Mirror of https://github.com/bcye/structured-wikivoyage-exports.git (synced 2025-06-07 08:24:05 +00:00)
move semaphore to handler level
This commit is contained in: parent 60c13fb9ec, commit 5031f33ea2

main.py: 29 changed lines (the diff also touches the base output handler and the WikiDumpHandler modules)
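In short: the per-write concurrency gate moves out of the SAX-level WikiDumpHandler and into BaseHandler, so each output handler throttles its own write_entry calls with an asyncio.Semaphore, and main() injects the MAX_CONCURRENT setting into every handler's kwargs. A minimal, self-contained sketch of the pattern this commit adopts (ThrottledHandler and demo are illustrative names, not code from the repo; 0 means unlimited, as in the diff):

    import asyncio

    class ThrottledHandler:
        # Stand-in for BaseHandler; only the gating logic is shown.
        def __init__(self, max_concurrent: int = 0):
            # 0 means unlimited: no semaphore is created at all.
            self.semaphore = (
                asyncio.Semaphore(max_concurrent) if max_concurrent > 0 else None
            )

        async def write_entry(self, uid: str) -> None:
            if self.semaphore:
                async with self.semaphore:  # at most max_concurrent writes in flight
                    await self._write_entry(uid)
            else:
                await self._write_entry(uid)

        async def _write_entry(self, uid: str) -> None:
            await asyncio.sleep(0.1)  # stand-in for real write I/O

    async def demo() -> None:
        handler = ThrottledHandler(max_concurrent=2)
        await asyncio.gather(*(handler.write_entry(str(i)) for i in range(10)))

    asyncio.run(demo())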
main.py

@@ -40,7 +40,7 @@ def gather_handler_kwargs(handler_name: str) -> dict:
 async def process_dump(
-    mappings: dict[str, str], handlers, max_concurrent: int
+    mappings: dict[str, str], handlers
 ):
     """
     Stream-download the bzip2-compressed XML dump and feed to SAX.
@@ -52,7 +52,7 @@ async def process_dump(
     )
     decomp = bz2.BZ2Decompressor()
     sax_parser = xml.sax.make_parser()
-    dump_handler = WikiDumpHandler(mappings, handlers, max_concurrent)
+    dump_handler = WikiDumpHandler(mappings, handlers)
     sax_parser.setContentHandler(dump_handler)

     async with aiohttp.ClientSession() as session:
@@ -75,9 +75,18 @@ async def main():
         logger.error("Error: set ENV HANDLER (e.g. 'filesystem' or 'filesystem,sftp')")
         sys.exit(1)

+    # 2. Read concurrency setting
+    try:
+        max_conc = int(os.getenv("MAX_CONCURRENT", "0"))
+    except ValueError:
+        raise ValueError("MAX_CONCURRENT must be an integer")
+
+    if max_conc < 0:
+        raise ValueError("MAX_CONCURRENT must be >= 0")
+
     handlers = []

-    # 2. Load each handler
+    # 3. Load each handler
     for handler_name in handler_names:
         handler_name = handler_name.strip()
         if not handler_name:
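Reading MAX_CONCURRENT was moved ahead of the handler loop so that max_conc is already in scope when each handler's kwargs are assembled below. The same parse-and-validate logic as a stand-alone sketch (read_max_concurrent is a hypothetical helper, not a function in the repo):

    import os

    def read_max_concurrent() -> int:
        # Mirrors the validation above: default 0, integers only, no negatives.
        try:
            value = int(os.getenv("MAX_CONCURRENT", "0"))
        except ValueError:
            raise ValueError("MAX_CONCURRENT must be an integer")
        if value < 0:
            raise ValueError("MAX_CONCURRENT must be >= 0")
        return value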
@@ -103,19 +112,13 @@ async def main():
         # Build kwargs from ENV
         handler_kwargs = gather_handler_kwargs(handler_name)

+        # Add max_concurrent to kwargs
+        handler_kwargs["max_concurrent"] = max_conc
+
         # Instantiate
         handler = HandlerCls(**handler_kwargs)
         handlers.append(handler)

-    # 3. read concurrency setting
-    try:
-        max_conc = int(os.getenv("MAX_CONCURRENT", "0"))
-    except ValueError:
-        raise ValueError("MAX_CONCURRENT must be an integer")
-
-    if max_conc < 0:
-        raise ValueError("MAX_CONCURRENT must be >= 0")
-
     # 4. Fetch mappings
     logger.info("Fetching mappings from SQL dump…")
     mappings = await fetch_mappings()
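Every handler now receives the limit through its constructor kwargs, which works because BaseHandler.__init__ (changed below) accepts max_concurrent. A tiny sketch of that flow, with DummyHandler standing in for the classes loaded from the HANDLER env var:

    class DummyHandler:
        # Hypothetical handler; the real ones subclass BaseHandler.
        def __init__(self, fail_on_error: bool = True, max_concurrent: int = 0, **kwargs):
            self.max_concurrent = max_concurrent

    handler_kwargs = {"fail_on_error": False}  # normally built by gather_handler_kwargs
    handler_kwargs["max_concurrent"] = 4       # injected for every handler
    handler = DummyHandler(**handler_kwargs)
    assert handler.max_concurrent == 4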
@@ -123,7 +126,7 @@ async def main():

     # 5. Stream & split the XML dump
     logger.info("Processing XML dump…")
-    await process_dump(mappings, handlers, max_conc)
+    await process_dump(mappings, handlers)  # concurrency is handled inside each handler

     # 6. Finish up
     await asyncio.gather(*[handler.close() for handler in handlers])
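With that, main() no longer threads a concurrency value through process_dump at all: the only consumer of max_conc is the handler constructor above, and the stream-and-parse stage runs unthrottled.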
Base output handler module

@@ -1,6 +1,7 @@
 """Reference handler for output handlers."""
 from abc import ABC, abstractmethod
 import logging
+import asyncio

@@ -14,15 +15,20 @@ class BaseHandler(ABC):
     _successful_writes = 0
     _failed_writes = 0

-    def __init__(self, fail_on_error: bool = True, **kwargs):
+    def __init__(self, fail_on_error: bool = True, max_concurrent=0, **kwargs):
         """
         Initializes the BaseHandler with optional parameters.

         Args:
             fail_on_error (bool): If True, the handler will raise an exception on error. Defaults to True.
+            max_concurrent: Maximum number of concurrent write operations.
+                0 means unlimited concurrency.
             **kwargs: Additional keyword arguments for specific handler implementations.
         """
         self.fail_on_error = fail_on_error
+        self.semaphore = None
+        if max_concurrent > 0:
+            self.semaphore = asyncio.Semaphore(max_concurrent)


     @abstractmethod
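Since the semaphore now lives on BaseHandler, any subclass gets throttling for free as long as it implements _write_entry and lets the base class's write_entry drive it. A minimal sketch of a concrete handler (InMemoryHandler is hypothetical, and the import path is assumed since the capture does not show the module name):

    from base_handler import BaseHandler  # import path assumed

    class InMemoryHandler(BaseHandler):
        # Hypothetical handler that stores entries in a dict instead of doing real I/O.
        def __init__(self, **kwargs):
            super().__init__(**kwargs)  # forwards max_concurrent to BaseHandler
            self.store: dict[str, dict] = {}

        async def _write_entry(self, entry: dict, uid: str) -> bool:
            self.store[uid] = entry
            return True

        async def close(self) -> None:
            # main() awaits handler.close() at the end of the run.
            pass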
@@ -47,7 +53,11 @@ class BaseHandler(ABC):
             entry (dict): The entry to write (will be JSON-encoded).
             uid (str): The unique identifier for the entry. The default id provided by wikivoyage is recommended.
         """
-        success = await self._write_entry(entry, uid)
+        if self.semaphore:
+            async with self.semaphore:
+                success = await self._write_entry(entry, uid)
+        else:
+            success = await self._write_entry(entry, uid)
         if success:
             self.logger.debug(f"Successfully wrote entry with UID {uid}")
             self._successful_writes += 1
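Note that the semaphore is held for the full duration of _write_entry, so slow writes exert backpressure on callers of write_entry. Usage, continuing the InMemoryHandler sketch above (at most two writes run at once):

    import asyncio

    async def demo_writes() -> None:
        handler = InMemoryHandler(max_concurrent=2)
        entries = [({"title": f"Page {i}"}, str(i)) for i in range(8)]
        await asyncio.gather(*(handler.write_entry(e, uid) for e, uid in entries))

    asyncio.run(demo_writes())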
WikiDumpHandler module

@@ -12,14 +12,11 @@ class WikiDumpHandler(xml.sax.ContentHandler):
     and write via the user‐supplied handler(s).
     """

-    def __init__(self, mappings, handlers, max_concurrent):
+    def __init__(self, mappings, handlers):
         super().__init__()
         self.mappings = mappings
         # Support a single handler or a list of handlers
         self.handlers = handlers
-        self.sem = (
-            asyncio.Semaphore(max_concurrent) if max_concurrent > 0 else None
-        )
         self.tasks: list[asyncio.Task] = []

         self.currentTag: str | None = None
@@ -55,10 +52,7 @@ class WikiDumpHandler(xml.sax.ContentHandler):
             title = self.currentTitle
             logger.debug(f"scheduled {wd_id} for handling")
             # schedule processing
-            if self.sem:
-                task = asyncio.create_task(self._bounded_process(text, wd_id, title))
-            else:
-                task = asyncio.create_task(self._process(text, wd_id, title))
+            task = asyncio.create_task(self._process(text, wd_id, title))
             self.tasks.append(task)
         else:
             logger.debug(f"page {pid} without wikidata id, skipping...")
@@ -104,8 +98,3 @@ class WikiDumpHandler(xml.sax.ContentHandler):
         await asyncio.gather(*[
             handler.write_entry(entry, uid) for handler in self.handlers
         ])
-
-    async def _bounded_process(self, text: str, uid: str, title: str):
-        # Only run N at once
-        async with self.sem:
-            await self._process(text, uid, title)
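The _bounded_process wrapper and self.sem are gone because the bound now applies at write time inside BaseHandler.write_entry. Page parsing itself is no longer limited, and each handler enforces its own ceiling, so two handlers with MAX_CONCURRENT=4 can have up to eight writes in flight in total.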