perf: session reuse for multi-turn latency
- Add SessionCache (LRU + TTL, per-API-key scoped) mapping conversation-prefix hash -> upstream Lingma sessionId.
- Hash only user/system/developer turns so client-side assistant reformatting doesn't invalidate the key.
- On cache hit: reuse sessionId, send only the latest user message with isReply=true, and stick the request to the instance that originally served it.
- LingmaGatewayClient.chat_complete/chat_stream accept session_id/is_reply and report the real finish.sessionId via out_meta so we persist what Lingma actually allocated.
- Invalidate cache on non-stream failure; skip writes on cancelled/partial streams.
- Expose cache stats in /internal/stats and /metrics.
- Configurable via SESSION_REUSE_ENABLED / SESSION_CACHE_MAX_ENTRIES / SESSION_CACHE_TTL_SEC (documented in README + .env.example).

Made-with: Cursor
This commit is contained in:
165
app/session_cache.py
Normal file
165
app/session_cache.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .logging_config import get_logger
|
||||
from .openai_schema import flatten_content
|
||||
|
||||
|
||||
logger = get_logger("lingma_gateway.session")
|
||||
|
||||
|
||||
@dataclass
class SessionEntry:
    """One cached upstream session plus its bookkeeping.

    Timestamps are `time.monotonic()` readings taken by `SessionCache`, so
    they are only meaningful relative to each other within one process.
    """

    # Session identifier to reuse for follow-up turns.
    session_id: str
    # Monotonic time the entry was stored; the cache's TTL check reads this.
    created_at: float
    # Monotonic time of the most recent cache hit or overwrite.
    last_used_at: float
    # Number of successful lookups served from this entry.
    hit_count: int = 0
    # Name of the pool instance that served the session ("" when unknown).
    instance_name: str = ""
|
||||
|
||||
|
||||
def hash_user_context(messages: list[dict]) -> str:
    """Return a SHA-1 hex digest over the input-side turns of `messages`.

    Only messages whose role is `system`, `user` or `developer` contribute;
    every other role (e.g. `assistant`, `tool`) is skipped. The rationale:
      - Clients may reformat or trim assistant replies between turns, which
        would break exact-match keying.
      - The inputs are the stable part of a conversation and are enough to
        identify a conversation prefix for cache lookup.

    Each contributing message is fed to the digest, in order, as
    `role + 0x1F + text + 0x1E`, UTF-8 encoded. String content is used as-is;
    any other content goes through `flatten_content` first (presumably
    turning structured content parts into text -- see that helper).
    """
    hashed_roles = ("system", "user", "developer")
    digest = hashlib.sha1()
    for message in messages:
        role = message.get("role", "")
        if role in hashed_roles:
            body = message.get("content")
            if not isinstance(body, str):
                body = flatten_content(body)
            # Falsy text (None / "") is hashed as the empty string.
            record = f"{role}\x1f{body or ''}\x1e"
            digest.update(record.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
class SessionCache:
|
||||
"""LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId.
|
||||
|
||||
Usage pattern:
|
||||
- On request arrival, look up `build_key(api_key, messages[:-1])`.
|
||||
If hit, reuse that sessionId and send ONLY the latest user message
|
||||
with `isReply=True` so Lingma can hit its KV cache.
|
||||
- After a successful completion, store under `build_key(api_key, messages)`.
|
||||
Since the next turn's `messages[:-1]` contains exactly the current turn's
|
||||
`messages` (in the standard OpenAI append-only pattern), the next lookup
|
||||
will hit. We never have to hash the assistant reply itself.
|
||||
|
||||
Concurrency: `asyncio.Lock` guards the ordered dict; two concurrent requests
|
||||
for the same conversation may both miss briefly, but that's harmless -- the
|
||||
second write just overwrites the first, both produce valid sessionIds
|
||||
upstream.
|
||||
|
||||
Safety: entries also remember which pool instance served them, so follow-up
|
||||
turns stick to the same Lingma process (which is the only place that
|
||||
knows the session).
|
||||
"""
|
||||
|
||||
def __init__(self, *, max_entries: int = 256, ttl_sec: float = 1800.0):
|
||||
self.max = max(0, int(max_entries))
|
||||
self.ttl = float(ttl_sec) if ttl_sec and ttl_sec > 0 else 0.0
|
||||
self._data: "OrderedDict[str, SessionEntry]" = OrderedDict()
|
||||
self._lock = asyncio.Lock()
|
||||
self.hit_total = 0
|
||||
self.miss_total = 0
|
||||
self.evict_total = 0
|
||||
self.expire_total = 0
|
||||
self.invalidate_total = 0
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.max > 0
|
||||
|
||||
def build_key(self, api_key: str, messages: list[dict]) -> str:
|
||||
# API key scoping prevents cross-tenant session leakage even when
|
||||
# different clients happen to produce identical histories.
|
||||
key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12]
|
||||
return f"{key_scope}:{hash_user_context(messages)}"
|
||||
|
||||
async def get(self, key: str) -> SessionEntry | None:
|
||||
if not self.enabled:
|
||||
return None
|
||||
async with self._lock:
|
||||
entry = self._data.get(key)
|
||||
if entry is None:
|
||||
self.miss_total += 1
|
||||
return None
|
||||
if self.ttl > 0 and (time.monotonic() - entry.created_at) > self.ttl:
|
||||
self._data.pop(key, None)
|
||||
self.expire_total += 1
|
||||
self.miss_total += 1
|
||||
return None
|
||||
entry.last_used_at = time.monotonic()
|
||||
entry.hit_count += 1
|
||||
self._data.move_to_end(key)
|
||||
self.hit_total += 1
|
||||
return entry
|
||||
|
||||
async def put(self, key: str, session_id: str, instance_name: str = "") -> None:
|
||||
if not self.enabled or not session_id:
|
||||
return
|
||||
now = time.monotonic()
|
||||
async with self._lock:
|
||||
existing = self._data.get(key)
|
||||
if existing is not None:
|
||||
existing.session_id = session_id
|
||||
existing.last_used_at = now
|
||||
if instance_name:
|
||||
existing.instance_name = instance_name
|
||||
self._data.move_to_end(key)
|
||||
return
|
||||
self._data[key] = SessionEntry(
|
||||
session_id=session_id,
|
||||
created_at=now,
|
||||
last_used_at=now,
|
||||
instance_name=instance_name,
|
||||
)
|
||||
while len(self._data) > self.max:
|
||||
self._data.popitem(last=False)
|
||||
self.evict_total += 1
|
||||
|
||||
async def invalidate(self, key: str) -> None:
|
||||
async with self._lock:
|
||||
if self._data.pop(key, None) is not None:
|
||||
self.invalidate_total += 1
|
||||
|
||||
def stats(self) -> dict:
|
||||
total = self.hit_total + self.miss_total
|
||||
rate = (self.hit_total / total) if total > 0 else 0.0
|
||||
return {
|
||||
"enabled": self.enabled,
|
||||
"size": len(self._data),
|
||||
"max": self.max,
|
||||
"ttl_sec": self.ttl,
|
||||
"hit_total": self.hit_total,
|
||||
"miss_total": self.miss_total,
|
||||
"hit_rate": round(rate, 4),
|
||||
"evict_total": self.evict_total,
|
||||
"expire_total": self.expire_total,
|
||||
"invalidate_total": self.invalidate_total,
|
||||
}
|
||||
|
||||
def prometheus_lines(self) -> list[str]:
|
||||
return [
|
||||
"# TYPE gateway_session_cache_size gauge",
|
||||
f"gateway_session_cache_size {len(self._data)}",
|
||||
"# TYPE gateway_session_cache_hit_total counter",
|
||||
f"gateway_session_cache_hit_total {self.hit_total}",
|
||||
"# TYPE gateway_session_cache_miss_total counter",
|
||||
f"gateway_session_cache_miss_total {self.miss_total}",
|
||||
"# TYPE gateway_session_cache_evict_total counter",
|
||||
f"gateway_session_cache_evict_total {self.evict_total}",
|
||||
"# TYPE gateway_session_cache_expire_total counter",
|
||||
f"gateway_session_cache_expire_total {self.expire_total}",
|
||||
"# TYPE gateway_session_cache_invalidate_total counter",
|
||||
f"gateway_session_cache_invalidate_total {self.invalidate_total}",
|
||||
]
|
||||
Reference in New Issue
Block a user