Files
lingma-openai-gateway/app/main.py
GitHub Actions 707acc9005 feat: M1+M2 gateway hardening and multi-instance pool
Behavior hardening (M1):
- Fix `_chat_streams` memory leak: pop_stream on completion, error, and
  client disconnect.
- Add WebSocket reconnect with state machine (stopped/starting/ready/
  reconnecting/failed/closed) and exponential backoff, so a Lingma
  restart no longer requires restarting the gateway.
- Lazy initialization: startup failure is non-fatal, first real request
  triggers retry, `/healthz` reflects readiness.
- Migrate FastAPI on_event to lifespan.
- Structured JSON logging with request_id ContextVar; `x-request-id`
  propagated to responses.
- SSE now sets `Cache-Control: no-cache`, `X-Accel-Buffering: no` to
  defeat proxy buffering.
- OpenAI schema compatibility: `content` accepts str | list[parts] | None,
  added `developer`/`function` roles, `tools/tool_choice/stream_options/
  user/max_tokens` fields, and `stream_options.include_usage` emits final
  usage chunk.
- `require_bearer` uses `hmac.compare_digest`; `/metrics` now requires
  Bearer when `METRICS_TOKEN` or `API_KEYS` are set.
- Unified handling of the builtin `TimeoutError` and `asyncio.TimeoutError`, which are distinct classes on Python 3.10 and the same class from Python 3.11 on.
- Error responses no longer leak `auto_login.status()` details.

Backpressure (M2 / A2):
- New `InFlightGuard` with a per-request ticket and queue + rejection
  accounting; when a request has waited `GATEWAY_QUEUE_TIMEOUT_SEC` without
  getting a slot, `BackpressureRejected` is raised and translated into
  HTTP 429 with a `Retry-After` header.
- Streaming ticket ownership transfers to the generator so CancelledError
  from client disconnect still releases the slot.
- `/internal/stats.concurrency` and `/metrics` expose in_flight/queued/
  accepted_total/rejected_total/max_in_flight.

Multi-instance pool (M2 / A1 + B3):
- New `LingmaPool` with N processes, each with its own workDir, socket
  port (dynamic when N>1), and `AutoLoginManager`.
- Account parser supports CSV (`u1:p1,u2:p2`) and JSON formats via
  `LINGMA_ACCOUNTS`; falls back to `LINGMA_USERNAME/LINGMA_PASSWORD` for
  backwards compatibility (N=1 keeps legacy paths/ports).
- Routing: sticky affinity by `user` / system-prompt hash, then
  least-in-flight, finally round-robin fallback for unhealthy pool.
- `/healthz` reports per-instance state and ready count.
- `/internal/stats.pool` and `/metrics` expose per-instance
  `gateway_pool_instance_in_flight{name}` / `gateway_pool_instance_ready{name}`.
- `/internal/auto-login/start?instance=inst-N` targets a specific instance;
  `/internal/auto-login/status` lists all instances.

Compat notes:
- `.env.example` adds `METRICS_TOKEN`, `LOG_LEVEL`, `GATEWAY_MAX_IN_FLIGHT`,
  `GATEWAY_QUEUE_TIMEOUT_SEC`, `LINGMA_ACCOUNTS`, `LINGMA_INSTANCE_COUNT`.
- `.gitignore` cleaned up data/ duplication.
- Existing single-instance deployments keep working without config change.

Made-with: Cursor
2026-04-18 07:40:32 +08:00

548 lines
19 KiB
Python

from __future__ import annotations
import asyncio
import hashlib
import json
import time
import uuid
from contextlib import asynccontextmanager
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from .auth import require_bearer, require_metrics_access
from .concurrency import BackpressureRejected, InFlightGuard
from .config import Settings, load_settings
from .lingma_pool import LingmaPool, PoolInstance
from .logging_config import configure_logging, get_logger, request_id_var
from .model_map import build_model_name_map, flatten_model_keys, resolve_model
from .openai_schema import (
ChatCompletionChoice,
ChatCompletionResponse,
ChatCompletionsRequest,
ModelData,
ModelsResponse,
flatten_content,
)
from .stats import StatsCollector, estimate_tokens
# Module-level singletons, created once at import time and shared by all handlers.
settings: Settings = load_settings()
configure_logging(settings.log_level)
logger = get_logger("lingma_gateway")
# Assigned inside `lifespan`; stays None until the application has started.
pool: LingmaPool | None = None
stats_collector = StatsCollector()
# Guard whose ticket /v1/chat/completions acquires before running a chat.
chat_guard = InFlightGuard(
max_in_flight=settings.gateway_max_in_flight,
queue_timeout_sec=settings.gateway_queue_timeout_sec,
)
def _require_pool() -> LingmaPool:
    """Return the module-level pool, or fail with HTTP 503 if it is not set yet.

    Raises:
        HTTPException: 503 ``service_unavailable`` while ``pool`` is ``None``.
    """
    current = pool
    if current is not None:
        return current
    detail = {"error": {"message": "pool not initialized", "type": "service_unavailable"}}
    raise HTTPException(status_code=503, detail=detail)
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Build, start and eventually close the Lingma pool around the app's lifetime.

    The pool is published through the module-level ``pool`` global so request
    handlers can reach it via ``_require_pool``.
    """
    global pool
    pool = LingmaPool.build(
        lingma_bin=settings.lingma_bin,
        base_work_dir=settings.lingma_work_dir,
        legacy_socket_port=settings.lingma_socket_port,
        startup_timeout=settings.lingma_startup_timeout,
        rpc_timeout=settings.lingma_rpc_timeout,
        default_model=settings.default_model,
        default_ask_mode=settings.default_ask_mode,
        accounts=settings.accounts,
        instance_count=settings.instance_count,
        auto_login_headless=settings.auto_login_headless,
        auto_login_timeout=settings.auto_login_timeout,
        auto_login_max_retry=settings.auto_login_max_retry,
    )
    logger.info(
        "gateway startup: pool_size=%d max_in_flight=%d",
        pool.size(),
        settings.gateway_max_in_flight,
    )
    try:
        # start() lives inside the try so that a failure or cancellation part-way
        # through startup still reaches close() instead of leaving instances behind.
        await pool.start()
        yield
    finally:
        if pool is not None:
            await pool.close()
# ASGI application object; `lifespan` is handed to FastAPI so pool startup and
# shutdown run around the application's lifetime.
app = FastAPI(title="Lingma OpenAI Gateway", version="0.3.0", lifespan=lifespan)
@app.middleware("http")
async def request_id_middleware(request: Request, call_next):
    """Tag the request with an id, echo it on the response and log the outcome.

    The id is taken from the incoming ``x-request-id`` header when present;
    otherwise a ``req-<12 hex chars>`` value is generated. It is stored in
    ``request_id_var`` while the downstream handler runs and reset afterwards.
    """
    supplied_id = request.headers.get("x-request-id")
    req_id = supplied_id if supplied_id else f"req-{uuid.uuid4().hex[:12]}"
    token = request_id_var.set(req_id)
    started_at = time.monotonic()
    # Stays 500 when call_next raises, so the access log still shows a failure.
    status_code = 500
    try:
        response = await call_next(request)
        status_code = response.status_code
        response.headers["x-request-id"] = req_id
        return response
    finally:
        elapsed_ms = int((time.monotonic() - started_at) * 1000)
        method = request.method
        path = request.url.path
        log_context = {
            "ctx_method": method,
            "ctx_path": path,
            "ctx_status": status_code,
            "ctx_elapsed_ms": elapsed_ms,
        }
        logger.info(
            "http %s %s -> %s in %dms",
            method,
            path,
            status_code,
            elapsed_ms,
            extra=log_context,
        )
        request_id_var.reset(token)
def auth_guard(request: Request) -> None:
    """FastAPI dependency that delegates to ``require_bearer`` with the configured API keys."""
    allowed_keys = settings.api_keys
    require_bearer(request, allowed_keys)
def metrics_auth_guard(request: Request) -> None:
    """FastAPI dependency that delegates to ``require_metrics_access``.

    Passes both the configured API keys and the dedicated metrics token.
    """
    allowed_keys = settings.api_keys
    metrics_token = settings.metrics_token
    require_metrics_access(request, allowed_keys, metrics_token)
@app.get("/healthz")
async def healthz():
    """Report pool readiness.

    ``ok`` is true when at least one instance reports state ``"ready"``. Before
    the pool exists the response carries ``reason: "pool uninitialized"``.
    """
    if pool is None:
        return {"ok": False, "time": int(time.time()), "reason": "pool uninitialized"}

    snapshot = pool.stats()
    ready_count = 0
    summaries = []
    for entry in snapshot:
        if entry["state"] == "ready":
            ready_count += 1
        summaries.append(
            {
                "name": entry["name"],
                "state": entry["state"],
                "in_flight": entry["in_flight"],
            }
        )

    return {
        "ok": ready_count > 0,
        "time": int(time.time()),
        "pool_size": len(snapshot),
        "pool_ready": ready_count,
        "instances": summaries,
    }
async def _ensure_instance_logged_in(inst: PoolInstance) -> dict:
    """Return the auth status of ``inst``, driving an auto login first if needed.

    When the instance is not logged in and auto login is enabled, this
    optionally switches the client to the dedicated endpoint, asks it for a
    login URL, hands that URL to the instance's auto-login manager, waits for
    it to finish and re-checks the auth status.

    Raises:
        HTTPException: 503 when the initial ``auth_status`` call fails; 401
            when not logged in and auto login is disabled, or when auto login
            did not yield a logged-in session; 502 when no login URL could be
            generated.
    """

    def _http_error(code: int, message: str, kind: str) -> HTTPException:
        # Every error raised here shares the same OpenAI-style envelope.
        return HTTPException(
            status_code=code,
            detail={"error": {"message": message, "type": kind}},
        )

    client, auto_login = inst.client, inst.auto_login

    try:
        status = await client.auth_status()
    except Exception as exc:
        logger.warning("[%s] auth_status failed before chat: %s", inst.name, exc)
        raise _http_error(503, "Lingma is not ready", "service_unavailable")
    if status and status.get("id"):
        return status

    if not settings.auto_login_enabled:
        raise _http_error(401, "Lingma not logged in", "invalid_request_error")

    wanted_endpoint = settings.dedicated_domain_url
    if wanted_endpoint:
        # Best effort: a failed endpoint switch is logged and login still proceeds.
        try:
            current = await client.get_endpoint()
            if isinstance(current, dict):
                current_ep = current.get("endpoint", "")
            else:
                current_ep = ""
            if current_ep != wanted_endpoint:
                await client.update_endpoint(wanted_endpoint)
        except Exception as exc:
            logger.warning("[%s] switch dedicated endpoint failed: %s", inst.name, exc)

    try:
        login_url, _login_raw = await client.generate_login_url()
    except Exception as exc:
        logger.warning("[%s] generate_login_url failed: %s", inst.name, exc)
        raise _http_error(502, "generate login url failed", "upstream_error")
    if not login_url:
        raise _http_error(502, "generate login url failed", "upstream_error")

    await auto_login.ensure_started(login_url)
    try:
        await auto_login.wait_done(timeout=settings.auto_login_timeout + 20)
    except Exception as exc:
        # Not fatal by itself: the auth status re-check below decides the outcome.
        logger.warning("[%s] auto_login wait_done failed: %s", inst.name, exc)

    try:
        status = await client.auth_status()
    except Exception as exc:
        logger.warning("[%s] post-login auth_status failed: %s", inst.name, exc)
        status = None
    if status and status.get("id"):
        return status

    # The auto-login status goes to the log only; the HTTP error stays generic.
    logger.warning(
        "[%s] auto login did not result in a logged-in session: %s",
        inst.name,
        auto_login.status(),
    )
    raise _http_error(401, "Lingma auto login failed", "invalid_request_error")
def _affinity_key_for(req: ChatCompletionsRequest) -> str | None:
    """Derive a stable affinity key so that follow-ups go to the same instance.

    Priority: a non-blank ``user`` (stripped), then a hash of the first system
    message that has text, then a hash of the first message. A missing, empty
    or whitespace-only ``user`` is treated the same way: it falls through to
    the content-based keys instead of disabling affinity.

    Returns:
        The affinity key, or ``None`` when nothing usable is present.
    """
    user = (req.user or "").strip()
    if user:
        return user

    def _content_key(prefix: str, content) -> str | None:
        # SHA-1 is used only to bucket requests, not for anything security-related.
        text = flatten_content(content)
        if not text:
            return None
        return prefix + hashlib.sha1(text.encode("utf-8")).hexdigest()[:16]

    for message in req.messages:
        if message.role == "system":
            key = _content_key("sys:", message.content)
            if key:
                return key

    if req.messages:
        return _content_key("first:", req.messages[0].content)
    return None
@app.get("/v1/models", dependencies=[Depends(auth_guard)])
async def v1_models():
    """List the models reported by one pool instance in OpenAI ``/v1/models`` shape."""
    inst = _require_pool().pick()
    await _ensure_instance_logged_in(inst)
    await stats_collector.inc_models()

    raw_models = await inst.client.query_models()
    model_keys = flatten_model_keys(raw_models)
    display_names = build_model_name_map(raw_models)

    entries = []
    for key in model_keys:
        entries.append(ModelData(id=key, name=display_names.get(key)))
    listing = ModelsResponse(data=entries)
    return JSONResponse(content=listing.model_dump())
def _messages_to_prompt(messages: list[dict]) -> str:
    """Flatten chat messages into one prompt with a ``[role] text`` line per message.

    A message without text but with ``tool_calls`` is rendered as a JSON dump
    of those calls; a message with neither is skipped. The role defaults to
    ``user`` and the joined result is stripped.
    """

    def _render(message: dict) -> str:
        body = flatten_content(message.get("content"))
        if not body and message.get("tool_calls"):
            body = f"[tool_calls] {json.dumps(message['tool_calls'], ensure_ascii=False)}"
        if not body:
            return ""
        return f"[{message.get('role', 'user')}] {body}"

    rendered = (_render(message) for message in messages)
    return "\n".join(line for line in rendered if line).strip()
def _include_usage(stream_options: dict | None) -> bool:
if not isinstance(stream_options, dict):
return False
return bool(stream_options.get("include_usage"))
def _sse_event(payload: dict) -> str:
    """Serialize ``payload`` as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


@app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)])
async def v1_chat_completions(req: ChatCompletionsRequest):
    """Serve an OpenAI-style chat completion from one pool instance.

    Flow: pick an instance (passing the request's affinity key to the pool),
    make sure it is logged in, resolve the model, flatten the messages into a
    single prompt, acquire a backpressure ticket, then either stream SSE
    chunks or return one JSON body.

    Raises:
        HTTPException: 400 when the messages yield an empty prompt; 429 with a
            ``Retry-After`` header when the in-flight guard rejects the
            request; 502 when the non-streaming upstream call fails; plus
            whatever ``_require_pool`` and ``_ensure_instance_logged_in`` raise.
    """
    p = _require_pool()
    affinity = _affinity_key_for(req)
    inst = p.pick(affinity_key=affinity)
    await _ensure_instance_logged_in(inst)

    models = await inst.client.query_models()
    available = flatten_model_keys(models)
    name_map = build_model_name_map(models)
    model = resolve_model(req.model, available, settings.default_model, name_map)

    ask_mode = settings.default_ask_mode
    # `or ""` keeps this check safe should the request carry no model name.
    if (req.model or "").lower() in {"lingma-agent", "agent"}:
        ask_mode = "agent"

    prompt = _messages_to_prompt([m.model_dump() for m in req.messages])
    if not prompt:
        raise HTTPException(
            status_code=400,
            detail={"error": {"message": "messages is empty", "type": "invalid_request_error"}},
        )
    prompt_tokens = estimate_tokens(prompt)
    include_usage = _include_usage(req.stream_options)

    # Backpressure: the slot is acquired after login/model resolution and prompt
    # validation, immediately before the chat call itself, so a saturated
    # gateway rejects without starting a chat upstream.
    try:
        ticket = await chat_guard.try_acquire()
    except BackpressureRejected as exc:
        retry_after = max(1, int(exc.retry_after))
        logger.warning("chat rejected by backpressure, retry_after=%ds", retry_after)
        raise HTTPException(
            status_code=429,
            detail={
                "error": {
                    "message": "Too many in-flight requests, please retry later",
                    "type": "rate_limit_error",
                    "code": "backpressure",
                }
            },
            headers={"Retry-After": str(retry_after)},
        )

    inst.in_flight += 1
    logger.info(
        "chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d",
        inst.name,
        model,
        ask_mode,
        req.stream,
        prompt_tokens,
        extra={
            "ctx_instance": inst.name,
            "ctx_model": model,
            "ctx_ask_mode": ask_mode,
            "ctx_stream": req.stream,
            "ctx_prompt_tokens": prompt_tokens,
            "ctx_in_flight": chat_guard.in_flight,
            "ctx_affinity": affinity,
        },
    )

    # True once the streaming generator owns the ticket; the outer `finally`
    # must then leave the release to the generator.
    ticket_transferred = False
    try:
        if req.stream:
            created = int(time.time())
            completion_id = f"chatcmpl-{uuid.uuid4().hex}"

            def chunk_event(choices: list, **extra) -> str:
                # All chunks of one completion share id/created/model.
                return _sse_event(
                    {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": created,
                        "model": model,
                        "choices": choices,
                        **extra,
                    }
                )

            async def event_stream(_ticket=ticket, _inst=inst):
                streamed_tokens = 0
                success = False
                try:
                    async for piece in _inst.client.chat_stream(prompt, model, ask_mode):
                        streamed_tokens += estimate_tokens(piece)
                        yield chunk_event(
                            [{"index": 0, "delta": {"content": piece}, "finish_reason": None}]
                        )
                    yield chunk_event([{"index": 0, "delta": {}, "finish_reason": "stop"}])
                    if include_usage:
                        yield chunk_event(
                            [],
                            usage={
                                "prompt_tokens": prompt_tokens,
                                "completion_tokens": streamed_tokens,
                                "total_tokens": prompt_tokens + streamed_tokens,
                            },
                        )
                    yield "data: [DONE]\n\n"
                    success = True
                except asyncio.CancelledError:
                    logger.info("chat.stream cancelled by client (inst=%s)", _inst.name)
                    raise
                except Exception as exc:
                    # The stream simply ends here; the failure is visible in logs and stats.
                    logger.warning("chat.stream error (inst=%s): %s", _inst.name, exc)
                finally:
                    # Free the slot synchronously *before* awaiting anything: if the
                    # stats await below is itself cancelled (client disconnect), the
                    # ticket and the in-flight counter have already been released.
                    _inst.in_flight = max(0, _inst.in_flight - 1)
                    _ticket.release()
                    await stats_collector.record_chat(
                        stream=True,
                        success=success,
                        prompt_tokens=prompt_tokens,
                        completion_tokens=streamed_tokens,
                    )

            streaming_response = StreamingResponse(
                event_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache, no-transform",
                    "X-Accel-Buffering": "no",
                    "Connection": "keep-alive",
                },
            )
            # Hand ownership over only once the response object exists, so a
            # failure while building it still releases the ticket below.
            # NOTE(review): a generator that is never iterated never runs its
            # `finally`; confirm the server always starts consuming the body.
            ticket_transferred = True
            return streaming_response

        try:
            result = await inst.client.chat_complete(prompt, model, ask_mode)
        except Exception as exc:
            logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
            await stats_collector.record_chat(
                stream=False,
                success=False,
                prompt_tokens=prompt_tokens,
                completion_tokens=0,
            )
            raise HTTPException(
                status_code=502,
                detail={"error": {"message": "upstream lingma error", "type": "upstream_error"}},
            )

        answer_text = result.get("text") or ""
        completion_tokens = estimate_tokens(answer_text)
        await stats_collector.record_chat(
            stream=False,
            success=True,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
        response = ChatCompletionResponse(
            id=f"chatcmpl-{uuid.uuid4().hex}",
            created=int(time.time()),
            model=model,
            choices=[
                ChatCompletionChoice(
                    index=0,
                    finish_reason="stop",
                    message={"role": "assistant", "content": answer_text},
                )
            ],
        )
        data = response.model_dump()
        # Gateway-specific additions on top of the OpenAI response shape.
        data["latency"] = {
            "first_token_ms": result.get("firstTokenLatencyMs"),
            "total_ms": result.get("totalLatencyMs"),
        }
        data["usage"] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        data["served_by"] = inst.name
        return JSONResponse(content=data)
    finally:
        if not ticket_transferred:
            inst.in_flight = max(0, inst.in_flight - 1)
            ticket.release()
@app.post("/internal/auto-login/start", dependencies=[Depends(auth_guard)])
async def internal_auto_login_start(instance: str | None = None):
    """Kick off auto login on one pool instance.

    Args:
        instance: Name of the instance to target. When omitted, the pool
            picks one.

    Returns:
        ``already_logged_in`` with the auth status when a session exists,
        otherwise ``running`` / ``already_running`` together with the
        auto-login manager's status.

    Raises:
        HTTPException: 404 when ``instance`` names no pool member; 503 when
            the auth status cannot be queried; 502 when no login URL could be
            generated.
    """
    p = _require_pool()
    if instance:
        target = next((inst for inst in p.instances if inst.name == instance), None)
        if target is None:
            raise HTTPException(
                status_code=404,
                detail={"error": {"message": f"instance not found: {instance}"}},
            )
    else:
        target = p.pick()

    client = target.client
    auto_login = target.auto_login

    # Same handling as the chat path: an unreachable instance yields a
    # structured 503 rather than an unhandled server error.
    try:
        status = await client.auth_status()
    except Exception as exc:
        logger.warning("[%s] auth_status failed before auto-login start: %s", target.name, exc)
        raise HTTPException(
            status_code=503,
            detail={"error": {"message": "Lingma is not ready", "type": "service_unavailable"}},
        )
    if status and status.get("id"):
        return {"ok": True, "state": "already_logged_in", "instance": target.name, "auth": status}

    if settings.dedicated_domain_url:
        # Best effort: a failed endpoint switch is logged and login still proceeds.
        try:
            current = await client.get_endpoint()
            current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else ""
            if current_ep != settings.dedicated_domain_url:
                await client.update_endpoint(settings.dedicated_domain_url)
        except Exception as exc:
            logger.warning("[%s] switch dedicated endpoint failed: %s", target.name, exc)

    try:
        login_url, _login_raw = await client.generate_login_url()
    except Exception as exc:
        logger.warning("[%s] generate_login_url failed: %s", target.name, exc)
        raise HTTPException(status_code=502, detail={"error": {"message": "generate login url failed"}})
    if not login_url:
        raise HTTPException(status_code=502, detail={"error": {"message": "generate login url failed"}})

    started = await auto_login.ensure_started(login_url)
    return {
        "ok": True,
        "state": "running" if started else "already_running",
        "instance": target.name,
        "auto_login": auto_login.status(),
    }
@app.get("/internal/auto-login/status", dependencies=[Depends(auth_guard)])
async def internal_auto_login_status():
    """Report auto-login status, auth status and client state for every instance.

    A failing ``auth_status`` call does not abort the listing; its message is
    reported under ``auth.error`` for that instance.
    """

    async def _describe(inst) -> dict:
        try:
            auth = await inst.client.auth_status()
        except Exception as exc:
            auth = {"error": str(exc)}
        return {
            "instance": inst.name,
            "auto_login": inst.auto_login.status(),
            "auth": auth,
            "state": inst.client.state,
        }

    p = _require_pool()
    # Instances are queried one after another, in pool order.
    reports = [await _describe(inst) for inst in p.instances]
    return {"ok": True, "instances": reports}
@app.get("/internal/stats", dependencies=[Depends(auth_guard)])
async def internal_stats():
    """Return the stats snapshot together with concurrency-guard and pool stats."""
    p = _require_pool()
    payload = {"ok": True, "stats": await stats_collector.snapshot()}
    payload["concurrency"] = chat_guard.stats()
    payload["pool"] = p.stats()
    return payload
@app.get("/metrics", dependencies=[Depends(metrics_auth_guard)])
async def metrics():
    """Serve Prometheus text: collector output followed by guard and pool lines.

    Pool lines are included only once the pool exists.
    """
    base = await stats_collector.prometheus_text()
    lines = [*chat_guard.prometheus_lines()]
    if pool is not None:
        lines += pool.prometheus_lines()
    body = base + "\n".join(lines) + "\n"
    return StreamingResponse(iter([body]), media_type="text/plain; version=0.0.4")