perf: session reuse for multi-turn latency
- Add SessionCache (LRU + TTL, per-API-key scoped) mapping conversation-prefix hash -> upstream Lingma sessionId. - Hash only user/system/developer turns so client-side assistant reformatting doesn't invalidate the key. - On cache hit: reuse sessionId, send only the latest user message with isReply=true, and stick the request to the instance that originally served it. - LingmaGatewayClient.chat_complete/chat_stream accept session_id/is_reply and report the real finish.sessionId via out_meta so we persist what Lingma actually allocated. - Invalidate cache on non-stream failure; skip writes on cancelled/partial streams. - Expose cache stats in /internal/stats and /metrics. - Configurable via SESSION_REUSE_ENABLED / SESSION_CACHE_MAX_ENTRIES / SESSION_CACHE_TTL_SEC (documented in README + .env.example). Made-with: Cursor
This commit is contained in:
129
app/main.py
129
app/main.py
@@ -24,6 +24,7 @@ from .openai_schema import (
|
||||
ModelsResponse,
|
||||
flatten_content,
|
||||
)
|
||||
from .session_cache import SessionCache
|
||||
from .stats import StatsCollector, estimate_tokens
|
||||
|
||||
|
||||
@@ -37,6 +38,10 @@ chat_guard = InFlightGuard(
|
||||
max_in_flight=settings.gateway_max_in_flight,
|
||||
queue_timeout_sec=settings.gateway_queue_timeout_sec,
|
||||
)
|
||||
session_cache = SessionCache(
|
||||
max_entries=settings.session_cache_max_entries if settings.session_reuse_enabled else 0,
|
||||
ttl_sec=settings.session_cache_ttl_sec,
|
||||
)
|
||||
|
||||
|
||||
def _require_pool() -> LingmaPool:
|
||||
@@ -228,6 +233,27 @@ def _affinity_key_for(req: ChatCompletionsRequest) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _extract_api_key(request: Request) -> str:
|
||||
h = request.headers.get("authorization", "")
|
||||
if h.lower().startswith("bearer "):
|
||||
return h[7:].strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _last_user_text(messages: list[dict]) -> str:
|
||||
"""Extract the text of the latest user message (trailing from end).
|
||||
|
||||
Used when we hit the session cache and only need to send the delta.
|
||||
Falls back to the last message regardless of role if no user is found.
|
||||
"""
|
||||
for m in reversed(messages):
|
||||
if m.get("role") == "user":
|
||||
return flatten_content(m.get("content")) or ""
|
||||
if messages:
|
||||
return flatten_content(messages[-1].get("content")) or ""
|
||||
return ""
|
||||
|
||||
|
||||
@app.get("/v1/models", dependencies=[Depends(auth_guard)])
|
||||
async def v1_models():
|
||||
p = _require_pool()
|
||||
@@ -261,10 +287,56 @@ def _include_usage(stream_options: dict | None) -> bool:
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)])
|
||||
async def v1_chat_completions(req: ChatCompletionsRequest):
|
||||
async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
p = _require_pool()
|
||||
affinity = _affinity_key_for(req)
|
||||
|
||||
messages_dump = [m.model_dump() for m in req.messages]
|
||||
api_key = _extract_api_key(request) or "-"
|
||||
|
||||
# ------------------------------------------------------------- session reuse
|
||||
# Look up the "conversation prefix" (everything except the latest user turn)
|
||||
# in the session cache. A hit lets us:
|
||||
# 1. Reuse the upstream sessionId so Lingma/Qwen hits its KV cache.
|
||||
# 2. Send only the new user message instead of the whole history.
|
||||
# 3. Stick the request to the pool instance that originally served it.
|
||||
ask_mode = settings.default_ask_mode
|
||||
if req.model.lower() in {"lingma-agent", "agent"}:
|
||||
ask_mode = "agent"
|
||||
|
||||
reuse_eligible = (
|
||||
session_cache.enabled
|
||||
and ask_mode == "chat"
|
||||
and len(messages_dump) >= 2
|
||||
)
|
||||
lookup_key: str | None = None
|
||||
write_key: str | None = None
|
||||
cached_session_id: str | None = None
|
||||
cached_instance_name: str | None = None
|
||||
if reuse_eligible:
|
||||
lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
|
||||
write_key = session_cache.build_key(api_key, messages_dump)
|
||||
entry = await session_cache.get(lookup_key)
|
||||
if entry is not None:
|
||||
cached_session_id = entry.session_id
|
||||
cached_instance_name = entry.instance_name or None
|
||||
|
||||
# Instance selection: prefer cached instance for continuity, else normal affinity.
|
||||
affinity = cached_instance_name or _affinity_key_for(req)
|
||||
inst = p.pick(affinity_key=affinity)
|
||||
|
||||
# If cache pointed at a specific instance that's no longer healthy, we already
|
||||
# fell back via pool.pick -> drop the cached session since Lingma on a
|
||||
# different process won't know about it.
|
||||
if cached_instance_name and inst.name != cached_instance_name:
|
||||
logger.info(
|
||||
"session cache instance %s unhealthy, falling back to %s (dropping cached session)",
|
||||
cached_instance_name,
|
||||
inst.name,
|
||||
)
|
||||
cached_session_id = None
|
||||
if lookup_key:
|
||||
await session_cache.invalidate(lookup_key)
|
||||
|
||||
await _ensure_instance_logged_in(inst)
|
||||
|
||||
models = await inst.client.query_models()
|
||||
@@ -272,11 +344,15 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
||||
name_map = build_model_name_map(models)
|
||||
model = resolve_model(req.model, available, settings.default_model, name_map)
|
||||
|
||||
ask_mode = settings.default_ask_mode
|
||||
if req.model.lower() in {"lingma-agent", "agent"}:
|
||||
ask_mode = "agent"
|
||||
# Prompt construction: on cache hit send only the last user turn so Lingma's
|
||||
# stored context isn't duplicated.
|
||||
if cached_session_id:
|
||||
prompt = _last_user_text(messages_dump)
|
||||
is_reply = True
|
||||
else:
|
||||
prompt = _messages_to_prompt(messages_dump)
|
||||
is_reply = False
|
||||
|
||||
prompt = _messages_to_prompt([m.model_dump() for m in req.messages])
|
||||
if not prompt:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -306,12 +382,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
||||
|
||||
inst.in_flight += 1
|
||||
logger.info(
|
||||
"chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d",
|
||||
"chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d reuse=%s",
|
||||
inst.name,
|
||||
model,
|
||||
ask_mode,
|
||||
req.stream,
|
||||
prompt_tokens,
|
||||
bool(cached_session_id),
|
||||
extra={
|
||||
"ctx_instance": inst.name,
|
||||
"ctx_model": model,
|
||||
@@ -320,6 +397,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
||||
"ctx_prompt_tokens": prompt_tokens,
|
||||
"ctx_in_flight": chat_guard.in_flight,
|
||||
"ctx_affinity": affinity,
|
||||
"ctx_session_reuse": bool(cached_session_id),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -330,11 +408,19 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
||||
created = int(time.time())
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
completion_tokens_holder = {"n": 0}
|
||||
stream_meta: dict = {}
|
||||
|
||||
async def event_stream(_ticket=ticket, _inst=inst):
|
||||
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
|
||||
success = False
|
||||
try:
|
||||
async for chunk in _inst.client.chat_stream(prompt, model, ask_mode):
|
||||
async for chunk in _inst.client.chat_stream(
|
||||
prompt,
|
||||
model,
|
||||
ask_mode,
|
||||
session_id=cached_session_id,
|
||||
is_reply=is_reply,
|
||||
out_meta=_meta,
|
||||
):
|
||||
completion_tokens_holder["n"] += estimate_tokens(chunk)
|
||||
payload = {
|
||||
"id": completion_id,
|
||||
@@ -383,6 +469,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
||||
except Exception as exc:
|
||||
logger.warning("chat.stream error (inst=%s): %s", _inst.name, exc)
|
||||
finally:
|
||||
# Persist upstream sessionId only on a clean chat/finish.
|
||||
# Partial streams (cancelled, timed out) leave Lingma's
|
||||
# session in an indeterminate state, so we must not reuse.
|
||||
if success and write_key:
|
||||
sid = _meta.get("session_id")
|
||||
if sid:
|
||||
await session_cache.put(write_key, sid, _inst.name)
|
||||
await stats_collector.record_chat(
|
||||
stream=True,
|
||||
success=success,
|
||||
@@ -404,7 +497,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
||||
)
|
||||
|
||||
try:
|
||||
result = await inst.client.chat_complete(prompt, model, ask_mode)
|
||||
result = await inst.client.chat_complete(
|
||||
prompt,
|
||||
model,
|
||||
ask_mode,
|
||||
session_id=cached_session_id,
|
||||
is_reply=is_reply,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
|
||||
await stats_collector.record_chat(
|
||||
@@ -413,6 +512,10 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=0,
|
||||
)
|
||||
# If we used a cached session and the call blew up, drop it so the
|
||||
# next turn can start fresh instead of hitting the same dead session.
|
||||
if cached_session_id and lookup_key:
|
||||
await session_cache.invalidate(lookup_key)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": {"message": "upstream lingma error", "type": "upstream_error"}},
|
||||
@@ -425,6 +528,10 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
if write_key:
|
||||
sid = result.get("sessionId")
|
||||
if sid:
|
||||
await session_cache.put(write_key, sid, inst.name)
|
||||
response = ChatCompletionResponse(
|
||||
id=f"chatcmpl-{uuid.uuid4().hex}",
|
||||
created=int(time.time()),
|
||||
@@ -534,6 +641,7 @@ async def internal_stats():
|
||||
"stats": await stats_collector.snapshot(),
|
||||
"concurrency": chat_guard.stats(),
|
||||
"pool": p.stats(),
|
||||
"session_cache": session_cache.stats(),
|
||||
}
|
||||
|
||||
|
||||
@@ -543,5 +651,6 @@ async def metrics():
|
||||
lines = list(chat_guard.prometheus_lines())
|
||||
if pool is not None:
|
||||
lines.extend(pool.prometheus_lines())
|
||||
lines.extend(session_cache.prometheus_lines())
|
||||
extra = "\n".join(lines) + "\n"
|
||||
return StreamingResponse(iter([base + extra]), media_type="text/plain; version=0.0.4")
|
||||
|
||||
Reference in New Issue
Block a user