from __future__ import annotations import asyncio import hashlib import json import time import uuid from contextlib import asynccontextmanager from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.responses import JSONResponse, StreamingResponse from .auth import require_bearer, require_metrics_access from .concurrency import BackpressureRejected, InFlightGuard from .config import Settings, load_settings from .lingma_pool import LingmaPool, PoolInstance from .logging_config import configure_logging, get_logger, request_id_var from .model_map import build_model_name_map, flatten_model_keys, resolve_model from .openai_schema import ( ChatCompletionChoice, ChatCompletionResponse, ChatCompletionsRequest, ModelData, ModelsResponse, flatten_content, ) from .stats import StatsCollector, estimate_tokens settings: Settings = load_settings() configure_logging(settings.log_level) logger = get_logger("lingma_gateway") pool: LingmaPool | None = None stats_collector = StatsCollector() chat_guard = InFlightGuard( max_in_flight=settings.gateway_max_in_flight, queue_timeout_sec=settings.gateway_queue_timeout_sec, ) def _require_pool() -> LingmaPool: if pool is None: raise HTTPException( status_code=503, detail={"error": {"message": "pool not initialized", "type": "service_unavailable"}}, ) return pool @asynccontextmanager async def lifespan(_app: FastAPI): global pool pool = LingmaPool.build( lingma_bin=settings.lingma_bin, base_work_dir=settings.lingma_work_dir, legacy_socket_port=settings.lingma_socket_port, startup_timeout=settings.lingma_startup_timeout, rpc_timeout=settings.lingma_rpc_timeout, default_model=settings.default_model, default_ask_mode=settings.default_ask_mode, accounts=settings.accounts, instance_count=settings.instance_count, auto_login_headless=settings.auto_login_headless, auto_login_timeout=settings.auto_login_timeout, auto_login_max_retry=settings.auto_login_max_retry, ) logger.info( "gateway startup: pool_size=%d max_in_flight=%d", pool.size(), settings.gateway_max_in_flight, ) await pool.start() try: yield finally: if pool is not None: await pool.close() app = FastAPI(title="Lingma OpenAI Gateway", version="0.3.0", lifespan=lifespan) @app.middleware("http") async def request_id_middleware(request: Request, call_next): req_id = request.headers.get("x-request-id") or f"req-{uuid.uuid4().hex[:12]}" token = request_id_var.set(req_id) start = time.monotonic() status_code = 500 try: response = await call_next(request) status_code = response.status_code response.headers["x-request-id"] = req_id return response finally: elapsed_ms = int((time.monotonic() - start) * 1000) logger.info( "http %s %s -> %s in %dms", request.method, request.url.path, status_code, elapsed_ms, extra={ "ctx_method": request.method, "ctx_path": request.url.path, "ctx_status": status_code, "ctx_elapsed_ms": elapsed_ms, }, ) request_id_var.reset(token) def auth_guard(request: Request): require_bearer(request, settings.api_keys) def metrics_auth_guard(request: Request): require_metrics_access(request, settings.api_keys, settings.metrics_token) @app.get("/healthz") async def healthz(): if pool is None: return {"ok": False, "time": int(time.time()), "reason": "pool uninitialized"} insts = pool.stats() ready = sum(1 for i in insts if i["state"] == "ready") return { "ok": ready > 0, "time": int(time.time()), "pool_size": len(insts), "pool_ready": ready, "instances": [ {"name": i["name"], "state": i["state"], "in_flight": i["in_flight"]} for i in insts ], } async def _ensure_instance_logged_in(inst: PoolInstance) -> dict: client = inst.client auto_login = inst.auto_login try: status = await client.auth_status() except Exception as exc: logger.warning("[%s] auth_status failed before chat: %s", inst.name, exc) raise HTTPException( status_code=503, detail={"error": {"message": "Lingma is not ready", "type": "service_unavailable"}}, ) if status and status.get("id"): return status if not settings.auto_login_enabled: raise HTTPException( status_code=401, detail={"error": {"message": "Lingma not logged in", "type": "invalid_request_error"}}, ) if settings.dedicated_domain_url: try: current = await client.get_endpoint() current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else "" if current_ep != settings.dedicated_domain_url: await client.update_endpoint(settings.dedicated_domain_url) except Exception as exc: logger.warning("[%s] switch dedicated endpoint failed: %s", inst.name, exc) try: login_url, _login_raw = await client.generate_login_url() except Exception as exc: logger.warning("[%s] generate_login_url failed: %s", inst.name, exc) raise HTTPException( status_code=502, detail={"error": {"message": "generate login url failed", "type": "upstream_error"}}, ) if not login_url: raise HTTPException( status_code=502, detail={"error": {"message": "generate login url failed", "type": "upstream_error"}}, ) await auto_login.ensure_started(login_url) try: await auto_login.wait_done(timeout=settings.auto_login_timeout + 20) except Exception as exc: logger.warning("[%s] auto_login wait_done failed: %s", inst.name, exc) try: status = await client.auth_status() except Exception as exc: logger.warning("[%s] post-login auth_status failed: %s", inst.name, exc) status = None if status and status.get("id"): return status logger.warning( "[%s] auto login did not result in a logged-in session: %s", inst.name, auto_login.status(), ) raise HTTPException( status_code=401, detail={"error": {"message": "Lingma auto login failed", "type": "invalid_request_error"}}, ) def _affinity_key_for(req: ChatCompletionsRequest) -> str | None: """Derive a stable affinity key so that follow-ups go to the same instance. Priority: explicit `user` > hash of the first/system message. """ if req.user: return req.user.strip() or None for m in req.messages: if m.role == "system": text = flatten_content(m.content) if text: return "sys:" + hashlib.sha1(text.encode("utf-8")).hexdigest()[:16] if req.messages: first = req.messages[0] text = flatten_content(first.content) if text: return "first:" + hashlib.sha1(text.encode("utf-8")).hexdigest()[:16] return None @app.get("/v1/models", dependencies=[Depends(auth_guard)]) async def v1_models(): p = _require_pool() inst = p.pick() await _ensure_instance_logged_in(inst) await stats_collector.inc_models() models = await inst.client.query_models() keys = flatten_model_keys(models) name_map = build_model_name_map(models) resp = ModelsResponse(data=[ModelData(id=k, name=name_map.get(k)) for k in keys]) return JSONResponse(content=resp.model_dump()) def _messages_to_prompt(messages: list[dict]) -> str: parts: list[str] = [] for m in messages: role = m.get("role", "user") text = flatten_content(m.get("content")) if not text and m.get("tool_calls"): text = f"[tool_calls] {json.dumps(m['tool_calls'], ensure_ascii=False)}" if not text: continue parts.append(f"[{role}] {text}") return "\n".join(parts).strip() def _include_usage(stream_options: dict | None) -> bool: if not isinstance(stream_options, dict): return False return bool(stream_options.get("include_usage")) @app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)]) async def v1_chat_completions(req: ChatCompletionsRequest): p = _require_pool() affinity = _affinity_key_for(req) inst = p.pick(affinity_key=affinity) await _ensure_instance_logged_in(inst) models = await inst.client.query_models() available = flatten_model_keys(models) name_map = build_model_name_map(models) model = resolve_model(req.model, available, settings.default_model, name_map) ask_mode = settings.default_ask_mode if req.model.lower() in {"lingma-agent", "agent"}: ask_mode = "agent" prompt = _messages_to_prompt([m.model_dump() for m in req.messages]) if not prompt: raise HTTPException( status_code=400, detail={"error": {"message": "messages is empty", "type": "invalid_request_error"}}, ) prompt_tokens = estimate_tokens(prompt) include_usage = _include_usage(req.stream_options) # Backpressure: acquire a slot *after* the cheap validation but before any # upstream call. This ensures we reject quickly when saturated. try: ticket = await chat_guard.try_acquire() except BackpressureRejected as exc: retry_after = max(1, int(exc.retry_after)) logger.warning("chat rejected by backpressure, retry_after=%ds", retry_after) raise HTTPException( status_code=429, detail={ "error": { "message": "Too many in-flight requests, please retry later", "type": "rate_limit_error", "code": "backpressure", } }, headers={"Retry-After": str(retry_after)}, ) inst.in_flight += 1 logger.info( "chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d", inst.name, model, ask_mode, req.stream, prompt_tokens, extra={ "ctx_instance": inst.name, "ctx_model": model, "ctx_ask_mode": ask_mode, "ctx_stream": req.stream, "ctx_prompt_tokens": prompt_tokens, "ctx_in_flight": chat_guard.in_flight, "ctx_affinity": affinity, }, ) ticket_transferred = False try: if req.stream: created = int(time.time()) completion_id = f"chatcmpl-{uuid.uuid4().hex}" completion_tokens_holder = {"n": 0} async def event_stream(_ticket=ticket, _inst=inst): success = False try: async for chunk in _inst.client.chat_stream(prompt, model, ask_mode): completion_tokens_holder["n"] += estimate_tokens(chunk) payload = { "id": completion_id, "object": "chat.completion.chunk", "created": created, "model": model, "choices": [ { "index": 0, "delta": {"content": chunk}, "finish_reason": None, } ], } yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" done_payload = { "id": completion_id, "object": "chat.completion.chunk", "created": created, "model": model, "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], } yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n" if include_usage: usage_payload = { "id": completion_id, "object": "chat.completion.chunk", "created": created, "model": model, "choices": [], "usage": { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens_holder["n"], "total_tokens": prompt_tokens + completion_tokens_holder["n"], }, } yield f"data: {json.dumps(usage_payload, ensure_ascii=False)}\n\n" yield "data: [DONE]\n\n" success = True except asyncio.CancelledError: logger.info("chat.stream cancelled by client (inst=%s)", _inst.name) raise except Exception as exc: logger.warning("chat.stream error (inst=%s): %s", _inst.name, exc) finally: await stats_collector.record_chat( stream=True, success=success, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens_holder["n"], ) _inst.in_flight = max(0, _inst.in_flight - 1) _ticket.release() ticket_transferred = True return StreamingResponse( event_stream(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache, no-transform", "X-Accel-Buffering": "no", "Connection": "keep-alive", }, ) try: result = await inst.client.chat_complete(prompt, model, ask_mode) except Exception as exc: logger.warning("chat.complete error (inst=%s): %s", inst.name, exc) await stats_collector.record_chat( stream=False, success=False, prompt_tokens=prompt_tokens, completion_tokens=0, ) raise HTTPException( status_code=502, detail={"error": {"message": "upstream lingma error", "type": "upstream_error"}}, ) completion_tokens = estimate_tokens(result.get("text") or "") await stats_collector.record_chat( stream=False, success=True, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, ) response = ChatCompletionResponse( id=f"chatcmpl-{uuid.uuid4().hex}", created=int(time.time()), model=model, choices=[ ChatCompletionChoice( index=0, finish_reason="stop", message={"role": "assistant", "content": result.get("text") or ""}, ) ], ) data = response.model_dump() data["latency"] = { "first_token_ms": result.get("firstTokenLatencyMs"), "total_ms": result.get("totalLatencyMs"), } data["usage"] = { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens, } data["served_by"] = inst.name return JSONResponse(content=data) finally: if not ticket_transferred: inst.in_flight = max(0, inst.in_flight - 1) ticket.release() @app.post("/internal/auto-login/start", dependencies=[Depends(auth_guard)]) async def internal_auto_login_start(instance: str | None = None): p = _require_pool() target = None if instance: for inst in p.instances: if inst.name == instance: target = inst break if target is None: raise HTTPException( status_code=404, detail={"error": {"message": f"instance not found: {instance}"}}, ) else: target = p.pick() client = target.client auto_login = target.auto_login status = await client.auth_status() if status and status.get("id"): return {"ok": True, "state": "already_logged_in", "instance": target.name, "auth": status} if settings.dedicated_domain_url: try: current = await client.get_endpoint() current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else "" if current_ep != settings.dedicated_domain_url: await client.update_endpoint(settings.dedicated_domain_url) except Exception as exc: logger.warning("[%s] switch dedicated endpoint failed: %s", target.name, exc) try: login_url, _login_raw = await client.generate_login_url() except Exception as exc: logger.warning("[%s] generate_login_url failed: %s", target.name, exc) raise HTTPException(status_code=502, detail={"error": {"message": "generate login url failed"}}) if not login_url: raise HTTPException(status_code=502, detail={"error": {"message": "generate login url failed"}}) started = await auto_login.ensure_started(login_url) return { "ok": True, "state": "running" if started else "already_running", "instance": target.name, "auto_login": auto_login.status(), } @app.get("/internal/auto-login/status", dependencies=[Depends(auth_guard)]) async def internal_auto_login_status(): p = _require_pool() out = [] for inst in p.instances: try: auth = await inst.client.auth_status() except Exception as exc: auth = {"error": str(exc)} out.append( { "instance": inst.name, "auto_login": inst.auto_login.status(), "auth": auth, "state": inst.client.state, } ) return {"ok": True, "instances": out} @app.get("/internal/stats", dependencies=[Depends(auth_guard)]) async def internal_stats(): p = _require_pool() return { "ok": True, "stats": await stats_collector.snapshot(), "concurrency": chat_guard.stats(), "pool": p.stats(), } @app.get("/metrics", dependencies=[Depends(metrics_auth_guard)]) async def metrics(): base = await stats_collector.prometheus_text() lines = list(chat_guard.prometheus_lines()) if pool is not None: lines.extend(pool.prometheus_lines()) extra = "\n".join(lines) + "\n" return StreamingResponse(iter([base + extra]), media_type="text/plain; version=0.0.4")