perf: session reuse for multi-turn latency

- Add SessionCache (LRU + TTL, per-API-key scoped) mapping
  conversation-prefix hash -> upstream Lingma sessionId.
- Hash only user/system/developer turns so client-side
  assistant reformatting doesn't invalidate the key.
- On cache hit: reuse sessionId, send only the latest user
  message with isReply=true, and stick the request to the
  instance that originally served it.
- LingmaGatewayClient.chat_complete/chat_stream accept
  session_id/is_reply and report the real finish.sessionId
  via out_meta so we persist what Lingma actually allocated.
- Invalidate cache on non-stream failure; skip writes on
  cancelled/partial streams.
- Expose cache stats in /internal/stats and /metrics.
- Configurable via SESSION_REUSE_ENABLED / SESSION_CACHE_MAX_ENTRIES
  / SESSION_CACHE_TTL_SEC (documented in README + .env.example).

Made-with: Cursor
This commit is contained in:
GitHub Actions
2026-04-18 08:10:39 +08:00
parent d209d8ac0b
commit dfdb7087dc
6 changed files with 360 additions and 19 deletions

View File

@@ -24,6 +24,7 @@ from .openai_schema import (
ModelsResponse,
flatten_content,
)
from .session_cache import SessionCache
from .stats import StatsCollector, estimate_tokens
@@ -37,6 +38,10 @@ chat_guard = InFlightGuard(
max_in_flight=settings.gateway_max_in_flight,
queue_timeout_sec=settings.gateway_queue_timeout_sec,
)
# Module-level cache of upstream session ids used for multi-turn session reuse.
# When settings.session_reuse_enabled is false the cache is built with
# max_entries=0 (presumably this is what makes `session_cache.enabled`
# evaluate false — TODO confirm against SessionCache).
session_cache = SessionCache(
max_entries=settings.session_cache_max_entries if settings.session_reuse_enabled else 0,
ttl_sec=settings.session_cache_ttl_sec,
)
def _require_pool() -> LingmaPool:
@@ -228,6 +233,27 @@ def _affinity_key_for(req: ChatCompletionsRequest) -> str | None:
return None
def _extract_api_key(request: Request) -> str:
    """Return the bearer token from the Authorization header.

    The "Bearer " scheme prefix is matched case-insensitively and the token
    is stripped of surrounding whitespace. Returns "" when the header is
    missing or does not use the bearer scheme.
    """
    scheme = "bearer "
    auth_header = request.headers.get("authorization", "")
    has_bearer = auth_header.lower().startswith(scheme)
    return auth_header[len(scheme):].strip() if has_bearer else ""
def _last_user_text(messages: list[dict]) -> str:
    """Return the flattened text of the most recent user message.

    Used on a session-cache hit, when only the newest turn has to be sent.
    If no message has role "user", the content of the final message is used
    regardless of its role. An empty message list yields "".
    """
    if not messages:
        return ""
    # Scan from the end so the newest user turn wins.
    latest_user = next(
        (msg for msg in reversed(messages) if msg.get("role") == "user"),
        None,
    )
    chosen = messages[-1] if latest_user is None else latest_user
    return flatten_content(chosen.get("content")) or ""
@app.get("/v1/models", dependencies=[Depends(auth_guard)])
async def v1_models():
p = _require_pool()
@@ -261,10 +287,56 @@ def _include_usage(stream_options: dict | None) -> bool:
@app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)])
async def v1_chat_completions(req: ChatCompletionsRequest):
async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
p = _require_pool()
affinity = _affinity_key_for(req)
messages_dump = [m.model_dump() for m in req.messages]
api_key = _extract_api_key(request) or "-"
# ------------------------------------------------------------- session reuse
# Look up the "conversation prefix" (every message except the last one) in
# the session cache. A hit lets us:
# 1. Reuse the upstream sessionId so Lingma/Qwen hits its KV cache.
# 2. Send only the new user message instead of the whole history.
# 3. Stick the request to the pool instance that originally served it.
ask_mode = settings.default_ask_mode
if req.model.lower() in {"lingma-agent", "agent"}:
ask_mode = "agent"
reuse_eligible = (
session_cache.enabled
and ask_mode == "chat"
and len(messages_dump) >= 2
)
lookup_key: str | None = None
write_key: str | None = None
cached_session_id: str | None = None
cached_instance_name: str | None = None
if reuse_eligible:
lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
write_key = session_cache.build_key(api_key, messages_dump)
entry = await session_cache.get(lookup_key)
if entry is not None:
cached_session_id = entry.session_id
cached_instance_name = entry.instance_name or None
# Instance selection: prefer cached instance for continuity, else normal affinity.
affinity = cached_instance_name or _affinity_key_for(req)
inst = p.pick(affinity_key=affinity)
# If the pool returned a different instance than the cached one (pool.pick
# fell back, e.g. because the cached instance is no longer healthy), drop the
# cached session: a different Lingma process won't know that sessionId.
if cached_instance_name and inst.name != cached_instance_name:
logger.info(
"session cache instance %s unhealthy, falling back to %s (dropping cached session)",
cached_instance_name,
inst.name,
)
cached_session_id = None
if lookup_key:
await session_cache.invalidate(lookup_key)
await _ensure_instance_logged_in(inst)
models = await inst.client.query_models()
@@ -272,11 +344,15 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
name_map = build_model_name_map(models)
model = resolve_model(req.model, available, settings.default_model, name_map)
ask_mode = settings.default_ask_mode
if req.model.lower() in {"lingma-agent", "agent"}:
ask_mode = "agent"
# Prompt construction: on cache hit send only the last user turn so Lingma's
# stored context isn't duplicated.
if cached_session_id:
prompt = _last_user_text(messages_dump)
is_reply = True
else:
prompt = _messages_to_prompt(messages_dump)
is_reply = False
prompt = _messages_to_prompt([m.model_dump() for m in req.messages])
if not prompt:
raise HTTPException(
status_code=400,
@@ -306,12 +382,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
inst.in_flight += 1
logger.info(
"chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d",
"chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d reuse=%s",
inst.name,
model,
ask_mode,
req.stream,
prompt_tokens,
bool(cached_session_id),
extra={
"ctx_instance": inst.name,
"ctx_model": model,
@@ -320,6 +397,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
"ctx_prompt_tokens": prompt_tokens,
"ctx_in_flight": chat_guard.in_flight,
"ctx_affinity": affinity,
"ctx_session_reuse": bool(cached_session_id),
},
)
@@ -330,11 +408,19 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
created = int(time.time())
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
completion_tokens_holder = {"n": 0}
stream_meta: dict = {}
async def event_stream(_ticket=ticket, _inst=inst):
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
success = False
try:
async for chunk in _inst.client.chat_stream(prompt, model, ask_mode):
async for chunk in _inst.client.chat_stream(
prompt,
model,
ask_mode,
session_id=cached_session_id,
is_reply=is_reply,
out_meta=_meta,
):
completion_tokens_holder["n"] += estimate_tokens(chunk)
payload = {
"id": completion_id,
@@ -383,6 +469,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
except Exception as exc:
logger.warning("chat.stream error (inst=%s): %s", _inst.name, exc)
finally:
# Persist upstream sessionId only on a clean chat/finish.
# Partial streams (cancelled, timed out) leave Lingma's
# session in an indeterminate state, so we must not reuse.
if success and write_key:
sid = _meta.get("session_id")
if sid:
await session_cache.put(write_key, sid, _inst.name)
await stats_collector.record_chat(
stream=True,
success=success,
@@ -404,7 +497,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
)
try:
result = await inst.client.chat_complete(prompt, model, ask_mode)
result = await inst.client.chat_complete(
prompt,
model,
ask_mode,
session_id=cached_session_id,
is_reply=is_reply,
)
except Exception as exc:
logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
await stats_collector.record_chat(
@@ -413,6 +512,10 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
prompt_tokens=prompt_tokens,
completion_tokens=0,
)
# If we used a cached session and the call blew up, drop it so the
# next turn can start fresh instead of hitting the same dead session.
if cached_session_id and lookup_key:
await session_cache.invalidate(lookup_key)
raise HTTPException(
status_code=502,
detail={"error": {"message": "upstream lingma error", "type": "upstream_error"}},
@@ -425,6 +528,10 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
if write_key:
sid = result.get("sessionId")
if sid:
await session_cache.put(write_key, sid, inst.name)
response = ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex}",
created=int(time.time()),
@@ -534,6 +641,7 @@ async def internal_stats():
"stats": await stats_collector.snapshot(),
"concurrency": chat_guard.stats(),
"pool": p.stats(),
"session_cache": session_cache.stats(),
}
@@ -543,5 +651,6 @@ async def metrics():
lines = list(chat_guard.prometheus_lines())
if pool is not None:
lines.extend(pool.prometheus_lines())
lines.extend(session_cache.prometheus_lines())
extra = "\n".join(lines) + "\n"
return StreamingResponse(iter([base + extra]), media_type="text/plain; version=0.0.4")