Files
lingma-openai-gateway/app/session_cache.py
GitHub Actions dfdb7087dc perf: session reuse for multi-turn latency
- Add SessionCache (LRU + TTL, per-API-key scoped) mapping
  conversation-prefix hash -> upstream Lingma sessionId.
- Hash only user/system/developer turns so client-side
  assistant reformatting doesn't invalidate the key.
- On cache hit: reuse sessionId, send only the latest user
  message with isReply=true, and stick the request to the
  instance that originally served it.
- LingmaGatewayClient.chat_complete/chat_stream accept
  session_id/is_reply and report the real finish.sessionId
  via out_meta so we persist what Lingma actually allocated.
- Invalidate cache on non-stream failure; skip writes on
  cancelled/partial streams.
- Expose cache stats in /internal/stats and /metrics.
- Configurable via SESSION_REUSE_ENABLED / SESSION_CACHE_MAX_ENTRIES
  / SESSION_CACHE_TTL_SEC (documented in README + .env.example).

Made-with: Cursor
2026-04-18 08:10:39 +08:00

166 lines
6.2 KiB
Python

from __future__ import annotations
import asyncio
import hashlib
import time
from collections import OrderedDict
from dataclasses import dataclass
from .logging_config import get_logger
from .openai_schema import flatten_content
logger = get_logger("lingma_gateway.session")
@dataclass
class SessionEntry:
    """One cached mapping from a conversation prefix to an upstream session."""

    # Upstream Lingma sessionId to reuse on the follow-up turn.
    session_id: str
    # time.monotonic() stamp that SessionCache compares against its TTL.
    created_at: float
    # time.monotonic() stamp of the most recent hit or write.
    last_used_at: float
    # Number of cache hits served from this entry.
    hit_count: int = 0
    # Pool instance that served the session; empty string when not recorded.
    instance_name: str = ""
def hash_user_context(messages: list[dict]) -> str:
    """Hash the user/system/developer turns of a message list.

    We deliberately skip `assistant`/`tool` messages because:
      - Clients may subtly reformat or trim assistant replies between turns,
        breaking exact-match keying.
      - Only the *inputs* are stable, and they're sufficient to identify a
        conversation prefix for cache lookup.

    Args:
        messages: OpenAI-style message dicts. Only the ``role`` and
            ``content`` keys are read. String content is hashed as-is;
            anything else is first passed through ``flatten_content``.

    Returns:
        The 40-character hexadecimal SHA-1 digest of the retained turns.
        An empty list (or one with no retained roles) yields the digest of
        empty input.
    """
    digest = hashlib.sha1()
    for message in messages:
        role = message.get("role", "")
        if role not in ("system", "user", "developer"):
            continue
        content = message.get("content")
        text = content if isinstance(content, str) else flatten_content(content)
        body = f"{text or ''}"
        # Length-prefix every field instead of relying on in-band delimiter
        # characters: with delimiters alone, content that itself contains the
        # delimiter bytes could hash identically to a different message list.
        framed = f"{len(role)}:{role}{len(body)}:{body}"
        # "surrogatepass" keeps lone surrogates (which JSON escapes can
        # produce) from raising UnicodeEncodeError during key computation.
        digest.update(framed.encode("utf-8", "surrogatepass"))
    return digest.hexdigest()
class SessionCache:
"""LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId.
Usage pattern:
- On request arrival, look up `build_key(api_key, messages[:-1])`.
If hit, reuse that sessionId and send ONLY the latest user message
with `isReply=True` so Lingma can hit its KV cache.
- After a successful completion, store under `build_key(api_key, messages)`.
Since the next turn's `messages[:-1]` contains exactly the current turn's
`messages` (in the standard OpenAI append-only pattern), the next lookup
will hit. We never have to hash the assistant reply itself.
Concurrency: `asyncio.Lock` guards the ordered dict; two concurrent requests
for the same conversation may both miss briefly, but that's harmless -- the
second write just overwrites the first, both produce valid sessionIds
upstream.
Safety: entries also remember which pool instance served them, so follow-up
turns stick to the same Lingma process (which is the only place that
knows the session).
"""
def __init__(self, *, max_entries: int = 256, ttl_sec: float = 1800.0):
self.max = max(0, int(max_entries))
self.ttl = float(ttl_sec) if ttl_sec and ttl_sec > 0 else 0.0
self._data: "OrderedDict[str, SessionEntry]" = OrderedDict()
self._lock = asyncio.Lock()
self.hit_total = 0
self.miss_total = 0
self.evict_total = 0
self.expire_total = 0
self.invalidate_total = 0
@property
def enabled(self) -> bool:
return self.max > 0
def build_key(self, api_key: str, messages: list[dict]) -> str:
# API key scoping prevents cross-tenant session leakage even when
# different clients happen to produce identical histories.
key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12]
return f"{key_scope}:{hash_user_context(messages)}"
async def get(self, key: str) -> SessionEntry | None:
if not self.enabled:
return None
async with self._lock:
entry = self._data.get(key)
if entry is None:
self.miss_total += 1
return None
if self.ttl > 0 and (time.monotonic() - entry.created_at) > self.ttl:
self._data.pop(key, None)
self.expire_total += 1
self.miss_total += 1
return None
entry.last_used_at = time.monotonic()
entry.hit_count += 1
self._data.move_to_end(key)
self.hit_total += 1
return entry
async def put(self, key: str, session_id: str, instance_name: str = "") -> None:
if not self.enabled or not session_id:
return
now = time.monotonic()
async with self._lock:
existing = self._data.get(key)
if existing is not None:
existing.session_id = session_id
existing.last_used_at = now
if instance_name:
existing.instance_name = instance_name
self._data.move_to_end(key)
return
self._data[key] = SessionEntry(
session_id=session_id,
created_at=now,
last_used_at=now,
instance_name=instance_name,
)
while len(self._data) > self.max:
self._data.popitem(last=False)
self.evict_total += 1
async def invalidate(self, key: str) -> None:
async with self._lock:
if self._data.pop(key, None) is not None:
self.invalidate_total += 1
def stats(self) -> dict:
total = self.hit_total + self.miss_total
rate = (self.hit_total / total) if total > 0 else 0.0
return {
"enabled": self.enabled,
"size": len(self._data),
"max": self.max,
"ttl_sec": self.ttl,
"hit_total": self.hit_total,
"miss_total": self.miss_total,
"hit_rate": round(rate, 4),
"evict_total": self.evict_total,
"expire_total": self.expire_total,
"invalidate_total": self.invalidate_total,
}
def prometheus_lines(self) -> list[str]:
return [
"# TYPE gateway_session_cache_size gauge",
f"gateway_session_cache_size {len(self._data)}",
"# TYPE gateway_session_cache_hit_total counter",
f"gateway_session_cache_hit_total {self.hit_total}",
"# TYPE gateway_session_cache_miss_total counter",
f"gateway_session_cache_miss_total {self.miss_total}",
"# TYPE gateway_session_cache_evict_total counter",
f"gateway_session_cache_evict_total {self.evict_total}",
"# TYPE gateway_session_cache_expire_total counter",
f"gateway_session_cache_expire_total {self.expire_total}",
"# TYPE gateway_session_cache_invalidate_total counter",
f"gateway_session_cache_invalidate_total {self.invalidate_total}",
]