feat: M1+M2 gateway hardening and multi-instance pool

Behavior hardening (M1):
- Fix `_chat_streams` memory leak: pop_stream on completion, error, and
  client disconnect.
- Add WebSocket reconnect with state machine (stopped/starting/ready/
  reconnecting/failed/closed) and exponential backoff, so a Lingma
  restart no longer requires restarting the gateway.
- Lazy initialization: startup failure is non-fatal, first real request
  triggers retry, `/healthz` reflects readiness.
- Migrate FastAPI on_event to lifespan.
- Structured JSON logging with request_id ContextVar; `x-request-id`
  propagated to responses.
- SSE now sets `Cache-Control: no-cache`, `X-Accel-Buffering: no` to
  defeat proxy buffering.
- OpenAI schema compatibility: `content` accepts str | list[parts] | None,
  added `developer`/`function` roles, `tools/tool_choice/stream_options/
  user/max_tokens` fields, and `stream_options.include_usage` emits final
  usage chunk.
- `require_bearer` uses `hmac.compare_digest`; `/metrics` now requires
  Bearer when `METRICS_TOKEN` or `API_KEYS` are set.
- Python 3.10/3.11 `TimeoutError` vs `asyncio.TimeoutError` unified.
- Error responses no longer leak `auto_login.status()` details.

Backpressure (M2 / A2):
- New `InFlightGuard` with per-request ticket, queue + rejection
  accounting, `BackpressureRejected` raises 429 + `Retry-After` once
  `GATEWAY_QUEUE_TIMEOUT_SEC` elapses.
- Streaming ticket ownership transfers to the generator so CancelledError
  from client disconnect still releases the slot.
- `/internal/stats.concurrency` and `/metrics` expose in_flight/queued/
  accepted_total/rejected_total/max_in_flight.

Multi-instance pool (M2 / A1 + B3):
- New `LingmaPool` with N processes, each with its own workDir, socket
  port (dynamic when N>1), and `AutoLoginManager`.
- Account parser supports CSV (`u1:p1,u2:p2`) and JSON formats via
  `LINGMA_ACCOUNTS`; falls back to `LINGMA_USERNAME/LINGMA_PASSWORD` for
  backwards compatibility (N=1 keeps legacy paths/ports).
- Routing: sticky affinity by `user` / system-prompt hash, then
  least-in-flight, finally round-robin fallback for unhealthy pool.
- `/healthz` reports per-instance state and ready count.
- `/internal/stats.pool` and `/metrics` expose per-instance
  `gateway_pool_instance_in_flight{name}` / `gateway_pool_instance_ready{name}`.
- `/internal/auto-login/start?instance=inst-N` targets a specific instance;
  `/internal/auto-login/status` lists all instances.

Compat notes:
- `.env.example` adds `METRICS_TOKEN`, `LOG_LEVEL`, `GATEWAY_MAX_IN_FLIGHT`,
  `GATEWAY_QUEUE_TIMEOUT_SEC`, `LINGMA_ACCOUNTS`, `LINGMA_INSTANCE_COUNT`.
- `.gitignore` cleaned up data/ duplication.
- Existing single-instance deployments keep working without config change.

Made-with: Cursor
This commit is contained in:
GitHub Actions
2026-04-18 07:40:32 +08:00
parent 6114c66aed
commit 707acc9005
11 changed files with 1360 additions and 222 deletions

View File

@@ -1,16 +1,20 @@
from __future__ import annotations
import asyncio
import hashlib
import json
import time
import uuid
from contextlib import asynccontextmanager
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from .auto_login import AutoLoginManager
from .auth import require_bearer
from .auth import require_bearer, require_metrics_access
from .concurrency import BackpressureRejected, InFlightGuard
from .config import Settings, load_settings
from .lingma_client import LingmaGatewayClient
from .lingma_pool import LingmaPool, PoolInstance
from .logging_config import configure_logging, get_logger, request_id_var
from .model_map import build_model_name_map, flatten_model_keys, resolve_model
from .openai_schema import (
ChatCompletionChoice,
@@ -18,107 +22,219 @@ from .openai_schema import (
ChatCompletionsRequest,
ModelData,
ModelsResponse,
flatten_content,
)
from .stats import StatsCollector, estimate_tokens
app = FastAPI(title="Lingma OpenAI Gateway", version="0.1.0")
settings: Settings = load_settings()
lingma: LingmaGatewayClient | None = None
auto_login: AutoLoginManager | None = None
configure_logging(settings.log_level)
logger = get_logger("lingma_gateway")
pool: LingmaPool | None = None
stats_collector = StatsCollector()
chat_guard = InFlightGuard(
max_in_flight=settings.gateway_max_in_flight,
queue_timeout_sec=settings.gateway_queue_timeout_sec,
)
def _require_pool() -> LingmaPool:
if pool is None:
raise HTTPException(
status_code=503,
detail={"error": {"message": "pool not initialized", "type": "service_unavailable"}},
)
return pool
@asynccontextmanager
async def lifespan(_app: FastAPI):
global pool
pool = LingmaPool.build(
lingma_bin=settings.lingma_bin,
base_work_dir=settings.lingma_work_dir,
legacy_socket_port=settings.lingma_socket_port,
startup_timeout=settings.lingma_startup_timeout,
rpc_timeout=settings.lingma_rpc_timeout,
default_model=settings.default_model,
default_ask_mode=settings.default_ask_mode,
accounts=settings.accounts,
instance_count=settings.instance_count,
auto_login_headless=settings.auto_login_headless,
auto_login_timeout=settings.auto_login_timeout,
auto_login_max_retry=settings.auto_login_max_retry,
)
logger.info(
"gateway startup: pool_size=%d max_in_flight=%d",
pool.size(),
settings.gateway_max_in_flight,
)
await pool.start()
try:
yield
finally:
if pool is not None:
await pool.close()
app = FastAPI(title="Lingma OpenAI Gateway", version="0.3.0", lifespan=lifespan)
@app.middleware("http")
async def request_id_middleware(request: Request, call_next):
req_id = request.headers.get("x-request-id") or f"req-{uuid.uuid4().hex[:12]}"
token = request_id_var.set(req_id)
start = time.monotonic()
status_code = 500
try:
response = await call_next(request)
status_code = response.status_code
response.headers["x-request-id"] = req_id
return response
finally:
elapsed_ms = int((time.monotonic() - start) * 1000)
logger.info(
"http %s %s -> %s in %dms",
request.method,
request.url.path,
status_code,
elapsed_ms,
extra={
"ctx_method": request.method,
"ctx_path": request.url.path,
"ctx_status": status_code,
"ctx_elapsed_ms": elapsed_ms,
},
)
request_id_var.reset(token)
def auth_guard(request: Request):
require_bearer(request, settings.api_keys)
async def _is_logged_in() -> bool:
assert lingma is not None
st = await lingma.auth_status()
return bool(st and st.get("id"))
@app.on_event("startup")
async def on_startup():
global lingma, auto_login
lingma = LingmaGatewayClient(
lingma_bin=settings.lingma_bin,
work_dir=settings.lingma_work_dir,
socket_port=settings.lingma_socket_port,
startup_timeout=settings.lingma_startup_timeout,
rpc_timeout=settings.lingma_rpc_timeout,
default_model=settings.default_model,
default_ask_mode=settings.default_ask_mode,
)
await lingma.start()
auto_login = AutoLoginManager(
username=settings.lingma_username,
password=settings.lingma_password,
headless=settings.auto_login_headless,
timeout_sec=settings.auto_login_timeout,
max_retry=settings.auto_login_max_retry,
verify_logged_in=_is_logged_in,
verify_timeout_sec=max(30, min(180, settings.auto_login_timeout)),
)
@app.on_event("shutdown")
async def on_shutdown():
if lingma:
await lingma.close()
def metrics_auth_guard(request: Request):
require_metrics_access(request, settings.api_keys, settings.metrics_token)
@app.get("/healthz")
async def healthz():
return {"ok": True, "time": int(time.time())}
if pool is None:
return {"ok": False, "time": int(time.time()), "reason": "pool uninitialized"}
insts = pool.stats()
ready = sum(1 for i in insts if i["state"] == "ready")
return {
"ok": ready > 0,
"time": int(time.time()),
"pool_size": len(insts),
"pool_ready": ready,
"instances": [
{"name": i["name"], "state": i["state"], "in_flight": i["in_flight"]}
for i in insts
],
}
async def _ensure_logged_in_or_auto_login() -> dict:
assert lingma is not None
status = await lingma.auth_status()
async def _ensure_instance_logged_in(inst: PoolInstance) -> dict:
client = inst.client
auto_login = inst.auto_login
try:
status = await client.auth_status()
except Exception as exc:
logger.warning("[%s] auth_status failed before chat: %s", inst.name, exc)
raise HTTPException(
status_code=503,
detail={"error": {"message": "Lingma is not ready", "type": "service_unavailable"}},
)
if status and status.get("id"):
return status
if not settings.auto_login_enabled:
raise HTTPException(status_code=401, detail={"error": {"message": "Lingma not logged in"}})
if settings.dedicated_domain_url:
current = await lingma.get_endpoint()
current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else ""
if current_ep != settings.dedicated_domain_url:
await lingma.update_endpoint(settings.dedicated_domain_url)
login_url, login_raw = await lingma.generate_login_url()
if not login_url:
raise HTTPException(
status_code=500,
detail={"error": {"message": f"generate login url failed: {login_raw}"}},
status_code=401,
detail={"error": {"message": "Lingma not logged in", "type": "invalid_request_error"}},
)
if settings.dedicated_domain_url:
try:
current = await client.get_endpoint()
current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else ""
if current_ep != settings.dedicated_domain_url:
await client.update_endpoint(settings.dedicated_domain_url)
except Exception as exc:
logger.warning("[%s] switch dedicated endpoint failed: %s", inst.name, exc)
try:
login_url, _login_raw = await client.generate_login_url()
except Exception as exc:
logger.warning("[%s] generate_login_url failed: %s", inst.name, exc)
raise HTTPException(
status_code=502,
detail={"error": {"message": "generate login url failed", "type": "upstream_error"}},
)
if not login_url:
raise HTTPException(
status_code=502,
detail={"error": {"message": "generate login url failed", "type": "upstream_error"}},
)
assert auto_login is not None
await auto_login.ensure_started(login_url)
try:
await auto_login.wait_done(timeout=settings.auto_login_timeout + 20)
except Exception:
pass
except Exception as exc:
logger.warning("[%s] auto_login wait_done failed: %s", inst.name, exc)
try:
status = await client.auth_status()
except Exception as exc:
logger.warning("[%s] post-login auth_status failed: %s", inst.name, exc)
status = None
status = await lingma.auth_status()
if status and status.get("id"):
return status
logger.warning(
"[%s] auto login did not result in a logged-in session: %s",
inst.name,
auto_login.status(),
)
raise HTTPException(
status_code=401,
detail={"error": {"message": "Lingma auto login failed", "auto_login": auto_login.status()}},
detail={"error": {"message": "Lingma auto login failed", "type": "invalid_request_error"}},
)
def _affinity_key_for(req: ChatCompletionsRequest) -> str | None:
"""Derive a stable affinity key so that follow-ups go to the same instance.
Priority: explicit `user` > hash of the first/system message.
"""
if req.user:
return req.user.strip() or None
for m in req.messages:
if m.role == "system":
text = flatten_content(m.content)
if text:
return "sys:" + hashlib.sha1(text.encode("utf-8")).hexdigest()[:16]
if req.messages:
first = req.messages[0]
text = flatten_content(first.content)
if text:
return "first:" + hashlib.sha1(text.encode("utf-8")).hexdigest()[:16]
return None
@app.get("/v1/models", dependencies=[Depends(auth_guard)])
async def v1_models():
assert lingma is not None
await _ensure_logged_in_or_auto_login()
p = _require_pool()
inst = p.pick()
await _ensure_instance_logged_in(inst)
await stats_collector.inc_models()
models = await lingma.query_models()
models = await inst.client.query_models()
keys = flatten_model_keys(models)
name_map = build_model_name_map(models)
resp = ModelsResponse(data=[ModelData(id=k, name=name_map.get(k)) for k in keys])
@@ -126,20 +242,32 @@ async def v1_models():
def _messages_to_prompt(messages: list[dict]) -> str:
parts = []
parts: list[str] = []
for m in messages:
role = m.get("role", "user")
content = m.get("content", "")
parts.append(f"[{role}] {content}")
text = flatten_content(m.get("content"))
if not text and m.get("tool_calls"):
text = f"[tool_calls] {json.dumps(m['tool_calls'], ensure_ascii=False)}"
if not text:
continue
parts.append(f"[{role}] {text}")
return "\n".join(parts).strip()
def _include_usage(stream_options: dict | None) -> bool:
if not isinstance(stream_options, dict):
return False
return bool(stream_options.get("include_usage"))
@app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)])
async def v1_chat_completions(req: ChatCompletionsRequest):
assert lingma is not None
await _ensure_logged_in_or_auto_login()
p = _require_pool()
affinity = _affinity_key_for(req)
inst = p.pick(affinity_key=affinity)
await _ensure_instance_logged_in(inst)
models = await lingma.query_models()
models = await inst.client.query_models()
available = flatten_model_keys(models)
name_map = build_model_name_map(models)
model = resolve_model(req.model, available, settings.default_model, name_map)
@@ -150,142 +278,270 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
prompt = _messages_to_prompt([m.model_dump() for m in req.messages])
if not prompt:
raise HTTPException(status_code=400, detail={"error": {"message": "messages is empty"}})
raise HTTPException(
status_code=400,
detail={"error": {"message": "messages is empty", "type": "invalid_request_error"}},
)
prompt_tokens = estimate_tokens(prompt)
include_usage = _include_usage(req.stream_options)
if req.stream:
created = int(time.time())
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
completion_tokens_holder = {"n": 0}
# Backpressure: acquire a slot *after* the cheap validation but before any
# upstream call. This ensures we reject quickly when saturated.
try:
ticket = await chat_guard.try_acquire()
except BackpressureRejected as exc:
retry_after = max(1, int(exc.retry_after))
logger.warning("chat rejected by backpressure, retry_after=%ds", retry_after)
raise HTTPException(
status_code=429,
detail={
"error": {
"message": "Too many in-flight requests, please retry later",
"type": "rate_limit_error",
"code": "backpressure",
}
},
headers={"Retry-After": str(retry_after)},
)
async def event_stream():
success = False
try:
async for chunk in lingma.chat_stream(prompt, model, ask_mode):
completion_tokens_holder["n"] += estimate_tokens(chunk)
payload = {
inst.in_flight += 1
logger.info(
"chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d",
inst.name,
model,
ask_mode,
req.stream,
prompt_tokens,
extra={
"ctx_instance": inst.name,
"ctx_model": model,
"ctx_ask_mode": ask_mode,
"ctx_stream": req.stream,
"ctx_prompt_tokens": prompt_tokens,
"ctx_in_flight": chat_guard.in_flight,
"ctx_affinity": affinity,
},
)
ticket_transferred = False
try:
if req.stream:
created = int(time.time())
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
completion_tokens_holder = {"n": 0}
async def event_stream(_ticket=ticket, _inst=inst):
success = False
try:
async for chunk in _inst.client.chat_stream(prompt, model, ask_mode):
completion_tokens_holder["n"] += estimate_tokens(chunk)
payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": chunk},
"finish_reason": None,
}
],
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
done_payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": chunk},
"finish_reason": None,
}
],
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n"
done_payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
success = True
finally:
await stats_collector.record_chat(
stream=True,
success=success,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens_holder["n"],
)
if include_usage:
usage_payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens_holder["n"],
"total_tokens": prompt_tokens + completion_tokens_holder["n"],
},
}
yield f"data: {json.dumps(usage_payload, ensure_ascii=False)}\n\n"
return StreamingResponse(event_stream(), media_type="text/event-stream")
yield "data: [DONE]\n\n"
success = True
except asyncio.CancelledError:
logger.info("chat.stream cancelled by client (inst=%s)", _inst.name)
raise
except Exception as exc:
logger.warning("chat.stream error (inst=%s): %s", _inst.name, exc)
finally:
await stats_collector.record_chat(
stream=True,
success=success,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens_holder["n"],
)
_inst.in_flight = max(0, _inst.in_flight - 1)
_ticket.release()
try:
result = await lingma.chat_complete(prompt, model, ask_mode)
except Exception:
ticket_transferred = True
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache, no-transform",
"X-Accel-Buffering": "no",
"Connection": "keep-alive",
},
)
try:
result = await inst.client.chat_complete(prompt, model, ask_mode)
except Exception as exc:
logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
await stats_collector.record_chat(
stream=False,
success=False,
prompt_tokens=prompt_tokens,
completion_tokens=0,
)
raise HTTPException(
status_code=502,
detail={"error": {"message": "upstream lingma error", "type": "upstream_error"}},
)
completion_tokens = estimate_tokens(result.get("text") or "")
await stats_collector.record_chat(
stream=False,
success=False,
success=True,
prompt_tokens=prompt_tokens,
completion_tokens=0,
completion_tokens=completion_tokens,
)
raise
completion_tokens = estimate_tokens(result.get("text") or "")
await stats_collector.record_chat(
stream=False,
success=True,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
response = ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex}",
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
finish_reason="stop",
message={"role": "assistant", "content": result.get("text") or ""},
)
],
)
data = response.model_dump()
data["latency"] = {
"first_token_ms": result.get("firstTokenLatencyMs"),
"total_ms": result.get("totalLatencyMs"),
}
data["usage"] = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
return JSONResponse(content=data)
response = ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex}",
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
finish_reason="stop",
message={"role": "assistant", "content": result.get("text") or ""},
)
],
)
data = response.model_dump()
data["latency"] = {
"first_token_ms": result.get("firstTokenLatencyMs"),
"total_ms": result.get("totalLatencyMs"),
}
data["usage"] = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
data["served_by"] = inst.name
return JSONResponse(content=data)
finally:
if not ticket_transferred:
inst.in_flight = max(0, inst.in_flight - 1)
ticket.release()
@app.post("/internal/auto-login/start", dependencies=[Depends(auth_guard)])
async def internal_auto_login_start():
assert lingma is not None
assert auto_login is not None
async def internal_auto_login_start(instance: str | None = None):
p = _require_pool()
target = None
if instance:
for inst in p.instances:
if inst.name == instance:
target = inst
break
if target is None:
raise HTTPException(
status_code=404,
detail={"error": {"message": f"instance not found: {instance}"}},
)
else:
target = p.pick()
status = await lingma.auth_status()
client = target.client
auto_login = target.auto_login
status = await client.auth_status()
if status and status.get("id"):
return {"ok": True, "state": "already_logged_in", "auth": status}
return {"ok": True, "state": "already_logged_in", "instance": target.name, "auth": status}
if settings.dedicated_domain_url:
current = await lingma.get_endpoint()
current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else ""
if current_ep != settings.dedicated_domain_url:
await lingma.update_endpoint(settings.dedicated_domain_url)
try:
current = await client.get_endpoint()
current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else ""
if current_ep != settings.dedicated_domain_url:
await client.update_endpoint(settings.dedicated_domain_url)
except Exception as exc:
logger.warning("[%s] switch dedicated endpoint failed: %s", target.name, exc)
try:
login_url, _login_raw = await client.generate_login_url()
except Exception as exc:
logger.warning("[%s] generate_login_url failed: %s", target.name, exc)
raise HTTPException(status_code=502, detail={"error": {"message": "generate login url failed"}})
login_url, login_raw = await lingma.generate_login_url()
if not login_url:
raise HTTPException(status_code=500, detail={"error": {"message": "generate login url failed", "raw": login_raw}})
raise HTTPException(status_code=502, detail={"error": {"message": "generate login url failed"}})
started = await auto_login.ensure_started(login_url)
return {
"ok": True,
"state": "running" if started else "already_running",
"loginUrl": login_url,
"instance": target.name,
"auto_login": auto_login.status(),
}
@app.get("/internal/auto-login/status", dependencies=[Depends(auth_guard)])
async def internal_auto_login_status():
assert auto_login is not None
assert lingma is not None
return {
"ok": True,
"auto_login": auto_login.status(),
"auth": await lingma.auth_status(),
}
p = _require_pool()
out = []
for inst in p.instances:
try:
auth = await inst.client.auth_status()
except Exception as exc:
auth = {"error": str(exc)}
out.append(
{
"instance": inst.name,
"auto_login": inst.auto_login.status(),
"auth": auth,
"state": inst.client.state,
}
)
return {"ok": True, "instances": out}
@app.get("/internal/stats", dependencies=[Depends(auth_guard)])
async def internal_stats():
return {"ok": True, "stats": await stats_collector.snapshot()}
p = _require_pool()
return {
"ok": True,
"stats": await stats_collector.snapshot(),
"concurrency": chat_guard.stats(),
"pool": p.stats(),
}
@app.get("/metrics")
@app.get("/metrics", dependencies=[Depends(metrics_auth_guard)])
async def metrics():
text = await stats_collector.prometheus_text()
return StreamingResponse(iter([text]), media_type="text/plain; version=0.0.4")
base = await stats_collector.prometheus_text()
lines = list(chat_guard.prometheus_lines())
if pool is not None:
lines.extend(pool.prometheus_lines())
extra = "\n".join(lines) + "\n"
return StreamingResponse(iter([base + extra]), media_type="text/plain; version=0.0.4")