move semaphore to handler level

This commit is contained in:
Bruce Röttgers 2025-05-16 20:32:21 +02:00
parent 60c13fb9ec
commit 5031f33ea2
3 changed files with 30 additions and 28 deletions

29
main.py
View File

@@ -40,7 +40,7 @@ def gather_handler_kwargs(handler_name: str) -> dict:
async def process_dump(
mappings: dict[str, str], handlers, max_concurrent: int
mappings: dict[str, str], handlers
):
"""
Stream-download the bzip2-compressed XML dump and feed to SAX.
@@ -52,7 +52,7 @@ async def process_dump(
)
decomp = bz2.BZ2Decompressor()
sax_parser = xml.sax.make_parser()
dump_handler = WikiDumpHandler(mappings, handlers, max_concurrent)
dump_handler = WikiDumpHandler(mappings, handlers)
sax_parser.setContentHandler(dump_handler)
async with aiohttp.ClientSession() as session:
@@ -75,9 +75,18 @@ async def main():
logger.error("Error: set ENV HANDLER (e.g. 'filesystem' or 'filesystem,sftp')")
sys.exit(1)
# 2. Read concurrency setting
try:
max_conc = int(os.getenv("MAX_CONCURRENT", "0"))
except ValueError:
raise ValueError("MAX_CONCURRENT must be an integer")
if max_conc < 0:
raise ValueError("MAX_CONCURRENT must be >= 0")
handlers = []
# 2. Load each handler
# 3. Load each handler
for handler_name in handler_names:
handler_name = handler_name.strip()
if not handler_name:
@@ -102,20 +111,14 @@ async def main():
# Build kwargs from ENV
handler_kwargs = gather_handler_kwargs(handler_name)
# Add max_concurrent to kwargs
handler_kwargs["max_concurrent"] = max_conc
# Instantiate
handler = HandlerCls(**handler_kwargs)
handlers.append(handler)
# 3. read concurrency setting
try:
max_conc = int(os.getenv("MAX_CONCURRENT", "0"))
except ValueError:
raise ValueError("MAX_CONCURRENT must be an integer")
if max_conc < 0:
raise ValueError("MAX_CONCURRENT must be >= 0")
# 4. Fetch mappings
logger.info("Fetching mappings from SQL dump…")
mappings = await fetch_mappings()
@@ -123,7 +126,7 @@ async def main():
# 5. Stream & split the XML dump
logger.info("Processing XML dump…")
await process_dump(mappings, handlers, max_conc)
await process_dump(mappings, handlers)  # concurrency is now managed inside each handler's own semaphore
# 6. Finish up
await asyncio.gather(*[handler.close() for handler in handlers])

View File

@@ -1,6 +1,7 @@
"""Reference handler for output handlers."""
from abc import ABC, abstractmethod
import logging
import asyncio
@@ -14,15 +15,20 @@ class BaseHandler(ABC):
_successful_writes = 0
_failed_writes = 0
def __init__(self, fail_on_error: bool = True, **kwargs):
def __init__(self, fail_on_error: bool = True, max_concurrent=0, **kwargs):
"""
Initializes the BaseHandler with optional parameters.
Args:
fail_on_error (bool): If True, the handler will raise an exception on error. Defaults to True.
max_concurrent: Maximum number of concurrent write operations.
0 means unlimited concurrency.
**kwargs: Additional keyword arguments for specific handler implementations.
"""
self.fail_on_error = fail_on_error
self.semaphore = None
if max_concurrent > 0:
self.semaphore = asyncio.Semaphore(max_concurrent)
@abstractmethod
@@ -47,7 +53,11 @@ class BaseHandler(ABC):
entry (dict): The entry to write (will be JSON-encoded).
uid (str): The unique identifier for the entry. The default id provided by wikivoyage is recommended.
"""
success = await self._write_entry(entry, uid)
if self.semaphore:
async with self.semaphore:
success = await self._write_entry(entry, uid)
else:
success = await self._write_entry(entry, uid)
if success:
self.logger.debug(f"Successfully wrote entry with UID {uid}")
self._successful_writes += 1

View File

@@ -12,14 +12,11 @@ class WikiDumpHandler(xml.sax.ContentHandler):
and write via the user-supplied handler(s).
"""
def __init__(self, mappings, handlers, max_concurrent):
def __init__(self, mappings, handlers):
super().__init__()
self.mappings = mappings
# Support a single handler or a list of handlers
self.handlers = handlers
self.sem = (
asyncio.Semaphore(max_concurrent) if max_concurrent > 0 else None
)
self.tasks: list[asyncio.Task] = []
self.currentTag: str | None = None
@@ -55,10 +52,7 @@ class WikiDumpHandler(xml.sax.ContentHandler):
title = self.currentTitle
logger.debug(f"scheduled {wd_id} for handling")
# schedule processing
if self.sem:
task = asyncio.create_task(self._bounded_process(text, wd_id, title))
else:
task = asyncio.create_task(self._process(text, wd_id, title))
task = asyncio.create_task(self._process(text, wd_id, title))
self.tasks.append(task)
else:
logger.debug(f"page {pid} without wikidata id, skipping...")
@@ -104,8 +98,3 @@ class WikiDumpHandler(xml.sax.ContentHandler):
await asyncio.gather(*[
handler.write_entry(entry, uid) for handler in self.handlers
])
async def _bounded_process(self, text: str, uid: str, title: str):
# Only run N at once
async with self.sem:
await self._process(text, uid, title)