perf: session reuse for multi-turn latency

- Add SessionCache (LRU + TTL, per-API-key scoped) mapping
  conversation-prefix hash -> upstream Lingma sessionId.
- Hash only user/system/developer turns so client-side
  assistant reformatting doesn't invalidate the key.
- On cache hit: reuse sessionId, send only the latest user
  message with isReply=true, and stick the request to the
  instance that originally served it.
- LingmaGatewayClient.chat_complete/chat_stream accept
  session_id/is_reply and report the real finish.sessionId
  via out_meta so we persist what Lingma actually allocated.
- Invalidate cache on non-stream failure; skip writes on
  cancelled/partial streams.
- Expose cache stats in /internal/stats and /metrics.
- Configurable via SESSION_REUSE_ENABLED / SESSION_CACHE_MAX_ENTRIES
  / SESSION_CACHE_TTL_SEC (documented in README + .env.example).

Made-with: Cursor
This commit is contained in:
GitHub Actions
2026-04-18 08:10:39 +08:00
parent d209d8ac0b
commit dfdb7087dc
6 changed files with 360 additions and 19 deletions

View File

@@ -67,3 +67,11 @@ LINGMA_PASSWORD=
LINGMA_ACCOUNTS=
# 实例数量:默认等于 LINGMA_ACCOUNTS 数;显式指定时账号不足会循环复用并打 warning
LINGMA_INSTANCE_COUNT=
# ==== 会话复用(多轮对话命中上游 KV cache减少首 token 延迟) ====
# 开关(默认开)
SESSION_REUSE_ENABLED=true
# 最多缓存多少条会话 (LRU)
SESSION_CACHE_MAX_ENTRIES=256
# 会话 TTL 秒数;超时自动失效,避免 Lingma 侧早已回收还在命中
SESSION_CACHE_TTL_SEC=1800

View File

@@ -70,6 +70,9 @@ cp .env.example .env
- `GATEWAY_QUEUE_TIMEOUT_SEC`:排队等待超时秒数(默认 30超过后直接 429 + `Retry-After`
- `LINGMA_ACCOUNTS`:多账号实例池,格式 `u1:p1,u2:p2` 或 JSON 数组;配置后每个账号起一个独立 Lingma 子进程
- `LINGMA_INSTANCE_COUNT`:实例数(默认等于账号数;显式指定且不足时账号会循环复用)
- `SESSION_REUSE_ENABLED`:多轮对话复用上游 sessionId默认 `true`)。命中时只把最新一条 user 消息发给 Lingma命中上游 KV cache显著降低第 2 轮及以后的首 token 延迟
- `SESSION_CACHE_MAX_ENTRIES`会话缓存容量LRU默认 256
- `SESSION_CACHE_TTL_SEC`:会话缓存 TTL 秒数(默认 1800超时自动失效避免复用到已被 Lingma 回收的 session
### `.env` 最小必填示例

View File

@@ -34,6 +34,9 @@ class Settings:
auto_login_max_retry: int
accounts: list[LingmaAccount] = field(default_factory=list)
instance_count: int = 1
session_reuse_enabled: bool = True
session_cache_max_entries: int = 256
session_cache_ttl_sec: float = 1800.0
def _bool_env(name: str, default: bool) -> bool:
@@ -131,4 +134,7 @@ def load_settings() -> Settings:
auto_login_max_retry=int(os.getenv("AUTO_LOGIN_MAX_RETRY", "2")),
accounts=accounts,
instance_count=instance_count,
session_reuse_enabled=_bool_env("SESSION_REUSE_ENABLED", True),
session_cache_max_entries=int(os.getenv("SESSION_CACHE_MAX_ENTRIES", "256")),
session_cache_ttl_sec=float(os.getenv("SESSION_CACHE_TTL_SEC", "1800")),
)

View File

@@ -536,7 +536,16 @@ class LingmaGatewayClient:
# ------------------------------------------------------------------ chat
def _build_payload(self, prompt: str, model_key: str, ask_mode: str, session_id: str, request_id: str):
def _build_payload(
self,
prompt: str,
model_key: str,
ask_mode: str,
session_id: str,
request_id: str,
*,
is_reply: bool = False,
):
session_type = "developer" if ask_mode == "agent" else "chat"
return {
"requestId": request_id,
@@ -546,7 +555,7 @@ class LingmaGatewayClient:
"mode": ask_mode,
"stream": True,
"source": 1,
"isReply": False,
"isReply": is_reply,
"taskDefinitionType": "system",
"content": prompt,
"text": prompt,
@@ -579,11 +588,21 @@ class LingmaGatewayClient:
"""
await self.rpc.notify("chat/ask", payload)
async def chat_complete(self, prompt: str, model_key: str, ask_mode: str) -> dict:
async def chat_complete(
self,
prompt: str,
model_key: str,
ask_mode: str,
*,
session_id: str | None = None,
is_reply: bool = False,
) -> dict:
await self.ensure_ready()
request_id = str(uuid.uuid4())
session_id = str(uuid.uuid4())
payload = self._build_payload(prompt, model_key, ask_mode, session_id, request_id)
sid = session_id or str(uuid.uuid4())
payload = self._build_payload(
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
)
self.rpc.create_stream(request_id)
try:
await self._kick_chat_ask(payload)
@@ -597,16 +616,37 @@ class LingmaGatewayClient:
self.rpc.pop_stream(request_id)
finish = result.get("finish") or {}
result["requestId"] = request_id
result["sessionId"] = finish.get("sessionId") or session_id
# Prefer upstream-reported sessionId so the next turn binds to whatever
# Lingma actually allocated (sometimes differs from our hint).
result["sessionId"] = finish.get("sessionId") or sid
result["model"] = model_key
result["mode"] = ask_mode
result["isReply"] = is_reply
return result
async def chat_stream(self, prompt: str, model_key: str, ask_mode: str) -> AsyncIterator[str]:
async def chat_stream(
self,
prompt: str,
model_key: str,
ask_mode: str,
*,
session_id: str | None = None,
is_reply: bool = False,
out_meta: dict | None = None,
) -> AsyncIterator[str]:
"""Stream `chat/answer` chunks.
If `out_meta` is provided, the final `chat/finish` payload's sessionId
(and the raw finish dict) is written into it when the stream ends or is
cancelled. This is the hook the session cache uses to record the
upstream sessionId without holding a second reference to the RPC.
"""
await self.ensure_ready()
request_id = str(uuid.uuid4())
session_id = str(uuid.uuid4())
payload = self._build_payload(prompt, model_key, ask_mode, session_id, request_id)
sid = session_id or str(uuid.uuid4())
payload = self._build_payload(
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
)
self.rpc.create_stream(request_id)
try:
await self._kick_chat_ask(payload)
@@ -616,4 +656,14 @@ class LingmaGatewayClient:
yield chunk
finally:
# Runs on normal completion, exception, or consumer GeneratorExit (client disconnect).
if out_meta is not None:
try:
stream_result = self.rpc.get_stream_result(request_id)
finish = stream_result.get("finish") or {}
out_meta["session_id"] = finish.get("sessionId") or sid
out_meta["finish"] = finish
out_meta["request_id"] = request_id
out_meta["chars"] = len(stream_result.get("text") or "")
except Exception:
pass
self.rpc.pop_stream(request_id)

View File

@@ -24,6 +24,7 @@ from .openai_schema import (
ModelsResponse,
flatten_content,
)
from .session_cache import SessionCache
from .stats import StatsCollector, estimate_tokens
@@ -37,6 +38,10 @@ chat_guard = InFlightGuard(
max_in_flight=settings.gateway_max_in_flight,
queue_timeout_sec=settings.gateway_queue_timeout_sec,
)
session_cache = SessionCache(
max_entries=settings.session_cache_max_entries if settings.session_reuse_enabled else 0,
ttl_sec=settings.session_cache_ttl_sec,
)
def _require_pool() -> LingmaPool:
@@ -228,6 +233,27 @@ def _affinity_key_for(req: ChatCompletionsRequest) -> str | None:
return None
def _extract_api_key(request: Request) -> str:
h = request.headers.get("authorization", "")
if h.lower().startswith("bearer "):
return h[7:].strip()
return ""
def _last_user_text(messages: list[dict]) -> str:
"""Extract the text of the latest user message (trailing from end).
Used when we hit the session cache and only need to send the delta.
Falls back to the last message regardless of role if no user is found.
"""
for m in reversed(messages):
if m.get("role") == "user":
return flatten_content(m.get("content")) or ""
if messages:
return flatten_content(messages[-1].get("content")) or ""
return ""
@app.get("/v1/models", dependencies=[Depends(auth_guard)])
async def v1_models():
p = _require_pool()
@@ -261,10 +287,56 @@ def _include_usage(stream_options: dict | None) -> bool:
@app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)])
async def v1_chat_completions(req: ChatCompletionsRequest):
async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
p = _require_pool()
affinity = _affinity_key_for(req)
messages_dump = [m.model_dump() for m in req.messages]
api_key = _extract_api_key(request) or "-"
# ------------------------------------------------------------- session reuse
# Look up the "conversation prefix" (everything except the latest user turn)
# in the session cache. A hit lets us:
# 1. Reuse the upstream sessionId so Lingma/Qwen hits its KV cache.
# 2. Send only the new user message instead of the whole history.
# 3. Stick the request to the pool instance that originally served it.
ask_mode = settings.default_ask_mode
if req.model.lower() in {"lingma-agent", "agent"}:
ask_mode = "agent"
reuse_eligible = (
session_cache.enabled
and ask_mode == "chat"
and len(messages_dump) >= 2
)
lookup_key: str | None = None
write_key: str | None = None
cached_session_id: str | None = None
cached_instance_name: str | None = None
if reuse_eligible:
lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
write_key = session_cache.build_key(api_key, messages_dump)
entry = await session_cache.get(lookup_key)
if entry is not None:
cached_session_id = entry.session_id
cached_instance_name = entry.instance_name or None
# Instance selection: prefer cached instance for continuity, else normal affinity.
affinity = cached_instance_name or _affinity_key_for(req)
inst = p.pick(affinity_key=affinity)
# If cache pointed at a specific instance that's no longer healthy, we already
# fell back via pool.pick -> drop the cached session since Lingma on a
# different process won't know about it.
if cached_instance_name and inst.name != cached_instance_name:
logger.info(
"session cache instance %s unhealthy, falling back to %s (dropping cached session)",
cached_instance_name,
inst.name,
)
cached_session_id = None
if lookup_key:
await session_cache.invalidate(lookup_key)
await _ensure_instance_logged_in(inst)
models = await inst.client.query_models()
@@ -272,11 +344,15 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
name_map = build_model_name_map(models)
model = resolve_model(req.model, available, settings.default_model, name_map)
ask_mode = settings.default_ask_mode
if req.model.lower() in {"lingma-agent", "agent"}:
ask_mode = "agent"
# Prompt construction: on cache hit send only the last user turn so Lingma's
# stored context isn't duplicated.
if cached_session_id:
prompt = _last_user_text(messages_dump)
is_reply = True
else:
prompt = _messages_to_prompt(messages_dump)
is_reply = False
prompt = _messages_to_prompt([m.model_dump() for m in req.messages])
if not prompt:
raise HTTPException(
status_code=400,
@@ -306,12 +382,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
inst.in_flight += 1
logger.info(
"chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d",
"chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d reuse=%s",
inst.name,
model,
ask_mode,
req.stream,
prompt_tokens,
bool(cached_session_id),
extra={
"ctx_instance": inst.name,
"ctx_model": model,
@@ -320,6 +397,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
"ctx_prompt_tokens": prompt_tokens,
"ctx_in_flight": chat_guard.in_flight,
"ctx_affinity": affinity,
"ctx_session_reuse": bool(cached_session_id),
},
)
@@ -330,11 +408,19 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
created = int(time.time())
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
completion_tokens_holder = {"n": 0}
stream_meta: dict = {}
async def event_stream(_ticket=ticket, _inst=inst):
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
success = False
try:
async for chunk in _inst.client.chat_stream(prompt, model, ask_mode):
async for chunk in _inst.client.chat_stream(
prompt,
model,
ask_mode,
session_id=cached_session_id,
is_reply=is_reply,
out_meta=_meta,
):
completion_tokens_holder["n"] += estimate_tokens(chunk)
payload = {
"id": completion_id,
@@ -383,6 +469,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
except Exception as exc:
logger.warning("chat.stream error (inst=%s): %s", _inst.name, exc)
finally:
# Persist upstream sessionId only on a clean chat/finish.
# Partial streams (cancelled, timed out) leave Lingma's
# session in an indeterminate state, so we must not reuse.
if success and write_key:
sid = _meta.get("session_id")
if sid:
await session_cache.put(write_key, sid, _inst.name)
await stats_collector.record_chat(
stream=True,
success=success,
@@ -404,7 +497,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
)
try:
result = await inst.client.chat_complete(prompt, model, ask_mode)
result = await inst.client.chat_complete(
prompt,
model,
ask_mode,
session_id=cached_session_id,
is_reply=is_reply,
)
except Exception as exc:
logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
await stats_collector.record_chat(
@@ -413,6 +512,10 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
prompt_tokens=prompt_tokens,
completion_tokens=0,
)
# If we used a cached session and the call blew up, drop it so the
# next turn can start fresh instead of hitting the same dead session.
if cached_session_id and lookup_key:
await session_cache.invalidate(lookup_key)
raise HTTPException(
status_code=502,
detail={"error": {"message": "upstream lingma error", "type": "upstream_error"}},
@@ -425,6 +528,10 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
if write_key:
sid = result.get("sessionId")
if sid:
await session_cache.put(write_key, sid, inst.name)
response = ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex}",
created=int(time.time()),
@@ -534,6 +641,7 @@ async def internal_stats():
"stats": await stats_collector.snapshot(),
"concurrency": chat_guard.stats(),
"pool": p.stats(),
"session_cache": session_cache.stats(),
}
@@ -543,5 +651,6 @@ async def metrics():
lines = list(chat_guard.prometheus_lines())
if pool is not None:
lines.extend(pool.prometheus_lines())
lines.extend(session_cache.prometheus_lines())
extra = "\n".join(lines) + "\n"
return StreamingResponse(iter([base + extra]), media_type="text/plain; version=0.0.4")

165
app/session_cache.py Normal file
View File

@@ -0,0 +1,165 @@
from __future__ import annotations
import asyncio
import hashlib
import time
from collections import OrderedDict
from dataclasses import dataclass
from .logging_config import get_logger
from .openai_schema import flatten_content
logger = get_logger("lingma_gateway.session")
@dataclass
class SessionEntry:
session_id: str
created_at: float
last_used_at: float
hit_count: int = 0
instance_name: str = ""
def hash_user_context(messages: list[dict]) -> str:
"""Hash the user/system/developer turns of a message list.
We deliberately skip `assistant`/`tool` messages because:
- Clients may subtly reformat or trim assistant replies between turns,
breaking exact-match keying.
- Only the *inputs* are stable, and they're sufficient to identify a
conversation prefix for cache lookup.
"""
h = hashlib.sha1()
for m in messages:
role = m.get("role", "")
if role not in ("system", "user", "developer"):
continue
content = m.get("content")
text = content if isinstance(content, str) else flatten_content(content)
h.update(f"{role}\x1f{text or ''}\x1e".encode("utf-8"))
return h.hexdigest()
class SessionCache:
"""LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId.
Usage pattern:
- On request arrival, look up `build_key(api_key, messages[:-1])`.
If hit, reuse that sessionId and send ONLY the latest user message
with `isReply=True` so Lingma can hit its KV cache.
- After a successful completion, store under `build_key(api_key, messages)`.
Since the next turn's `messages[:-1]` contains exactly the current turn's
`messages` (in the standard OpenAI append-only pattern), the next lookup
will hit. We never have to hash the assistant reply itself.
Concurrency: `asyncio.Lock` guards the ordered dict; two concurrent requests
for the same conversation may both miss briefly, but that's harmless -- the
second write just overwrites the first, both produce valid sessionIds
upstream.
Safety: entries also remember which pool instance served them, so follow-up
turns stick to the same Lingma process (which is the only place that
knows the session).
"""
def __init__(self, *, max_entries: int = 256, ttl_sec: float = 1800.0):
self.max = max(0, int(max_entries))
self.ttl = float(ttl_sec) if ttl_sec and ttl_sec > 0 else 0.0
self._data: "OrderedDict[str, SessionEntry]" = OrderedDict()
self._lock = asyncio.Lock()
self.hit_total = 0
self.miss_total = 0
self.evict_total = 0
self.expire_total = 0
self.invalidate_total = 0
@property
def enabled(self) -> bool:
return self.max > 0
def build_key(self, api_key: str, messages: list[dict]) -> str:
# API key scoping prevents cross-tenant session leakage even when
# different clients happen to produce identical histories.
key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12]
return f"{key_scope}:{hash_user_context(messages)}"
async def get(self, key: str) -> SessionEntry | None:
if not self.enabled:
return None
async with self._lock:
entry = self._data.get(key)
if entry is None:
self.miss_total += 1
return None
if self.ttl > 0 and (time.monotonic() - entry.created_at) > self.ttl:
self._data.pop(key, None)
self.expire_total += 1
self.miss_total += 1
return None
entry.last_used_at = time.monotonic()
entry.hit_count += 1
self._data.move_to_end(key)
self.hit_total += 1
return entry
async def put(self, key: str, session_id: str, instance_name: str = "") -> None:
if not self.enabled or not session_id:
return
now = time.monotonic()
async with self._lock:
existing = self._data.get(key)
if existing is not None:
existing.session_id = session_id
existing.last_used_at = now
if instance_name:
existing.instance_name = instance_name
self._data.move_to_end(key)
return
self._data[key] = SessionEntry(
session_id=session_id,
created_at=now,
last_used_at=now,
instance_name=instance_name,
)
while len(self._data) > self.max:
self._data.popitem(last=False)
self.evict_total += 1
async def invalidate(self, key: str) -> None:
async with self._lock:
if self._data.pop(key, None) is not None:
self.invalidate_total += 1
def stats(self) -> dict:
total = self.hit_total + self.miss_total
rate = (self.hit_total / total) if total > 0 else 0.0
return {
"enabled": self.enabled,
"size": len(self._data),
"max": self.max,
"ttl_sec": self.ttl,
"hit_total": self.hit_total,
"miss_total": self.miss_total,
"hit_rate": round(rate, 4),
"evict_total": self.evict_total,
"expire_total": self.expire_total,
"invalidate_total": self.invalidate_total,
}
def prometheus_lines(self) -> list[str]:
return [
"# TYPE gateway_session_cache_size gauge",
f"gateway_session_cache_size {len(self._data)}",
"# TYPE gateway_session_cache_hit_total counter",
f"gateway_session_cache_hit_total {self.hit_total}",
"# TYPE gateway_session_cache_miss_total counter",
f"gateway_session_cache_miss_total {self.miss_total}",
"# TYPE gateway_session_cache_evict_total counter",
f"gateway_session_cache_evict_total {self.evict_total}",
"# TYPE gateway_session_cache_expire_total counter",
f"gateway_session_cache_expire_total {self.expire_total}",
"# TYPE gateway_session_cache_invalidate_total counter",
f"gateway_session_cache_invalidate_total {self.invalidate_total}",
]