feat: M1+M2 gateway hardening and multi-instance pool

Behavior hardening (M1):
- Fix `_chat_streams` memory leak: pop_stream on completion, error, and
  client disconnect.
- Add WebSocket reconnect with state machine (stopped/starting/ready/
  reconnecting/failed/closed) and exponential backoff, so a Lingma
  restart no longer requires restarting the gateway.
- Lazy initialization: startup failure is non-fatal, first real request
  triggers retry, `/healthz` reflects readiness.
- Migrate FastAPI on_event to lifespan.
- Structured JSON logging with request_id ContextVar; `x-request-id`
  propagated to responses.
- SSE now sets `Cache-Control: no-cache`, `X-Accel-Buffering: no` to
  defeat proxy buffering.
- OpenAI schema compatibility: `content` accepts str | list[parts] | None,
  added `developer`/`function` roles, `tools/tool_choice/stream_options/
  user/max_tokens` fields, and `stream_options.include_usage` emits final
  usage chunk.
- `require_bearer` uses `hmac.compare_digest`; `/metrics` now requires
  Bearer when `METRICS_TOKEN` or `API_KEYS` are set.
- Python 3.10/3.11 `TimeoutError` vs `asyncio.TimeoutError` unified.
- Error responses no longer leak `auto_login.status()` details.

Backpressure (M2 / A2):
- New `InFlightGuard` with per-request ticket, queue + rejection
  accounting, `BackpressureRejected` raises 429 + `Retry-After` once
  `GATEWAY_QUEUE_TIMEOUT_SEC` elapses.
- Streaming ticket ownership transfers to the generator so CancelledError
  from client disconnect still releases the slot.
- `/internal/stats.concurrency` and `/metrics` expose in_flight/queued/
  accepted_total/rejected_total/max_in_flight.

Multi-instance pool (M2 / A1 + B3):
- New `LingmaPool` with N processes, each with its own workDir, socket
  port (dynamic when N>1), and `AutoLoginManager`.
- Account parser supports CSV (`u1:p1,u2:p2`) and JSON formats via
  `LINGMA_ACCOUNTS`; falls back to `LINGMA_USERNAME/LINGMA_PASSWORD` for
  backwards compatibility (N=1 keeps legacy paths/ports).
- Routing: sticky affinity by `user` / system-prompt hash, then
  least-in-flight, finally round-robin fallback for unhealthy pool.
- `/healthz` reports per-instance state and ready count.
- `/internal/stats.pool` and `/metrics` expose per-instance
  `gateway_pool_instance_in_flight{name}` / `gateway_pool_instance_ready{name}`.
- `/internal/auto-login/start?instance=inst-N` targets a specific instance;
  `/internal/auto-login/status` lists all instances.

Compat notes:
- `.env.example` adds `METRICS_TOKEN`, `LOG_LEVEL`, `GATEWAY_MAX_IN_FLIGHT`,
  `GATEWAY_QUEUE_TIMEOUT_SEC`, `LINGMA_ACCOUNTS`, `LINGMA_INSTANCE_COUNT`.
- `.gitignore` cleaned up data/ duplication.
- Existing single-instance deployments keep working without config change.

Made-with: Cursor
This commit is contained in:
GitHub Actions
2026-04-18 07:40:32 +08:00
parent 6114c66aed
commit 707acc9005
11 changed files with 1360 additions and 222 deletions

View File

@@ -1,12 +1,11 @@
from __future__ import annotations
import hmac
from fastapi import HTTPException, Request, status
def require_bearer(request: Request, api_keys: list[str]) -> None:
if not api_keys:
return
def _extract_bearer(request: Request) -> str:
auth = request.headers.get("authorization", "")
if not auth.startswith("Bearer "):
raise HTTPException(
@@ -19,9 +18,22 @@ def require_bearer(request: Request, api_keys: list[str]) -> None:
}
},
)
return auth[len("Bearer ") :].strip()
token = auth[len("Bearer ") :].strip()
if token not in api_keys:
def _match_any(token: str, candidates: list[str]) -> bool:
for c in candidates:
if c and hmac.compare_digest(token, c):
return True
return False
def require_bearer(request: Request, api_keys: list[str]) -> None:
# Empty api_keys means auth is disabled (keeps the old behavior).
if not api_keys:
return
token = _extract_bearer(request)
if not _match_any(token, api_keys):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={
@@ -32,3 +44,31 @@ def require_bearer(request: Request, api_keys: list[str]) -> None:
}
},
)
def require_metrics_access(
request: Request, api_keys: list[str], metrics_token: str
) -> None:
"""Allow metrics if any of: METRICS_TOKEN matches, or any API_KEYS match.
If neither METRICS_TOKEN nor API_KEYS are configured, metrics is public
(backwards compatible default).
"""
accepted: list[str] = []
if metrics_token:
accepted.append(metrics_token)
accepted.extend(api_keys)
if not accepted:
return
token = _extract_bearer(request)
if not _match_any(token, accepted):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={
"error": {
"message": "Invalid metrics token",
"type": "invalid_request_error",
"code": "invalid_api_key",
}
},
)

121
app/concurrency.py Normal file
View File

@@ -0,0 +1,121 @@
from __future__ import annotations
import asyncio
from .logging_config import get_logger
logger = get_logger("lingma_gateway.concurrency")
class BackpressureRejected(Exception):
"""Raised when a request cannot acquire an in-flight slot before timeout."""
def __init__(self, retry_after: float):
super().__init__(f"backpressure rejected, retry_after={retry_after:.1f}s")
self.retry_after = retry_after
class InFlightTicket:
"""Reference-counted handle for a single in-flight slot.
Release is idempotent so callers can defensively `release()` from multiple
cleanup paths (stream finally + outer exception handler) without worrying.
"""
__slots__ = ("_parent", "_released")
def __init__(self, parent: "InFlightGuard | None"):
self._parent = parent
self._released = False
def release(self) -> None:
if self._released or self._parent is None:
self._released = True
return
self._released = True
self._parent._on_release()
async def __aenter__(self) -> "InFlightTicket":
return self
async def __aexit__(self, *_exc) -> None:
self.release()
class InFlightGuard:
"""Async semaphore wrapper with queue/reject accounting and Prometheus hooks.
- `max_in_flight <= 0` disables limiting (back-compat, unlimited).
- `queue_timeout_sec` bounds how long a request may wait for a slot. On
timeout, `try_acquire()` raises `BackpressureRejected`.
"""
def __init__(self, max_in_flight: int, queue_timeout_sec: float):
self.max = max(0, int(max_in_flight))
self.queue_timeout = max(0.0, float(queue_timeout_sec))
self._sem: asyncio.Semaphore | None = (
asyncio.Semaphore(self.max) if self.max > 0 else None
)
self.in_flight = 0
self.queued = 0
self.accepted_total = 0
self.rejected_total = 0
async def try_acquire(self) -> InFlightTicket:
if self._sem is None:
self.in_flight += 1
self.accepted_total += 1
return InFlightTicket(parent=self)
self.queued += 1
try:
if self.queue_timeout <= 0:
await self._sem.acquire()
else:
try:
await asyncio.wait_for(self._sem.acquire(), timeout=self.queue_timeout)
except (asyncio.TimeoutError, TimeoutError):
self.rejected_total += 1
logger.warning(
"backpressure rejected: in_flight=%d queued=%d max=%d",
self.in_flight,
self.queued - 1,
self.max,
)
raise BackpressureRejected(retry_after=self.queue_timeout)
finally:
self.queued -= 1
self.in_flight += 1
self.accepted_total += 1
return InFlightTicket(parent=self)
def _on_release(self) -> None:
self.in_flight -= 1
if self._sem is not None:
self._sem.release()
def stats(self) -> dict:
return {
"max_in_flight": self.max,
"in_flight": self.in_flight,
"queued": self.queued,
"accepted_total": self.accepted_total,
"rejected_total": self.rejected_total,
"queue_timeout_sec": self.queue_timeout,
}
def prometheus_lines(self) -> list[str]:
return [
"# TYPE gateway_in_flight gauge",
f"gateway_in_flight {self.in_flight}",
"# TYPE gateway_queued gauge",
f"gateway_queued {self.queued}",
"# TYPE gateway_max_in_flight gauge",
f"gateway_max_in_flight {self.max}",
"# TYPE gateway_accepted_total counter",
f"gateway_accepted_total {self.accepted_total}",
"# TYPE gateway_rejected_total counter",
f"gateway_rejected_total {self.rejected_total}",
]

View File

@@ -1,8 +1,14 @@
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from pathlib import Path
from dataclasses import dataclass, field
@dataclass
class LingmaAccount:
username: str
password: str
@dataclass
@@ -10,6 +16,10 @@ class Settings:
host: str
port: int
api_keys: list[str]
metrics_token: str
log_level: str
gateway_max_in_flight: int
gateway_queue_timeout_sec: float
lingma_bin: str
lingma_work_dir: str
lingma_socket_port: int
@@ -22,8 +32,57 @@ class Settings:
auto_login_headless: bool
auto_login_timeout: int
auto_login_max_retry: int
lingma_username: str
lingma_password: str
accounts: list[LingmaAccount] = field(default_factory=list)
instance_count: int = 1
def _bool_env(name: str, default: bool) -> bool:
raw = os.getenv(name)
if raw is None:
return default
return raw.strip().lower() in {"1", "true", "yes", "on"}
def _parse_accounts(raw: str) -> list[LingmaAccount]:
"""Parse LINGMA_ACCOUNTS.
Accepted formats:
- JSON array: `[{"username":"u1","password":"p1"},{"username":"u2","password":"p2"}]`
- CSV: `u1:p1,u2:p2`
- Newlines: `u1:p1\nu2:p2`
Whitespace around entries is trimmed. Empty entries are ignored.
Passwords containing ':' are supported (only the first ':' is the separator).
"""
raw = (raw or "").strip()
if not raw:
return []
if raw.startswith("["):
try:
data = json.loads(raw)
except Exception:
return []
out: list[LingmaAccount] = []
if isinstance(data, list):
for item in data:
if isinstance(item, dict):
u = str(item.get("username", "")).strip()
p = str(item.get("password", "")).strip()
if u and p:
out.append(LingmaAccount(u, p))
return out
out: list[LingmaAccount] = []
for entry in raw.replace("\n", ",").split(","):
entry = entry.strip()
if not entry or ":" not in entry:
continue
u, p = entry.split(":", 1)
u, p = u.strip(), p.strip()
if u and p:
out.append(LingmaAccount(u, p))
return out
def load_settings() -> Settings:
@@ -33,10 +92,31 @@ def load_settings() -> Settings:
"LINGMA_WORK_DIR",
"/app/data/.lingma/vscode/sharedClientCache",
)
accounts = _parse_accounts(os.getenv("LINGMA_ACCOUNTS", ""))
if not accounts:
u = os.getenv("LINGMA_USERNAME", "").strip()
p = os.getenv("LINGMA_PASSWORD", "").strip()
if u and p:
accounts.append(LingmaAccount(u, p))
explicit_count = os.getenv("LINGMA_INSTANCE_COUNT", "").strip()
if explicit_count:
try:
instance_count = max(1, int(explicit_count))
except ValueError:
instance_count = len(accounts) or 1
else:
instance_count = max(1, len(accounts)) if accounts else 1
return Settings(
host=os.getenv("HOST", "0.0.0.0"),
port=int(os.getenv("PORT", "8317")),
api_keys=api_keys,
metrics_token=os.getenv("METRICS_TOKEN", "").strip(),
log_level=os.getenv("LOG_LEVEL", "INFO").strip() or "INFO",
gateway_max_in_flight=int(os.getenv("GATEWAY_MAX_IN_FLIGHT", "4")),
gateway_queue_timeout_sec=float(os.getenv("GATEWAY_QUEUE_TIMEOUT_SEC", "30")),
lingma_bin=os.getenv("LINGMA_BIN", "/app/data/bin/Lingma"),
lingma_work_dir=work_dir,
lingma_socket_port=int(os.getenv("LINGMA_SOCKET_PORT", "36510")),
@@ -45,10 +125,10 @@ def load_settings() -> Settings:
default_model=os.getenv("DEFAULT_MODEL", "org_auto"),
default_ask_mode=os.getenv("DEFAULT_ASK_MODE", "chat"),
dedicated_domain_url=os.getenv("DEDICATED_DOMAIN_URL", "").strip(),
auto_login_enabled=os.getenv("AUTO_LOGIN_ENABLED", "true").lower() in {"1", "true", "yes", "on"},
auto_login_headless=os.getenv("AUTO_LOGIN_HEADLESS", "true").lower() in {"1", "true", "yes", "on"},
auto_login_enabled=_bool_env("AUTO_LOGIN_ENABLED", True),
auto_login_headless=_bool_env("AUTO_LOGIN_HEADLESS", True),
auto_login_timeout=int(os.getenv("AUTO_LOGIN_TIMEOUT", "180")),
auto_login_max_retry=int(os.getenv("AUTO_LOGIN_MAX_RETRY", "2")),
lingma_username=os.getenv("LINGMA_USERNAME", "").strip(),
lingma_password=os.getenv("LINGMA_PASSWORD", "").strip(),
accounts=accounts,
instance_count=instance_count,
)

View File

@@ -9,10 +9,23 @@ import subprocess
import time
import uuid
from pathlib import Path
from typing import AsyncIterator
from typing import AsyncIterator, Callable, Optional
import websockets
from .logging_config import get_logger
logger = get_logger("lingma_gateway.client")
# Some callers live on Python 3.10 where asyncio.TimeoutError is a distinct class,
# while 3.11+ unifies it with the builtin TimeoutError. Always catch both.
TIMEOUT_EXCEPTIONS: tuple[type[BaseException], ...] = (
asyncio.TimeoutError,
TimeoutError,
)
def _is_port_open(host: str, port: int, timeout_sec: float = 0.5) -> bool:
try:
@@ -79,23 +92,37 @@ def _parse_lsp_frames(buf: bytes):
class LspWsRpcClient:
def __init__(self, ws):
def __init__(self, ws, on_disconnect: Optional[Callable[[BaseException], None]] = None):
self.ws = ws
self._id = 1
self._pending: dict[int, asyncio.Future] = {}
self._send_lock = asyncio.Lock()
self._reader_task = None
self._reader_task: asyncio.Task | None = None
self._rx_buffer = b""
self._chat_streams: dict[str, dict] = {}
self._on_disconnect = on_disconnect
self._closed = False
async def start(self):
self._reader_task = asyncio.create_task(self._reader_loop())
async def close(self):
self._closed = True
if self._reader_task:
self._reader_task.cancel()
with contextlib.suppress(Exception):
await self._reader_task
# Abort any pending futures so callers fail fast instead of hanging.
for fut in self._pending.values():
if not fut.done():
fut.set_exception(ConnectionError("lingma client closed"))
self._pending.clear()
# Signal open streams to terminate.
for stream in self._chat_streams.values():
if not stream["done"].is_set():
stream["done"].set()
stream["chunks"].put_nowait(None)
self._chat_streams.clear()
async def _send(self, payload: dict):
async with self._send_lock:
@@ -127,10 +154,23 @@ class LspWsRpcClient:
except asyncio.CancelledError:
pass
except Exception as exc:
if not self._closed:
logger.warning("lingma reader loop terminated: %s", exc)
# Propagate failure to anyone waiting on an RPC.
for fut in self._pending.values():
if not fut.done():
fut.set_exception(exc)
self._pending.clear()
# Also unblock any in-flight chat streams so consumers exit.
for stream in self._chat_streams.values():
if not stream["done"].is_set():
stream["done"].set()
stream["chunks"].put_nowait(None)
if not self._closed and self._on_disconnect is not None:
try:
self._on_disconnect(exc)
except Exception:
logger.exception("on_disconnect callback failed")
async def _handle_server_message(self, msg: dict):
method = msg.get("method")
@@ -168,7 +208,7 @@ class LspWsRpcClient:
await self._send(payload)
try:
msg = await asyncio.wait_for(fut, timeout=timeout)
except TimeoutError:
except TIMEOUT_EXCEPTIONS:
self._pending.pop(rid, None)
raise TimeoutError(f"RPC timeout: {method}")
if "error" in msg:
@@ -189,8 +229,20 @@ class LspWsRpcClient:
"finish_at": None,
}
def pop_stream(self, request_id: str) -> None:
stream = self._chat_streams.pop(request_id, None)
if stream is None:
return
# Drain queue so no stray future gets stuck if the consumer bailed early.
if not stream["done"].is_set():
stream["done"].set()
with contextlib.suppress(Exception):
stream["chunks"].put_nowait(None)
async def consume_stream(self, request_id: str, timeout: float) -> AsyncIterator[str]:
stream = self._chat_streams[request_id]
stream = self._chat_streams.get(request_id)
if stream is None:
return
start = time.monotonic()
while True:
remain = timeout - (time.monotonic() - start)
@@ -218,6 +270,19 @@ class LspWsRpcClient:
class LingmaGatewayClient:
"""Owns the Lingma subprocess and the LSP-over-WS connection.
Adds a small state machine + reconnect loop so the gateway can survive Lingma
restarts and slow cold starts without bringing down the FastAPI app.
"""
STATE_STOPPED = "stopped"
STATE_STARTING = "starting"
STATE_READY = "ready"
STATE_RECONNECTING = "reconnecting"
STATE_FAILED = "failed"
STATE_CLOSED = "closed"
def __init__(
self,
lingma_bin: str,
@@ -227,7 +292,11 @@ class LingmaGatewayClient:
rpc_timeout: int,
default_model: str,
default_ask_mode: str,
*,
name: str = "lingma",
extra_info_paths: list[Path] | None = None,
):
self.name = name
self.lingma_bin = Path(lingma_bin)
self.work_dir = Path(work_dir)
self.socket_port = socket_port
@@ -235,19 +304,115 @@ class LingmaGatewayClient:
self.rpc_timeout = rpc_timeout
self.default_model = default_model
self.default_ask_mode = default_ask_mode
# Each pool instance should only look at its own workDir .info to avoid
# cross-instance clobbering via the shared ~/.lingma/.info path.
if extra_info_paths is None:
extra_info_paths = [Path.home() / ".lingma" / ".info"]
self._extra_info_paths = list(extra_info_paths)
self._rpc: LspWsRpcClient | None = None
self._ws = None
self._state = self.STATE_STOPPED
self._state_lock = asyncio.Lock()
self._ready_event = asyncio.Event()
self._reconnect_task: asyncio.Task | None = None
self._last_error: str = ""
# ------------------------------------------------------------------ state
@property
def state(self) -> str:
return self._state
@property
def last_error(self) -> str:
return self._last_error
def _set_state(self, state: str, err: str = "") -> None:
if state != self._state:
logger.info("lingma client state %s -> %s", self._state, state, extra={"ctx_new_state": state})
self._state = state
if err:
self._last_error = err
if state == self.STATE_READY:
self._ready_event.set()
else:
self._ready_event.clear()
# -------------------------------------------------------------- lifecycle
async def start(self) -> None:
"""Initial start. Failure is non-fatal: ensure_ready() will retry later."""
try:
await self._connect(initial=True)
except Exception as exc:
self._set_state(self.STATE_FAILED, err=str(exc))
logger.exception("initial lingma start failed; will retry on demand")
async def close(self) -> None:
self._set_state(self.STATE_CLOSED)
if self._reconnect_task and not self._reconnect_task.done():
self._reconnect_task.cancel()
with contextlib.suppress(Exception):
await self._reconnect_task
if self._rpc:
await self._rpc.close()
if self._ws:
with contextlib.suppress(Exception):
await self._ws.close()
async def ensure_ready(self, timeout: float | None = None) -> None:
"""Block until the RPC connection is usable, (re)connecting on demand."""
if self._state == self.STATE_CLOSED:
raise RuntimeError("lingma client is closed")
if self._state == self.STATE_READY and self._ws is not None:
return
async with self._state_lock:
if self._state == self.STATE_READY and self._ws is not None:
return
if self._state in (self.STATE_STOPPED, self.STATE_FAILED):
try:
await self._connect(initial=False)
return
except Exception as exc:
self._set_state(self.STATE_FAILED, err=str(exc))
raise
wait_timeout = timeout if timeout is not None else max(
30.0, float(self.startup_timeout) + 10.0
)
try:
await asyncio.wait_for(self._ready_event.wait(), timeout=wait_timeout)
except TIMEOUT_EXCEPTIONS:
raise RuntimeError(f"lingma not ready (state={self._state}, err={self._last_error})")
# --------------------------------------------------------------- connect
async def _connect(self, *, initial: bool) -> None:
self._set_state(self.STATE_STARTING)
async def start(self):
if not self.lingma_bin.exists():
raise FileNotFoundError(f"Lingma not found: {self.lingma_bin}")
if not _is_port_open("127.0.0.1", self.socket_port):
info_paths = [self.work_dir / ".info", *self._extra_info_paths]
# socket_port <= 0 is the pool-friendly "always spawn and read .info" mode.
port_prewarmed = self.socket_port > 0 and _is_port_open(
"127.0.0.1", self.socket_port
)
if not port_prewarmed:
self.work_dir.mkdir(parents=True, exist_ok=True)
# Remove stale info files from host-mounted workspace before boot.
for p in [self.work_dir / ".info", Path.home() / ".lingma" / ".info"]:
for p in info_paths:
with contextlib.suppress(Exception):
if p.exists():
p.unlink()
logger.info(
"[%s] spawning lingma: %s start --workDir %s",
self.name,
self.lingma_bin,
self.work_dir,
)
subprocess.Popen(
[str(self.lingma_bin), "start", "--workDir", str(self.work_dir)],
cwd=str(self.lingma_bin.parent),
@@ -255,13 +420,9 @@ class LingmaGatewayClient:
stderr=subprocess.DEVNULL,
start_new_session=True,
)
info, _, _ = _wait_info_any(
[self.work_dir / ".info", Path.home() / ".lingma" / ".info"],
timeout_sec=self.startup_timeout,
)
info, _, _ = _wait_info_any(info_paths, timeout_sec=self.startup_timeout)
self.socket_port = info
# Wait for socket to actually become connectable.
deadline = time.time() + self.startup_timeout
while time.time() < deadline:
if _is_port_open("127.0.0.1", self.socket_port, timeout_sec=0.3):
@@ -270,9 +431,19 @@ class LingmaGatewayClient:
else:
raise TimeoutError(f"Lingma socket not open on port {self.socket_port}")
# Close any stale ws/rpc before creating fresh ones (reconnect path).
if self._rpc is not None:
with contextlib.suppress(Exception):
await self._rpc.close()
self._rpc = None
if self._ws is not None:
with contextlib.suppress(Exception):
await self._ws.close()
self._ws = None
ws_url = f"ws://127.0.0.1:{self.socket_port}"
self._ws = await websockets.connect(ws_url, max_size=10 * 1024 * 1024)
self._rpc = LspWsRpcClient(self._ws)
self._rpc = LspWsRpcClient(self._ws, on_disconnect=self._on_disconnect)
await self._rpc.start()
await self._rpc.request(
"initialize",
@@ -286,32 +457,73 @@ class LingmaGatewayClient:
timeout=self.rpc_timeout,
)
await self._rpc.notify("initialized", {})
self._set_state(self.STATE_READY)
logger.info(
"[%s] lingma ready on port %d (initial=%s)",
self.name,
self.socket_port,
initial,
)
async def close(self):
if self._rpc:
await self._rpc.close()
if self._ws:
await self._ws.close()
def _on_disconnect(self, exc: BaseException) -> None:
if self._state == self.STATE_CLOSED:
return
self._set_state(self.STATE_RECONNECTING, err=str(exc))
if self._reconnect_task and not self._reconnect_task.done():
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
self._reconnect_task = loop.create_task(self._reconnect_loop())
async def _reconnect_loop(self) -> None:
backoff = 1.0
max_backoff = 30.0
max_attempts = 20
for attempt in range(1, max_attempts + 1):
if self._state == self.STATE_CLOSED:
return
await asyncio.sleep(backoff)
try:
async with self._state_lock:
await self._connect(initial=False)
logger.info("lingma reconnected after %d attempt(s)", attempt)
return
except Exception as exc:
self._last_error = str(exc)
logger.warning("lingma reconnect attempt %d failed: %s", attempt, exc)
backoff = min(backoff * 2, max_backoff)
self._set_state(self.STATE_FAILED, err="reconnect exhausted")
# ------------------------------------------------------------------ RPC
@property
def rpc(self) -> LspWsRpcClient:
if self._rpc is None:
raise RuntimeError("Lingma RPC not initialized")
raise RuntimeError(f"Lingma RPC not initialized (state={self._state})")
return self._rpc
async def auth_status(self):
await self.ensure_ready()
return await self.rpc.request("auth/status", {}, timeout=self.rpc_timeout)
async def query_models(self):
await self.ensure_ready()
return await self.rpc.request("config/queryModels", {}, timeout=self.rpc_timeout)
async def get_endpoint(self):
await self.ensure_ready()
return await self.rpc.request("config/getEndpoint", {}, timeout=self.rpc_timeout)
async def update_endpoint(self, endpoint: str):
return await self.rpc.request("config/updateEndpoint", {"endpoint": endpoint}, timeout=self.rpc_timeout)
await self.ensure_ready()
return await self.rpc.request(
"config/updateEndpoint", {"endpoint": endpoint}, timeout=self.rpc_timeout
)
async def generate_login_url(self):
await self.ensure_ready()
result = await self.rpc.request("login/generateUrl", {}, timeout=self.rpc_timeout)
if isinstance(result, str):
return result, {"raw": result}
@@ -322,6 +534,8 @@ class LingmaGatewayClient:
return "", result
return "", {"raw": result}
# ------------------------------------------------------------------ chat
def _build_payload(self, prompt: str, model_key: str, ask_mode: str, session_id: str, request_id: str):
session_type = "developer" if ask_mode == "agent" else "chat"
return {
@@ -355,17 +569,24 @@ class LingmaGatewayClient:
}
async def chat_complete(self, prompt: str, model_key: str, ask_mode: str) -> dict:
await self.ensure_ready()
request_id = str(uuid.uuid4())
session_id = str(uuid.uuid4())
payload = self._build_payload(prompt, model_key, ask_mode, session_id, request_id)
self.rpc.create_stream(request_id)
try:
await self.rpc.request("chat/ask", payload, timeout=self.rpc_timeout)
except (TimeoutError, asyncio.TimeoutError):
pass
async for _ in self.rpc.consume_stream(request_id, timeout=max(20.0, self.rpc_timeout + 20.0)):
pass
result = self.rpc.get_stream_result(request_id)
try:
await self.rpc.request("chat/ask", payload, timeout=self.rpc_timeout)
except TIMEOUT_EXCEPTIONS:
# chat/ask often returns nothing until chat/finish arrives; tolerate.
pass
async for _ in self.rpc.consume_stream(
request_id, timeout=max(20.0, self.rpc_timeout + 20.0)
):
pass
result = self.rpc.get_stream_result(request_id)
finally:
self.rpc.pop_stream(request_id)
finish = result.get("finish") or {}
result["requestId"] = request_id
result["sessionId"] = finish.get("sessionId") or session_id
@@ -374,13 +595,20 @@ class LingmaGatewayClient:
return result
async def chat_stream(self, prompt: str, model_key: str, ask_mode: str) -> AsyncIterator[str]:
await self.ensure_ready()
request_id = str(uuid.uuid4())
session_id = str(uuid.uuid4())
payload = self._build_payload(prompt, model_key, ask_mode, session_id, request_id)
self.rpc.create_stream(request_id)
try:
await self.rpc.request("chat/ask", payload, timeout=self.rpc_timeout)
except (TimeoutError, asyncio.TimeoutError):
pass
async for chunk in self.rpc.consume_stream(request_id, timeout=max(20.0, self.rpc_timeout + 40.0)):
yield chunk
try:
await self.rpc.request("chat/ask", payload, timeout=self.rpc_timeout)
except TIMEOUT_EXCEPTIONS:
pass
async for chunk in self.rpc.consume_stream(
request_id, timeout=max(20.0, self.rpc_timeout + 40.0)
):
yield chunk
finally:
# Runs on normal completion, exception, or consumer GeneratorExit (client disconnect).
self.rpc.pop_stream(request_id)

275
app/lingma_pool.py Normal file
View File

@@ -0,0 +1,275 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from pathlib import Path
from .auto_login import AutoLoginManager
from .config import LingmaAccount
from .lingma_client import LingmaGatewayClient
from .logging_config import get_logger
logger = get_logger("lingma_gateway.pool")
@dataclass
class InstanceConfig:
index: int
name: str
work_dir: str
socket_port: int
account: LingmaAccount
class PoolInstance:
"""A single Lingma process + its auto_login + in-flight counter."""
__slots__ = ("cfg", "client", "auto_login", "in_flight")
def __init__(
self,
cfg: InstanceConfig,
client: LingmaGatewayClient,
auto_login: AutoLoginManager,
):
self.cfg = cfg
self.client = client
self.auto_login = auto_login
self.in_flight = 0
@property
def name(self) -> str:
return self.cfg.name
@property
def healthy(self) -> bool:
return self.client.state == LingmaGatewayClient.STATE_READY
class LingmaPool:
"""N-Lingma process pool with least-in-flight + affinity routing.
For N=1 this degenerates into the original single-client setup, preserving
backwards compatibility with `LINGMA_USERNAME/LINGMA_PASSWORD`-only deploys.
"""
def __init__(self, instances: list[PoolInstance]):
if not instances:
raise RuntimeError("LingmaPool requires at least 1 instance")
self._instances: list[PoolInstance] = instances
self._rr_counter = 0
@classmethod
def build(
cls,
*,
lingma_bin: str,
base_work_dir: str,
legacy_socket_port: int,
startup_timeout: int,
rpc_timeout: int,
default_model: str,
default_ask_mode: str,
accounts: list[LingmaAccount],
instance_count: int,
auto_login_headless: bool,
auto_login_timeout: int,
auto_login_max_retry: int,
verify_timeout_sec: int | None = None,
) -> "LingmaPool":
"""Materialize N PoolInstances.
Single-instance (N=1) uses the legacy workDir and LINGMA_SOCKET_PORT so
existing deployments keep their state after upgrade. N>1 derives per-instance
workDirs under `<base_work_dir>/../pool/inst-<i>` and uses dynamic ports.
"""
if instance_count < 1:
instance_count = 1
resolved_accounts: list[LingmaAccount] = []
for i in range(instance_count):
if accounts:
resolved_accounts.append(accounts[i % len(accounts)])
else:
resolved_accounts.append(LingmaAccount(username="", password=""))
if instance_count > len(accounts) and accounts:
logger.warning(
"instance_count=%d exceeds unique accounts=%d; accounts will be reused",
instance_count,
len(accounts),
)
base_dir = Path(base_work_dir)
# Put per-instance workDirs under `<data>/.lingma/pool/inst-<i>`.
# Walk up past the vscode/sharedClientCache layout if present.
pool_root = base_dir
for _ in range(3):
if pool_root.name == ".lingma":
break
if pool_root.parent == pool_root:
break
pool_root = pool_root.parent
pool_root = pool_root / "pool"
instances: list[PoolInstance] = []
for i, acc in enumerate(resolved_accounts):
if instance_count == 1:
work_dir = str(base_dir)
socket_port = legacy_socket_port
extra_info: list[Path] | None = None
else:
work_dir = str(pool_root / f"inst-{i}")
socket_port = 0
# In pool mode each instance reads only its own workDir .info to
# avoid the shared ~/.lingma/.info race between instances.
extra_info = []
name = f"inst-{i}"
client = LingmaGatewayClient(
lingma_bin=lingma_bin,
work_dir=work_dir,
socket_port=socket_port,
startup_timeout=startup_timeout,
rpc_timeout=rpc_timeout,
default_model=default_model,
default_ask_mode=default_ask_mode,
name=name,
extra_info_paths=extra_info,
)
def _make_verify(_client: LingmaGatewayClient):
async def _verify() -> bool:
try:
st = await _client.auth_status()
except Exception:
return False
return bool(st and st.get("id"))
return _verify
auto_login = AutoLoginManager(
username=acc.username,
password=acc.password,
headless=auto_login_headless,
timeout_sec=auto_login_timeout,
max_retry=auto_login_max_retry,
verify_logged_in=_make_verify(client),
verify_timeout_sec=verify_timeout_sec
or max(30, min(180, auto_login_timeout)),
debug_dir=f"/tmp/lingma-auto-login/{name}",
)
cfg = InstanceConfig(
index=i,
name=name,
work_dir=work_dir,
socket_port=socket_port,
account=acc,
)
instances.append(PoolInstance(cfg, client, auto_login))
return cls(instances)
# -------------------------------------------------------------- lifecycle
async def start(self) -> None:
"""Start all instances sequentially.
Sequential startup avoids racing on the shared ~/.lingma/.info file (for
pool-mode we skip it anyway, but Lingma may still write there internally)
and keeps docker logs readable. Failures are non-fatal; per-instance
reconnect loops will take over.
"""
for inst in self._instances:
logger.info(
"pool starting %s (workDir=%s port=%d account=%s)",
inst.name,
inst.cfg.work_dir,
inst.cfg.socket_port,
inst.cfg.account.username or "<empty>",
)
try:
await inst.client.start()
except Exception as exc:
logger.warning("pool start %s failed: %s", inst.name, exc)
async def close(self) -> None:
tasks = [asyncio.create_task(inst.client.close()) for inst in self._instances]
for t in tasks:
try:
await t
except Exception:
pass
# -------------------------------------------------------------- inspection
@property
def instances(self) -> list[PoolInstance]:
return list(self._instances)
def size(self) -> int:
return len(self._instances)
def stats(self) -> list[dict]:
return [
{
"index": inst.cfg.index,
"name": inst.name,
"state": inst.client.state,
"last_error": inst.client.last_error,
"in_flight": inst.in_flight,
"work_dir": inst.cfg.work_dir,
"socket_port": inst.cfg.socket_port,
"username": inst.cfg.account.username,
"auto_login": inst.auto_login.status(),
}
for inst in self._instances
]
def prometheus_lines(self) -> list[str]:
lines: list[str] = [
"# TYPE gateway_pool_instance_in_flight gauge",
"# TYPE gateway_pool_instance_ready gauge",
]
for inst in self._instances:
lbl = f'name="{inst.name}",idx="{inst.cfg.index}"'
lines.append(f"gateway_pool_instance_in_flight{{{lbl}}} {inst.in_flight}")
lines.append(
f"gateway_pool_instance_ready{{{lbl}}} {1 if inst.healthy else 0}"
)
return lines
# -------------------------------------------------------------- selection
def pick(self, affinity_key: str | None = None) -> PoolInstance:
"""Pick an instance for a request.
Preference order:
1. Sticky affinity if `affinity_key` is provided and the bucket is healthy.
2. Least-in-flight among healthy instances.
3. Round-robin fallback when nothing is healthy (lazy-start will kick in).
"""
if not self._instances:
raise RuntimeError("lingma pool is empty")
healthy = [i for i in self._instances if i.healthy]
if affinity_key:
bucket = self._instances[
abs(hash(affinity_key)) % len(self._instances)
]
if bucket.healthy:
return bucket
if healthy:
return min(healthy, key=lambda x: (x.in_flight, x.cfg.index))
# Nothing healthy. Fall back to round-robin so every instance gets a
# chance to reconnect via ensure_ready().
idx = self._rr_counter % len(self._instances)
self._rr_counter += 1
return self._instances[idx]

56
app/logging_config.py Normal file
View File

@@ -0,0 +1,56 @@
from __future__ import annotations
import contextvars
import json
import logging
import sys
import time
request_id_var: contextvars.ContextVar[str] = contextvars.ContextVar(
"request_id", default="-"
)
class _JsonFormatter(logging.Formatter):
def format(self, record: logging.LogRecord) -> str:
ts = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime(record.created))
data: dict = {
"ts": f"{ts}.{int(record.msecs):03d}Z",
"level": record.levelname,
"logger": record.name,
"msg": record.getMessage(),
"request_id": request_id_var.get(),
}
if record.exc_info:
data["exc"] = self.formatException(record.exc_info)
for key, val in record.__dict__.items():
if key.startswith("ctx_"):
data[key[4:]] = val
return json.dumps(data, ensure_ascii=False)
def configure_logging(level: str = "INFO") -> None:
level = (level or "INFO").upper()
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(_JsonFormatter())
root = logging.getLogger()
root.handlers.clear()
root.addHandler(handler)
root.setLevel(level)
# Align uvicorn access/error logs with our JSON formatter.
for name in ("uvicorn", "uvicorn.error", "uvicorn.access", "fastapi"):
lg = logging.getLogger(name)
lg.handlers.clear()
lg.propagate = True
lg.setLevel(level)
# Trim noisy libraries.
logging.getLogger("websockets").setLevel("WARNING")
logging.getLogger("websockets.client").setLevel("WARNING")
def get_logger(name: str = "lingma_gateway") -> logging.Logger:
return logging.getLogger(name)

View File

@@ -1,16 +1,20 @@
from __future__ import annotations
import asyncio
import hashlib
import json
import time
import uuid
from contextlib import asynccontextmanager
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from .auto_login import AutoLoginManager
from .auth import require_bearer
from .auth import require_bearer, require_metrics_access
from .concurrency import BackpressureRejected, InFlightGuard
from .config import Settings, load_settings
from .lingma_client import LingmaGatewayClient
from .lingma_pool import LingmaPool, PoolInstance
from .logging_config import configure_logging, get_logger, request_id_var
from .model_map import build_model_name_map, flatten_model_keys, resolve_model
from .openai_schema import (
ChatCompletionChoice,
@@ -18,107 +22,219 @@ from .openai_schema import (
ChatCompletionsRequest,
ModelData,
ModelsResponse,
flatten_content,
)
from .stats import StatsCollector, estimate_tokens
app = FastAPI(title="Lingma OpenAI Gateway", version="0.1.0")
settings: Settings = load_settings()
lingma: LingmaGatewayClient | None = None
auto_login: AutoLoginManager | None = None
configure_logging(settings.log_level)
logger = get_logger("lingma_gateway")
pool: LingmaPool | None = None
stats_collector = StatsCollector()
chat_guard = InFlightGuard(
max_in_flight=settings.gateway_max_in_flight,
queue_timeout_sec=settings.gateway_queue_timeout_sec,
)
def _require_pool() -> LingmaPool:
if pool is None:
raise HTTPException(
status_code=503,
detail={"error": {"message": "pool not initialized", "type": "service_unavailable"}},
)
return pool
@asynccontextmanager
async def lifespan(_app: FastAPI):
global pool
pool = LingmaPool.build(
lingma_bin=settings.lingma_bin,
base_work_dir=settings.lingma_work_dir,
legacy_socket_port=settings.lingma_socket_port,
startup_timeout=settings.lingma_startup_timeout,
rpc_timeout=settings.lingma_rpc_timeout,
default_model=settings.default_model,
default_ask_mode=settings.default_ask_mode,
accounts=settings.accounts,
instance_count=settings.instance_count,
auto_login_headless=settings.auto_login_headless,
auto_login_timeout=settings.auto_login_timeout,
auto_login_max_retry=settings.auto_login_max_retry,
)
logger.info(
"gateway startup: pool_size=%d max_in_flight=%d",
pool.size(),
settings.gateway_max_in_flight,
)
await pool.start()
try:
yield
finally:
if pool is not None:
await pool.close()
app = FastAPI(title="Lingma OpenAI Gateway", version="0.3.0", lifespan=lifespan)
@app.middleware("http")
async def request_id_middleware(request: Request, call_next):
req_id = request.headers.get("x-request-id") or f"req-{uuid.uuid4().hex[:12]}"
token = request_id_var.set(req_id)
start = time.monotonic()
status_code = 500
try:
response = await call_next(request)
status_code = response.status_code
response.headers["x-request-id"] = req_id
return response
finally:
elapsed_ms = int((time.monotonic() - start) * 1000)
logger.info(
"http %s %s -> %s in %dms",
request.method,
request.url.path,
status_code,
elapsed_ms,
extra={
"ctx_method": request.method,
"ctx_path": request.url.path,
"ctx_status": status_code,
"ctx_elapsed_ms": elapsed_ms,
},
)
request_id_var.reset(token)
def auth_guard(request: Request):
require_bearer(request, settings.api_keys)
async def _is_logged_in() -> bool:
assert lingma is not None
st = await lingma.auth_status()
return bool(st and st.get("id"))
@app.on_event("startup")
async def on_startup():
global lingma, auto_login
lingma = LingmaGatewayClient(
lingma_bin=settings.lingma_bin,
work_dir=settings.lingma_work_dir,
socket_port=settings.lingma_socket_port,
startup_timeout=settings.lingma_startup_timeout,
rpc_timeout=settings.lingma_rpc_timeout,
default_model=settings.default_model,
default_ask_mode=settings.default_ask_mode,
)
await lingma.start()
auto_login = AutoLoginManager(
username=settings.lingma_username,
password=settings.lingma_password,
headless=settings.auto_login_headless,
timeout_sec=settings.auto_login_timeout,
max_retry=settings.auto_login_max_retry,
verify_logged_in=_is_logged_in,
verify_timeout_sec=max(30, min(180, settings.auto_login_timeout)),
)
@app.on_event("shutdown")
async def on_shutdown():
if lingma:
await lingma.close()
def metrics_auth_guard(request: Request):
require_metrics_access(request, settings.api_keys, settings.metrics_token)
@app.get("/healthz")
async def healthz():
return {"ok": True, "time": int(time.time())}
if pool is None:
return {"ok": False, "time": int(time.time()), "reason": "pool uninitialized"}
insts = pool.stats()
ready = sum(1 for i in insts if i["state"] == "ready")
return {
"ok": ready > 0,
"time": int(time.time()),
"pool_size": len(insts),
"pool_ready": ready,
"instances": [
{"name": i["name"], "state": i["state"], "in_flight": i["in_flight"]}
for i in insts
],
}
async def _ensure_logged_in_or_auto_login() -> dict:
assert lingma is not None
status = await lingma.auth_status()
async def _ensure_instance_logged_in(inst: PoolInstance) -> dict:
client = inst.client
auto_login = inst.auto_login
try:
status = await client.auth_status()
except Exception as exc:
logger.warning("[%s] auth_status failed before chat: %s", inst.name, exc)
raise HTTPException(
status_code=503,
detail={"error": {"message": "Lingma is not ready", "type": "service_unavailable"}},
)
if status and status.get("id"):
return status
if not settings.auto_login_enabled:
raise HTTPException(status_code=401, detail={"error": {"message": "Lingma not logged in"}})
if settings.dedicated_domain_url:
current = await lingma.get_endpoint()
current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else ""
if current_ep != settings.dedicated_domain_url:
await lingma.update_endpoint(settings.dedicated_domain_url)
login_url, login_raw = await lingma.generate_login_url()
if not login_url:
raise HTTPException(
status_code=500,
detail={"error": {"message": f"generate login url failed: {login_raw}"}},
status_code=401,
detail={"error": {"message": "Lingma not logged in", "type": "invalid_request_error"}},
)
if settings.dedicated_domain_url:
try:
current = await client.get_endpoint()
current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else ""
if current_ep != settings.dedicated_domain_url:
await client.update_endpoint(settings.dedicated_domain_url)
except Exception as exc:
logger.warning("[%s] switch dedicated endpoint failed: %s", inst.name, exc)
try:
login_url, _login_raw = await client.generate_login_url()
except Exception as exc:
logger.warning("[%s] generate_login_url failed: %s", inst.name, exc)
raise HTTPException(
status_code=502,
detail={"error": {"message": "generate login url failed", "type": "upstream_error"}},
)
if not login_url:
raise HTTPException(
status_code=502,
detail={"error": {"message": "generate login url failed", "type": "upstream_error"}},
)
assert auto_login is not None
await auto_login.ensure_started(login_url)
try:
await auto_login.wait_done(timeout=settings.auto_login_timeout + 20)
except Exception:
pass
except Exception as exc:
logger.warning("[%s] auto_login wait_done failed: %s", inst.name, exc)
try:
status = await client.auth_status()
except Exception as exc:
logger.warning("[%s] post-login auth_status failed: %s", inst.name, exc)
status = None
status = await lingma.auth_status()
if status and status.get("id"):
return status
logger.warning(
"[%s] auto login did not result in a logged-in session: %s",
inst.name,
auto_login.status(),
)
raise HTTPException(
status_code=401,
detail={"error": {"message": "Lingma auto login failed", "auto_login": auto_login.status()}},
detail={"error": {"message": "Lingma auto login failed", "type": "invalid_request_error"}},
)
def _affinity_key_for(req: ChatCompletionsRequest) -> str | None:
"""Derive a stable affinity key so that follow-ups go to the same instance.
Priority: explicit `user` > hash of the first/system message.
"""
if req.user:
return req.user.strip() or None
for m in req.messages:
if m.role == "system":
text = flatten_content(m.content)
if text:
return "sys:" + hashlib.sha1(text.encode("utf-8")).hexdigest()[:16]
if req.messages:
first = req.messages[0]
text = flatten_content(first.content)
if text:
return "first:" + hashlib.sha1(text.encode("utf-8")).hexdigest()[:16]
return None
@app.get("/v1/models", dependencies=[Depends(auth_guard)])
async def v1_models():
assert lingma is not None
await _ensure_logged_in_or_auto_login()
p = _require_pool()
inst = p.pick()
await _ensure_instance_logged_in(inst)
await stats_collector.inc_models()
models = await lingma.query_models()
models = await inst.client.query_models()
keys = flatten_model_keys(models)
name_map = build_model_name_map(models)
resp = ModelsResponse(data=[ModelData(id=k, name=name_map.get(k)) for k in keys])
@@ -126,20 +242,32 @@ async def v1_models():
def _messages_to_prompt(messages: list[dict]) -> str:
parts = []
parts: list[str] = []
for m in messages:
role = m.get("role", "user")
content = m.get("content", "")
parts.append(f"[{role}] {content}")
text = flatten_content(m.get("content"))
if not text and m.get("tool_calls"):
text = f"[tool_calls] {json.dumps(m['tool_calls'], ensure_ascii=False)}"
if not text:
continue
parts.append(f"[{role}] {text}")
return "\n".join(parts).strip()
def _include_usage(stream_options: dict | None) -> bool:
if not isinstance(stream_options, dict):
return False
return bool(stream_options.get("include_usage"))
@app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)])
async def v1_chat_completions(req: ChatCompletionsRequest):
assert lingma is not None
await _ensure_logged_in_or_auto_login()
p = _require_pool()
affinity = _affinity_key_for(req)
inst = p.pick(affinity_key=affinity)
await _ensure_instance_logged_in(inst)
models = await lingma.query_models()
models = await inst.client.query_models()
available = flatten_model_keys(models)
name_map = build_model_name_map(models)
model = resolve_model(req.model, available, settings.default_model, name_map)
@@ -150,142 +278,270 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
prompt = _messages_to_prompt([m.model_dump() for m in req.messages])
if not prompt:
raise HTTPException(status_code=400, detail={"error": {"message": "messages is empty"}})
raise HTTPException(
status_code=400,
detail={"error": {"message": "messages is empty", "type": "invalid_request_error"}},
)
prompt_tokens = estimate_tokens(prompt)
include_usage = _include_usage(req.stream_options)
if req.stream:
created = int(time.time())
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
completion_tokens_holder = {"n": 0}
# Backpressure: acquire a slot *after* the cheap validation but before any
# upstream call. This ensures we reject quickly when saturated.
try:
ticket = await chat_guard.try_acquire()
except BackpressureRejected as exc:
retry_after = max(1, int(exc.retry_after))
logger.warning("chat rejected by backpressure, retry_after=%ds", retry_after)
raise HTTPException(
status_code=429,
detail={
"error": {
"message": "Too many in-flight requests, please retry later",
"type": "rate_limit_error",
"code": "backpressure",
}
},
headers={"Retry-After": str(retry_after)},
)
async def event_stream():
success = False
try:
async for chunk in lingma.chat_stream(prompt, model, ask_mode):
completion_tokens_holder["n"] += estimate_tokens(chunk)
payload = {
inst.in_flight += 1
logger.info(
"chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d",
inst.name,
model,
ask_mode,
req.stream,
prompt_tokens,
extra={
"ctx_instance": inst.name,
"ctx_model": model,
"ctx_ask_mode": ask_mode,
"ctx_stream": req.stream,
"ctx_prompt_tokens": prompt_tokens,
"ctx_in_flight": chat_guard.in_flight,
"ctx_affinity": affinity,
},
)
ticket_transferred = False
try:
if req.stream:
created = int(time.time())
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
completion_tokens_holder = {"n": 0}
async def event_stream(_ticket=ticket, _inst=inst):
success = False
try:
async for chunk in _inst.client.chat_stream(prompt, model, ask_mode):
completion_tokens_holder["n"] += estimate_tokens(chunk)
payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": chunk},
"finish_reason": None,
}
],
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
done_payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": chunk},
"finish_reason": None,
}
],
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n"
done_payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
success = True
finally:
await stats_collector.record_chat(
stream=True,
success=success,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens_holder["n"],
)
if include_usage:
usage_payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens_holder["n"],
"total_tokens": prompt_tokens + completion_tokens_holder["n"],
},
}
yield f"data: {json.dumps(usage_payload, ensure_ascii=False)}\n\n"
return StreamingResponse(event_stream(), media_type="text/event-stream")
yield "data: [DONE]\n\n"
success = True
except asyncio.CancelledError:
logger.info("chat.stream cancelled by client (inst=%s)", _inst.name)
raise
except Exception as exc:
logger.warning("chat.stream error (inst=%s): %s", _inst.name, exc)
finally:
await stats_collector.record_chat(
stream=True,
success=success,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens_holder["n"],
)
_inst.in_flight = max(0, _inst.in_flight - 1)
_ticket.release()
try:
result = await lingma.chat_complete(prompt, model, ask_mode)
except Exception:
ticket_transferred = True
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
"Connection": "keep-alive",
},
)
try:
result = await inst.client.chat_complete(prompt, model, ask_mode)
except Exception as exc:
logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
await stats_collector.record_chat(
stream=False,
success=False,
prompt_tokens=prompt_tokens,
completion_tokens=0,
)
raise HTTPException(
status_code=502,
detail={"error": {"message": "upstream lingma error", "type": "upstream_error"}},
)
completion_tokens = estimate_tokens(result.get("text") or "")
await stats_collector.record_chat(
stream=False,
success=False,
success=True,
prompt_tokens=prompt_tokens,
completion_tokens=0,
completion_tokens=completion_tokens,
)
raise
completion_tokens = estimate_tokens(result.get("text") or "")
await stats_collector.record_chat(
stream=False,
success=True,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
response = ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex}",
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
finish_reason="stop",
message={"role": "assistant", "content": result.get("text") or ""},
)
],
)
data = response.model_dump()
data["latency"] = {
"first_token_ms": result.get("firstTokenLatencyMs"),
"total_ms": result.get("totalLatencyMs"),
}
data["usage"] = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
return JSONResponse(content=data)
response = ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex}",
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
finish_reason="stop",
message={"role": "assistant", "content": result.get("text") or ""},
)
],
)
data = response.model_dump()
data["latency"] = {
"first_token_ms": result.get("firstTokenLatencyMs"),
"total_ms": result.get("totalLatencyMs"),
}
data["usage"] = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
data["served_by"] = inst.name
return JSONResponse(content=data)
finally:
if not ticket_transferred:
inst.in_flight = max(0, inst.in_flight - 1)
ticket.release()
@app.post("/internal/auto-login/start", dependencies=[Depends(auth_guard)])
async def internal_auto_login_start():
assert lingma is not None
assert auto_login is not None
async def internal_auto_login_start(instance: str | None = None):
p = _require_pool()
target = None
if instance:
for inst in p.instances:
if inst.name == instance:
target = inst
break
if target is None:
raise HTTPException(
status_code=404,
detail={"error": {"message": f"instance not found: {instance}"}},
)
else:
target = p.pick()
status = await lingma.auth_status()
client = target.client
auto_login = target.auto_login
status = await client.auth_status()
if status and status.get("id"):
return {"ok": True, "state": "already_logged_in", "auth": status}
return {"ok": True, "state": "already_logged_in", "instance": target.name, "auth": status}
if settings.dedicated_domain_url:
current = await lingma.get_endpoint()
current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else ""
if current_ep != settings.dedicated_domain_url:
await lingma.update_endpoint(settings.dedicated_domain_url)
try:
current = await client.get_endpoint()
current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else ""
if current_ep != settings.dedicated_domain_url:
await client.update_endpoint(settings.dedicated_domain_url)
except Exception as exc:
logger.warning("[%s] switch dedicated endpoint failed: %s", target.name, exc)
try:
login_url, _login_raw = await client.generate_login_url()
except Exception as exc:
logger.warning("[%s] generate_login_url failed: %s", target.name, exc)
raise HTTPException(status_code=502, detail={"error": {"message": "generate login url failed"}})
login_url, login_raw = await lingma.generate_login_url()
if not login_url:
raise HTTPException(status_code=500, detail={"error": {"message": "generate login url failed", "raw": login_raw}})
raise HTTPException(status_code=502, detail={"error": {"message": "generate login url failed"}})
started = await auto_login.ensure_started(login_url)
return {
"ok": True,
"state": "running" if started else "already_running",
"loginUrl": login_url,
"instance": target.name,
"auto_login": auto_login.status(),
}
@app.get("/internal/auto-login/status", dependencies=[Depends(auth_guard)])
async def internal_auto_login_status():
assert auto_login is not None
assert lingma is not None
return {
"ok": True,
"auto_login": auto_login.status(),
"auth": await lingma.auth_status(),
}
p = _require_pool()
out = []
for inst in p.instances:
try:
auth = await inst.client.auth_status()
except Exception as exc:
auth = {"error": str(exc)}
out.append(
{
"instance": inst.name,
"auto_login": inst.auto_login.status(),
"auth": auth,
"state": inst.client.state,
}
)
return {"ok": True, "instances": out}
@app.get("/internal/stats", dependencies=[Depends(auth_guard)])
async def internal_stats():
return {"ok": True, "stats": await stats_collector.snapshot()}
p = _require_pool()
return {
"ok": True,
"stats": await stats_collector.snapshot(),
"concurrency": chat_guard.stats(),
"pool": p.stats(),
}
@app.get("/metrics")
@app.get("/metrics", dependencies=[Depends(metrics_auth_guard)])
async def metrics():
text = await stats_collector.prometheus_text()
return StreamingResponse(iter([text]), media_type="text/plain; version=0.0.4")
base = await stats_collector.prometheus_text()
lines = list(chat_guard.prometheus_lines())
if pool is not None:
lines.extend(pool.prometheus_lines())
extra = "\n".join(lines) + "\n"
return StreamingResponse(iter([base + extra]), media_type="text/plain; version=0.0.4")

View File

@@ -1,13 +1,22 @@
from __future__ import annotations
from typing import Literal
from typing import Any, Literal
from pydantic import BaseModel, Field
# Keep permissive: OpenAI clients routinely send list-of-parts (multi-modal) or None
# (for tool calls). We flatten to plain text downstream.
MessageContent = str | list[dict[str, Any]] | None
class ChatMessage(BaseModel):
role: Literal["system", "user", "assistant", "tool"]
content: str
# OpenAI supports "developer" on newer API versions in addition to the classic set.
role: Literal["system", "user", "assistant", "tool", "developer", "function"]
content: MessageContent = None
name: str | None = None
tool_call_id: str | None = None
tool_calls: list[dict[str, Any]] | None = None
class ChatCompletionsRequest(BaseModel):
@@ -16,6 +25,11 @@ class ChatCompletionsRequest(BaseModel):
stream: bool = False
temperature: float | None = None
top_p: float | None = None
max_tokens: int | None = None
user: str | None = None
stream_options: dict[str, Any] | None = None
tools: list[dict[str, Any]] | None = None
tool_choice: Any | None = None
class ModelData(BaseModel):
@@ -35,6 +49,7 @@ class ChatCompletionChoice(BaseModel):
index: int = 0
finish_reason: str | None = "stop"
message: dict = Field(default_factory=dict)
logprobs: Any | None = None
class ChatCompletionResponse(BaseModel):
@@ -43,3 +58,34 @@ class ChatCompletionResponse(BaseModel):
created: int
model: str
choices: list[ChatCompletionChoice]
system_fingerprint: str | None = None
def flatten_content(content: MessageContent) -> str:
"""Reduce OpenAI multi-part content to a plain string prompt for Lingma."""
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if not isinstance(item, dict):
parts.append(str(item))
continue
t = item.get("type")
if t == "text":
text = item.get("text") or ""
if text:
parts.append(text)
elif t in ("image_url", "input_image"):
# Lingma 不支持多模态,降级成占位符,保留语义信号
parts.append("[image]")
elif t == "input_audio":
parts.append("[audio]")
else:
text = item.get("text") or item.get("content")
if isinstance(text, str) and text:
parts.append(text)
return "\n".join(p for p in parts if p)
return str(content)