diff --git a/.env.example b/.env.example index 152502d..4606ce8 100644 --- a/.env.example +++ b/.env.example @@ -67,3 +67,11 @@ LINGMA_PASSWORD= LINGMA_ACCOUNTS= # 实例数量:默认等于 LINGMA_ACCOUNTS 数;显式指定时账号不足会循环复用并打 warning LINGMA_INSTANCE_COUNT= + +# ==== 会话复用(多轮对话命中上游 KV cache,减少首 token 延迟) ==== +# 开关(默认开) +SESSION_REUSE_ENABLED=true +# 最多缓存多少条会话 (LRU) +SESSION_CACHE_MAX_ENTRIES=256 +# 会话 TTL 秒数;超时自动失效,避免 Lingma 侧早已回收还在命中 +SESSION_CACHE_TTL_SEC=1800 diff --git a/README.md b/README.md index 68d3d5c..4529e09 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,9 @@ cp .env.example .env - `GATEWAY_QUEUE_TIMEOUT_SEC`:排队等待超时秒数(默认 30,超过后直接 429 + `Retry-After`) - `LINGMA_ACCOUNTS`:多账号实例池,格式 `u1:p1,u2:p2` 或 JSON 数组;配置后每个账号起一个独立 Lingma 子进程 - `LINGMA_INSTANCE_COUNT`:实例数(默认等于账号数;显式指定且不足时账号会循环复用) +- `SESSION_REUSE_ENABLED`:多轮对话复用上游 sessionId(默认 `true`)。命中时只把最新一条 user 消息发给 Lingma,命中上游 KV cache,显著降低第 2 轮及以后的首 token 延迟 +- `SESSION_CACHE_MAX_ENTRIES`:会话缓存容量(LRU,默认 256) +- `SESSION_CACHE_TTL_SEC`:会话缓存 TTL 秒数(默认 1800;超时自动失效,避免复用到已被 Lingma 回收的 session) ### `.env` 最小必填示例 diff --git a/app/config.py b/app/config.py index 5b96094..15852c9 100644 --- a/app/config.py +++ b/app/config.py @@ -34,6 +34,9 @@ class Settings: auto_login_max_retry: int accounts: list[LingmaAccount] = field(default_factory=list) instance_count: int = 1 + session_reuse_enabled: bool = True + session_cache_max_entries: int = 256 + session_cache_ttl_sec: float = 1800.0 def _bool_env(name: str, default: bool) -> bool: @@ -131,4 +134,7 @@ def load_settings() -> Settings: auto_login_max_retry=int(os.getenv("AUTO_LOGIN_MAX_RETRY", "2")), accounts=accounts, instance_count=instance_count, + session_reuse_enabled=_bool_env("SESSION_REUSE_ENABLED", True), + session_cache_max_entries=int(os.getenv("SESSION_CACHE_MAX_ENTRIES", "256")), + session_cache_ttl_sec=float(os.getenv("SESSION_CACHE_TTL_SEC", "1800")), ) diff --git a/app/lingma_client.py b/app/lingma_client.py index b5fe213..13f293e 100644 --- a/app/lingma_client.py +++ b/app/lingma_client.py @@ -536,7 +536,16 @@ class LingmaGatewayClient: # ------------------------------------------------------------------ chat - def _build_payload(self, prompt: str, model_key: str, ask_mode: str, session_id: str, request_id: str): + def _build_payload( + self, + prompt: str, + model_key: str, + ask_mode: str, + session_id: str, + request_id: str, + *, + is_reply: bool = False, + ): session_type = "developer" if ask_mode == "agent" else "chat" return { "requestId": request_id, @@ -546,7 +555,7 @@ class LingmaGatewayClient: "mode": ask_mode, "stream": True, "source": 1, - "isReply": False, + "isReply": is_reply, "taskDefinitionType": "system", "content": prompt, "text": prompt, @@ -579,11 +588,21 @@ class LingmaGatewayClient: """ await self.rpc.notify("chat/ask", payload) - async def chat_complete(self, prompt: str, model_key: str, ask_mode: str) -> dict: + async def chat_complete( + self, + prompt: str, + model_key: str, + ask_mode: str, + *, + session_id: str | None = None, + is_reply: bool = False, + ) -> dict: await self.ensure_ready() request_id = str(uuid.uuid4()) - session_id = str(uuid.uuid4()) - payload = self._build_payload(prompt, model_key, ask_mode, session_id, request_id) + sid = session_id or str(uuid.uuid4()) + payload = self._build_payload( + prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply + ) self.rpc.create_stream(request_id) try: await self._kick_chat_ask(payload) @@ -597,16 +616,37 @@ class LingmaGatewayClient: self.rpc.pop_stream(request_id) finish = result.get("finish") or {} result["requestId"] = request_id - result["sessionId"] = finish.get("sessionId") or session_id + # Prefer upstream-reported sessionId so the next turn binds to whatever + # Lingma actually allocated (sometimes differs from our hint). + result["sessionId"] = finish.get("sessionId") or sid result["model"] = model_key result["mode"] = ask_mode + result["isReply"] = is_reply return result - async def chat_stream(self, prompt: str, model_key: str, ask_mode: str) -> AsyncIterator[str]: + async def chat_stream( + self, + prompt: str, + model_key: str, + ask_mode: str, + *, + session_id: str | None = None, + is_reply: bool = False, + out_meta: dict | None = None, + ) -> AsyncIterator[str]: + """Stream `chat/answer` chunks. + + If `out_meta` is provided, the final `chat/finish` payload's sessionId + (and the raw finish dict) is written into it when the stream ends or is + cancelled. This is the hook the session cache uses to record the + upstream sessionId without holding a second reference to the RPC. + """ await self.ensure_ready() request_id = str(uuid.uuid4()) - session_id = str(uuid.uuid4()) - payload = self._build_payload(prompt, model_key, ask_mode, session_id, request_id) + sid = session_id or str(uuid.uuid4()) + payload = self._build_payload( + prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply + ) self.rpc.create_stream(request_id) try: await self._kick_chat_ask(payload) @@ -616,4 +656,14 @@ class LingmaGatewayClient: yield chunk finally: # Runs on normal completion, exception, or consumer GeneratorExit (client disconnect). + if out_meta is not None: + try: + stream_result = self.rpc.get_stream_result(request_id) + finish = stream_result.get("finish") or {} + out_meta["session_id"] = finish.get("sessionId") or sid + out_meta["finish"] = finish + out_meta["request_id"] = request_id + out_meta["chars"] = len(stream_result.get("text") or "") + except Exception: + pass self.rpc.pop_stream(request_id) diff --git a/app/main.py b/app/main.py index 218c26e..7526bb2 100644 --- a/app/main.py +++ b/app/main.py @@ -24,6 +24,7 @@ from .openai_schema import ( ModelsResponse, flatten_content, ) +from .session_cache import SessionCache from .stats import StatsCollector, estimate_tokens @@ -37,6 +38,10 @@ chat_guard = InFlightGuard( max_in_flight=settings.gateway_max_in_flight, queue_timeout_sec=settings.gateway_queue_timeout_sec, ) +session_cache = SessionCache( + max_entries=settings.session_cache_max_entries if settings.session_reuse_enabled else 0, + ttl_sec=settings.session_cache_ttl_sec, +) def _require_pool() -> LingmaPool: @@ -228,6 +233,27 @@ def _affinity_key_for(req: ChatCompletionsRequest) -> str | None: return None +def _extract_api_key(request: Request) -> str: + h = request.headers.get("authorization", "") + if h.lower().startswith("bearer "): + return h[7:].strip() + return "" + + +def _last_user_text(messages: list[dict]) -> str: + """Extract the text of the latest user message (trailing from end). + + Used when we hit the session cache and only need to send the delta. + Falls back to the last message regardless of role if no user is found. + """ + for m in reversed(messages): + if m.get("role") == "user": + return flatten_content(m.get("content")) or "" + if messages: + return flatten_content(messages[-1].get("content")) or "" + return "" + + @app.get("/v1/models", dependencies=[Depends(auth_guard)]) async def v1_models(): p = _require_pool() @@ -261,10 +287,56 @@ def _include_usage(stream_options: dict | None) -> bool: @app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)]) -async def v1_chat_completions(req: ChatCompletionsRequest): +async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): p = _require_pool() - affinity = _affinity_key_for(req) + + messages_dump = [m.model_dump() for m in req.messages] + api_key = _extract_api_key(request) or "-" + + # ------------------------------------------------------------- session reuse + # Look up the "conversation prefix" (everything except the latest user turn) + # in the session cache. A hit lets us: + # 1. Reuse the upstream sessionId so Lingma/Qwen hits its KV cache. + # 2. Send only the new user message instead of the whole history. + # 3. Stick the request to the pool instance that originally served it. + ask_mode = settings.default_ask_mode + if req.model.lower() in {"lingma-agent", "agent"}: + ask_mode = "agent" + + reuse_eligible = ( + session_cache.enabled + and ask_mode == "chat" + and len(messages_dump) >= 2 + ) + lookup_key: str | None = None + write_key: str | None = None + cached_session_id: str | None = None + cached_instance_name: str | None = None + if reuse_eligible: + lookup_key = session_cache.build_key(api_key, messages_dump[:-1]) + write_key = session_cache.build_key(api_key, messages_dump) + entry = await session_cache.get(lookup_key) + if entry is not None: + cached_session_id = entry.session_id + cached_instance_name = entry.instance_name or None + + # Instance selection: prefer cached instance for continuity, else normal affinity. + affinity = cached_instance_name or _affinity_key_for(req) inst = p.pick(affinity_key=affinity) + + # If cache pointed at a specific instance that's no longer healthy, we already + # fell back via pool.pick -> drop the cached session since Lingma on a + # different process won't know about it. + if cached_instance_name and inst.name != cached_instance_name: + logger.info( + "session cache instance %s unhealthy, falling back to %s (dropping cached session)", + cached_instance_name, + inst.name, + ) + cached_session_id = None + if lookup_key: + await session_cache.invalidate(lookup_key) + await _ensure_instance_logged_in(inst) models = await inst.client.query_models() @@ -272,11 +344,15 @@ async def v1_chat_completions(req: ChatCompletionsRequest): name_map = build_model_name_map(models) model = resolve_model(req.model, available, settings.default_model, name_map) - ask_mode = settings.default_ask_mode - if req.model.lower() in {"lingma-agent", "agent"}: - ask_mode = "agent" + # Prompt construction: on cache hit send only the last user turn so Lingma's + # stored context isn't duplicated. + if cached_session_id: + prompt = _last_user_text(messages_dump) + is_reply = True + else: + prompt = _messages_to_prompt(messages_dump) + is_reply = False - prompt = _messages_to_prompt([m.model_dump() for m in req.messages]) if not prompt: raise HTTPException( status_code=400, @@ -306,12 +382,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest): inst.in_flight += 1 logger.info( - "chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d", + "chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d reuse=%s", inst.name, model, ask_mode, req.stream, prompt_tokens, + bool(cached_session_id), extra={ "ctx_instance": inst.name, "ctx_model": model, @@ -320,6 +397,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest): "ctx_prompt_tokens": prompt_tokens, "ctx_in_flight": chat_guard.in_flight, "ctx_affinity": affinity, + "ctx_session_reuse": bool(cached_session_id), }, ) @@ -330,11 +408,19 @@ async def v1_chat_completions(req: ChatCompletionsRequest): created = int(time.time()) completion_id = f"chatcmpl-{uuid.uuid4().hex}" completion_tokens_holder = {"n": 0} + stream_meta: dict = {} - async def event_stream(_ticket=ticket, _inst=inst): + async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta): success = False try: - async for chunk in _inst.client.chat_stream(prompt, model, ask_mode): + async for chunk in _inst.client.chat_stream( + prompt, + model, + ask_mode, + session_id=cached_session_id, + is_reply=is_reply, + out_meta=_meta, + ): completion_tokens_holder["n"] += estimate_tokens(chunk) payload = { "id": completion_id, @@ -383,6 +469,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest): except Exception as exc: logger.warning("chat.stream error (inst=%s): %s", _inst.name, exc) finally: + # Persist upstream sessionId only on a clean chat/finish. + # Partial streams (cancelled, timed out) leave Lingma's + # session in an indeterminate state, so we must not reuse. + if success and write_key: + sid = _meta.get("session_id") + if sid: + await session_cache.put(write_key, sid, _inst.name) await stats_collector.record_chat( stream=True, success=success, @@ -404,7 +497,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest): ) try: - result = await inst.client.chat_complete(prompt, model, ask_mode) + result = await inst.client.chat_complete( + prompt, + model, + ask_mode, + session_id=cached_session_id, + is_reply=is_reply, + ) except Exception as exc: logger.warning("chat.complete error (inst=%s): %s", inst.name, exc) await stats_collector.record_chat( @@ -413,6 +512,10 @@ async def v1_chat_completions(req: ChatCompletionsRequest): prompt_tokens=prompt_tokens, completion_tokens=0, ) + # If we used a cached session and the call blew up, drop it so the + # next turn can start fresh instead of hitting the same dead session. + if cached_session_id and lookup_key: + await session_cache.invalidate(lookup_key) raise HTTPException( status_code=502, detail={"error": {"message": "upstream lingma error", "type": "upstream_error"}}, @@ -425,6 +528,10 @@ async def v1_chat_completions(req: ChatCompletionsRequest): prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, ) + if write_key: + sid = result.get("sessionId") + if sid: + await session_cache.put(write_key, sid, inst.name) response = ChatCompletionResponse( id=f"chatcmpl-{uuid.uuid4().hex}", created=int(time.time()), @@ -534,6 +641,7 @@ async def internal_stats(): "stats": await stats_collector.snapshot(), "concurrency": chat_guard.stats(), "pool": p.stats(), + "session_cache": session_cache.stats(), } @@ -543,5 +651,6 @@ async def metrics(): lines = list(chat_guard.prometheus_lines()) if pool is not None: lines.extend(pool.prometheus_lines()) + lines.extend(session_cache.prometheus_lines()) extra = "\n".join(lines) + "\n" return StreamingResponse(iter([base + extra]), media_type="text/plain; version=0.0.4") diff --git a/app/session_cache.py b/app/session_cache.py new file mode 100644 index 0000000..46d9778 --- /dev/null +++ b/app/session_cache.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import asyncio +import hashlib +import time +from collections import OrderedDict +from dataclasses import dataclass + +from .logging_config import get_logger +from .openai_schema import flatten_content + + +logger = get_logger("lingma_gateway.session") + + +@dataclass +class SessionEntry: + session_id: str + created_at: float + last_used_at: float + hit_count: int = 0 + instance_name: str = "" + + +def hash_user_context(messages: list[dict]) -> str: + """Hash the user/system/developer turns of a message list. + + We deliberately skip `assistant`/`tool` messages because: + - Clients may subtly reformat or trim assistant replies between turns, + breaking exact-match keying. + - Only the *inputs* are stable, and they're sufficient to identify a + conversation prefix for cache lookup. + """ + h = hashlib.sha1() + for m in messages: + role = m.get("role", "") + if role not in ("system", "user", "developer"): + continue + content = m.get("content") + text = content if isinstance(content, str) else flatten_content(content) + h.update(f"{role}\x1f{text or ''}\x1e".encode("utf-8")) + return h.hexdigest() + + +class SessionCache: + """LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId. + + Usage pattern: + - On request arrival, look up `build_key(api_key, messages[:-1])`. + If hit, reuse that sessionId and send ONLY the latest user message + with `isReply=True` so Lingma can hit its KV cache. + - After a successful completion, store under `build_key(api_key, messages)`. + Since the next turn's `messages[:-1]` contains exactly the current turn's + `messages` (in the standard OpenAI append-only pattern), the next lookup + will hit. We never have to hash the assistant reply itself. + + Concurrency: `asyncio.Lock` guards the ordered dict; two concurrent requests + for the same conversation may both miss briefly, but that's harmless -- the + second write just overwrites the first, both produce valid sessionIds + upstream. + + Safety: entries also remember which pool instance served them, so follow-up + turns stick to the same Lingma process (which is the only place that + knows the session). + """ + + def __init__(self, *, max_entries: int = 256, ttl_sec: float = 1800.0): + self.max = max(0, int(max_entries)) + self.ttl = float(ttl_sec) if ttl_sec and ttl_sec > 0 else 0.0 + self._data: "OrderedDict[str, SessionEntry]" = OrderedDict() + self._lock = asyncio.Lock() + self.hit_total = 0 + self.miss_total = 0 + self.evict_total = 0 + self.expire_total = 0 + self.invalidate_total = 0 + + @property + def enabled(self) -> bool: + return self.max > 0 + + def build_key(self, api_key: str, messages: list[dict]) -> str: + # API key scoping prevents cross-tenant session leakage even when + # different clients happen to produce identical histories. + key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12] + return f"{key_scope}:{hash_user_context(messages)}" + + async def get(self, key: str) -> SessionEntry | None: + if not self.enabled: + return None + async with self._lock: + entry = self._data.get(key) + if entry is None: + self.miss_total += 1 + return None + if self.ttl > 0 and (time.monotonic() - entry.created_at) > self.ttl: + self._data.pop(key, None) + self.expire_total += 1 + self.miss_total += 1 + return None + entry.last_used_at = time.monotonic() + entry.hit_count += 1 + self._data.move_to_end(key) + self.hit_total += 1 + return entry + + async def put(self, key: str, session_id: str, instance_name: str = "") -> None: + if not self.enabled or not session_id: + return + now = time.monotonic() + async with self._lock: + existing = self._data.get(key) + if existing is not None: + existing.session_id = session_id + existing.last_used_at = now + if instance_name: + existing.instance_name = instance_name + self._data.move_to_end(key) + return + self._data[key] = SessionEntry( + session_id=session_id, + created_at=now, + last_used_at=now, + instance_name=instance_name, + ) + while len(self._data) > self.max: + self._data.popitem(last=False) + self.evict_total += 1 + + async def invalidate(self, key: str) -> None: + async with self._lock: + if self._data.pop(key, None) is not None: + self.invalidate_total += 1 + + def stats(self) -> dict: + total = self.hit_total + self.miss_total + rate = (self.hit_total / total) if total > 0 else 0.0 + return { + "enabled": self.enabled, + "size": len(self._data), + "max": self.max, + "ttl_sec": self.ttl, + "hit_total": self.hit_total, + "miss_total": self.miss_total, + "hit_rate": round(rate, 4), + "evict_total": self.evict_total, + "expire_total": self.expire_total, + "invalidate_total": self.invalidate_total, + } + + def prometheus_lines(self) -> list[str]: + return [ + "# TYPE gateway_session_cache_size gauge", + f"gateway_session_cache_size {len(self._data)}", + "# TYPE gateway_session_cache_hit_total counter", + f"gateway_session_cache_hit_total {self.hit_total}", + "# TYPE gateway_session_cache_miss_total counter", + f"gateway_session_cache_miss_total {self.miss_total}", + "# TYPE gateway_session_cache_evict_total counter", + f"gateway_session_cache_evict_total {self.evict_total}", + "# TYPE gateway_session_cache_expire_total counter", + f"gateway_session_cache_expire_total {self.expire_total}", + "# TYPE gateway_session_cache_invalidate_total counter", + f"gateway_session_cache_invalidate_total {self.invalidate_total}", + ]