from __future__ import annotations

import asyncio

from .logging_config import get_logger

logger = get_logger("lingma_gateway.concurrency")


class BackpressureRejected(Exception):
    """Signal that no in-flight slot became available within the queue timeout.

    Attributes:
        retry_after: The value passed to the constructor, in seconds; it is
            also embedded in the exception message.
    """

    def __init__(self, retry_after: float):
        self.retry_after = retry_after
        super().__init__(f"backpressure rejected, retry_after={retry_after:.1f}s")


class InFlightTicket:
    """Handle for one in-flight slot handed out by an `InFlightGuard`.

    `release()` is idempotent: only the first call notifies the parent guard,
    so several cleanup paths may all call it safely. The ticket also works as
    an async context manager that releases the slot on exit.
    """

    __slots__ = ("_parent", "_released")

    def __init__(self, parent: "InFlightGuard | None"):
        self._released = False
        self._parent = parent

    def release(self) -> None:
        """Give the slot back to the parent guard; repeated calls are no-ops."""
        was_released = self._released
        # Flag the ticket before notifying the parent, mirroring the order in
        # which the state is updated on every path.
        self._released = True
        if was_released or self._parent is None:
            return
        self._parent._on_release()

    async def __aenter__(self) -> "InFlightTicket":
        return self

    async def __aexit__(self, *_exc) -> None:
        self.release()


class InFlightGuard:
    """Semaphore-backed admission control with queue/reject counters.

    - A `max_in_flight` of zero or less turns limiting off: no semaphore is
      created and every request is admitted immediately.
    - `queue_timeout_sec` caps the time a request waits for a slot; when the
      wait times out, `try_acquire()` raises `BackpressureRejected`. A value
      of zero or less means the wait has no time limit.

    Counters are exposed through `stats()` and, in Prometheus text format,
    through `prometheus_lines()`.
    """

    def __init__(self, max_in_flight: int, queue_timeout_sec: float):
        # Negative settings are clamped to zero.
        self.max = max(0, int(max_in_flight))
        self.queue_timeout = max(0.0, float(queue_timeout_sec))

        self._sem: asyncio.Semaphore | None
        if self.max > 0:
            self._sem = asyncio.Semaphore(self.max)
        else:
            self._sem = None

        self.in_flight = 0
        self.queued = 0
        self.accepted_total = 0
        self.rejected_total = 0

    def _admit(self) -> InFlightTicket:
        """Record an accepted request and return its ticket."""
        self.in_flight += 1
        self.accepted_total += 1
        return InFlightTicket(parent=self)

    async def try_acquire(self) -> InFlightTicket:
        """Wait for an in-flight slot and return a ticket for it.

        Returns:
            An `InFlightTicket` whose `release()` frees the slot.

        Raises:
            BackpressureRejected: If a positive queue timeout elapses before
                a slot is acquired.
        """
        semaphore = self._sem
        if semaphore is None:
            # Limiting is disabled: admit without queueing.
            return self._admit()

        self.queued += 1
        try:
            if self.queue_timeout > 0:
                try:
                    await asyncio.wait_for(
                        semaphore.acquire(), timeout=self.queue_timeout
                    )
                except (asyncio.TimeoutError, TimeoutError):
                    self.rejected_total += 1
                    # This request is still counted in `queued` until the
                    # `finally` below runs, so subtract it from the report.
                    logger.warning(
                        "backpressure rejected: in_flight=%d queued=%d max=%d",
                        self.in_flight,
                        self.queued - 1,
                        self.max,
                    )
                    raise BackpressureRejected(retry_after=self.queue_timeout)
            else:
                # No timeout configured: wait for a slot without a deadline.
                await semaphore.acquire()
        finally:
            self.queued -= 1

        return self._admit()

    def _on_release(self) -> None:
        """Account for a released ticket and free its semaphore slot, if any."""
        self.in_flight -= 1
        if self._sem is not None:
            self._sem.release()

    def stats(self) -> dict:
        """Return the current limits and counters as a plain dict."""
        snapshot = {
            "max_in_flight": self.max,
            "in_flight": self.in_flight,
            "queued": self.queued,
            "accepted_total": self.accepted_total,
            "rejected_total": self.rejected_total,
            "queue_timeout_sec": self.queue_timeout,
        }
        return snapshot

    def prometheus_lines(self) -> list[str]:
        """Return the gauges and counters as Prometheus text-format lines."""
        lines = [
            "# TYPE gateway_in_flight gauge",
            f"gateway_in_flight {self.in_flight}",
            "# TYPE gateway_queued gauge",
            f"gateway_queued {self.queued}",
            "# TYPE gateway_max_in_flight gauge",
            f"gateway_max_in_flight {self.max}",
            "# TYPE gateway_accepted_total counter",
            f"gateway_accepted_total {self.accepted_total}",
            "# TYPE gateway_rejected_total counter",
            f"gateway_rejected_total {self.rejected_total}",
        ]
        return lines