from __future__ import annotations import asyncio import hashlib import time from collections import OrderedDict from dataclasses import dataclass from .logging_config import get_logger from .openai_schema import flatten_content logger = get_logger("lingma_gateway.session") @dataclass class SessionEntry: session_id: str created_at: float last_used_at: float hit_count: int = 0 instance_name: str = "" def hash_user_context(messages: list[dict]) -> str: """Hash the user/system/developer turns of a message list. We deliberately skip `assistant`/`tool` messages because: - Clients may subtly reformat or trim assistant replies between turns, breaking exact-match keying. - Only the *inputs* are stable, and they're sufficient to identify a conversation prefix for cache lookup. """ h = hashlib.sha1() for m in messages: role = m.get("role", "") if role not in ("system", "user", "developer"): continue content = m.get("content") text = content if isinstance(content, str) else flatten_content(content) h.update(f"{role}\x1f{text or ''}\x1e".encode("utf-8")) return h.hexdigest() class SessionCache: """LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId. Usage pattern: - On request arrival, look up `build_key(api_key, messages[:-1])`. If hit, reuse that sessionId and send ONLY the latest user message with `isReply=True` so Lingma can hit its KV cache. - After a successful completion, store under `build_key(api_key, messages)`. Since the next turn's `messages[:-1]` contains exactly the current turn's `messages` (in the standard OpenAI append-only pattern), the next lookup will hit. We never have to hash the assistant reply itself. Concurrency: `asyncio.Lock` guards the ordered dict; two concurrent requests for the same conversation may both miss briefly, but that's harmless -- the second write just overwrites the first, both produce valid sessionIds upstream. Safety: entries also remember which pool instance served them, so follow-up turns stick to the same Lingma process (which is the only place that knows the session). """ def __init__(self, *, max_entries: int = 256, ttl_sec: float = 1800.0): self.max = max(0, int(max_entries)) self.ttl = float(ttl_sec) if ttl_sec and ttl_sec > 0 else 0.0 self._data: "OrderedDict[str, SessionEntry]" = OrderedDict() self._lock = asyncio.Lock() self.hit_total = 0 self.miss_total = 0 self.evict_total = 0 self.expire_total = 0 self.invalidate_total = 0 @property def enabled(self) -> bool: return self.max > 0 def build_key(self, api_key: str, messages: list[dict]) -> str: # API key scoping prevents cross-tenant session leakage even when # different clients happen to produce identical histories. key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12] return f"{key_scope}:{hash_user_context(messages)}" async def get(self, key: str) -> SessionEntry | None: if not self.enabled: return None async with self._lock: entry = self._data.get(key) if entry is None: self.miss_total += 1 return None if self.ttl > 0 and (time.monotonic() - entry.created_at) > self.ttl: self._data.pop(key, None) self.expire_total += 1 self.miss_total += 1 return None entry.last_used_at = time.monotonic() entry.hit_count += 1 self._data.move_to_end(key) self.hit_total += 1 return entry async def put(self, key: str, session_id: str, instance_name: str = "") -> None: if not self.enabled or not session_id: return now = time.monotonic() async with self._lock: existing = self._data.get(key) if existing is not None: existing.session_id = session_id existing.last_used_at = now if instance_name: existing.instance_name = instance_name self._data.move_to_end(key) return self._data[key] = SessionEntry( session_id=session_id, created_at=now, last_used_at=now, instance_name=instance_name, ) while len(self._data) > self.max: self._data.popitem(last=False) self.evict_total += 1 async def invalidate(self, key: str) -> None: async with self._lock: if self._data.pop(key, None) is not None: self.invalidate_total += 1 def stats(self) -> dict: total = self.hit_total + self.miss_total rate = (self.hit_total / total) if total > 0 else 0.0 return { "enabled": self.enabled, "size": len(self._data), "max": self.max, "ttl_sec": self.ttl, "hit_total": self.hit_total, "miss_total": self.miss_total, "hit_rate": round(rate, 4), "evict_total": self.evict_total, "expire_total": self.expire_total, "invalidate_total": self.invalidate_total, } def prometheus_lines(self) -> list[str]: return [ "# TYPE gateway_session_cache_size gauge", f"gateway_session_cache_size {len(self._data)}", "# TYPE gateway_session_cache_hit_total counter", f"gateway_session_cache_hit_total {self.hit_total}", "# TYPE gateway_session_cache_miss_total counter", f"gateway_session_cache_miss_total {self.miss_total}", "# TYPE gateway_session_cache_evict_total counter", f"gateway_session_cache_evict_total {self.evict_total}", "# TYPE gateway_session_cache_expire_total counter", f"gateway_session_cache_expire_total {self.expire_total}", "# TYPE gateway_session_cache_invalidate_total counter", f"gateway_session_cache_invalidate_total {self.invalidate_total}", ]