diff --git a/.env.example b/.env.example index 34f812f..152502d 100644 --- a/.env.example +++ b/.env.example @@ -4,6 +4,15 @@ HOST=0.0.0.0 PORT=8317 # API Key,可配置多个(逗号分隔) API_KEYS=sk-your-api-key +# 独立的 /metrics 鉴权 token(留空则退化为 API_KEYS 也可访问;若连 API_KEYS 都没配,/metrics 为公开) +METRICS_TOKEN= +# 日志级别(DEBUG / INFO / WARNING / ERROR) +LOG_LEVEL=INFO + +# /v1/chat/completions 并发上限(<=0 表示不限流) +GATEWAY_MAX_IN_FLIGHT=4 +# 排队等待超时秒数,超过后返回 429 + Retry-After +GATEWAY_QUEUE_TIMEOUT_SEC=30 # 容器内 Lingma 二进制路径 LINGMA_BIN=/app/data/bin/Lingma @@ -45,7 +54,16 @@ AUTO_LOGIN_TIMEOUT=180 # 自动登录重试次数 AUTO_LOGIN_MAX_RETRY=2 -# Lingma 登录用户名 +# Lingma 登录用户名(仅当 LINGMA_ACCOUNTS 为空时生效,单实例模式) LINGMA_USERNAME= -# Lingma 登录密码 +# Lingma 登录密码(仅当 LINGMA_ACCOUNTS 为空时生效) LINGMA_PASSWORD= + +# ==== 多实例池(方案乙:多账号) ==== +# 多账号列表,支持两种格式: +# CSV: user1:pass1,user2:pass2 +# JSON: [{"username":"u1","password":"p1"},{"username":"u2","password":"p2"}] +# 配置后每个账号对应一个独立 Lingma 实例(独立 workDir + 独立自动登录) +LINGMA_ACCOUNTS= +# 实例数量:默认等于 LINGMA_ACCOUNTS 数;显式指定时账号不足会循环复用并打 warning +LINGMA_INSTANCE_COUNT= diff --git a/.gitignore b/.gitignore index e2e1fce..8d2999e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,5 @@ __pycache__/ *.pyc bin/ runtime-bin/ -data/ -!data/ data/* !data/.gitkeep diff --git a/README.md b/README.md index 3b4cd54..68d3d5c 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,12 @@ cp .env.example .env - `AUTO_LOGIN_MAX_RETRY`:自动登录重试次数 - `LINGMA_USERNAME`:Lingma 登录用户名 - `LINGMA_PASSWORD`:Lingma 登录密码 +- `METRICS_TOKEN`:`/metrics` 独立鉴权 token(留空则 `API_KEYS` 也可访问;两者都留空时 `/metrics` 为公开) +- `LOG_LEVEL`:日志级别(默认 `INFO`,输出结构化 JSON,包含 `request_id`) +- `GATEWAY_MAX_IN_FLIGHT`:`/v1/chat/completions` 并发上限(默认 4,`<=0` 表示不限流) +- `GATEWAY_QUEUE_TIMEOUT_SEC`:排队等待超时秒数(默认 30,超过后直接 429 + `Retry-After`) +- `LINGMA_ACCOUNTS`:多账号实例池,格式 `u1:p1,u2:p2` 或 JSON 数组;配置后每个账号起一个独立 Lingma 子进程 +- `LINGMA_INSTANCE_COUNT`:实例数(默认等于账号数;显式指定且不足时账号会循环复用) ### `.env` 最小必填示例 @@ -85,7 +91,18 @@ DEDICATED_DOMAIN_URL= - 本项目所有持久化数据都在 `./data`: - `data/bin/Lingma`:自动提取的 Lingma 二进制 - - `data/.lingma/...`:Lingma 登录态、缓存、日志 + - `data/.lingma/...`:Lingma 登录态、缓存、日志(单实例模式) + - `data/.lingma/pool/inst-/...`:多实例模式下每个实例独立的登录态/缓存 + +### 多实例池(方案乙:多账号) + +启用方式:在 `.env` 里配置 `LINGMA_ACCOUNTS=u1:p1,u2:p2`,重启容器即可。 + +- 每个账号对应一个独立 Lingma 子进程,各自独立登录、独立 workDir。 +- 路由策略:同一 `user` 字段或同一 system prompt 的请求粘性路由到同一实例;其余按 least-in-flight 分配。 +- 一个实例挂了/断连不影响整体,`/healthz` 汇报 `pool_ready` 计数。 +- `/internal/stats.pool` 按实例粒度暴露状态,`/metrics` 增加 `gateway_pool_instance_in_flight{name}` / `gateway_pool_instance_ready{name}`。 +- 未配置 `LINGMA_ACCOUNTS` 时自动退化为单实例模式(沿用 `LINGMA_USERNAME/LINGMA_PASSWORD`),向下兼容。 ## 3. Docker 运行 @@ -163,13 +180,16 @@ curl -s http://127.0.0.1:8317/internal/stats \ ``` ```bash -curl -s http://127.0.0.1:8317/metrics +curl -s http://127.0.0.1:8317/metrics \ + -H "Authorization: Bearer ${METRICS_TOKEN:-sk-your-api-key}" ``` 说明: - `usage.prompt_tokens/completion_tokens` 为估算值(按字节近似换算)。 - 非流式响应里会附带 `usage` 字段。 +- 流式响应可传 `stream_options: {"include_usage": true}` 让最后一帧返回 `usage`。 +- `/metrics` 默认需要 Bearer 鉴权:优先匹配 `METRICS_TOKEN`,否则接受 `API_KEYS` 里任意一个;两者都未配置时保持公开。 ## 6. 容器内自动登录 diff --git a/app/auth.py b/app/auth.py index 574db14..c320933 100644 --- a/app/auth.py +++ b/app/auth.py @@ -1,12 +1,11 @@ from __future__ import annotations +import hmac + from fastapi import HTTPException, Request, status -def require_bearer(request: Request, api_keys: list[str]) -> None: - if not api_keys: - return - +def _extract_bearer(request: Request) -> str: auth = request.headers.get("authorization", "") if not auth.startswith("Bearer "): raise HTTPException( @@ -19,9 +18,22 @@ def require_bearer(request: Request, api_keys: list[str]) -> None: } }, ) + return auth[len("Bearer ") :].strip() - token = auth[len("Bearer ") :].strip() - if token not in api_keys: + +def _match_any(token: str, candidates: list[str]) -> bool: + for c in candidates: + if c and hmac.compare_digest(token, c): + return True + return False + + +def require_bearer(request: Request, api_keys: list[str]) -> None: + # Empty api_keys means auth is disabled (keeps the old behavior). + if not api_keys: + return + token = _extract_bearer(request) + if not _match_any(token, api_keys): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail={ @@ -32,3 +44,31 @@ def require_bearer(request: Request, api_keys: list[str]) -> None: } }, ) + + +def require_metrics_access( + request: Request, api_keys: list[str], metrics_token: str +) -> None: + """Allow metrics if any of: METRICS_TOKEN matches, or any API_KEYS match. + + If neither METRICS_TOKEN nor API_KEYS are configured, metrics is public + (backwards compatible default). + """ + accepted: list[str] = [] + if metrics_token: + accepted.append(metrics_token) + accepted.extend(api_keys) + if not accepted: + return + token = _extract_bearer(request) + if not _match_any(token, accepted): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail={ + "error": { + "message": "Invalid metrics token", + "type": "invalid_request_error", + "code": "invalid_api_key", + } + }, + ) diff --git a/app/concurrency.py b/app/concurrency.py new file mode 100644 index 0000000..8998174 --- /dev/null +++ b/app/concurrency.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import asyncio + +from .logging_config import get_logger + + +logger = get_logger("lingma_gateway.concurrency") + + +class BackpressureRejected(Exception): + """Raised when a request cannot acquire an in-flight slot before timeout.""" + + def __init__(self, retry_after: float): + super().__init__(f"backpressure rejected, retry_after={retry_after:.1f}s") + self.retry_after = retry_after + + +class InFlightTicket: + """Reference-counted handle for a single in-flight slot. + + Release is idempotent so callers can defensively `release()` from multiple + cleanup paths (stream finally + outer exception handler) without worrying. + """ + + __slots__ = ("_parent", "_released") + + def __init__(self, parent: "InFlightGuard | None"): + self._parent = parent + self._released = False + + def release(self) -> None: + if self._released or self._parent is None: + self._released = True + return + self._released = True + self._parent._on_release() + + async def __aenter__(self) -> "InFlightTicket": + return self + + async def __aexit__(self, *_exc) -> None: + self.release() + + +class InFlightGuard: + """Async semaphore wrapper with queue/reject accounting and Prometheus hooks. + + - `max_in_flight <= 0` disables limiting (back-compat, unlimited). + - `queue_timeout_sec` bounds how long a request may wait for a slot. On + timeout, `try_acquire()` raises `BackpressureRejected`. + """ + + def __init__(self, max_in_flight: int, queue_timeout_sec: float): + self.max = max(0, int(max_in_flight)) + self.queue_timeout = max(0.0, float(queue_timeout_sec)) + self._sem: asyncio.Semaphore | None = ( + asyncio.Semaphore(self.max) if self.max > 0 else None + ) + self.in_flight = 0 + self.queued = 0 + self.accepted_total = 0 + self.rejected_total = 0 + + async def try_acquire(self) -> InFlightTicket: + if self._sem is None: + self.in_flight += 1 + self.accepted_total += 1 + return InFlightTicket(parent=self) + + self.queued += 1 + try: + if self.queue_timeout <= 0: + await self._sem.acquire() + else: + try: + await asyncio.wait_for(self._sem.acquire(), timeout=self.queue_timeout) + except (asyncio.TimeoutError, TimeoutError): + self.rejected_total += 1 + logger.warning( + "backpressure rejected: in_flight=%d queued=%d max=%d", + self.in_flight, + self.queued - 1, + self.max, + ) + raise BackpressureRejected(retry_after=self.queue_timeout) + finally: + self.queued -= 1 + + self.in_flight += 1 + self.accepted_total += 1 + return InFlightTicket(parent=self) + + def _on_release(self) -> None: + self.in_flight -= 1 + if self._sem is not None: + self._sem.release() + + def stats(self) -> dict: + return { + "max_in_flight": self.max, + "in_flight": self.in_flight, + "queued": self.queued, + "accepted_total": self.accepted_total, + "rejected_total": self.rejected_total, + "queue_timeout_sec": self.queue_timeout, + } + + def prometheus_lines(self) -> list[str]: + return [ + "# TYPE gateway_in_flight gauge", + f"gateway_in_flight {self.in_flight}", + "# TYPE gateway_queued gauge", + f"gateway_queued {self.queued}", + "# TYPE gateway_max_in_flight gauge", + f"gateway_max_in_flight {self.max}", + "# TYPE gateway_accepted_total counter", + f"gateway_accepted_total {self.accepted_total}", + "# TYPE gateway_rejected_total counter", + f"gateway_rejected_total {self.rejected_total}", + ] diff --git a/app/config.py b/app/config.py index 99dd3c8..5b96094 100644 --- a/app/config.py +++ b/app/config.py @@ -1,8 +1,14 @@ from __future__ import annotations +import json import os -from dataclasses import dataclass -from pathlib import Path +from dataclasses import dataclass, field + + +@dataclass +class LingmaAccount: + username: str + password: str @dataclass @@ -10,6 +16,10 @@ class Settings: host: str port: int api_keys: list[str] + metrics_token: str + log_level: str + gateway_max_in_flight: int + gateway_queue_timeout_sec: float lingma_bin: str lingma_work_dir: str lingma_socket_port: int @@ -22,8 +32,57 @@ class Settings: auto_login_headless: bool auto_login_timeout: int auto_login_max_retry: int - lingma_username: str - lingma_password: str + accounts: list[LingmaAccount] = field(default_factory=list) + instance_count: int = 1 + + +def _bool_env(name: str, default: bool) -> bool: + raw = os.getenv(name) + if raw is None: + return default + return raw.strip().lower() in {"1", "true", "yes", "on"} + + +def _parse_accounts(raw: str) -> list[LingmaAccount]: + """Parse LINGMA_ACCOUNTS. + + Accepted formats: + - JSON array: `[{"username":"u1","password":"p1"},{"username":"u2","password":"p2"}]` + - CSV: `u1:p1,u2:p2` + - Newlines: `u1:p1\nu2:p2` + + Whitespace around entries is trimmed. Empty entries are ignored. + Passwords containing ':' are supported (only the first ':' is the separator). + """ + raw = (raw or "").strip() + if not raw: + return [] + + if raw.startswith("["): + try: + data = json.loads(raw) + except Exception: + return [] + out: list[LingmaAccount] = [] + if isinstance(data, list): + for item in data: + if isinstance(item, dict): + u = str(item.get("username", "")).strip() + p = str(item.get("password", "")).strip() + if u and p: + out.append(LingmaAccount(u, p)) + return out + + out: list[LingmaAccount] = [] + for entry in raw.replace("\n", ",").split(","): + entry = entry.strip() + if not entry or ":" not in entry: + continue + u, p = entry.split(":", 1) + u, p = u.strip(), p.strip() + if u and p: + out.append(LingmaAccount(u, p)) + return out def load_settings() -> Settings: @@ -33,10 +92,31 @@ def load_settings() -> Settings: "LINGMA_WORK_DIR", "/app/data/.lingma/vscode/sharedClientCache", ) + + accounts = _parse_accounts(os.getenv("LINGMA_ACCOUNTS", "")) + if not accounts: + u = os.getenv("LINGMA_USERNAME", "").strip() + p = os.getenv("LINGMA_PASSWORD", "").strip() + if u and p: + accounts.append(LingmaAccount(u, p)) + + explicit_count = os.getenv("LINGMA_INSTANCE_COUNT", "").strip() + if explicit_count: + try: + instance_count = max(1, int(explicit_count)) + except ValueError: + instance_count = len(accounts) or 1 + else: + instance_count = max(1, len(accounts)) if accounts else 1 + return Settings( host=os.getenv("HOST", "0.0.0.0"), port=int(os.getenv("PORT", "8317")), api_keys=api_keys, + metrics_token=os.getenv("METRICS_TOKEN", "").strip(), + log_level=os.getenv("LOG_LEVEL", "INFO").strip() or "INFO", + gateway_max_in_flight=int(os.getenv("GATEWAY_MAX_IN_FLIGHT", "4")), + gateway_queue_timeout_sec=float(os.getenv("GATEWAY_QUEUE_TIMEOUT_SEC", "30")), lingma_bin=os.getenv("LINGMA_BIN", "/app/data/bin/Lingma"), lingma_work_dir=work_dir, lingma_socket_port=int(os.getenv("LINGMA_SOCKET_PORT", "36510")), @@ -45,10 +125,10 @@ def load_settings() -> Settings: default_model=os.getenv("DEFAULT_MODEL", "org_auto"), default_ask_mode=os.getenv("DEFAULT_ASK_MODE", "chat"), dedicated_domain_url=os.getenv("DEDICATED_DOMAIN_URL", "").strip(), - auto_login_enabled=os.getenv("AUTO_LOGIN_ENABLED", "true").lower() in {"1", "true", "yes", "on"}, - auto_login_headless=os.getenv("AUTO_LOGIN_HEADLESS", "true").lower() in {"1", "true", "yes", "on"}, + auto_login_enabled=_bool_env("AUTO_LOGIN_ENABLED", True), + auto_login_headless=_bool_env("AUTO_LOGIN_HEADLESS", True), auto_login_timeout=int(os.getenv("AUTO_LOGIN_TIMEOUT", "180")), auto_login_max_retry=int(os.getenv("AUTO_LOGIN_MAX_RETRY", "2")), - lingma_username=os.getenv("LINGMA_USERNAME", "").strip(), - lingma_password=os.getenv("LINGMA_PASSWORD", "").strip(), + accounts=accounts, + instance_count=instance_count, ) diff --git a/app/lingma_client.py b/app/lingma_client.py index 36ed509..9c4c80d 100644 --- a/app/lingma_client.py +++ b/app/lingma_client.py @@ -9,10 +9,23 @@ import subprocess import time import uuid from pathlib import Path -from typing import AsyncIterator +from typing import AsyncIterator, Callable, Optional import websockets +from .logging_config import get_logger + + +logger = get_logger("lingma_gateway.client") + + +# Some callers live on Python 3.10 where asyncio.TimeoutError is a distinct class, +# while 3.11+ unifies it with the builtin TimeoutError. Always catch both. +TIMEOUT_EXCEPTIONS: tuple[type[BaseException], ...] = ( + asyncio.TimeoutError, + TimeoutError, +) + def _is_port_open(host: str, port: int, timeout_sec: float = 0.5) -> bool: try: @@ -79,23 +92,37 @@ def _parse_lsp_frames(buf: bytes): class LspWsRpcClient: - def __init__(self, ws): + def __init__(self, ws, on_disconnect: Optional[Callable[[BaseException], None]] = None): self.ws = ws self._id = 1 self._pending: dict[int, asyncio.Future] = {} self._send_lock = asyncio.Lock() - self._reader_task = None + self._reader_task: asyncio.Task | None = None self._rx_buffer = b"" self._chat_streams: dict[str, dict] = {} + self._on_disconnect = on_disconnect + self._closed = False async def start(self): self._reader_task = asyncio.create_task(self._reader_loop()) async def close(self): + self._closed = True if self._reader_task: self._reader_task.cancel() with contextlib.suppress(Exception): await self._reader_task + # Abort any pending futures so callers fail fast instead of hanging. + for fut in self._pending.values(): + if not fut.done(): + fut.set_exception(ConnectionError("lingma client closed")) + self._pending.clear() + # Signal open streams to terminate. + for stream in self._chat_streams.values(): + if not stream["done"].is_set(): + stream["done"].set() + stream["chunks"].put_nowait(None) + self._chat_streams.clear() async def _send(self, payload: dict): async with self._send_lock: @@ -127,10 +154,23 @@ class LspWsRpcClient: except asyncio.CancelledError: pass except Exception as exc: + if not self._closed: + logger.warning("lingma reader loop terminated: %s", exc) + # Propagate failure to anyone waiting on an RPC. for fut in self._pending.values(): if not fut.done(): fut.set_exception(exc) self._pending.clear() + # Also unblock any in-flight chat streams so consumers exit. + for stream in self._chat_streams.values(): + if not stream["done"].is_set(): + stream["done"].set() + stream["chunks"].put_nowait(None) + if not self._closed and self._on_disconnect is not None: + try: + self._on_disconnect(exc) + except Exception: + logger.exception("on_disconnect callback failed") async def _handle_server_message(self, msg: dict): method = msg.get("method") @@ -168,7 +208,7 @@ class LspWsRpcClient: await self._send(payload) try: msg = await asyncio.wait_for(fut, timeout=timeout) - except TimeoutError: + except TIMEOUT_EXCEPTIONS: self._pending.pop(rid, None) raise TimeoutError(f"RPC timeout: {method}") if "error" in msg: @@ -189,8 +229,20 @@ class LspWsRpcClient: "finish_at": None, } + def pop_stream(self, request_id: str) -> None: + stream = self._chat_streams.pop(request_id, None) + if stream is None: + return + # Drain queue so no stray future gets stuck if the consumer bailed early. + if not stream["done"].is_set(): + stream["done"].set() + with contextlib.suppress(Exception): + stream["chunks"].put_nowait(None) + async def consume_stream(self, request_id: str, timeout: float) -> AsyncIterator[str]: - stream = self._chat_streams[request_id] + stream = self._chat_streams.get(request_id) + if stream is None: + return start = time.monotonic() while True: remain = timeout - (time.monotonic() - start) @@ -218,6 +270,19 @@ class LspWsRpcClient: class LingmaGatewayClient: + """Owns the Lingma subprocess and the LSP-over-WS connection. + + Adds a small state machine + reconnect loop so the gateway can survive Lingma + restarts and slow cold starts without bringing down the FastAPI app. + """ + + STATE_STOPPED = "stopped" + STATE_STARTING = "starting" + STATE_READY = "ready" + STATE_RECONNECTING = "reconnecting" + STATE_FAILED = "failed" + STATE_CLOSED = "closed" + def __init__( self, lingma_bin: str, @@ -227,7 +292,11 @@ class LingmaGatewayClient: rpc_timeout: int, default_model: str, default_ask_mode: str, + *, + name: str = "lingma", + extra_info_paths: list[Path] | None = None, ): + self.name = name self.lingma_bin = Path(lingma_bin) self.work_dir = Path(work_dir) self.socket_port = socket_port @@ -235,19 +304,115 @@ class LingmaGatewayClient: self.rpc_timeout = rpc_timeout self.default_model = default_model self.default_ask_mode = default_ask_mode + # Each pool instance should only look at its own workDir .info to avoid + # cross-instance clobbering via the shared ~/.lingma/.info path. + if extra_info_paths is None: + extra_info_paths = [Path.home() / ".lingma" / ".info"] + self._extra_info_paths = list(extra_info_paths) self._rpc: LspWsRpcClient | None = None self._ws = None + self._state = self.STATE_STOPPED + self._state_lock = asyncio.Lock() + self._ready_event = asyncio.Event() + self._reconnect_task: asyncio.Task | None = None + self._last_error: str = "" + + # ------------------------------------------------------------------ state + + @property + def state(self) -> str: + return self._state + + @property + def last_error(self) -> str: + return self._last_error + + def _set_state(self, state: str, err: str = "") -> None: + if state != self._state: + logger.info("lingma client state %s -> %s", self._state, state, extra={"ctx_new_state": state}) + self._state = state + if err: + self._last_error = err + if state == self.STATE_READY: + self._ready_event.set() + else: + self._ready_event.clear() + + # -------------------------------------------------------------- lifecycle + + async def start(self) -> None: + """Initial start. Failure is non-fatal: ensure_ready() will retry later.""" + try: + await self._connect(initial=True) + except Exception as exc: + self._set_state(self.STATE_FAILED, err=str(exc)) + logger.exception("initial lingma start failed; will retry on demand") + + async def close(self) -> None: + self._set_state(self.STATE_CLOSED) + if self._reconnect_task and not self._reconnect_task.done(): + self._reconnect_task.cancel() + with contextlib.suppress(Exception): + await self._reconnect_task + if self._rpc: + await self._rpc.close() + if self._ws: + with contextlib.suppress(Exception): + await self._ws.close() + + async def ensure_ready(self, timeout: float | None = None) -> None: + """Block until the RPC connection is usable, (re)connecting on demand.""" + if self._state == self.STATE_CLOSED: + raise RuntimeError("lingma client is closed") + if self._state == self.STATE_READY and self._ws is not None: + return + + async with self._state_lock: + if self._state == self.STATE_READY and self._ws is not None: + return + if self._state in (self.STATE_STOPPED, self.STATE_FAILED): + try: + await self._connect(initial=False) + return + except Exception as exc: + self._set_state(self.STATE_FAILED, err=str(exc)) + raise + + wait_timeout = timeout if timeout is not None else max( + 30.0, float(self.startup_timeout) + 10.0 + ) + try: + await asyncio.wait_for(self._ready_event.wait(), timeout=wait_timeout) + except TIMEOUT_EXCEPTIONS: + raise RuntimeError(f"lingma not ready (state={self._state}, err={self._last_error})") + + # --------------------------------------------------------------- connect + + async def _connect(self, *, initial: bool) -> None: + self._set_state(self.STATE_STARTING) - async def start(self): if not self.lingma_bin.exists(): raise FileNotFoundError(f"Lingma not found: {self.lingma_bin}") - if not _is_port_open("127.0.0.1", self.socket_port): + + info_paths = [self.work_dir / ".info", *self._extra_info_paths] + + # socket_port <= 0 is the pool-friendly "always spawn and read .info" mode. + port_prewarmed = self.socket_port > 0 and _is_port_open( + "127.0.0.1", self.socket_port + ) + if not port_prewarmed: self.work_dir.mkdir(parents=True, exist_ok=True) # Remove stale info files from host-mounted workspace before boot. - for p in [self.work_dir / ".info", Path.home() / ".lingma" / ".info"]: + for p in info_paths: with contextlib.suppress(Exception): if p.exists(): p.unlink() + logger.info( + "[%s] spawning lingma: %s start --workDir %s", + self.name, + self.lingma_bin, + self.work_dir, + ) subprocess.Popen( [str(self.lingma_bin), "start", "--workDir", str(self.work_dir)], cwd=str(self.lingma_bin.parent), @@ -255,13 +420,9 @@ class LingmaGatewayClient: stderr=subprocess.DEVNULL, start_new_session=True, ) - info, _, _ = _wait_info_any( - [self.work_dir / ".info", Path.home() / ".lingma" / ".info"], - timeout_sec=self.startup_timeout, - ) + info, _, _ = _wait_info_any(info_paths, timeout_sec=self.startup_timeout) self.socket_port = info - # Wait for socket to actually become connectable. deadline = time.time() + self.startup_timeout while time.time() < deadline: if _is_port_open("127.0.0.1", self.socket_port, timeout_sec=0.3): @@ -270,9 +431,19 @@ class LingmaGatewayClient: else: raise TimeoutError(f"Lingma socket not open on port {self.socket_port}") + # Close any stale ws/rpc before creating fresh ones (reconnect path). + if self._rpc is not None: + with contextlib.suppress(Exception): + await self._rpc.close() + self._rpc = None + if self._ws is not None: + with contextlib.suppress(Exception): + await self._ws.close() + self._ws = None + ws_url = f"ws://127.0.0.1:{self.socket_port}" self._ws = await websockets.connect(ws_url, max_size=10 * 1024 * 1024) - self._rpc = LspWsRpcClient(self._ws) + self._rpc = LspWsRpcClient(self._ws, on_disconnect=self._on_disconnect) await self._rpc.start() await self._rpc.request( "initialize", @@ -286,32 +457,73 @@ class LingmaGatewayClient: timeout=self.rpc_timeout, ) await self._rpc.notify("initialized", {}) + self._set_state(self.STATE_READY) + logger.info( + "[%s] lingma ready on port %d (initial=%s)", + self.name, + self.socket_port, + initial, + ) - async def close(self): - if self._rpc: - await self._rpc.close() - if self._ws: - await self._ws.close() + def _on_disconnect(self, exc: BaseException) -> None: + if self._state == self.STATE_CLOSED: + return + self._set_state(self.STATE_RECONNECTING, err=str(exc)) + if self._reconnect_task and not self._reconnect_task.done(): + return + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + self._reconnect_task = loop.create_task(self._reconnect_loop()) + + async def _reconnect_loop(self) -> None: + backoff = 1.0 + max_backoff = 30.0 + max_attempts = 20 + for attempt in range(1, max_attempts + 1): + if self._state == self.STATE_CLOSED: + return + await asyncio.sleep(backoff) + try: + async with self._state_lock: + await self._connect(initial=False) + logger.info("lingma reconnected after %d attempt(s)", attempt) + return + except Exception as exc: + self._last_error = str(exc) + logger.warning("lingma reconnect attempt %d failed: %s", attempt, exc) + backoff = min(backoff * 2, max_backoff) + self._set_state(self.STATE_FAILED, err="reconnect exhausted") + + # ------------------------------------------------------------------ RPC @property def rpc(self) -> LspWsRpcClient: if self._rpc is None: - raise RuntimeError("Lingma RPC not initialized") + raise RuntimeError(f"Lingma RPC not initialized (state={self._state})") return self._rpc async def auth_status(self): + await self.ensure_ready() return await self.rpc.request("auth/status", {}, timeout=self.rpc_timeout) async def query_models(self): + await self.ensure_ready() return await self.rpc.request("config/queryModels", {}, timeout=self.rpc_timeout) async def get_endpoint(self): + await self.ensure_ready() return await self.rpc.request("config/getEndpoint", {}, timeout=self.rpc_timeout) async def update_endpoint(self, endpoint: str): - return await self.rpc.request("config/updateEndpoint", {"endpoint": endpoint}, timeout=self.rpc_timeout) + await self.ensure_ready() + return await self.rpc.request( + "config/updateEndpoint", {"endpoint": endpoint}, timeout=self.rpc_timeout + ) async def generate_login_url(self): + await self.ensure_ready() result = await self.rpc.request("login/generateUrl", {}, timeout=self.rpc_timeout) if isinstance(result, str): return result, {"raw": result} @@ -322,6 +534,8 @@ class LingmaGatewayClient: return "", result return "", {"raw": result} + # ------------------------------------------------------------------ chat + def _build_payload(self, prompt: str, model_key: str, ask_mode: str, session_id: str, request_id: str): session_type = "developer" if ask_mode == "agent" else "chat" return { @@ -355,17 +569,24 @@ class LingmaGatewayClient: } async def chat_complete(self, prompt: str, model_key: str, ask_mode: str) -> dict: + await self.ensure_ready() request_id = str(uuid.uuid4()) session_id = str(uuid.uuid4()) payload = self._build_payload(prompt, model_key, ask_mode, session_id, request_id) self.rpc.create_stream(request_id) try: - await self.rpc.request("chat/ask", payload, timeout=self.rpc_timeout) - except (TimeoutError, asyncio.TimeoutError): - pass - async for _ in self.rpc.consume_stream(request_id, timeout=max(20.0, self.rpc_timeout + 20.0)): - pass - result = self.rpc.get_stream_result(request_id) + try: + await self.rpc.request("chat/ask", payload, timeout=self.rpc_timeout) + except TIMEOUT_EXCEPTIONS: + # chat/ask often returns nothing until chat/finish arrives; tolerate. + pass + async for _ in self.rpc.consume_stream( + request_id, timeout=max(20.0, self.rpc_timeout + 20.0) + ): + pass + result = self.rpc.get_stream_result(request_id) + finally: + self.rpc.pop_stream(request_id) finish = result.get("finish") or {} result["requestId"] = request_id result["sessionId"] = finish.get("sessionId") or session_id @@ -374,13 +595,20 @@ class LingmaGatewayClient: return result async def chat_stream(self, prompt: str, model_key: str, ask_mode: str) -> AsyncIterator[str]: + await self.ensure_ready() request_id = str(uuid.uuid4()) session_id = str(uuid.uuid4()) payload = self._build_payload(prompt, model_key, ask_mode, session_id, request_id) self.rpc.create_stream(request_id) try: - await self.rpc.request("chat/ask", payload, timeout=self.rpc_timeout) - except (TimeoutError, asyncio.TimeoutError): - pass - async for chunk in self.rpc.consume_stream(request_id, timeout=max(20.0, self.rpc_timeout + 40.0)): - yield chunk + try: + await self.rpc.request("chat/ask", payload, timeout=self.rpc_timeout) + except TIMEOUT_EXCEPTIONS: + pass + async for chunk in self.rpc.consume_stream( + request_id, timeout=max(20.0, self.rpc_timeout + 40.0) + ): + yield chunk + finally: + # Runs on normal completion, exception, or consumer GeneratorExit (client disconnect). + self.rpc.pop_stream(request_id) diff --git a/app/lingma_pool.py b/app/lingma_pool.py new file mode 100644 index 0000000..476ffdd --- /dev/null +++ b/app/lingma_pool.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from pathlib import Path + +from .auto_login import AutoLoginManager +from .config import LingmaAccount +from .lingma_client import LingmaGatewayClient +from .logging_config import get_logger + + +logger = get_logger("lingma_gateway.pool") + + +@dataclass +class InstanceConfig: + index: int + name: str + work_dir: str + socket_port: int + account: LingmaAccount + + +class PoolInstance: + """A single Lingma process + its auto_login + in-flight counter.""" + + __slots__ = ("cfg", "client", "auto_login", "in_flight") + + def __init__( + self, + cfg: InstanceConfig, + client: LingmaGatewayClient, + auto_login: AutoLoginManager, + ): + self.cfg = cfg + self.client = client + self.auto_login = auto_login + self.in_flight = 0 + + @property + def name(self) -> str: + return self.cfg.name + + @property + def healthy(self) -> bool: + return self.client.state == LingmaGatewayClient.STATE_READY + + +class LingmaPool: + """N-Lingma process pool with least-in-flight + affinity routing. + + For N=1 this degenerates into the original single-client setup, preserving + backwards compatibility with `LINGMA_USERNAME/LINGMA_PASSWORD`-only deploys. + """ + + def __init__(self, instances: list[PoolInstance]): + if not instances: + raise RuntimeError("LingmaPool requires at least 1 instance") + self._instances: list[PoolInstance] = instances + self._rr_counter = 0 + + @classmethod + def build( + cls, + *, + lingma_bin: str, + base_work_dir: str, + legacy_socket_port: int, + startup_timeout: int, + rpc_timeout: int, + default_model: str, + default_ask_mode: str, + accounts: list[LingmaAccount], + instance_count: int, + auto_login_headless: bool, + auto_login_timeout: int, + auto_login_max_retry: int, + verify_timeout_sec: int | None = None, + ) -> "LingmaPool": + """Materialize N PoolInstances. + + Single-instance (N=1) uses the legacy workDir and LINGMA_SOCKET_PORT so + existing deployments keep their state after upgrade. N>1 derives per-instance + workDirs under `/../pool/inst-` and uses dynamic ports. + """ + + if instance_count < 1: + instance_count = 1 + + resolved_accounts: list[LingmaAccount] = [] + for i in range(instance_count): + if accounts: + resolved_accounts.append(accounts[i % len(accounts)]) + else: + resolved_accounts.append(LingmaAccount(username="", password="")) + + if instance_count > len(accounts) and accounts: + logger.warning( + "instance_count=%d exceeds unique accounts=%d; accounts will be reused", + instance_count, + len(accounts), + ) + + base_dir = Path(base_work_dir) + # Put per-instance workDirs under `/.lingma/pool/inst-`. + # Walk up past the vscode/sharedClientCache layout if present. + pool_root = base_dir + for _ in range(3): + if pool_root.name == ".lingma": + break + if pool_root.parent == pool_root: + break + pool_root = pool_root.parent + pool_root = pool_root / "pool" + + instances: list[PoolInstance] = [] + for i, acc in enumerate(resolved_accounts): + if instance_count == 1: + work_dir = str(base_dir) + socket_port = legacy_socket_port + extra_info: list[Path] | None = None + else: + work_dir = str(pool_root / f"inst-{i}") + socket_port = 0 + # In pool mode each instance reads only its own workDir .info to + # avoid the shared ~/.lingma/.info race between instances. + extra_info = [] + + name = f"inst-{i}" + + client = LingmaGatewayClient( + lingma_bin=lingma_bin, + work_dir=work_dir, + socket_port=socket_port, + startup_timeout=startup_timeout, + rpc_timeout=rpc_timeout, + default_model=default_model, + default_ask_mode=default_ask_mode, + name=name, + extra_info_paths=extra_info, + ) + + def _make_verify(_client: LingmaGatewayClient): + async def _verify() -> bool: + try: + st = await _client.auth_status() + except Exception: + return False + return bool(st and st.get("id")) + + return _verify + + auto_login = AutoLoginManager( + username=acc.username, + password=acc.password, + headless=auto_login_headless, + timeout_sec=auto_login_timeout, + max_retry=auto_login_max_retry, + verify_logged_in=_make_verify(client), + verify_timeout_sec=verify_timeout_sec + or max(30, min(180, auto_login_timeout)), + debug_dir=f"/tmp/lingma-auto-login/{name}", + ) + + cfg = InstanceConfig( + index=i, + name=name, + work_dir=work_dir, + socket_port=socket_port, + account=acc, + ) + instances.append(PoolInstance(cfg, client, auto_login)) + + return cls(instances) + + # -------------------------------------------------------------- lifecycle + + async def start(self) -> None: + """Start all instances sequentially. + + Sequential startup avoids racing on the shared ~/.lingma/.info file (for + pool-mode we skip it anyway, but Lingma may still write there internally) + and keeps docker logs readable. Failures are non-fatal; per-instance + reconnect loops will take over. + """ + for inst in self._instances: + logger.info( + "pool starting %s (workDir=%s port=%d account=%s)", + inst.name, + inst.cfg.work_dir, + inst.cfg.socket_port, + inst.cfg.account.username or "", + ) + try: + await inst.client.start() + except Exception as exc: + logger.warning("pool start %s failed: %s", inst.name, exc) + + async def close(self) -> None: + tasks = [asyncio.create_task(inst.client.close()) for inst in self._instances] + for t in tasks: + try: + await t + except Exception: + pass + + # -------------------------------------------------------------- inspection + + @property + def instances(self) -> list[PoolInstance]: + return list(self._instances) + + def size(self) -> int: + return len(self._instances) + + def stats(self) -> list[dict]: + return [ + { + "index": inst.cfg.index, + "name": inst.name, + "state": inst.client.state, + "last_error": inst.client.last_error, + "in_flight": inst.in_flight, + "work_dir": inst.cfg.work_dir, + "socket_port": inst.cfg.socket_port, + "username": inst.cfg.account.username, + "auto_login": inst.auto_login.status(), + } + for inst in self._instances + ] + + def prometheus_lines(self) -> list[str]: + lines: list[str] = [ + "# TYPE gateway_pool_instance_in_flight gauge", + "# TYPE gateway_pool_instance_ready gauge", + ] + for inst in self._instances: + lbl = f'name="{inst.name}",idx="{inst.cfg.index}"' + lines.append(f"gateway_pool_instance_in_flight{{{lbl}}} {inst.in_flight}") + lines.append( + f"gateway_pool_instance_ready{{{lbl}}} {1 if inst.healthy else 0}" + ) + return lines + + # -------------------------------------------------------------- selection + + def pick(self, affinity_key: str | None = None) -> PoolInstance: + """Pick an instance for a request. + + Preference order: + 1. Sticky affinity if `affinity_key` is provided and the bucket is healthy. + 2. Least-in-flight among healthy instances. + 3. Round-robin fallback when nothing is healthy (lazy-start will kick in). + """ + if not self._instances: + raise RuntimeError("lingma pool is empty") + + healthy = [i for i in self._instances if i.healthy] + + if affinity_key: + bucket = self._instances[ + abs(hash(affinity_key)) % len(self._instances) + ] + if bucket.healthy: + return bucket + + if healthy: + return min(healthy, key=lambda x: (x.in_flight, x.cfg.index)) + + # Nothing healthy. Fall back to round-robin so every instance gets a + # chance to reconnect via ensure_ready(). + idx = self._rr_counter % len(self._instances) + self._rr_counter += 1 + return self._instances[idx] diff --git a/app/logging_config.py b/app/logging_config.py new file mode 100644 index 0000000..a7a6ace --- /dev/null +++ b/app/logging_config.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import contextvars +import json +import logging +import sys +import time + + +request_id_var: contextvars.ContextVar[str] = contextvars.ContextVar( + "request_id", default="-" +) + + +class _JsonFormatter(logging.Formatter): + def format(self, record: logging.LogRecord) -> str: + ts = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime(record.created)) + data: dict = { + "ts": f"{ts}.{int(record.msecs):03d}Z", + "level": record.levelname, + "logger": record.name, + "msg": record.getMessage(), + "request_id": request_id_var.get(), + } + if record.exc_info: + data["exc"] = self.formatException(record.exc_info) + for key, val in record.__dict__.items(): + if key.startswith("ctx_"): + data[key[4:]] = val + return json.dumps(data, ensure_ascii=False) + + +def configure_logging(level: str = "INFO") -> None: + level = (level or "INFO").upper() + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(_JsonFormatter()) + + root = logging.getLogger() + root.handlers.clear() + root.addHandler(handler) + root.setLevel(level) + + # Align uvicorn access/error logs with our JSON formatter. + for name in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi"): + lg = logging.getLogger(name) + lg.handlers.clear() + lg.propagate = True + lg.setLevel(level) + + # Trim noisy libraries. + logging.getLogger("websockets").setLevel("WARNING") + logging.getLogger("websockets.client").setLevel("WARNING") + + +def get_logger(name: str = "lingma_gateway") -> logging.Logger: + return logging.getLogger(name) diff --git a/app/main.py b/app/main.py index 9259425..218c26e 100644 --- a/app/main.py +++ b/app/main.py @@ -1,16 +1,20 @@ from __future__ import annotations +import asyncio +import hashlib import json import time import uuid +from contextlib import asynccontextmanager from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.responses import JSONResponse, StreamingResponse -from .auto_login import AutoLoginManager -from .auth import require_bearer +from .auth import require_bearer, require_metrics_access +from .concurrency import BackpressureRejected, InFlightGuard from .config import Settings, load_settings -from .lingma_client import LingmaGatewayClient +from .lingma_pool import LingmaPool, PoolInstance +from .logging_config import configure_logging, get_logger, request_id_var from .model_map import build_model_name_map, flatten_model_keys, resolve_model from .openai_schema import ( ChatCompletionChoice, @@ -18,107 +22,219 @@ from .openai_schema import ( ChatCompletionsRequest, ModelData, ModelsResponse, + flatten_content, ) from .stats import StatsCollector, estimate_tokens -app = FastAPI(title="Lingma OpenAI Gateway", version="0.1.0") settings: Settings = load_settings() -lingma: LingmaGatewayClient | None = None -auto_login: AutoLoginManager | None = None +configure_logging(settings.log_level) +logger = get_logger("lingma_gateway") + +pool: LingmaPool | None = None stats_collector = StatsCollector() +chat_guard = InFlightGuard( + max_in_flight=settings.gateway_max_in_flight, + queue_timeout_sec=settings.gateway_queue_timeout_sec, +) + + +def _require_pool() -> LingmaPool: + if pool is None: + raise HTTPException( + status_code=503, + detail={"error": {"message": "pool not initialized", "type": "service_unavailable"}}, + ) + return pool + + +@asynccontextmanager +async def lifespan(_app: FastAPI): + global pool + pool = LingmaPool.build( + lingma_bin=settings.lingma_bin, + base_work_dir=settings.lingma_work_dir, + legacy_socket_port=settings.lingma_socket_port, + startup_timeout=settings.lingma_startup_timeout, + rpc_timeout=settings.lingma_rpc_timeout, + default_model=settings.default_model, + default_ask_mode=settings.default_ask_mode, + accounts=settings.accounts, + instance_count=settings.instance_count, + auto_login_headless=settings.auto_login_headless, + auto_login_timeout=settings.auto_login_timeout, + auto_login_max_retry=settings.auto_login_max_retry, + ) + logger.info( + "gateway startup: pool_size=%d max_in_flight=%d", + pool.size(), + settings.gateway_max_in_flight, + ) + await pool.start() + try: + yield + finally: + if pool is not None: + await pool.close() + + +app = FastAPI(title="Lingma OpenAI Gateway", version="0.3.0", lifespan=lifespan) + + +@app.middleware("http") +async def request_id_middleware(request: Request, call_next): + req_id = request.headers.get("x-request-id") or f"req-{uuid.uuid4().hex[:12]}" + token = request_id_var.set(req_id) + start = time.monotonic() + status_code = 500 + try: + response = await call_next(request) + status_code = response.status_code + response.headers["x-request-id"] = req_id + return response + finally: + elapsed_ms = int((time.monotonic() - start) * 1000) + logger.info( + "http %s %s -> %s in %dms", + request.method, + request.url.path, + status_code, + elapsed_ms, + extra={ + "ctx_method": request.method, + "ctx_path": request.url.path, + "ctx_status": status_code, + "ctx_elapsed_ms": elapsed_ms, + }, + ) + request_id_var.reset(token) def auth_guard(request: Request): require_bearer(request, settings.api_keys) -async def _is_logged_in() -> bool: - assert lingma is not None - st = await lingma.auth_status() - return bool(st and st.get("id")) - - -@app.on_event("startup") -async def on_startup(): - global lingma, auto_login - lingma = LingmaGatewayClient( - lingma_bin=settings.lingma_bin, - work_dir=settings.lingma_work_dir, - socket_port=settings.lingma_socket_port, - startup_timeout=settings.lingma_startup_timeout, - rpc_timeout=settings.lingma_rpc_timeout, - default_model=settings.default_model, - default_ask_mode=settings.default_ask_mode, - ) - await lingma.start() - auto_login = AutoLoginManager( - username=settings.lingma_username, - password=settings.lingma_password, - headless=settings.auto_login_headless, - timeout_sec=settings.auto_login_timeout, - max_retry=settings.auto_login_max_retry, - verify_logged_in=_is_logged_in, - verify_timeout_sec=max(30, min(180, settings.auto_login_timeout)), - ) - - -@app.on_event("shutdown") -async def on_shutdown(): - if lingma: - await lingma.close() +def metrics_auth_guard(request: Request): + require_metrics_access(request, settings.api_keys, settings.metrics_token) @app.get("/healthz") async def healthz(): - return {"ok": True, "time": int(time.time())} + if pool is None: + return {"ok": False, "time": int(time.time()), "reason": "pool uninitialized"} + insts = pool.stats() + ready = sum(1 for i in insts if i["state"] == "ready") + return { + "ok": ready > 0, + "time": int(time.time()), + "pool_size": len(insts), + "pool_ready": ready, + "instances": [ + {"name": i["name"], "state": i["state"], "in_flight": i["in_flight"]} + for i in insts + ], + } -async def _ensure_logged_in_or_auto_login() -> dict: - assert lingma is not None - status = await lingma.auth_status() +async def _ensure_instance_logged_in(inst: PoolInstance) -> dict: + client = inst.client + auto_login = inst.auto_login + + try: + status = await client.auth_status() + except Exception as exc: + logger.warning("[%s] auth_status failed before chat: %s", inst.name, exc) + raise HTTPException( + status_code=503, + detail={"error": {"message": "Lingma is not ready", "type": "service_unavailable"}}, + ) + if status and status.get("id"): return status if not settings.auto_login_enabled: - raise HTTPException(status_code=401, detail={"error": {"message": "Lingma not logged in"}}) - - if settings.dedicated_domain_url: - current = await lingma.get_endpoint() - current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else "" - if current_ep != settings.dedicated_domain_url: - await lingma.update_endpoint(settings.dedicated_domain_url) - - login_url, login_raw = await lingma.generate_login_url() - if not login_url: raise HTTPException( - status_code=500, - detail={"error": {"message": f"generate login url failed: {login_raw}"}}, + status_code=401, + detail={"error": {"message": "Lingma not logged in", "type": "invalid_request_error"}}, + ) + + if settings.dedicated_domain_url: + try: + current = await client.get_endpoint() + current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else "" + if current_ep != settings.dedicated_domain_url: + await client.update_endpoint(settings.dedicated_domain_url) + except Exception as exc: + logger.warning("[%s] switch dedicated endpoint failed: %s", inst.name, exc) + + try: + login_url, _login_raw = await client.generate_login_url() + except Exception as exc: + logger.warning("[%s] generate_login_url failed: %s", inst.name, exc) + raise HTTPException( + status_code=502, + detail={"error": {"message": "generate login url failed", "type": "upstream_error"}}, + ) + + if not login_url: + raise HTTPException( + status_code=502, + detail={"error": {"message": "generate login url failed", "type": "upstream_error"}}, ) - assert auto_login is not None await auto_login.ensure_started(login_url) try: await auto_login.wait_done(timeout=settings.auto_login_timeout + 20) - except Exception: - pass + except Exception as exc: + logger.warning("[%s] auto_login wait_done failed: %s", inst.name, exc) + + try: + status = await client.auth_status() + except Exception as exc: + logger.warning("[%s] post-login auth_status failed: %s", inst.name, exc) + status = None - status = await lingma.auth_status() if status and status.get("id"): return status + logger.warning( + "[%s] auto login did not result in a logged-in session: %s", + inst.name, + auto_login.status(), + ) raise HTTPException( status_code=401, - detail={"error": {"message": "Lingma auto login failed", "auto_login": auto_login.status()}}, + detail={"error": {"message": "Lingma auto login failed", "type": "invalid_request_error"}}, ) +def _affinity_key_for(req: ChatCompletionsRequest) -> str | None: + """Derive a stable affinity key so that follow-ups go to the same instance. + + Priority: explicit `user` > hash of the first/system message. + """ + if req.user: + return req.user.strip() or None + for m in req.messages: + if m.role == "system": + text = flatten_content(m.content) + if text: + return "sys:" + hashlib.sha1(text.encode("utf-8")).hexdigest()[:16] + if req.messages: + first = req.messages[0] + text = flatten_content(first.content) + if text: + return "first:" + hashlib.sha1(text.encode("utf-8")).hexdigest()[:16] + return None + + @app.get("/v1/models", dependencies=[Depends(auth_guard)]) async def v1_models(): - assert lingma is not None - await _ensure_logged_in_or_auto_login() + p = _require_pool() + inst = p.pick() + await _ensure_instance_logged_in(inst) await stats_collector.inc_models() - models = await lingma.query_models() + models = await inst.client.query_models() keys = flatten_model_keys(models) name_map = build_model_name_map(models) resp = ModelsResponse(data=[ModelData(id=k, name=name_map.get(k)) for k in keys]) @@ -126,20 +242,32 @@ async def v1_models(): def _messages_to_prompt(messages: list[dict]) -> str: - parts = [] + parts: list[str] = [] for m in messages: role = m.get("role", "user") - content = m.get("content", "") - parts.append(f"[{role}] {content}") + text = flatten_content(m.get("content")) + if not text and m.get("tool_calls"): + text = f"[tool_calls] {json.dumps(m['tool_calls'], ensure_ascii=False)}" + if not text: + continue + parts.append(f"[{role}] {text}") return "\n".join(parts).strip() +def _include_usage(stream_options: dict | None) -> bool: + if not isinstance(stream_options, dict): + return False + return bool(stream_options.get("include_usage")) + + @app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)]) async def v1_chat_completions(req: ChatCompletionsRequest): - assert lingma is not None - await _ensure_logged_in_or_auto_login() + p = _require_pool() + affinity = _affinity_key_for(req) + inst = p.pick(affinity_key=affinity) + await _ensure_instance_logged_in(inst) - models = await lingma.query_models() + models = await inst.client.query_models() available = flatten_model_keys(models) name_map = build_model_name_map(models) model = resolve_model(req.model, available, settings.default_model, name_map) @@ -150,142 +278,270 @@ async def v1_chat_completions(req: ChatCompletionsRequest): prompt = _messages_to_prompt([m.model_dump() for m in req.messages]) if not prompt: - raise HTTPException(status_code=400, detail={"error": {"message": "messages is empty"}}) + raise HTTPException( + status_code=400, + detail={"error": {"message": "messages is empty", "type": "invalid_request_error"}}, + ) prompt_tokens = estimate_tokens(prompt) + include_usage = _include_usage(req.stream_options) - if req.stream: - created = int(time.time()) - completion_id = f"chatcmpl-{uuid.uuid4().hex}" - completion_tokens_holder = {"n": 0} + # Backpressure: acquire a slot *after* the cheap validation but before any + # upstream call. This ensures we reject quickly when saturated. + try: + ticket = await chat_guard.try_acquire() + except BackpressureRejected as exc: + retry_after = max(1, int(exc.retry_after)) + logger.warning("chat rejected by backpressure, retry_after=%ds", retry_after) + raise HTTPException( + status_code=429, + detail={ + "error": { + "message": "Too many in-flight requests, please retry later", + "type": "rate_limit_error", + "code": "backpressure", + } + }, + headers={"Retry-After": str(retry_after)}, + ) - async def event_stream(): - success = False - try: - async for chunk in lingma.chat_stream(prompt, model, ask_mode): - completion_tokens_holder["n"] += estimate_tokens(chunk) - payload = { + inst.in_flight += 1 + logger.info( + "chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d", + inst.name, + model, + ask_mode, + req.stream, + prompt_tokens, + extra={ + "ctx_instance": inst.name, + "ctx_model": model, + "ctx_ask_mode": ask_mode, + "ctx_stream": req.stream, + "ctx_prompt_tokens": prompt_tokens, + "ctx_in_flight": chat_guard.in_flight, + "ctx_affinity": affinity, + }, + ) + + ticket_transferred = False + + try: + if req.stream: + created = int(time.time()) + completion_id = f"chatcmpl-{uuid.uuid4().hex}" + completion_tokens_holder = {"n": 0} + + async def event_stream(_ticket=ticket, _inst=inst): + success = False + try: + async for chunk in _inst.client.chat_stream(prompt, model, ask_mode): + completion_tokens_holder["n"] += estimate_tokens(chunk) + payload = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": chunk}, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + + done_payload = { "id": completion_id, "object": "chat.completion.chunk", "created": created, "model": model, - "choices": [ - { - "index": 0, - "delta": {"content": chunk}, - "finish_reason": None, - } - ], + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], } - yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n" - done_payload = { - "id": completion_id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], - } - yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n" - yield "data: [DONE]\n\n" - success = True - finally: - await stats_collector.record_chat( - stream=True, - success=success, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens_holder["n"], - ) + if include_usage: + usage_payload = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens_holder["n"], + "total_tokens": prompt_tokens + completion_tokens_holder["n"], + }, + } + yield f"data: {json.dumps(usage_payload, ensure_ascii=False)}\n\n" - return StreamingResponse(event_stream(), media_type="text/event-stream") + yield "data: [DONE]\n\n" + success = True + except asyncio.CancelledError: + logger.info("chat.stream cancelled by client (inst=%s)", _inst.name) + raise + except Exception as exc: + logger.warning("chat.stream error (inst=%s): %s", _inst.name, exc) + finally: + await stats_collector.record_chat( + stream=True, + success=success, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens_holder["n"], + ) + _inst.in_flight = max(0, _inst.in_flight - 1) + _ticket.release() - try: - result = await lingma.chat_complete(prompt, model, ask_mode) - except Exception: + ticket_transferred = True + return StreamingResponse( + event_stream(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache, no-transform", + "X-Accel-Buffering": "no", + "Connection": "keep-alive", + }, + ) + + try: + result = await inst.client.chat_complete(prompt, model, ask_mode) + except Exception as exc: + logger.warning("chat.complete error (inst=%s): %s", inst.name, exc) + await stats_collector.record_chat( + stream=False, + success=False, + prompt_tokens=prompt_tokens, + completion_tokens=0, + ) + raise HTTPException( + status_code=502, + detail={"error": {"message": "upstream lingma error", "type": "upstream_error"}}, + ) + + completion_tokens = estimate_tokens(result.get("text") or "") await stats_collector.record_chat( stream=False, - success=False, + success=True, prompt_tokens=prompt_tokens, - completion_tokens=0, + completion_tokens=completion_tokens, ) - raise - - completion_tokens = estimate_tokens(result.get("text") or "") - await stats_collector.record_chat( - stream=False, - success=True, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - ) - response = ChatCompletionResponse( - id=f"chatcmpl-{uuid.uuid4().hex}", - created=int(time.time()), - model=model, - choices=[ - ChatCompletionChoice( - index=0, - finish_reason="stop", - message={"role": "assistant", "content": result.get("text") or ""}, - ) - ], - ) - data = response.model_dump() - data["latency"] = { - "first_token_ms": result.get("firstTokenLatencyMs"), - "total_ms": result.get("totalLatencyMs"), - } - data["usage"] = { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens, - } - return JSONResponse(content=data) + response = ChatCompletionResponse( + id=f"chatcmpl-{uuid.uuid4().hex}", + created=int(time.time()), + model=model, + choices=[ + ChatCompletionChoice( + index=0, + finish_reason="stop", + message={"role": "assistant", "content": result.get("text") or ""}, + ) + ], + ) + data = response.model_dump() + data["latency"] = { + "first_token_ms": result.get("firstTokenLatencyMs"), + "total_ms": result.get("totalLatencyMs"), + } + data["usage"] = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + } + data["served_by"] = inst.name + return JSONResponse(content=data) + finally: + if not ticket_transferred: + inst.in_flight = max(0, inst.in_flight - 1) + ticket.release() @app.post("/internal/auto-login/start", dependencies=[Depends(auth_guard)]) -async def internal_auto_login_start(): - assert lingma is not None - assert auto_login is not None +async def internal_auto_login_start(instance: str | None = None): + p = _require_pool() + target = None + if instance: + for inst in p.instances: + if inst.name == instance: + target = inst + break + if target is None: + raise HTTPException( + status_code=404, + detail={"error": {"message": f"instance not found: {instance}"}}, + ) + else: + target = p.pick() - status = await lingma.auth_status() + client = target.client + auto_login = target.auto_login + + status = await client.auth_status() if status and status.get("id"): - return {"ok": True, "state": "already_logged_in", "auth": status} + return {"ok": True, "state": "already_logged_in", "instance": target.name, "auth": status} if settings.dedicated_domain_url: - current = await lingma.get_endpoint() - current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else "" - if current_ep != settings.dedicated_domain_url: - await lingma.update_endpoint(settings.dedicated_domain_url) + try: + current = await client.get_endpoint() + current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else "" + if current_ep != settings.dedicated_domain_url: + await client.update_endpoint(settings.dedicated_domain_url) + except Exception as exc: + logger.warning("[%s] switch dedicated endpoint failed: %s", target.name, exc) + + try: + login_url, _login_raw = await client.generate_login_url() + except Exception as exc: + logger.warning("[%s] generate_login_url failed: %s", target.name, exc) + raise HTTPException(status_code=502, detail={"error": {"message": "generate login url failed"}}) - login_url, login_raw = await lingma.generate_login_url() if not login_url: - raise HTTPException(status_code=500, detail={"error": {"message": "generate login url failed", "raw": login_raw}}) + raise HTTPException(status_code=502, detail={"error": {"message": "generate login url failed"}}) started = await auto_login.ensure_started(login_url) return { "ok": True, "state": "running" if started else "already_running", - "loginUrl": login_url, + "instance": target.name, "auto_login": auto_login.status(), } @app.get("/internal/auto-login/status", dependencies=[Depends(auth_guard)]) async def internal_auto_login_status(): - assert auto_login is not None - assert lingma is not None - return { - "ok": True, - "auto_login": auto_login.status(), - "auth": await lingma.auth_status(), - } + p = _require_pool() + out = [] + for inst in p.instances: + try: + auth = await inst.client.auth_status() + except Exception as exc: + auth = {"error": str(exc)} + out.append( + { + "instance": inst.name, + "auto_login": inst.auto_login.status(), + "auth": auth, + "state": inst.client.state, + } + ) + return {"ok": True, "instances": out} @app.get("/internal/stats", dependencies=[Depends(auth_guard)]) async def internal_stats(): - return {"ok": True, "stats": await stats_collector.snapshot()} + p = _require_pool() + return { + "ok": True, + "stats": await stats_collector.snapshot(), + "concurrency": chat_guard.stats(), + "pool": p.stats(), + } -@app.get("/metrics") +@app.get("/metrics", dependencies=[Depends(metrics_auth_guard)]) async def metrics(): - text = await stats_collector.prometheus_text() - return StreamingResponse(iter([text]), media_type="text/plain; version=0.0.4") + base = await stats_collector.prometheus_text() + lines = list(chat_guard.prometheus_lines()) + if pool is not None: + lines.extend(pool.prometheus_lines()) + extra = "\n".join(lines) + "\n" + return StreamingResponse(iter([base + extra]), media_type="text/plain; version=0.0.4") diff --git a/app/openai_schema.py b/app/openai_schema.py index 7806015..33a22df 100644 --- a/app/openai_schema.py +++ b/app/openai_schema.py @@ -1,13 +1,22 @@ from __future__ import annotations -from typing import Literal +from typing import Any, Literal from pydantic import BaseModel, Field +# Keep permissive: OpenAI clients routinely send list-of-parts (multi-modal) or None +# (for tool calls). We flatten to plain text downstream. +MessageContent = str | list[dict[str, Any]] | None + + class ChatMessage(BaseModel): - role: Literal["system", "user", "assistant", "tool"] - content: str + # OpenAI supports "developer" on newer API versions in addition to the classic set. + role: Literal["system", "user", "assistant", "tool", "developer", "function"] + content: MessageContent = None + name: str | None = None + tool_call_id: str | None = None + tool_calls: list[dict[str, Any]] | None = None class ChatCompletionsRequest(BaseModel): @@ -16,6 +25,11 @@ class ChatCompletionsRequest(BaseModel): stream: bool = False temperature: float | None = None top_p: float | None = None + max_tokens: int | None = None + user: str | None = None + stream_options: dict[str, Any] | None = None + tools: list[dict[str, Any]] | None = None + tool_choice: Any | None = None class ModelData(BaseModel): @@ -35,6 +49,7 @@ class ChatCompletionChoice(BaseModel): index: int = 0 finish_reason: str | None = "stop" message: dict = Field(default_factory=dict) + logprobs: Any | None = None class ChatCompletionResponse(BaseModel): @@ -43,3 +58,34 @@ class ChatCompletionResponse(BaseModel): created: int model: str choices: list[ChatCompletionChoice] + system_fingerprint: str | None = None + + +def flatten_content(content: MessageContent) -> str: + """Reduce OpenAI multi-part content to a plain string prompt for Lingma.""" + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if not isinstance(item, dict): + parts.append(str(item)) + continue + t = item.get("type") + if t == "text": + text = item.get("text") or "" + if text: + parts.append(text) + elif t in ("image_url", "input_image"): + # Lingma 不支持多模态,降级成占位符,保留语义信号 + parts.append("[image]") + elif t == "input_audio": + parts.append("[audio]") + else: + text = item.get("text") or item.get("content") + if isinstance(text, str) and text: + parts.append(text) + return "\n".join(p for p in parts if p) + return str(content)