perf: session reuse for multi-turn latency
- Add SessionCache (LRU + TTL, per-API-key scoped) mapping conversation-prefix hash -> upstream Lingma sessionId.
- Hash only user/system/developer turns so client-side assistant reformatting doesn't invalidate the key.
- On cache hit: reuse sessionId, send only the latest user message with isReply=true, and stick the request to the instance that originally served it.
- LingmaGatewayClient.chat_complete/chat_stream accept session_id/is_reply and report the real finish.sessionId via out_meta so we persist what Lingma actually allocated.
- Invalidate cache on non-stream failure; skip writes on cancelled/partial streams.
- Expose cache stats in /internal/stats and /metrics.
- Configurable via SESSION_REUSE_ENABLED / SESSION_CACHE_MAX_ENTRIES / SESSION_CACHE_TTL_SEC (documented in README + .env.example).

Made-with: Cursor
This commit is contained in:
@@ -67,3 +67,11 @@ LINGMA_PASSWORD=
|
|||||||
LINGMA_ACCOUNTS=
|
LINGMA_ACCOUNTS=
|
||||||
# 实例数量:默认等于 LINGMA_ACCOUNTS 数;显式指定时账号不足会循环复用并打 warning
|
# 实例数量:默认等于 LINGMA_ACCOUNTS 数;显式指定时账号不足会循环复用并打 warning
|
||||||
LINGMA_INSTANCE_COUNT=
|
LINGMA_INSTANCE_COUNT=
|
||||||
|
|
||||||
|
# ==== 会话复用(多轮对话命中上游 KV cache,减少首 token 延迟) ====
|
||||||
|
# 开关(默认开)
|
||||||
|
SESSION_REUSE_ENABLED=true
|
||||||
|
# 最多缓存多少条会话 (LRU)
|
||||||
|
SESSION_CACHE_MAX_ENTRIES=256
|
||||||
|
# 会话 TTL 秒数;超时自动失效,避免 Lingma 侧早已回收还在命中
|
||||||
|
SESSION_CACHE_TTL_SEC=1800
|
||||||
|
|||||||
@@ -70,6 +70,9 @@ cp .env.example .env
|
|||||||
- `GATEWAY_QUEUE_TIMEOUT_SEC`:排队等待超时秒数(默认 30,超过后直接 429 + `Retry-After`)
|
- `GATEWAY_QUEUE_TIMEOUT_SEC`:排队等待超时秒数(默认 30,超过后直接 429 + `Retry-After`)
|
||||||
- `LINGMA_ACCOUNTS`:多账号实例池,格式 `u1:p1,u2:p2` 或 JSON 数组;配置后每个账号起一个独立 Lingma 子进程
|
- `LINGMA_ACCOUNTS`:多账号实例池,格式 `u1:p1,u2:p2` 或 JSON 数组;配置后每个账号起一个独立 Lingma 子进程
|
||||||
- `LINGMA_INSTANCE_COUNT`:实例数(默认等于账号数;显式指定且不足时账号会循环复用)
|
- `LINGMA_INSTANCE_COUNT`:实例数(默认等于账号数;显式指定且不足时账号会循环复用)
|
||||||
|
- `SESSION_REUSE_ENABLED`:多轮对话复用上游 sessionId(默认 `true`)。命中时只把最新一条 user 消息发给 Lingma,命中上游 KV cache,显著降低第 2 轮及以后的首 token 延迟
|
||||||
|
- `SESSION_CACHE_MAX_ENTRIES`:会话缓存容量(LRU,默认 256)
|
||||||
|
- `SESSION_CACHE_TTL_SEC`:会话缓存 TTL 秒数(默认 1800;超时自动失效,避免复用到已被 Lingma 回收的 session)
|
||||||
|
|
||||||
### `.env` 最小必填示例
|
### `.env` 最小必填示例
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,9 @@ class Settings:
|
|||||||
auto_login_max_retry: int
|
auto_login_max_retry: int
|
||||||
accounts: list[LingmaAccount] = field(default_factory=list)
|
accounts: list[LingmaAccount] = field(default_factory=list)
|
||||||
instance_count: int = 1
|
instance_count: int = 1
|
||||||
|
session_reuse_enabled: bool = True
|
||||||
|
session_cache_max_entries: int = 256
|
||||||
|
session_cache_ttl_sec: float = 1800.0
|
||||||
|
|
||||||
|
|
||||||
def _bool_env(name: str, default: bool) -> bool:
|
def _bool_env(name: str, default: bool) -> bool:
|
||||||
@@ -131,4 +134,7 @@ def load_settings() -> Settings:
|
|||||||
auto_login_max_retry=int(os.getenv("AUTO_LOGIN_MAX_RETRY", "2")),
|
auto_login_max_retry=int(os.getenv("AUTO_LOGIN_MAX_RETRY", "2")),
|
||||||
accounts=accounts,
|
accounts=accounts,
|
||||||
instance_count=instance_count,
|
instance_count=instance_count,
|
||||||
|
session_reuse_enabled=_bool_env("SESSION_REUSE_ENABLED", True),
|
||||||
|
session_cache_max_entries=int(os.getenv("SESSION_CACHE_MAX_ENTRIES", "256")),
|
||||||
|
session_cache_ttl_sec=float(os.getenv("SESSION_CACHE_TTL_SEC", "1800")),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -536,7 +536,16 @@ class LingmaGatewayClient:
|
|||||||
|
|
||||||
# ------------------------------------------------------------------ chat
|
# ------------------------------------------------------------------ chat
|
||||||
|
|
||||||
def _build_payload(self, prompt: str, model_key: str, ask_mode: str, session_id: str, request_id: str):
|
def _build_payload(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_key: str,
|
||||||
|
ask_mode: str,
|
||||||
|
session_id: str,
|
||||||
|
request_id: str,
|
||||||
|
*,
|
||||||
|
is_reply: bool = False,
|
||||||
|
):
|
||||||
session_type = "developer" if ask_mode == "agent" else "chat"
|
session_type = "developer" if ask_mode == "agent" else "chat"
|
||||||
return {
|
return {
|
||||||
"requestId": request_id,
|
"requestId": request_id,
|
||||||
@@ -546,7 +555,7 @@ class LingmaGatewayClient:
|
|||||||
"mode": ask_mode,
|
"mode": ask_mode,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"source": 1,
|
"source": 1,
|
||||||
"isReply": False,
|
"isReply": is_reply,
|
||||||
"taskDefinitionType": "system",
|
"taskDefinitionType": "system",
|
||||||
"content": prompt,
|
"content": prompt,
|
||||||
"text": prompt,
|
"text": prompt,
|
||||||
@@ -579,11 +588,21 @@ class LingmaGatewayClient:
|
|||||||
"""
|
"""
|
||||||
await self.rpc.notify("chat/ask", payload)
|
await self.rpc.notify("chat/ask", payload)
|
||||||
|
|
||||||
async def chat_complete(self, prompt: str, model_key: str, ask_mode: str) -> dict:
|
async def chat_complete(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_key: str,
|
||||||
|
ask_mode: str,
|
||||||
|
*,
|
||||||
|
session_id: str | None = None,
|
||||||
|
is_reply: bool = False,
|
||||||
|
) -> dict:
|
||||||
await self.ensure_ready()
|
await self.ensure_ready()
|
||||||
request_id = str(uuid.uuid4())
|
request_id = str(uuid.uuid4())
|
||||||
session_id = str(uuid.uuid4())
|
sid = session_id or str(uuid.uuid4())
|
||||||
payload = self._build_payload(prompt, model_key, ask_mode, session_id, request_id)
|
payload = self._build_payload(
|
||||||
|
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
|
||||||
|
)
|
||||||
self.rpc.create_stream(request_id)
|
self.rpc.create_stream(request_id)
|
||||||
try:
|
try:
|
||||||
await self._kick_chat_ask(payload)
|
await self._kick_chat_ask(payload)
|
||||||
@@ -597,16 +616,37 @@ class LingmaGatewayClient:
|
|||||||
self.rpc.pop_stream(request_id)
|
self.rpc.pop_stream(request_id)
|
||||||
finish = result.get("finish") or {}
|
finish = result.get("finish") or {}
|
||||||
result["requestId"] = request_id
|
result["requestId"] = request_id
|
||||||
result["sessionId"] = finish.get("sessionId") or session_id
|
# Prefer upstream-reported sessionId so the next turn binds to whatever
|
||||||
|
# Lingma actually allocated (sometimes differs from our hint).
|
||||||
|
result["sessionId"] = finish.get("sessionId") or sid
|
||||||
result["model"] = model_key
|
result["model"] = model_key
|
||||||
result["mode"] = ask_mode
|
result["mode"] = ask_mode
|
||||||
|
result["isReply"] = is_reply
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def chat_stream(self, prompt: str, model_key: str, ask_mode: str) -> AsyncIterator[str]:
|
async def chat_stream(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_key: str,
|
||||||
|
ask_mode: str,
|
||||||
|
*,
|
||||||
|
session_id: str | None = None,
|
||||||
|
is_reply: bool = False,
|
||||||
|
out_meta: dict | None = None,
|
||||||
|
) -> AsyncIterator[str]:
|
||||||
|
"""Stream `chat/answer` chunks.
|
||||||
|
|
||||||
|
If `out_meta` is provided, the final `chat/finish` payload's sessionId
|
||||||
|
(and the raw finish dict) is written into it when the stream ends or is
|
||||||
|
cancelled. This is the hook the session cache uses to record the
|
||||||
|
upstream sessionId without holding a second reference to the RPC.
|
||||||
|
"""
|
||||||
await self.ensure_ready()
|
await self.ensure_ready()
|
||||||
request_id = str(uuid.uuid4())
|
request_id = str(uuid.uuid4())
|
||||||
session_id = str(uuid.uuid4())
|
sid = session_id or str(uuid.uuid4())
|
||||||
payload = self._build_payload(prompt, model_key, ask_mode, session_id, request_id)
|
payload = self._build_payload(
|
||||||
|
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
|
||||||
|
)
|
||||||
self.rpc.create_stream(request_id)
|
self.rpc.create_stream(request_id)
|
||||||
try:
|
try:
|
||||||
await self._kick_chat_ask(payload)
|
await self._kick_chat_ask(payload)
|
||||||
@@ -616,4 +656,14 @@ class LingmaGatewayClient:
|
|||||||
yield chunk
|
yield chunk
|
||||||
finally:
|
finally:
|
||||||
# Runs on normal completion, exception, or consumer GeneratorExit (client disconnect).
|
# Runs on normal completion, exception, or consumer GeneratorExit (client disconnect).
|
||||||
|
if out_meta is not None:
|
||||||
|
try:
|
||||||
|
stream_result = self.rpc.get_stream_result(request_id)
|
||||||
|
finish = stream_result.get("finish") or {}
|
||||||
|
out_meta["session_id"] = finish.get("sessionId") or sid
|
||||||
|
out_meta["finish"] = finish
|
||||||
|
out_meta["request_id"] = request_id
|
||||||
|
out_meta["chars"] = len(stream_result.get("text") or "")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
self.rpc.pop_stream(request_id)
|
self.rpc.pop_stream(request_id)
|
||||||
|
|||||||
129
app/main.py
129
app/main.py
@@ -24,6 +24,7 @@ from .openai_schema import (
|
|||||||
ModelsResponse,
|
ModelsResponse,
|
||||||
flatten_content,
|
flatten_content,
|
||||||
)
|
)
|
||||||
|
from .session_cache import SessionCache
|
||||||
from .stats import StatsCollector, estimate_tokens
|
from .stats import StatsCollector, estimate_tokens
|
||||||
|
|
||||||
|
|
||||||
@@ -37,6 +38,10 @@ chat_guard = InFlightGuard(
|
|||||||
max_in_flight=settings.gateway_max_in_flight,
|
max_in_flight=settings.gateway_max_in_flight,
|
||||||
queue_timeout_sec=settings.gateway_queue_timeout_sec,
|
queue_timeout_sec=settings.gateway_queue_timeout_sec,
|
||||||
)
|
)
|
||||||
|
session_cache = SessionCache(
|
||||||
|
max_entries=settings.session_cache_max_entries if settings.session_reuse_enabled else 0,
|
||||||
|
ttl_sec=settings.session_cache_ttl_sec,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _require_pool() -> LingmaPool:
|
def _require_pool() -> LingmaPool:
|
||||||
@@ -228,6 +233,27 @@ def _affinity_key_for(req: ChatCompletionsRequest) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_api_key(request: Request) -> str:
    """Return the token carried by a ``Bearer`` Authorization header.

    The scheme is matched case-insensitively and the token is stripped of
    surrounding whitespace. An empty string is returned when the header is
    missing or does not start with ``bearer ``.
    """
    scheme = "bearer "
    header = request.headers.get("authorization", "")
    if not header.lower().startswith(scheme):
        return ""
    # Slice by the scheme length so the original casing of the token is kept.
    return header[len(scheme):].strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _last_user_text(messages: list[dict]) -> str:
    """Return the flattened text of the most recent ``user`` message.

    Used on a session-cache hit, where only the newest turn is sent upstream.
    When no message has role ``user`` the very last message is used instead,
    whatever its role; an empty message list yields ``""``.
    """
    # Scan from the end so the newest user turn wins.
    chosen = next((m for m in reversed(messages) if m.get("role") == "user"), None)
    if chosen is None:
        if not messages:
            return ""
        chosen = messages[-1]
    return flatten_content(chosen.get("content")) or ""
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/models", dependencies=[Depends(auth_guard)])
|
@app.get("/v1/models", dependencies=[Depends(auth_guard)])
|
||||||
async def v1_models():
|
async def v1_models():
|
||||||
p = _require_pool()
|
p = _require_pool()
|
||||||
@@ -261,10 +287,56 @@ def _include_usage(stream_options: dict | None) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)])
|
@app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)])
|
||||||
async def v1_chat_completions(req: ChatCompletionsRequest):
|
async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||||
p = _require_pool()
|
p = _require_pool()
|
||||||
affinity = _affinity_key_for(req)
|
|
||||||
|
messages_dump = [m.model_dump() for m in req.messages]
|
||||||
|
api_key = _extract_api_key(request) or "-"
|
||||||
|
|
||||||
|
# ------------------------------------------------------------- session reuse
|
||||||
|
# Look up the "conversation prefix" (everything except the latest user turn)
|
||||||
|
# in the session cache. A hit lets us:
|
||||||
|
# 1. Reuse the upstream sessionId so Lingma/Qwen hits its KV cache.
|
||||||
|
# 2. Send only the new user message instead of the whole history.
|
||||||
|
# 3. Stick the request to the pool instance that originally served it.
|
||||||
|
ask_mode = settings.default_ask_mode
|
||||||
|
if req.model.lower() in {"lingma-agent", "agent"}:
|
||||||
|
ask_mode = "agent"
|
||||||
|
|
||||||
|
reuse_eligible = (
|
||||||
|
session_cache.enabled
|
||||||
|
and ask_mode == "chat"
|
||||||
|
and len(messages_dump) >= 2
|
||||||
|
)
|
||||||
|
lookup_key: str | None = None
|
||||||
|
write_key: str | None = None
|
||||||
|
cached_session_id: str | None = None
|
||||||
|
cached_instance_name: str | None = None
|
||||||
|
if reuse_eligible:
|
||||||
|
lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
|
||||||
|
write_key = session_cache.build_key(api_key, messages_dump)
|
||||||
|
entry = await session_cache.get(lookup_key)
|
||||||
|
if entry is not None:
|
||||||
|
cached_session_id = entry.session_id
|
||||||
|
cached_instance_name = entry.instance_name or None
|
||||||
|
|
||||||
|
# Instance selection: prefer cached instance for continuity, else normal affinity.
|
||||||
|
affinity = cached_instance_name or _affinity_key_for(req)
|
||||||
inst = p.pick(affinity_key=affinity)
|
inst = p.pick(affinity_key=affinity)
|
||||||
|
|
||||||
|
# If cache pointed at a specific instance that's no longer healthy, we already
|
||||||
|
# fell back via pool.pick -> drop the cached session since Lingma on a
|
||||||
|
# different process won't know about it.
|
||||||
|
if cached_instance_name and inst.name != cached_instance_name:
|
||||||
|
logger.info(
|
||||||
|
"session cache instance %s unhealthy, falling back to %s (dropping cached session)",
|
||||||
|
cached_instance_name,
|
||||||
|
inst.name,
|
||||||
|
)
|
||||||
|
cached_session_id = None
|
||||||
|
if lookup_key:
|
||||||
|
await session_cache.invalidate(lookup_key)
|
||||||
|
|
||||||
await _ensure_instance_logged_in(inst)
|
await _ensure_instance_logged_in(inst)
|
||||||
|
|
||||||
models = await inst.client.query_models()
|
models = await inst.client.query_models()
|
||||||
@@ -272,11 +344,15 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
|||||||
name_map = build_model_name_map(models)
|
name_map = build_model_name_map(models)
|
||||||
model = resolve_model(req.model, available, settings.default_model, name_map)
|
model = resolve_model(req.model, available, settings.default_model, name_map)
|
||||||
|
|
||||||
ask_mode = settings.default_ask_mode
|
# Prompt construction: on cache hit send only the last user turn so Lingma's
|
||||||
if req.model.lower() in {"lingma-agent", "agent"}:
|
# stored context isn't duplicated.
|
||||||
ask_mode = "agent"
|
if cached_session_id:
|
||||||
|
prompt = _last_user_text(messages_dump)
|
||||||
|
is_reply = True
|
||||||
|
else:
|
||||||
|
prompt = _messages_to_prompt(messages_dump)
|
||||||
|
is_reply = False
|
||||||
|
|
||||||
prompt = _messages_to_prompt([m.model_dump() for m in req.messages])
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
@@ -306,12 +382,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
|||||||
|
|
||||||
inst.in_flight += 1
|
inst.in_flight += 1
|
||||||
logger.info(
|
logger.info(
|
||||||
"chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d",
|
"chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d reuse=%s",
|
||||||
inst.name,
|
inst.name,
|
||||||
model,
|
model,
|
||||||
ask_mode,
|
ask_mode,
|
||||||
req.stream,
|
req.stream,
|
||||||
prompt_tokens,
|
prompt_tokens,
|
||||||
|
bool(cached_session_id),
|
||||||
extra={
|
extra={
|
||||||
"ctx_instance": inst.name,
|
"ctx_instance": inst.name,
|
||||||
"ctx_model": model,
|
"ctx_model": model,
|
||||||
@@ -320,6 +397,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
|||||||
"ctx_prompt_tokens": prompt_tokens,
|
"ctx_prompt_tokens": prompt_tokens,
|
||||||
"ctx_in_flight": chat_guard.in_flight,
|
"ctx_in_flight": chat_guard.in_flight,
|
||||||
"ctx_affinity": affinity,
|
"ctx_affinity": affinity,
|
||||||
|
"ctx_session_reuse": bool(cached_session_id),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -330,11 +408,19 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
|||||||
created = int(time.time())
|
created = int(time.time())
|
||||||
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||||
completion_tokens_holder = {"n": 0}
|
completion_tokens_holder = {"n": 0}
|
||||||
|
stream_meta: dict = {}
|
||||||
|
|
||||||
async def event_stream(_ticket=ticket, _inst=inst):
|
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
|
||||||
success = False
|
success = False
|
||||||
try:
|
try:
|
||||||
async for chunk in _inst.client.chat_stream(prompt, model, ask_mode):
|
async for chunk in _inst.client.chat_stream(
|
||||||
|
prompt,
|
||||||
|
model,
|
||||||
|
ask_mode,
|
||||||
|
session_id=cached_session_id,
|
||||||
|
is_reply=is_reply,
|
||||||
|
out_meta=_meta,
|
||||||
|
):
|
||||||
completion_tokens_holder["n"] += estimate_tokens(chunk)
|
completion_tokens_holder["n"] += estimate_tokens(chunk)
|
||||||
payload = {
|
payload = {
|
||||||
"id": completion_id,
|
"id": completion_id,
|
||||||
@@ -383,6 +469,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("chat.stream error (inst=%s): %s", _inst.name, exc)
|
logger.warning("chat.stream error (inst=%s): %s", _inst.name, exc)
|
||||||
finally:
|
finally:
|
||||||
|
# Persist upstream sessionId only on a clean chat/finish.
|
||||||
|
# Partial streams (cancelled, timed out) leave Lingma's
|
||||||
|
# session in an indeterminate state, so we must not reuse.
|
||||||
|
if success and write_key:
|
||||||
|
sid = _meta.get("session_id")
|
||||||
|
if sid:
|
||||||
|
await session_cache.put(write_key, sid, _inst.name)
|
||||||
await stats_collector.record_chat(
|
await stats_collector.record_chat(
|
||||||
stream=True,
|
stream=True,
|
||||||
success=success,
|
success=success,
|
||||||
@@ -404,7 +497,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await inst.client.chat_complete(prompt, model, ask_mode)
|
result = await inst.client.chat_complete(
|
||||||
|
prompt,
|
||||||
|
model,
|
||||||
|
ask_mode,
|
||||||
|
session_id=cached_session_id,
|
||||||
|
is_reply=is_reply,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
|
logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
|
||||||
await stats_collector.record_chat(
|
await stats_collector.record_chat(
|
||||||
@@ -413,6 +512,10 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
|||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
completion_tokens=0,
|
completion_tokens=0,
|
||||||
)
|
)
|
||||||
|
# If we used a cached session and the call blew up, drop it so the
|
||||||
|
# next turn can start fresh instead of hitting the same dead session.
|
||||||
|
if cached_session_id and lookup_key:
|
||||||
|
await session_cache.invalidate(lookup_key)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=502,
|
status_code=502,
|
||||||
detail={"error": {"message": "upstream lingma error", "type": "upstream_error"}},
|
detail={"error": {"message": "upstream lingma error", "type": "upstream_error"}},
|
||||||
@@ -425,6 +528,10 @@ async def v1_chat_completions(req: ChatCompletionsRequest):
|
|||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
)
|
)
|
||||||
|
if write_key:
|
||||||
|
sid = result.get("sessionId")
|
||||||
|
if sid:
|
||||||
|
await session_cache.put(write_key, sid, inst.name)
|
||||||
response = ChatCompletionResponse(
|
response = ChatCompletionResponse(
|
||||||
id=f"chatcmpl-{uuid.uuid4().hex}",
|
id=f"chatcmpl-{uuid.uuid4().hex}",
|
||||||
created=int(time.time()),
|
created=int(time.time()),
|
||||||
@@ -534,6 +641,7 @@ async def internal_stats():
|
|||||||
"stats": await stats_collector.snapshot(),
|
"stats": await stats_collector.snapshot(),
|
||||||
"concurrency": chat_guard.stats(),
|
"concurrency": chat_guard.stats(),
|
||||||
"pool": p.stats(),
|
"pool": p.stats(),
|
||||||
|
"session_cache": session_cache.stats(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -543,5 +651,6 @@ async def metrics():
|
|||||||
lines = list(chat_guard.prometheus_lines())
|
lines = list(chat_guard.prometheus_lines())
|
||||||
if pool is not None:
|
if pool is not None:
|
||||||
lines.extend(pool.prometheus_lines())
|
lines.extend(pool.prometheus_lines())
|
||||||
|
lines.extend(session_cache.prometheus_lines())
|
||||||
extra = "\n".join(lines) + "\n"
|
extra = "\n".join(lines) + "\n"
|
||||||
return StreamingResponse(iter([base + extra]), media_type="text/plain; version=0.0.4")
|
return StreamingResponse(iter([base + extra]), media_type="text/plain; version=0.0.4")
|
||||||
|
|||||||
165
app/session_cache.py
Normal file
165
app/session_cache.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from .logging_config import get_logger
|
||||||
|
from .openai_schema import flatten_content
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger("lingma_gateway.session")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SessionEntry:
    """One cached conversation: the upstream session plus bookkeeping."""

    # Upstream Lingma sessionId to reuse on the next turn.
    session_id: str
    # time.monotonic() stamp taken when the entry was stored; TTL is measured from it.
    created_at: float
    # time.monotonic() stamp of the most recent read or write of this entry.
    last_used_at: float
    # Number of cache hits served from this entry.
    hit_count: int = 0
    # Name of the pool instance that served the session ("" when not recorded).
    instance_name: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def hash_user_context(messages: list[dict]) -> str:
    """Return a SHA-1 hex digest over the system/user/developer turns.

    Messages with any other role (e.g. ``assistant`` or ``tool``) are ignored:
    clients may reformat or trim assistant replies between turns, which would
    break exact-match keying, whereas the inputs alone are stable and enough
    to identify a conversation prefix for cache lookup.

    Each hashed message contributes ``role``, a unit separator (0x1F), its
    text, and a record separator (0x1E). Non-string content is converted
    with ``flatten_content``; missing or empty text is hashed as "".
    """
    hashed_roles = ("system", "user", "developer")
    digest = hashlib.sha1()
    for message in messages:
        role = message.get("role", "")
        if role in hashed_roles:
            body = message.get("content")
            if not isinstance(body, str):
                body = flatten_content(body)
            record = "{}\x1f{}\x1e".format(role, body or "")
            digest.update(record.encode("utf-8"))
    return digest.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class SessionCache:
|
||||||
|
"""LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId.
|
||||||
|
|
||||||
|
Usage pattern:
|
||||||
|
- On request arrival, look up `build_key(api_key, messages[:-1])`.
|
||||||
|
If hit, reuse that sessionId and send ONLY the latest user message
|
||||||
|
with `isReply=True` so Lingma can hit its KV cache.
|
||||||
|
- After a successful completion, store under `build_key(api_key, messages)`.
|
||||||
|
Since the next turn's `messages[:-1]` contains exactly the current turn's
|
||||||
|
`messages` (in the standard OpenAI append-only pattern), the next lookup
|
||||||
|
will hit. We never have to hash the assistant reply itself.
|
||||||
|
|
||||||
|
Concurrency: `asyncio.Lock` guards the ordered dict; two concurrent requests
|
||||||
|
for the same conversation may both miss briefly, but that's harmless -- the
|
||||||
|
second write just overwrites the first, both produce valid sessionIds
|
||||||
|
upstream.
|
||||||
|
|
||||||
|
Safety: entries also remember which pool instance served them, so follow-up
|
||||||
|
turns stick to the same Lingma process (which is the only place that
|
||||||
|
knows the session).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *, max_entries: int = 256, ttl_sec: float = 1800.0):
|
||||||
|
self.max = max(0, int(max_entries))
|
||||||
|
self.ttl = float(ttl_sec) if ttl_sec and ttl_sec > 0 else 0.0
|
||||||
|
self._data: "OrderedDict[str, SessionEntry]" = OrderedDict()
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self.hit_total = 0
|
||||||
|
self.miss_total = 0
|
||||||
|
self.evict_total = 0
|
||||||
|
self.expire_total = 0
|
||||||
|
self.invalidate_total = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enabled(self) -> bool:
|
||||||
|
return self.max > 0
|
||||||
|
|
||||||
|
def build_key(self, api_key: str, messages: list[dict]) -> str:
|
||||||
|
# API key scoping prevents cross-tenant session leakage even when
|
||||||
|
# different clients happen to produce identical histories.
|
||||||
|
key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12]
|
||||||
|
return f"{key_scope}:{hash_user_context(messages)}"
|
||||||
|
|
||||||
|
async def get(self, key: str) -> SessionEntry | None:
|
||||||
|
if not self.enabled:
|
||||||
|
return None
|
||||||
|
async with self._lock:
|
||||||
|
entry = self._data.get(key)
|
||||||
|
if entry is None:
|
||||||
|
self.miss_total += 1
|
||||||
|
return None
|
||||||
|
if self.ttl > 0 and (time.monotonic() - entry.created_at) > self.ttl:
|
||||||
|
self._data.pop(key, None)
|
||||||
|
self.expire_total += 1
|
||||||
|
self.miss_total += 1
|
||||||
|
return None
|
||||||
|
entry.last_used_at = time.monotonic()
|
||||||
|
entry.hit_count += 1
|
||||||
|
self._data.move_to_end(key)
|
||||||
|
self.hit_total += 1
|
||||||
|
return entry
|
||||||
|
|
||||||
|
async def put(self, key: str, session_id: str, instance_name: str = "") -> None:
|
||||||
|
if not self.enabled or not session_id:
|
||||||
|
return
|
||||||
|
now = time.monotonic()
|
||||||
|
async with self._lock:
|
||||||
|
existing = self._data.get(key)
|
||||||
|
if existing is not None:
|
||||||
|
existing.session_id = session_id
|
||||||
|
existing.last_used_at = now
|
||||||
|
if instance_name:
|
||||||
|
existing.instance_name = instance_name
|
||||||
|
self._data.move_to_end(key)
|
||||||
|
return
|
||||||
|
self._data[key] = SessionEntry(
|
||||||
|
session_id=session_id,
|
||||||
|
created_at=now,
|
||||||
|
last_used_at=now,
|
||||||
|
instance_name=instance_name,
|
||||||
|
)
|
||||||
|
while len(self._data) > self.max:
|
||||||
|
self._data.popitem(last=False)
|
||||||
|
self.evict_total += 1
|
||||||
|
|
||||||
|
async def invalidate(self, key: str) -> None:
|
||||||
|
async with self._lock:
|
||||||
|
if self._data.pop(key, None) is not None:
|
||||||
|
self.invalidate_total += 1
|
||||||
|
|
||||||
|
def stats(self) -> dict:
|
||||||
|
total = self.hit_total + self.miss_total
|
||||||
|
rate = (self.hit_total / total) if total > 0 else 0.0
|
||||||
|
return {
|
||||||
|
"enabled": self.enabled,
|
||||||
|
"size": len(self._data),
|
||||||
|
"max": self.max,
|
||||||
|
"ttl_sec": self.ttl,
|
||||||
|
"hit_total": self.hit_total,
|
||||||
|
"miss_total": self.miss_total,
|
||||||
|
"hit_rate": round(rate, 4),
|
||||||
|
"evict_total": self.evict_total,
|
||||||
|
"expire_total": self.expire_total,
|
||||||
|
"invalidate_total": self.invalidate_total,
|
||||||
|
}
|
||||||
|
|
||||||
|
def prometheus_lines(self) -> list[str]:
|
||||||
|
return [
|
||||||
|
"# TYPE gateway_session_cache_size gauge",
|
||||||
|
f"gateway_session_cache_size {len(self._data)}",
|
||||||
|
"# TYPE gateway_session_cache_hit_total counter",
|
||||||
|
f"gateway_session_cache_hit_total {self.hit_total}",
|
||||||
|
"# TYPE gateway_session_cache_miss_total counter",
|
||||||
|
f"gateway_session_cache_miss_total {self.miss_total}",
|
||||||
|
"# TYPE gateway_session_cache_evict_total counter",
|
||||||
|
f"gateway_session_cache_evict_total {self.evict_total}",
|
||||||
|
"# TYPE gateway_session_cache_expire_total counter",
|
||||||
|
f"gateway_session_cache_expire_total {self.expire_total}",
|
||||||
|
"# TYPE gateway_session_cache_invalidate_total counter",
|
||||||
|
f"gateway_session_cache_invalidate_total {self.invalidate_total}",
|
||||||
|
]
|
||||||
Reference in New Issue
Block a user