"""Lingma OpenAI Gateway: OpenAI-style HTTP endpoints backed by a ``LingmaPool``.

Exposes ``/v1/models`` and ``/v1/chat/completions`` plus health, stats, metrics
and auto-login helper endpoints. Chat requests are subject to an in-flight
limit (``chat_guard``) and may reuse an upstream session through
``session_cache``.
"""

from __future__ import annotations

import asyncio
import hashlib
import json
import time
import uuid
from contextlib import asynccontextmanager

from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse

from .auth import require_bearer, require_metrics_access
from .concurrency import BackpressureRejected, InFlightGuard
from .config import Settings, load_settings
from .lingma_pool import LingmaPool, PoolInstance
from .logging_config import configure_logging, get_logger, request_id_var
from .model_map import build_model_name_map, flatten_model_keys, resolve_model
from .openai_schema import (
    ChatCompletionChoice,
    ChatCompletionResponse,
    ChatCompletionsRequest,
    ModelData,
    ModelsResponse,
    flatten_content,
)
from .session_cache import SessionCache
from .stats import StatsCollector, estimate_tokens

settings: Settings = load_settings()
configure_logging(settings.log_level)
logger = get_logger("lingma_gateway")

# Built in ``lifespan``; stays None until the application has started.
pool: LingmaPool | None = None
stats_collector = StatsCollector()
chat_guard = InFlightGuard(
    max_in_flight=settings.gateway_max_in_flight,
    queue_timeout_sec=settings.gateway_queue_timeout_sec,
)
# max_entries=0 is passed when session reuse is switched off in settings.
session_cache = SessionCache(
    max_entries=settings.session_cache_max_entries if settings.session_reuse_enabled else 0,
    ttl_sec=settings.session_cache_ttl_sec,
)


def _require_pool() -> LingmaPool:
    """Return the module-level pool, or raise HTTP 503 if it has not been built yet."""
    if pool is None:
        raise HTTPException(
            status_code=503,
            detail={"error": {"message": "pool not initialized", "type": "service_unavailable"}},
        )
    return pool


@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Build and start the Lingma pool on startup; close it on shutdown."""
    global pool
    pool = LingmaPool.build(
        lingma_bin=settings.lingma_bin,
        base_work_dir=settings.lingma_work_dir,
        legacy_socket_port=settings.lingma_socket_port,
        startup_timeout=settings.lingma_startup_timeout,
        rpc_timeout=settings.lingma_rpc_timeout,
        default_model=settings.default_model,
        default_ask_mode=settings.default_ask_mode,
        accounts=settings.accounts,
        instance_count=settings.instance_count,
        auto_login_headless=settings.auto_login_headless,
        auto_login_timeout=settings.auto_login_timeout,
        auto_login_max_retry=settings.auto_login_max_retry,
    )
    logger.info(
        "gateway startup: pool_size=%d max_in_flight=%d",
        pool.size(),
        settings.gateway_max_in_flight,
    )
    await pool.start()
    try:
        yield
    finally:
        if pool is not None:
            await pool.close()


app = FastAPI(title="Lingma OpenAI Gateway", version="0.3.0", lifespan=lifespan)


@app.middleware("http")
async def request_id_middleware(request: Request, call_next):
    """Attach a request id to the context and response, and log one line per request.

    The id is taken from the incoming ``x-request-id`` header when present,
    otherwise a fresh ``req-<12 hex>`` value is generated.
    """
    req_id = request.headers.get("x-request-id") or f"req-{uuid.uuid4().hex[:12]}"
    token = request_id_var.set(req_id)
    start = time.monotonic()
    # Logged as 500 if call_next raises before a response exists.
    status_code = 500
    try:
        response = await call_next(request)
        status_code = response.status_code
        response.headers["x-request-id"] = req_id
        return response
    finally:
        elapsed_ms = int((time.monotonic() - start) * 1000)
        logger.info(
            "http %s %s -> %s in %dms",
            request.method,
            request.url.path,
            status_code,
            elapsed_ms,
            extra={
                "ctx_method": request.method,
                "ctx_path": request.url.path,
                "ctx_status": status_code,
                "ctx_elapsed_ms": elapsed_ms,
            },
        )
        request_id_var.reset(token)


def auth_guard(request: Request):
    """FastAPI dependency: check the request's bearer token against ``settings.api_keys``."""
    require_bearer(request, settings.api_keys)


def metrics_auth_guard(request: Request):
    """FastAPI dependency: check access to ``/metrics`` (API keys or the metrics token)."""
    require_metrics_access(request, settings.api_keys, settings.metrics_token)


@app.get("/healthz")
async def healthz():
    """Report pool health; ``ok`` is true when at least one instance is in state "ready"."""
    if pool is None:
        return {"ok": False, "time": int(time.time()), "reason": "pool uninitialized"}
    insts = pool.stats()
    ready = sum(1 for i in insts if i["state"] == "ready")
    return {
        "ok": ready > 0,
        "time": int(time.time()),
        "pool_size": len(insts),
        "pool_ready": ready,
        "instances": [
            {"name": i["name"], "state": i["state"], "in_flight": i["in_flight"]}
            for i in insts
        ],
    }


async def _switch_to_dedicated_endpoint(inst: PoolInstance) -> None:
    """Best effort: point *inst* at ``settings.dedicated_domain_url`` if it is configured.

    Failures are logged and swallowed so that the login flow can still proceed.
    """
    if not settings.dedicated_domain_url:
        return
    try:
        current = await inst.client.get_endpoint()
        current_ep = current.get("endpoint", "") if isinstance(current, dict) else ""
        if current_ep != settings.dedicated_domain_url:
            await inst.client.update_endpoint(settings.dedicated_domain_url)
    except Exception as exc:
        logger.warning("[%s] switch dedicated endpoint failed: %s", inst.name, exc)


async def _ensure_instance_logged_in(inst: PoolInstance) -> dict:
    """Return the auth status of *inst*, running the auto-login flow if needed.

    Raises:
        HTTPException: 503 when the initial ``auth_status`` call fails; 401 when
            the instance is not logged in and auto login is disabled or did not
            produce a logged-in session; 502 when no login URL could be generated.
    """
    client = inst.client
    auto_login = inst.auto_login
    try:
        status = await client.auth_status()
    except Exception as exc:
        logger.warning("[%s] auth_status failed before chat: %s", inst.name, exc)
        raise HTTPException(
            status_code=503,
            detail={"error": {"message": "Lingma is not ready", "type": "service_unavailable"}},
        ) from exc

    # A status carrying an "id" is treated as a logged-in session.
    if status and status.get("id"):
        return status

    if not settings.auto_login_enabled:
        raise HTTPException(
            status_code=401,
            detail={"error": {"message": "Lingma not logged in", "type": "invalid_request_error"}},
        )

    await _switch_to_dedicated_endpoint(inst)

    try:
        login_url, _login_raw = await client.generate_login_url()
    except Exception as exc:
        logger.warning("[%s] generate_login_url failed: %s", inst.name, exc)
        raise HTTPException(
            status_code=502,
            detail={"error": {"message": "generate login url failed", "type": "upstream_error"}},
        ) from exc
    if not login_url:
        raise HTTPException(
            status_code=502,
            detail={"error": {"message": "generate login url failed", "type": "upstream_error"}},
        )

    await auto_login.ensure_started(login_url)
    try:
        # 20 seconds of slack on top of the configured auto-login timeout.
        await auto_login.wait_done(timeout=settings.auto_login_timeout + 20)
    except Exception as exc:
        # Not fatal by itself: the auth_status re-check below decides the outcome.
        logger.warning("[%s] auto_login wait_done failed: %s", inst.name, exc)

    try:
        status = await client.auth_status()
    except Exception as exc:
        logger.warning("[%s] post-login auth_status failed: %s", inst.name, exc)
        status = None
    if status and status.get("id"):
        return status

    logger.warning(
        "[%s] auto login did not result in a logged-in session: %s",
        inst.name,
        auto_login.status(),
    )
    raise HTTPException(
        status_code=401,
        detail={"error": {"message": "Lingma auto login failed", "type": "invalid_request_error"}},
    )


def _affinity_key_for(req: ChatCompletionsRequest) -> str | None:
    """Derive a stable affinity key so that follow-ups go to the same instance.

    Priority: explicit non-blank `user` > hash of the first non-empty system
    message > hash of the first message. Returns None when none of these
    yields any text.
    """
    # A blank/whitespace-only `user` carries no identity, so fall through to the
    # content-based keys instead of giving up on affinity altogether.
    user = (req.user or "").strip()
    if user:
        return user

    for m in req.messages:
        if m.role == "system":
            text = flatten_content(m.content)
            if text:
                return "sys:" + hashlib.sha1(text.encode("utf-8")).hexdigest()[:16]

    if req.messages:
        text = flatten_content(req.messages[0].content)
        if text:
            return "first:" + hashlib.sha1(text.encode("utf-8")).hexdigest()[:16]
    return None


def _extract_api_key(request: Request) -> str:
    """Return the token from an ``Authorization: Bearer <token>`` header, or ""."""
    h = request.headers.get("authorization", "")
    if h.lower().startswith("bearer "):
        return h[7:].strip()
    return ""


def _last_user_text(messages: list[dict]) -> str:
    """Return the text of the most recent user message, searching from the end.

    Used when we hit the session cache and only need to send the delta.
    Falls back to the last message regardless of role if no user message is
    found, and to "" when *messages* is empty.
    """
    for m in reversed(messages):
        if m.get("role") == "user":
            return flatten_content(m.get("content")) or ""
    if messages:
        return flatten_content(messages[-1].get("content")) or ""
    return ""


@app.get("/v1/models", dependencies=[Depends(auth_guard)])
async def v1_models():
    """List the models reported by one pool instance."""
    p = _require_pool()
    inst = p.pick()
    await _ensure_instance_logged_in(inst)
    await stats_collector.inc_models()
    models = await inst.client.query_models()
    keys = flatten_model_keys(models)
    name_map = build_model_name_map(models)
    resp = ModelsResponse(data=[ModelData(id=k, name=name_map.get(k)) for k in keys])
    return JSONResponse(content=resp.model_dump())


def _messages_to_prompt(messages: list[dict]) -> str:
    """Flatten chat messages into one prompt with a ``[role] text`` line per message.

    Messages with no text but with ``tool_calls`` are rendered as a
    ``[tool_calls] <json>`` line; messages with neither are skipped.
    """
    parts: list[str] = []
    for m in messages:
        role = m.get("role", "user")
        text = flatten_content(m.get("content"))
        if not text and m.get("tool_calls"):
            text = f"[tool_calls] {json.dumps(m['tool_calls'], ensure_ascii=False)}"
        if not text:
            continue
        parts.append(f"[{role}] {text}")
    return "\n".join(parts).strip()


def _include_usage(stream_options: dict | None) -> bool:
    """Return True when ``stream_options`` is a dict with a truthy ``include_usage``."""
    if not isinstance(stream_options, dict):
        return False
    return bool(stream_options.get("include_usage"))


@app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)])
async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
    """Handle a chat completion, streaming (SSE) or non-streaming.

    Raises:
        HTTPException: 400 when the resulting prompt is empty, 429 (with a
            ``Retry-After`` header) when the in-flight guard rejects the request,
            502 when the non-streaming upstream call fails, plus whatever
            ``_require_pool`` / ``_ensure_instance_logged_in`` raise.
    """
    p = _require_pool()
    messages_dump = [m.model_dump() for m in req.messages]
    api_key = _extract_api_key(request) or "-"

    # ------------------------------------------------------------- session reuse
    # Look up the "conversation prefix" (everything except the latest user turn)
    # in the session cache. A hit lets us:
    #   1. Reuse the upstream sessionId so Lingma/Qwen hits its KV cache.
    #   2. Send only the new user message instead of the whole history.
    #   3. Stick the request to the pool instance that originally served it.
    ask_mode = settings.default_ask_mode
    if req.model.lower() in {"lingma-agent", "agent"}:
        ask_mode = "agent"

    reuse_eligible = (
        session_cache.enabled
        and ask_mode == "chat"
        and len(messages_dump) >= 2
    )
    lookup_key: str | None = None
    write_key: str | None = None
    cached_session_id: str | None = None
    cached_instance_name: str | None = None
    if reuse_eligible:
        # lookup_key: history without the newest message (what a previous turn
        # stored); write_key: full history (what the next turn will look up).
        lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
        write_key = session_cache.build_key(api_key, messages_dump)
        entry = await session_cache.get(lookup_key)
        if entry is not None:
            cached_session_id = entry.session_id
            cached_instance_name = entry.instance_name or None

    # Instance selection: prefer cached instance for continuity, else normal affinity.
    affinity = cached_instance_name or _affinity_key_for(req)
    inst = p.pick(affinity_key=affinity)

    # If cache pointed at a specific instance that's no longer healthy, we already
    # fell back via pool.pick -> drop the cached session since Lingma on a
    # different process won't know about it.
    if cached_instance_name and inst.name != cached_instance_name:
        logger.info(
            "session cache instance %s unhealthy, falling back to %s (dropping cached session)",
            cached_instance_name,
            inst.name,
        )
        cached_session_id = None
        if lookup_key:
            await session_cache.invalidate(lookup_key)

    await _ensure_instance_logged_in(inst)

    models = await inst.client.query_models()
    available = flatten_model_keys(models)
    name_map = build_model_name_map(models)
    model = resolve_model(req.model, available, settings.default_model, name_map)

    # Prompt construction: on cache hit send only the last user turn so Lingma's
    # stored context isn't duplicated.
    if cached_session_id:
        prompt = _last_user_text(messages_dump)
        is_reply = True
    else:
        prompt = _messages_to_prompt(messages_dump)
        is_reply = False

    if not prompt:
        raise HTTPException(
            status_code=400,
            detail={"error": {"message": "messages is empty", "type": "invalid_request_error"}},
        )

    prompt_tokens = estimate_tokens(prompt)
    include_usage = _include_usage(req.stream_options)

    # Backpressure: acquire a slot after validation and instance preparation
    # (auth check, model lookup) but before the chat call itself, so that a
    # saturated gateway rejects the request instead of queueing more chat work.
    try:
        ticket = await chat_guard.try_acquire()
    except BackpressureRejected as exc:
        retry_after = max(1, int(exc.retry_after))
        logger.warning("chat rejected by backpressure, retry_after=%ds", retry_after)
        raise HTTPException(
            status_code=429,
            detail={
                "error": {
                    "message": "Too many in-flight requests, please retry later",
                    "type": "rate_limit_error",
                    "code": "backpressure",
                }
            },
            headers={"Retry-After": str(retry_after)},
        ) from exc

    inst.in_flight += 1
    logger.info(
        "chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d reuse=%s",
        inst.name,
        model,
        ask_mode,
        req.stream,
        prompt_tokens,
        bool(cached_session_id),
        extra={
            "ctx_instance": inst.name,
            "ctx_model": model,
            "ctx_ask_mode": ask_mode,
            "ctx_stream": req.stream,
            "ctx_prompt_tokens": prompt_tokens,
            "ctx_in_flight": chat_guard.in_flight,
            "ctx_affinity": affinity,
            "ctx_session_reuse": bool(cached_session_id),
        },
    )

    # For streaming responses the generator takes over responsibility for
    # releasing the slot; this flag stops the outer ``finally`` from doing it too.
    ticket_transferred = False
    try:
        if req.stream:
            created = int(time.time())
            completion_id = f"chatcmpl-{uuid.uuid4().hex}"
            completion_tokens_holder = {"n": 0}
            stream_meta: dict = {}

            async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
                """Yield OpenAI-style SSE chunks, then release the in-flight slot."""
                # NOTE(review): if the response is dropped before this generator
                # is iterated for the first time, the cleanup below never runs and
                # the slot stays held - confirm the framework always starts it.
                success = False
                try:
                    async for chunk in _inst.client.chat_stream(
                        prompt,
                        model,
                        ask_mode,
                        session_id=cached_session_id,
                        is_reply=is_reply,
                        out_meta=_meta,
                    ):
                        completion_tokens_holder["n"] += estimate_tokens(chunk)
                        payload = {
                            "id": completion_id,
                            "object": "chat.completion.chunk",
                            "created": created,
                            "model": model,
                            "choices": [
                                {
                                    "index": 0,
                                    "delta": {"content": chunk},
                                    "finish_reason": None,
                                }
                            ],
                        }
                        yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

                    done_payload = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": created,
                        "model": model,
                        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                    }
                    yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n"

                    if include_usage:
                        usage_payload = {
                            "id": completion_id,
                            "object": "chat.completion.chunk",
                            "created": created,
                            "model": model,
                            "choices": [],
                            "usage": {
                                "prompt_tokens": prompt_tokens,
                                "completion_tokens": completion_tokens_holder["n"],
                                "total_tokens": prompt_tokens + completion_tokens_holder["n"],
                            },
                        }
                        yield f"data: {json.dumps(usage_payload, ensure_ascii=False)}\n\n"

                    yield "data: [DONE]\n\n"
                    success = True
                except asyncio.CancelledError:
                    logger.info("chat.stream cancelled by client (inst=%s)", _inst.name)
                    raise
                except Exception as exc:
                    # Headers are already sent, so the failure can only be logged;
                    # the stream simply ends without the "[DONE]" marker.
                    logger.warning("chat.stream error (inst=%s): %s", _inst.name, exc)
                finally:
                    # The awaits below can themselves raise (for example a second
                    # CancelledError while cleaning up after a disconnect), so the
                    # slot release lives in its own ``finally`` and always runs.
                    try:
                        if success:
                            # Persist upstream sessionId only on a clean chat/finish.
                            if write_key:
                                sid = _meta.get("session_id")
                                if sid:
                                    await session_cache.put(write_key, sid, _inst.name)
                        elif cached_session_id and lookup_key:
                            # Partial streams (cancelled, timed out, errored) leave
                            # Lingma's session in an indeterminate state, so drop
                            # the cached entry, as the non-streaming path does.
                            await session_cache.invalidate(lookup_key)
                        await stats_collector.record_chat(
                            stream=True,
                            success=success,
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens_holder["n"],
                        )
                    finally:
                        _inst.in_flight = max(0, _inst.in_flight - 1)
                        _ticket.release()

            ticket_transferred = True
            return StreamingResponse(
                event_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache, no-transform",
                    "X-Accel-Buffering": "no",
                    "Connection": "keep-alive",
                },
            )

        try:
            result = await inst.client.chat_complete(
                prompt,
                model,
                ask_mode,
                session_id=cached_session_id,
                is_reply=is_reply,
            )
        except Exception as exc:
            logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
            await stats_collector.record_chat(
                stream=False,
                success=False,
                prompt_tokens=prompt_tokens,
                completion_tokens=0,
            )
            # If we used a cached session and the call blew up, drop it so the
            # next turn can start fresh instead of hitting the same dead session.
            if cached_session_id and lookup_key:
                await session_cache.invalidate(lookup_key)
            raise HTTPException(
                status_code=502,
                detail={"error": {"message": "upstream lingma error", "type": "upstream_error"}},
            ) from exc

        completion_tokens = estimate_tokens(result.get("text") or "")
        await stats_collector.record_chat(
            stream=False,
            success=True,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )

        if write_key:
            sid = result.get("sessionId")
            if sid:
                await session_cache.put(write_key, sid, inst.name)

        response = ChatCompletionResponse(
            id=f"chatcmpl-{uuid.uuid4().hex}",
            created=int(time.time()),
            model=model,
            choices=[
                ChatCompletionChoice(
                    index=0,
                    finish_reason="stop",
                    message={"role": "assistant", "content": result.get("text") or ""},
                )
            ],
        )
        # Gateway-specific fields added on top of the schema dump.
        data = response.model_dump()
        data["latency"] = {
            "first_token_ms": result.get("firstTokenLatencyMs"),
            "total_ms": result.get("totalLatencyMs"),
        }
        data["usage"] = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        data["served_by"] = inst.name
        return JSONResponse(content=data)
    finally:
        if not ticket_transferred:
            inst.in_flight = max(0, inst.in_flight - 1)
            ticket.release()


@app.post("/internal/auto-login/start", dependencies=[Depends(auth_guard)])
async def internal_auto_login_start(instance: str | None = None):
    """Start the auto-login flow on the named instance, or on a picked one.

    Raises:
        HTTPException: 404 when *instance* names no pool instance, 502 when no
            login URL could be generated.
    """
    p = _require_pool()
    target = None
    if instance:
        for inst in p.instances:
            if inst.name == instance:
                target = inst
                break
        if target is None:
            raise HTTPException(
                status_code=404,
                detail={"error": {"message": f"instance not found: {instance}"}},
            )
    else:
        target = p.pick()

    client = target.client
    auto_login = target.auto_login

    status = await client.auth_status()
    if status and status.get("id"):
        return {"ok": True, "state": "already_logged_in", "instance": target.name, "auth": status}

    await _switch_to_dedicated_endpoint(target)

    try:
        login_url, _login_raw = await client.generate_login_url()
    except Exception as exc:
        logger.warning("[%s] generate_login_url failed: %s", target.name, exc)
        raise HTTPException(
            status_code=502,
            detail={"error": {"message": "generate login url failed"}},
        ) from exc
    if not login_url:
        raise HTTPException(
            status_code=502,
            detail={"error": {"message": "generate login url failed"}},
        )

    started = await auto_login.ensure_started(login_url)
    return {
        "ok": True,
        "state": "running" if started else "already_running",
        "instance": target.name,
        "auto_login": auto_login.status(),
    }


@app.get("/internal/auto-login/status", dependencies=[Depends(auth_guard)])
async def internal_auto_login_status():
    """Report auto-login, auth and client state for every pool instance."""
    p = _require_pool()
    out = []
    for inst in p.instances:
        try:
            auth = await inst.client.auth_status()
        except Exception as exc:
            # Report the failure for this instance instead of failing the whole call.
            auth = {"error": str(exc)}
        out.append(
            {
                "instance": inst.name,
                "auto_login": inst.auto_login.status(),
                "auth": auth,
                "state": inst.client.state,
            }
        )
    return {"ok": True, "instances": out}


@app.get("/internal/stats", dependencies=[Depends(auth_guard)])
async def internal_stats():
    """Return a JSON snapshot of request stats, concurrency, pool and session cache."""
    p = _require_pool()
    return {
        "ok": True,
        "stats": await stats_collector.snapshot(),
        "concurrency": chat_guard.stats(),
        "pool": p.stats(),
        "session_cache": session_cache.stats(),
    }


@app.get("/metrics", dependencies=[Depends(metrics_auth_guard)])
async def metrics():
    """Return Prometheus text metrics from the stats collector, guard, pool and cache."""
    base = await stats_collector.prometheus_text()
    lines = list(chat_guard.prometheus_lines())
    if pool is not None:
        lines.extend(pool.prometheus_lines())
    lines.extend(session_cache.prometheus_lines())
    extra = "\n".join(lines) + "\n"
    return StreamingResponse(iter([base + extra]), media_type="text/plain; version=0.0.4")