Files
lingma-openai-gateway/app/lingma_client.py
GitHub Actions 707acc9005 feat: M1+M2 gateway hardening and multi-instance pool
Behavior hardening (M1):
- Fix `_chat_streams` memory leak: pop_stream on completion, error, and
  client disconnect.
- Add WebSocket reconnect with state machine (stopped/starting/ready/
  reconnecting/failed/closed) and exponential backoff, so a Lingma
  restart no longer requires restarting the gateway.
- Lazy initialization: startup failure is non-fatal, first real request
  triggers retry, `/healthz` reflects readiness.
- Migrate FastAPI on_event to lifespan.
- Structured JSON logging with request_id ContextVar; `x-request-id`
  propagated to responses.
- SSE now sets `Cache-Control: no-cache`, `X-Accel-Buffering: no` to
  defeat proxy buffering.
- OpenAI schema compatibility: `content` accepts str | list[parts] | None,
  added `developer`/`function` roles, `tools/tool_choice/stream_options/
  user/max_tokens` fields, and `stream_options.include_usage` emits final
  usage chunk.
- `require_bearer` uses `hmac.compare_digest`; `/metrics` now requires
  Bearer when `METRICS_TOKEN` or `API_KEYS` are set.
- Python 3.10/3.11 `TimeoutError` vs `asyncio.TimeoutError` unified.
- Error responses no longer leak `auto_login.status()` details.

Backpressure (M2 / A2):
- New `InFlightGuard` with per-request ticket, queue + rejection
  accounting, `BackpressureRejected` raises 429 + `Retry-After` once
  `GATEWAY_QUEUE_TIMEOUT_SEC` elapses.
- Streaming ticket ownership transfers to the generator so CancelledError
  from client disconnect still releases the slot.
- `/internal/stats.concurrency` and `/metrics` expose in_flight/queued/
  accepted_total/rejected_total/max_in_flight.

Multi-instance pool (M2 / A1 + B3):
- New `LingmaPool` with N processes, each with its own workDir, socket
  port (dynamic when N>1), and `AutoLoginManager`.
- Account parser supports CSV (`u1:p1,u2:p2`) and JSON formats via
  `LINGMA_ACCOUNTS`; falls back to `LINGMA_USERNAME/LINGMA_PASSWORD` for
  backwards compatibility (N=1 keeps legacy paths/ports).
- Routing: sticky affinity by `user` / system-prompt hash, then
  least-in-flight, finally round-robin fallback for unhealthy pool.
- `/healthz` reports per-instance state and ready count.
- `/internal/stats.pool` and `/metrics` expose per-instance
  `gateway_pool_instance_in_flight{name}` / `gateway_pool_instance_ready{name}`.
- `/internal/auto-login/start?instance=inst-N` targets a specific instance;
  `/internal/auto-login/status` lists all instances.

Compat notes:
- `.env.example` adds `METRICS_TOKEN`, `LOG_LEVEL`, `GATEWAY_MAX_IN_FLIGHT`,
  `GATEWAY_QUEUE_TIMEOUT_SEC`, `LINGMA_ACCOUNTS`, `LINGMA_INSTANCE_COUNT`.
- `.gitignore` cleaned up data/ duplication.
- Existing single-instance deployments keep working without config change.

Made-with: Cursor
2026-04-18 07:40:32 +08:00

615 lines
23 KiB
Python

from __future__ import annotations
import asyncio
import contextlib
import json
import os
import socket
import subprocess
import time
import uuid
from pathlib import Path
from typing import AsyncIterator, Callable, Optional
import websockets
from .logging_config import get_logger
logger = get_logger("lingma_gateway.client")
# On Python 3.10 asyncio.TimeoutError is a class distinct from the builtin
# TimeoutError; 3.11+ unifies the two. Handlers in this module catch this tuple
# so a timeout is recognised the same way on either version.
TIMEOUT_EXCEPTIONS: tuple[type[BaseException], ...] = (
asyncio.TimeoutError,
TimeoutError,
)
def _is_port_open(host: str, port: int, timeout_sec: float = 0.5) -> bool:
try:
with socket.create_connection((host, port), timeout=timeout_sec):
return True
except OSError:
return False
def _read_info_file(info_path: Path):
if not info_path.exists():
return None, None
txt = info_path.read_text(encoding="utf-8", errors="ignore").strip()
if not txt:
return None, None
lines = txt.splitlines()
if len(lines) < 2:
return None, None
try:
return int(lines[0].strip()), int(lines[1].strip())
except ValueError:
return None, None
def _wait_info_any(info_paths: list[Path], timeout_sec: int):
start = time.time()
while time.time() - start < timeout_sec:
for p in info_paths:
port, pid = _read_info_file(p)
if port and pid:
return port, pid, p
time.sleep(0.2)
raise TimeoutError(".info not ready")
def _encode_lsp_frame(payload_obj: dict) -> bytes:
body = json.dumps(payload_obj, ensure_ascii=False).encode("utf-8")
header = f"Content-Length: {len(body)}\r\n\r\n".encode("ascii")
return header + body
def _parse_lsp_frames(buf: bytes):
frames = []
while True:
header_end = buf.find(b"\r\n\r\n")
if header_end < 0:
break
header = buf[:header_end]
body_start = header_end + 4
content_length = None
for line in header.split(b"\r\n"):
if line.lower().startswith(b"content-length:"):
content_length = int(line.split(b":", 1)[1].strip())
break
if content_length is None:
buf = buf[body_start:]
continue
if len(buf) < body_start + content_length:
break
body = buf[body_start : body_start + content_length]
frames.append(body.decode("utf-8", errors="ignore"))
buf = buf[body_start + content_length :]
return frames, buf
class LspWsRpcClient:
    """JSON-RPC client exchanging ``Content-Length``-framed messages over a WebSocket.

    A background reader task parses incoming frames, resolves the futures of
    outstanding requests by ``id``, and routes ``chat/answer`` / ``chat/finish``
    server messages into per-request chat streams.

    Args:
        ws: Connected WebSocket-like object providing awaitable ``send(data)``
            and ``recv()``.
        on_disconnect: Optional callback invoked with the exception that ended
            the reader loop, unless the client was closed deliberately.
    """

    def __init__(self, ws, on_disconnect: Optional[Callable[[BaseException], None]] = None):
        self.ws = ws
        self._id = 1  # next JSON-RPC request id
        self._pending: dict[int, asyncio.Future] = {}  # request id -> response future
        self._send_lock = asyncio.Lock()  # serialises writes to the socket
        self._reader_task: asyncio.Task | None = None
        self._rx_buffer = b""  # bytes received but not yet forming a whole frame
        self._chat_streams: dict[str, dict] = {}  # chat requestId -> stream state
        self._on_disconnect = on_disconnect
        self._closed = False

    async def start(self):
        """Start the background task that reads and dispatches incoming frames."""
        self._reader_task = asyncio.create_task(self._reader_loop())

    async def close(self):
        """Stop the reader and release everyone waiting on this client.

        Pending requests fail with ``ConnectionError`` and open chat streams
        are ended, so callers fail fast instead of hanging.
        """
        self._closed = True
        if self._reader_task:
            self._reader_task.cancel()
            # CancelledError derives from BaseException, so it must be named
            # explicitly; it surfaces here when the task is cancelled before
            # the reader loop had a chance to run.
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await self._reader_task
        self._fail_pending(ConnectionError("lingma client closed"))
        self._end_open_streams()
        self._chat_streams.clear()

    def _fail_pending(self, exc: BaseException) -> None:
        """Fail every unresolved request future with ``exc`` and forget them."""
        for fut in self._pending.values():
            if not fut.done():
                fut.set_exception(exc)
        self._pending.clear()

    def _end_open_streams(self) -> None:
        """Mark unfinished chat streams done and wake their consumers."""
        for stream in self._chat_streams.values():
            if not stream["done"].is_set():
                stream["done"].set()
                stream["chunks"].put_nowait(None)  # None is the end-of-stream sentinel

    async def _send(self, payload: dict):
        """Frame ``payload`` and write it to the socket, one writer at a time."""
        async with self._send_lock:
            await self.ws.send(_encode_lsp_frame(payload))

    async def _reader_loop(self):
        """Receive frames until the socket fails or the task is cancelled.

        Cancellation ends the loop quietly. Any other exception is treated as
        a disconnect: pending requests fail with it, open streams are ended,
        and ``on_disconnect`` is invoked unless ``close()`` was called.
        """
        try:
            while True:
                raw = await self.ws.recv()
                chunk = raw if isinstance(raw, bytes) else raw.encode("utf-8", errors="ignore")
                self._rx_buffer += chunk
                bodies, self._rx_buffer = _parse_lsp_frames(self._rx_buffer)
                for body in bodies:
                    try:
                        msg = json.loads(body)
                    except (ValueError, RecursionError):
                        continue
                    if not isinstance(msg, dict):
                        # Not a JSON-RPC object; ignoring it keeps a stray
                        # frame from tearing down the whole connection.
                        continue
                    if "method" in msg and "result" not in msg and "error" not in msg:
                        await self._handle_server_message(msg)
                        continue
                    rid = msg.get("id")
                    if rid is None:
                        continue
                    fut = self._pending.pop(rid, None)
                    if fut and not fut.done():
                        fut.set_result(msg)
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            if not self._closed:
                logger.warning("lingma reader loop terminated: %s", exc)
            # Propagate the failure to anyone waiting on an RPC.
            self._fail_pending(exc)
            # Also unblock in-flight chat streams so consumers exit. Entries
            # are kept so callers can still read what was collected.
            self._end_open_streams()
            if not self._closed and self._on_disconnect is not None:
                try:
                    self._on_disconnect(exc)
                except Exception:
                    logger.exception("on_disconnect callback failed")

    async def _handle_server_message(self, msg: dict):
        """Dispatch a server-initiated request or notification.

        ``chat/answer`` appends text to the matching stream, ``chat/finish``
        completes it. Messages carrying an ``id`` are acknowledged with an
        empty result.
        """
        method = msg.get("method")
        params = msg.get("params") or {}
        if method == "chat/answer":
            stream = self._chat_streams.get(params.get("requestId"))
            if stream is not None:
                text = params.get("text") or params.get("content") or ""
                if text:
                    stream["parts"].append(text)
                    if stream["first_chunk_at"] is None:
                        stream["first_chunk_at"] = time.monotonic()
                    stream["chunks"].put_nowait(text)
        if method == "chat/finish":
            stream = self._chat_streams.get(params.get("requestId"))
            if stream is not None and not stream["done"].is_set():
                stream["finish"] = params
                stream["finish_at"] = time.monotonic()
                stream["done"].set()
                stream["chunks"].put_nowait(None)
        if "id" in msg:
            await self._send({"jsonrpc": "2.0", "id": msg.get("id"), "result": {}})

    async def request(self, method, params=None, timeout=20):
        """Send a JSON-RPC request and wait for its response.

        Args:
            method: RPC method name.
            params: Request parameters; ``None`` is sent as ``{}``.
            timeout: Seconds to wait for the response.

        Returns:
            The ``result`` member of the response (``None`` if absent).

        Raises:
            TimeoutError: No response arrived within ``timeout``.
            RuntimeError: The response carried an ``error`` member.
        """
        rid = self._id
        self._id += 1
        payload = {"jsonrpc": "2.0", "id": rid, "method": method, "params": params or {}}
        fut = asyncio.get_running_loop().create_future()
        self._pending[rid] = fut
        try:
            await self._send(payload)
            try:
                msg = await asyncio.wait_for(fut, timeout=timeout)
            except TIMEOUT_EXCEPTIONS:
                raise TimeoutError(f"RPC timeout: {method}")
        finally:
            # Always forget the future: a failed send or a cancelled caller
            # would otherwise leave it in _pending for the connection's life.
            self._pending.pop(rid, None)
        if "error" in msg:
            raise RuntimeError(f"RPC {method} error: {msg['error']}")
        return msg.get("result")

    async def notify(self, method, params=None):
        """Send a JSON-RPC notification (no ``id``, no response awaited)."""
        await self._send({"jsonrpc": "2.0", "method": method, "params": params or {}})

    def create_stream(self, request_id: str):
        """Register fresh stream state for ``request_id``, replacing any existing entry."""
        self._chat_streams[request_id] = {
            "parts": [],  # every text chunk received, in order
            "chunks": asyncio.Queue(),  # chunks for the consumer; None ends the stream
            "done": asyncio.Event(),
            "finish": None,  # params of the chat/finish message
            "started_at": time.monotonic(),
            "first_chunk_at": None,
            "finish_at": None,
        }

    def pop_stream(self, request_id: str) -> None:
        """Forget the stream for ``request_id``; a no-op if it is unknown.

        If the stream was not finished yet, it is marked done and the
        end-of-stream sentinel is queued so a consumer still blocked on the
        queue wakes up.
        """
        stream = self._chat_streams.pop(request_id, None)
        if stream is None:
            return
        if not stream["done"].is_set():
            stream["done"].set()
            with contextlib.suppress(Exception):
                stream["chunks"].put_nowait(None)

    async def consume_stream(self, request_id: str, timeout: float) -> AsyncIterator[str]:
        """Yield text chunks for ``request_id`` until the stream ends.

        Yields nothing if no stream is registered under ``request_id``.

        Args:
            request_id: Identifier passed to ``create_stream``.
            timeout: Overall budget in seconds for the whole stream.

        Raises:
            TimeoutError: The stream did not end within ``timeout``. The
                builtin class is raised on every Python version.
        """
        stream = self._chat_streams.get(request_id)
        if stream is None:
            return
        queue = stream["chunks"]
        deadline = time.monotonic() + timeout
        while True:
            remain = deadline - time.monotonic()
            if remain <= 0:
                raise TimeoutError("chat stream timeout")
            try:
                chunk = await asyncio.wait_for(queue.get(), timeout=remain)
            except TIMEOUT_EXCEPTIONS:
                # wait_for raises asyncio.TimeoutError, which on Python 3.10
                # is not the builtin; normalise so callers see one type.
                raise TimeoutError("chat stream timeout") from None
            if chunk is None:
                break
            yield chunk

    def get_stream_result(self, request_id: str) -> dict:
        """Summarise what has been collected for ``request_id`` so far.

        Returns:
            A dict with ``text`` (all chunks joined), ``finish`` (params of
            ``chat/finish`` or ``{}``), and ``firstTokenLatencyMs`` /
            ``totalLatencyMs`` (milliseconds since ``create_stream``, or
            ``None`` when the corresponding event has not happened). An
            unknown ``request_id`` yields empty text and ``None`` latencies.
        """
        stream = self._chat_streams.get(request_id) or {}
        first_ms = None
        total_ms = None
        if stream.get("first_chunk_at") is not None:
            first_ms = int((stream["first_chunk_at"] - stream["started_at"]) * 1000)
        if stream.get("finish_at") is not None:
            total_ms = int((stream["finish_at"] - stream["started_at"]) * 1000)
        return {
            "text": "".join(stream.get("parts") or []),
            "finish": stream.get("finish") or {},
            "firstTokenLatencyMs": first_ms,
            "totalLatencyMs": total_ms,
        }
class LingmaGatewayClient:
    """Owns the Lingma subprocess and the LSP-over-WS connection.

    Adds a small state machine + reconnect loop so the gateway can survive Lingma
    restarts and slow cold starts without bringing down the FastAPI app.

    States: ``stopped`` (never started), ``starting`` (connect in progress),
    ``ready`` (RPC usable), ``reconnecting`` (connection lost, retry scheduled),
    ``failed`` (last connect or all reconnect attempts failed), ``closed``.
    """

    STATE_STOPPED = "stopped"
    STATE_STARTING = "starting"
    STATE_READY = "ready"
    STATE_RECONNECTING = "reconnecting"
    STATE_FAILED = "failed"
    STATE_CLOSED = "closed"

    def __init__(
        self,
        lingma_bin: str,
        work_dir: str,
        socket_port: int,
        startup_timeout: int,
        rpc_timeout: int,
        default_model: str,
        default_ask_mode: str,
        *,
        name: str = "lingma",
        extra_info_paths: list[Path] | None = None,
    ):
        """Store configuration; no process is spawned and no socket is opened.

        Args:
            lingma_bin: Path to the Lingma executable.
            work_dir: Directory passed to Lingma as ``--workDir``; its
                ``.info`` file is always consulted for the port.
            socket_port: Port to try first. A value ``<= 0`` means "always
                spawn and read the port from ``.info``".
            startup_timeout: Seconds to wait for ``.info`` and, separately,
                for the socket to open.
            rpc_timeout: Timeout in seconds for individual RPC calls.
            default_model: Stored on the instance; not used by this class.
            default_ask_mode: Stored on the instance; not used by this class.
            name: Label used in log messages.
            extra_info_paths: Additional ``.info`` locations to poll. ``None``
                selects the shared ``~/.lingma/.info``.
        """
        self.name = name
        self.lingma_bin = Path(lingma_bin)
        self.work_dir = Path(work_dir)
        self.socket_port = socket_port
        self.startup_timeout = startup_timeout
        self.rpc_timeout = rpc_timeout
        self.default_model = default_model
        self.default_ask_mode = default_ask_mode
        # Each pool instance should only look at its own workDir .info to avoid
        # cross-instance clobbering via the shared ~/.lingma/.info path; the
        # shared path is only added when the caller does not specify anything.
        if extra_info_paths is None:
            extra_info_paths = [Path.home() / ".lingma" / ".info"]
        self._extra_info_paths = list(extra_info_paths)
        self._rpc: LspWsRpcClient | None = None
        self._ws = None
        self._state = self.STATE_STOPPED
        self._state_lock = asyncio.Lock()  # serialises connect attempts
        self._ready_event = asyncio.Event()  # set exactly while state is READY
        self._reconnect_task: asyncio.Task | None = None
        self._last_error: str = ""

    # ------------------------------------------------------------------ state
    @property
    def state(self) -> str:
        """Current state, one of the ``STATE_*`` constants."""
        return self._state

    @property
    def last_error(self) -> str:
        """Text of the most recent recorded error, or ``""``."""
        return self._last_error

    def _set_state(self, state: str, err: str = "") -> None:
        """Switch to ``state``, optionally record ``err``, and sync the ready event."""
        if state != self._state:
            logger.info(
                "lingma client state %s -> %s",
                self._state,
                state,
                extra={"ctx_new_state": state},
            )
            self._state = state
        if err:
            self._last_error = err
        if state == self.STATE_READY:
            self._ready_event.set()
        else:
            self._ready_event.clear()

    # -------------------------------------------------------------- lifecycle
    async def start(self) -> None:
        """Initial start. Failure is non-fatal: ensure_ready() will retry later."""
        try:
            await self._connect(initial=True)
        except Exception as exc:
            self._set_state(self.STATE_FAILED, err=str(exc))
            logger.exception("initial lingma start failed; will retry on demand")

    async def close(self) -> None:
        """Stop reconnecting and close the RPC client and the WebSocket."""
        self._set_state(self.STATE_CLOSED)
        if self._reconnect_task and not self._reconnect_task.done():
            self._reconnect_task.cancel()
            # Awaiting a cancelled task raises CancelledError, which is a
            # BaseException; without naming it here the error would escape
            # and the rpc/ws cleanup below would never run.
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await self._reconnect_task
        if self._rpc:
            await self._rpc.close()
        if self._ws:
            with contextlib.suppress(Exception):
                await self._ws.close()

    async def ensure_ready(self, timeout: float | None = None) -> None:
        """Block until the RPC connection is usable, (re)connecting on demand.

        Args:
            timeout: Seconds to wait for another task's connect/reconnect to
                finish. Defaults to ``max(30, startup_timeout + 10)``. It does
                not bound a connect attempt made by this call itself.

        Raises:
            RuntimeError: The client is closed, or it did not become ready in
                time.
            Exception: Whatever the on-demand connect attempt raised.
        """
        if self._state == self.STATE_CLOSED:
            raise RuntimeError("lingma client is closed")
        if self._state == self.STATE_READY and self._ws is not None:
            return
        async with self._state_lock:
            # Re-check: another task may have connected while we waited.
            if self._state == self.STATE_READY and self._ws is not None:
                return
            if self._state in (self.STATE_STOPPED, self.STATE_FAILED):
                try:
                    await self._connect(initial=False)
                    return
                except Exception as exc:
                    self._set_state(self.STATE_FAILED, err=str(exc))
                    raise
        # starting/reconnecting: someone else is connecting; wait outside the
        # lock so the reconnect loop can acquire it.
        wait_timeout = timeout if timeout is not None else max(
            30.0, float(self.startup_timeout) + 10.0
        )
        try:
            await asyncio.wait_for(self._ready_event.wait(), timeout=wait_timeout)
        except TIMEOUT_EXCEPTIONS:
            raise RuntimeError(f"lingma not ready (state={self._state}, err={self._last_error})")

    # --------------------------------------------------------------- connect
    async def _connect(self, *, initial: bool) -> None:
        """Make sure Lingma is running, open the WebSocket and initialize RPC.

        If the configured port is not already accepting connections, Lingma is
        spawned and the port is taken from the first ``.info`` file that
        appears. On success the state becomes ``ready``; on failure the
        exception propagates and the state is left for the caller to set.

        Args:
            initial: Only used in the "ready" log line.

        Raises:
            FileNotFoundError: ``lingma_bin`` does not exist.
            TimeoutError: ``.info`` or the socket did not appear in time.
        """
        self._set_state(self.STATE_STARTING)
        if not self.lingma_bin.exists():
            raise FileNotFoundError(f"Lingma not found: {self.lingma_bin}")
        info_paths = [self.work_dir / ".info", *self._extra_info_paths]
        # socket_port <= 0 is the pool-friendly "always spawn and read .info" mode.
        port_prewarmed = self.socket_port > 0 and _is_port_open(
            "127.0.0.1", self.socket_port
        )
        if not port_prewarmed:
            self.work_dir.mkdir(parents=True, exist_ok=True)
            # Remove stale info files from host-mounted workspace before boot.
            for p in info_paths:
                with contextlib.suppress(Exception):
                    if p.exists():
                        p.unlink()
            logger.info(
                "[%s] spawning lingma: %s start --workDir %s",
                self.name,
                self.lingma_bin,
                self.work_dir,
            )
            # The handle is not retained; the child runs in its own session.
            subprocess.Popen(
                [str(self.lingma_bin), "start", "--workDir", str(self.work_dir)],
                cwd=str(self.lingma_bin.parent),
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                start_new_session=True,
            )
            # _wait_info_any polls with time.sleep for up to startup_timeout
            # seconds; run it in a worker thread so the event loop keeps
            # serving other requests meanwhile.
            info, _, _ = await asyncio.to_thread(
                _wait_info_any, info_paths, timeout_sec=self.startup_timeout
            )
            self.socket_port = info
            deadline = time.monotonic() + self.startup_timeout
            while not _is_port_open("127.0.0.1", self.socket_port, timeout_sec=0.3):
                if time.monotonic() >= deadline:
                    raise TimeoutError(f"Lingma socket not open on port {self.socket_port}")
                await asyncio.sleep(0.2)
        # Close any stale ws/rpc before creating fresh ones (reconnect path).
        if self._rpc is not None:
            with contextlib.suppress(Exception):
                await self._rpc.close()
            self._rpc = None
        if self._ws is not None:
            with contextlib.suppress(Exception):
                await self._ws.close()
            self._ws = None
        ws_url = f"ws://127.0.0.1:{self.socket_port}"
        self._ws = await websockets.connect(ws_url, max_size=10 * 1024 * 1024)
        self._rpc = LspWsRpcClient(self._ws, on_disconnect=self._on_disconnect)
        await self._rpc.start()
        await self._rpc.request(
            "initialize",
            {
                "processId": os.getpid(),
                "clientInfo": {"name": "lingma-openai-gateway", "version": "0.1.0"},
                "capabilities": {},
                "workspaceFolders": [],
                "rootUri": None,
            },
            timeout=self.rpc_timeout,
        )
        await self._rpc.notify("initialized", {})
        self._set_state(self.STATE_READY)
        logger.info(
            "[%s] lingma ready on port %d (initial=%s)",
            self.name,
            self.socket_port,
            initial,
        )

    def _on_disconnect(self, exc: BaseException) -> None:
        """Reader-loop callback: mark ``reconnecting`` and start the retry loop once."""
        if self._state == self.STATE_CLOSED:
            return
        self._set_state(self.STATE_RECONNECTING, err=str(exc))
        if self._reconnect_task and not self._reconnect_task.done():
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop to schedule on.
            return
        self._reconnect_task = loop.create_task(self._reconnect_loop())

    async def _reconnect_loop(self) -> None:
        """Retry ``_connect`` with exponential backoff (1s doubling to 30s, 20 tries).

        Ends on the first success or when the client is closed; after the last
        failed attempt the state becomes ``failed``.
        """
        backoff = 1.0
        max_backoff = 30.0
        max_attempts = 20
        for attempt in range(1, max_attempts + 1):
            if self._state == self.STATE_CLOSED:
                return
            await asyncio.sleep(backoff)
            try:
                async with self._state_lock:
                    await self._connect(initial=False)
                logger.info("lingma reconnected after %d attempt(s)", attempt)
                return
            except Exception as exc:
                self._last_error = str(exc)
                logger.warning("lingma reconnect attempt %d failed: %s", attempt, exc)
            backoff = min(backoff * 2, max_backoff)
        self._set_state(self.STATE_FAILED, err="reconnect exhausted")

    # ------------------------------------------------------------------ RPC
    @property
    def rpc(self) -> LspWsRpcClient:
        """The current RPC client.

        Raises:
            RuntimeError: No connection has been established yet.
        """
        if self._rpc is None:
            raise RuntimeError(f"Lingma RPC not initialized (state={self._state})")
        return self._rpc

    async def auth_status(self):
        """Return the result of the ``auth/status`` RPC."""
        await self.ensure_ready()
        return await self.rpc.request("auth/status", {}, timeout=self.rpc_timeout)

    async def query_models(self):
        """Return the result of the ``config/queryModels`` RPC."""
        await self.ensure_ready()
        return await self.rpc.request("config/queryModels", {}, timeout=self.rpc_timeout)

    async def get_endpoint(self):
        """Return the result of the ``config/getEndpoint`` RPC."""
        await self.ensure_ready()
        return await self.rpc.request("config/getEndpoint", {}, timeout=self.rpc_timeout)

    async def update_endpoint(self, endpoint: str):
        """Call ``config/updateEndpoint`` with ``endpoint`` and return its result."""
        await self.ensure_ready()
        return await self.rpc.request(
            "config/updateEndpoint", {"endpoint": endpoint}, timeout=self.rpc_timeout
        )

    async def generate_login_url(self):
        """Call ``login/generateUrl`` and extract the URL from its result.

        Returns:
            ``(url, raw)``. ``url`` is the result itself when it is a string,
            else the first string found under ``loginUrl``, ``url`` or
            ``login_url``, else ``""``. ``raw`` is the result dict, or
            ``{"raw": result}`` when the result is not a dict.
        """
        await self.ensure_ready()
        result = await self.rpc.request("login/generateUrl", {}, timeout=self.rpc_timeout)
        if isinstance(result, str):
            return result, {"raw": result}
        if isinstance(result, dict):
            for key in ("loginUrl", "url", "login_url"):
                if isinstance(result.get(key), str):
                    return result[key], result
            return "", result
        return "", {"raw": result}

    # ------------------------------------------------------------------ chat
    def _build_payload(self, prompt: str, model_key: str, ask_mode: str, session_id: str, request_id: str):
        """Build the ``chat/ask`` parameters for one prompt.

        The prompt is placed under several field names; ``sessionType`` is
        ``"developer"`` when ``ask_mode`` is ``"agent"`` and ``"chat"``
        otherwise.
        """
        session_type = "developer" if ask_mode == "agent" else "chat"
        return {
            "requestId": request_id,
            "sessionId": session_id,
            "sessionType": session_type,
            "chatTask": "FREE_INPUT",
            "mode": ask_mode,
            "stream": True,
            "source": 1,
            "isReply": False,
            "taskDefinitionType": "system",
            "content": prompt,
            "text": prompt,
            "message": prompt,
            "questionText": prompt,
            "extra": {
                "modelConfig": {"key": model_key},
                "workspacePath": str(Path.cwd()),
            },
            "pluginPayloadConfig": {
                "isEnableAskAgent": ask_mode == "agent",
                "isEnableAutoMemory": True,
            },
            "chatContext": {
                "text": prompt,
                "features": [],
                "preferredLanguage": "zh-CN",
                "localeLang": "zh-CN",
            },
        }

    async def chat_complete(self, prompt: str, model_key: str, ask_mode: str) -> dict:
        """Send one prompt and return the fully collected answer.

        Returns:
            The stream result (``text``, ``finish``, ``firstTokenLatencyMs``,
            ``totalLatencyMs``) extended with ``requestId``, ``sessionId``,
            ``model`` and ``mode``.

        Raises:
            TimeoutError: The answer stream did not finish in time.
        """
        await self.ensure_ready()
        # Bind the client once: a reconnect replaces self._rpc, and the stream
        # registered below lives only on the client it was created on.
        rpc = self.rpc
        request_id = str(uuid.uuid4())
        session_id = str(uuid.uuid4())
        payload = self._build_payload(prompt, model_key, ask_mode, session_id, request_id)
        rpc.create_stream(request_id)
        try:
            try:
                await rpc.request("chat/ask", payload, timeout=self.rpc_timeout)
            except TIMEOUT_EXCEPTIONS:
                # chat/ask often returns nothing until chat/finish arrives; tolerate.
                pass
            async for _ in rpc.consume_stream(
                request_id, timeout=max(20.0, self.rpc_timeout + 20.0)
            ):
                pass
            result = rpc.get_stream_result(request_id)
        finally:
            rpc.pop_stream(request_id)
        finish = result.get("finish") or {}
        result["requestId"] = request_id
        result["sessionId"] = finish.get("sessionId") or session_id
        result["model"] = model_key
        result["mode"] = ask_mode
        return result

    async def chat_stream(self, prompt: str, model_key: str, ask_mode: str) -> AsyncIterator[str]:
        """Send one prompt and yield answer text chunks as they arrive.

        Raises:
            TimeoutError: The answer stream did not finish in time.
        """
        await self.ensure_ready()
        # Bind the client once so cleanup targets the client owning the stream
        # even if a reconnect swaps self._rpc mid-stream.
        rpc = self.rpc
        request_id = str(uuid.uuid4())
        session_id = str(uuid.uuid4())
        payload = self._build_payload(prompt, model_key, ask_mode, session_id, request_id)
        rpc.create_stream(request_id)
        try:
            try:
                await rpc.request("chat/ask", payload, timeout=self.rpc_timeout)
            except TIMEOUT_EXCEPTIONS:
                # Same tolerance as chat_complete: the reply may lag the stream.
                pass
            async for chunk in rpc.consume_stream(
                request_id, timeout=max(20.0, self.rpc_timeout + 40.0)
            ):
                yield chunk
        finally:
            # Runs on normal completion, exception, or consumer GeneratorExit (client disconnect).
            rpc.pop_stream(request_id)