Merge pull request #33 from bcye/feature/multiple-handlers

Allow for multiple handlers
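With this change the HANDLER environment variable accepts a comma-separated list of handler names, and MAX_CONCURRENT bounds concurrent writes per handler rather than globally in the dump handler. A minimal sketch of the new configuration, assuming the listed modules exist under output_handlers/ (values are illustrative):

    import os

    # Two handlers, each capped at 8 concurrent writes (0 = unlimited)
    os.environ["HANDLER"] = "filesystem,bunny_storage"
    os.environ["MAX_CONCURRENT"] = "8"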
This commit is contained in:
Bruce 2025-06-03 14:52:14 +02:00 committed by GitHub
commit 8f099dc7bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 70 additions and 54 deletions

main.py

@@ -40,7 +40,7 @@ def gather_handler_kwargs(handler_name: str) -> dict:
 async def process_dump(
-    mappings: dict[str, str], handler, max_concurrent: int
+    mappings: dict[str, str], handlers
 ):
     """
     Stream-download the bzip2-compressed XML dump and feed to SAX.
@@ -52,7 +52,7 @@ async def process_dump(
     )
     decomp = bz2.BZ2Decompressor()
     sax_parser = xml.sax.make_parser()
-    dump_handler = WikiDumpHandler(mappings, handler, max_concurrent)
+    dump_handler = WikiDumpHandler(mappings, handlers)
     sax_parser.setContentHandler(dump_handler)

     async with aiohttp.ClientSession() as session:
@@ -69,36 +69,13 @@ async def process_dump(
     await asyncio.gather(*dump_handler.tasks)

 async def main():
-    # 1. Which handler to load?
-    handler_name = os.getenv("HANDLER")
-    if not handler_name:
-        logger.error("Error: set ENV HANDLER (e.g. 'filesystem')")
+    # 1. Which handler(s) to load?
+    handler_names = os.getenv("HANDLER", "").split(",")
+    if not handler_names or not handler_names[0]:
+        logger.error("Error: set ENV HANDLER (e.g. 'filesystem' or 'filesystem,sftp')")
         sys.exit(1)

-    # 2. Dynamic import
-    module_path = f"output_handlers.{handler_name}"
-    try:
-        mod = importlib.import_module(module_path)
-    except ImportError as e:
-        logger.error(f"Error loading handler module {module_path}: {e}")
-        sys.exit(1)
-
-    # 3. Find the class: e.g. "sftp" → "SftpHandler"
-    class_name = handler_name.title().replace("_", "") + "Handler"
-    if not hasattr(mod, class_name):
-        logger.error(f"{module_path} defines no class {class_name}")
-        sys.exit(1)
-    HandlerCls = getattr(mod, class_name)
-    logger.info(f"Using handler from {module_path}")
-
-    # 4. Build kwargs from ENV
-    handler_kwargs = gather_handler_kwargs(handler_name)
-
-    # 5. Instantiate
-    handler = HandlerCls(**handler_kwargs)
-
-    # 6. read concurrency setting
+    # 2. Read concurrency setting
     try:
         max_conc = int(os.getenv("MAX_CONCURRENT", "0"))
     except ValueError:
@@ -107,18 +84,52 @@ async def main():
     if max_conc < 0:
         raise ValueError("MAX_CONCURRENT must be >= 0")

-    # 7. Fetch mappings
+    handlers = []
+
+    # 3. Load each handler
+    for handler_name in handler_names:
+        handler_name = handler_name.strip()
+        if not handler_name:
+            continue
+
+        # Dynamic import
+        module_path = f"output_handlers.{handler_name}"
+        try:
+            mod = importlib.import_module(module_path)
+        except ImportError as e:
+            logger.error(f"Error loading handler module {module_path}: {e}")
+            sys.exit(1)
+
+        # Find the class: e.g. "sftp" → "SftpHandler"
+        class_name = handler_name.title().replace("_", "") + "Handler"
+        if not hasattr(mod, class_name):
+            logger.error(f"{module_path} defines no class {class_name}")
+            sys.exit(1)
+        HandlerCls = getattr(mod, class_name)
+        logger.info(f"Using handler from {module_path}")
+
+        # Build kwargs from ENV
+        handler_kwargs = gather_handler_kwargs(handler_name)
+
+        # Add max_concurrent to kwargs
+        handler_kwargs["max_concurrent"] = max_conc
+
+        # Instantiate
+        handler = HandlerCls(**handler_kwargs)
+        handlers.append(handler)
+
+    # 4. Fetch mappings
     logger.info("Fetching mappings from SQL dump…")
     mappings = await fetch_mappings()
     logger.info(f"Got {len(mappings)} wikibase_item mappings.")

-    # 8. Stream & split the XML dump
+    # 5. Stream & split the XML dump
     logger.info("Processing XML dump…")
-    await process_dump(mappings, handler, max_conc)
+    await process_dump(mappings, handlers)  # concurrency is enforced inside each handler

-    # 5. Finish up
-    await handler.close()
+    # 6. Finish up
+    await asyncio.gather(*[handler.close() for handler in handlers])
     logger.info("All done.")


@@ -1,6 +1,7 @@
 """Reference handler for output handlers."""

 from abc import ABC, abstractmethod
 import logging
+import asyncio
@@ -14,15 +15,20 @@ class BaseHandler(ABC):
     _successful_writes = 0
     _failed_writes = 0

-    def __init__(self, fail_on_error: bool = True, **kwargs):
+    def __init__(self, fail_on_error: bool = True, max_concurrent=0, **kwargs):
         """
         Initializes the BaseHandler with optional parameters.

         Args:
             fail_on_error (bool): If True, the handler will raise an exception on error. Defaults to True.
+            max_concurrent: Maximum number of concurrent write operations.
+                0 means unlimited concurrency.
             **kwargs: Additional keyword arguments for specific handler implementations.
         """
         self.fail_on_error = fail_on_error
+        self.semaphore = None
+        if max_concurrent > 0:
+            self.semaphore = asyncio.Semaphore(max_concurrent)

     @abstractmethod
@@ -47,7 +53,11 @@ class BaseHandler(ABC):
             entry (dict): The entry to write (will be JSON-encoded).
             uid (str): The unique identifier for the entry. The default id provided by wikivoyage is recommended.
         """
-        success = await self._write_entry(entry, uid)
+        if self.semaphore:
+            async with self.semaphore:
+                success = await self._write_entry(entry, uid)
+        else:
+            success = await self._write_entry(entry, uid)
         if success:
             self.logger.debug(f"Successfully wrote entry with UID {uid}")
             self._successful_writes += 1
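The gate in write_entry is the standard asyncio.Semaphore pattern: acquire a slot before the write, release it afterwards. A self-contained sketch, with a stub writer standing in for _write_entry:

    import asyncio

    async def main() -> None:
        sem = asyncio.Semaphore(2)  # at most two writes in flight

        async def write(uid: str) -> None:
            async with sem:  # later writers wait until a slot frees up
                await asyncio.sleep(0.1)  # stand-in for real I/O
                print(f"wrote {uid}")

        await asyncio.gather(*(write(str(i)) for i in range(5)))

    asyncio.run(main())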


@@ -10,8 +10,9 @@ class BunnyStorageHandler(BaseHandler):
         api_key: str,
         fail_on_error: bool = True,
         keepalive_timeout: int = 75,
+        **kwargs,
     ):
-        super().__init__(fail_on_error=fail_on_error)
+        super().__init__(fail_on_error=fail_on_error, **kwargs)

         self.base_url = f"https://{region}.bunnycdn.com/{base_path}"
         self.headers = {
             "AccessKey": api_key,


@@ -9,16 +9,14 @@ class WikiDumpHandler(xml.sax.ContentHandler):
     """
     SAX handler that, for each <page> whose <id> is in mappings,
     collects the <text> and schedules an async task to parse
-    and write via the user-supplied handler.
+    and write via the user-supplied handler(s).
     """

-    def __init__(self, mappings, handler, max_concurrent):
+    def __init__(self, mappings, handlers):
         super().__init__()
         self.mappings = mappings
-        self.handler = handler
-        self.sem = (
-            asyncio.Semaphore(max_concurrent) if max_concurrent > 0 else None
-        )
+        # Support a single handler or a list of handlers
+        self.handlers = handlers
         self.tasks: list[asyncio.Task] = []
         self.currentTag: str | None = None
@@ -54,10 +52,7 @@ class WikiDumpHandler(xml.sax.ContentHandler):
                 title = self.currentTitle
                 logger.debug(f"scheduled {wd_id} for handling")
                 # schedule processing
-                if self.sem:
-                    task = asyncio.create_task(self._bounded_process(text, wd_id, title))
-                else:
-                    task = asyncio.create_task(self._process(text, wd_id, title))
+                task = asyncio.create_task(self._process(text, wd_id, title))
                 self.tasks.append(task)
             else:
                 logger.debug(f"page {pid} without wikidata id, skipping...")
@@ -98,9 +93,8 @@ class WikiDumpHandler(xml.sax.ContentHandler):
         parser = WikivoyageParser()
         entry = parser.parse(text)
         entry['properties']['title'] = title
-        await self.handler.write_entry(entry, uid)
-
-    async def _bounded_process(self, text: str, uid: str, title: str):
-        # Only run N at once
-        async with self.sem:
-            await self._process(text, uid, title)
+        # Write to all handlers concurrently
+        await asyncio.gather(*[
+            handler.write_entry(entry, uid) for handler in self.handlers
+        ])
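The gather call fans each parsed entry out to every configured handler and, by default, propagates the first exception it hits. A minimal runnable sketch of that fan-out, with stub handlers in place of real handler objects:

    import asyncio

    async def write_entry(name: str, uid: str) -> None:
        print(f"{name}: wrote {uid}")

    async def main() -> None:
        handlers = ["filesystem", "bunny_storage"]  # stand-ins for handler objects
        # One awaitable per handler, awaited together, as in _process above
        await asyncio.gather(*(write_entry(h, "Q42") for h in handlers))

    asyncio.run(main())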