Stabilize cross-protocol ask-mode/streaming behavior and reduce session-reuse branch collisions, then add focused docs/tests for multimodal normalization and pool/stats/config paths. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
209 lines
7.7 KiB
Python
209 lines
7.7 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import time
|
|
from collections import OrderedDict
|
|
from dataclasses import dataclass
|
|
|
|
from .logging_config import get_logger
|
|
from .openai_schema import flatten_content
|
|
|
|
|
|
# Logger for this module, obtained from the project's get_logger helper.
logger = get_logger("lingma_gateway.session")
|
|
|
|
|
|
@dataclass
class SessionEntry:
    """One cached upstream session, stored as a value inside SessionCache."""

    # Upstream session identifier handed back to callers on a cache hit.
    session_id: str
    # time.monotonic() stamp that SessionCache.get compares against the TTL.
    created_at: float
    # time.monotonic() stamp refreshed whenever the entry is hit or rewritten.
    last_used_at: float
    # Number of times SessionCache.get has returned this entry.
    hit_count: int = 0
    # Name of the pool instance that served the session ("" if not recorded).
    instance_name: str = ""
|
|
|
|
|
|
def hash_user_context(messages: list[dict]) -> str:
    """Return a SHA-1 hex digest of the input-side turns in *messages*.

    Only `system`, `user` and `developer` messages contribute; every other
    role (notably `assistant` and `tool`) is ignored on purpose:
      - Clients may reformat or trim assistant replies between turns, which
        would break exact-match keying.
      - The inputs alone are stable and are enough to identify a
        conversation prefix for cache lookup.

    String content is hashed as-is; any other content value is first passed
    through `flatten_content`. Each contributing message is fed to the hash
    as role + unit separator + text + record separator, in message order.
    """
    digest = hashlib.sha1()
    for message in messages:
        role = message.get("role", "")
        if role in ("system", "user", "developer"):
            body = message.get("content")
            if not isinstance(body, str):
                body = flatten_content(body)
            record = f"{role}\x1f{body or ''}\x1e"
            digest.update(record.encode("utf-8"))
    return digest.hexdigest()
|
|
|
|
|
|
def hash_branch_context(messages: list[dict]) -> str:
    """Return a SHA-1 hex digest of the assistant/tool turns in *messages*.

    This is the counterpart of `hash_user_context`: it covers only the
    `assistant` and `tool` roles, so two conversations with identical inputs
    but different model/tool output produce different digests (fewer branch
    collisions).

    For each contributing message the hash receives, in order: the role, the
    content text (non-string content goes through `flatten_content`), the
    `tool_calls` value serialized as compact key-sorted JSON (or `str()` of
    it when JSON serialization fails; empty when the key is absent or None),
    and the `tool_call_id` (empty when missing or falsy).
    """
    digest = hashlib.sha1()
    for message in messages:
        role = message.get("role", "")
        if role in ("assistant", "tool"):
            body = message.get("content")
            if not isinstance(body, str):
                body = flatten_content(body)

            calls = message.get("tool_calls")
            calls_repr = ""
            if calls is not None:
                try:
                    calls_repr = json.dumps(
                        calls,
                        ensure_ascii=False,
                        sort_keys=True,
                        separators=(",", ":"),
                    )
                except Exception:
                    # Best effort: fall back to str() so hashing never fails.
                    calls_repr = str(calls)

            call_id = message.get("tool_call_id") or ""
            record = f"{role}\x1f{body or ''}\x1f{calls_repr}\x1f{call_id}\x1e"
            digest.update(record.encode("utf-8"))
    return digest.hexdigest()
|
|
|
|
|
|
def _tool_fingerprint(tool_config: dict | None) -> str:
    """Return a short, order-independent fingerprint of *tool_config*.

    Anything that is not a dict (including None) maps to the placeholder
    "-". A dict is serialized as compact, key-sorted JSON (falling back to
    `str()` when it is not JSON-serializable) and the first 16 hex characters
    of the SHA-1 of that text are returned.
    """
    if isinstance(tool_config, dict):
        try:
            serialized = json.dumps(
                tool_config,
                ensure_ascii=False,
                sort_keys=True,
                separators=(",", ":"),
            )
        except Exception:
            # Best effort: an unserializable config still gets a fingerprint.
            serialized = str(tool_config)
        full_digest = hashlib.sha1(serialized.encode("utf-8")).hexdigest()
        return full_digest[:16]
    return "-"
|
|
|
|
|
|
class SessionCache:
    """LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId.

    Usage pattern:
      - On request arrival, look up `build_key(api_key, messages[:-1])`.
        If hit, reuse that sessionId and send ONLY the latest user message
        with `isReply=True` so Lingma can hit its KV cache.
      - After a successful completion, store under `build_key(api_key, messages)`.
        Since the next turn's `messages[:-1]` contains exactly the current turn's
        `messages` (in the standard OpenAI append-only pattern), the next lookup
        will hit. We never have to hash the assistant reply itself.
      - Callers that want to tell apart branches sharing the same inputs can
        pass `branch_context` to `build_key`; it is appended to the key.

    Expiry: an entry is considered expired once it is older than `ttl`
    seconds, measured from `created_at`. Expired entries are dropped lazily,
    on the `get` that finds them. A `ttl` of 0 disables expiry.

    Concurrency: `asyncio.Lock` guards the ordered dict; two concurrent requests
    for the same conversation may both miss briefly, but that's harmless -- the
    second write just overwrites the first, both produce valid sessionIds
    upstream.

    Safety: entries also remember which pool instance served them, so follow-up
    turns stick to the same Lingma process (which is the only place that
    knows the session).
    """

    def __init__(self, *, max_entries: int = 256, ttl_sec: float = 1800.0):
        """Create an empty cache.

        Args:
            max_entries: Maximum number of entries kept; values <= 0 disable
                the cache entirely (see `enabled`).
            ttl_sec: Entry lifetime in seconds; a falsy or non-positive value
                is stored as 0.0, which turns expiry off.
        """
        self.max = max(0, int(max_entries))
        self.ttl = float(ttl_sec) if ttl_sec and ttl_sec > 0 else 0.0
        # Insertion/recency ordered: the first item is the eviction candidate.
        self._data: "OrderedDict[str, SessionEntry]" = OrderedDict()
        self._lock = asyncio.Lock()
        # Monotonically increasing counters exposed by stats()/prometheus_lines().
        self.hit_total = 0
        self.miss_total = 0
        self.evict_total = 0
        self.expire_total = 0
        self.invalidate_total = 0

    @property
    def enabled(self) -> bool:
        """True when the cache may hold at least one entry."""
        return self.max > 0

    def build_key(
        self,
        api_key: str,
        messages: list[dict],
        *,
        tool_config: dict | None = None,
        branch_context: str | None = None,
    ) -> str:
        """Build the cache key for a conversation prefix.

        The key is `<api-key scope>:<user-context hash>:<tool fingerprint>`,
        with `:<branch_context>` appended when a non-empty `branch_context`
        is given.
        """
        # API key scoping prevents cross-tenant session leakage even when
        # different clients happen to produce identical histories.
        key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12]
        base = f"{key_scope}:{hash_user_context(messages)}:{_tool_fingerprint(tool_config)}"
        if not branch_context:
            return base
        return f"{base}:{branch_context}"

    async def get(self, key: str) -> SessionEntry | None:
        """Return the live entry for *key*, or None on a miss.

        A disabled cache returns None without touching the counters. An
        expired entry is removed and counted as both an expiry and a miss.
        A hit refreshes `last_used_at`, bumps `hit_count` and marks the entry
        as most recently used.
        """
        if not self.enabled:
            return None
        async with self._lock:
            entry = self._data.get(key)
            if entry is None:
                self.miss_total += 1
                return None
            if self.ttl > 0 and (time.monotonic() - entry.created_at) > self.ttl:
                self._data.pop(key, None)
                self.expire_total += 1
                self.miss_total += 1
                return None
            entry.last_used_at = time.monotonic()
            entry.hit_count += 1
            self._data.move_to_end(key)
            self.hit_total += 1
            return entry

    async def put(self, key: str, session_id: str, instance_name: str = "") -> None:
        """Store *session_id* under *key*, evicting the oldest entries if full.

        Does nothing when the cache is disabled or *session_id* is empty.
        If *key* already exists the entry is updated in place; an empty
        *instance_name* leaves the previously recorded instance untouched.
        """
        if not self.enabled or not session_id:
            return
        now = time.monotonic()
        async with self._lock:
            existing = self._data.get(key)
            if existing is not None:
                if existing.session_id != session_id:
                    # A different upstream session now lives under this key.
                    # Restart its TTL clock, exactly as a fresh insert would,
                    # so it does not inherit the replaced session's age and
                    # expire on the very next lookup.
                    existing.session_id = session_id
                    existing.created_at = now
                existing.last_used_at = now
                if instance_name:
                    existing.instance_name = instance_name
                self._data.move_to_end(key)
                return
            self._data[key] = SessionEntry(
                session_id=session_id,
                created_at=now,
                last_used_at=now,
                instance_name=instance_name,
            )
            # Trim from the least recently used end until within capacity.
            while len(self._data) > self.max:
                self._data.popitem(last=False)
                self.evict_total += 1

    async def invalidate(self, key: str) -> None:
        """Drop *key* if present; only an actual removal is counted."""
        async with self._lock:
            if self._data.pop(key, None) is not None:
                self.invalidate_total += 1

    def stats(self) -> dict:
        """Return a snapshot of configuration, size and counters.

        `hit_rate` is hits / (hits + misses), rounded to 4 decimals, and 0.0
        before any lookup has been counted.
        """
        total = self.hit_total + self.miss_total
        rate = (self.hit_total / total) if total > 0 else 0.0
        return {
            "enabled": self.enabled,
            "size": len(self._data),
            "max": self.max,
            "ttl_sec": self.ttl,
            "hit_total": self.hit_total,
            "miss_total": self.miss_total,
            "hit_rate": round(rate, 4),
            "evict_total": self.evict_total,
            "expire_total": self.expire_total,
            "invalidate_total": self.invalidate_total,
        }

    def prometheus_lines(self) -> list[str]:
        """Return the cache metrics as Prometheus text-format lines."""
        return [
            "# TYPE gateway_session_cache_size gauge",
            f"gateway_session_cache_size {len(self._data)}",
            "# TYPE gateway_session_cache_hit_total counter",
            f"gateway_session_cache_hit_total {self.hit_total}",
            "# TYPE gateway_session_cache_miss_total counter",
            f"gateway_session_cache_miss_total {self.miss_total}",
            "# TYPE gateway_session_cache_evict_total counter",
            f"gateway_session_cache_evict_total {self.evict_total}",
            "# TYPE gateway_session_cache_expire_total counter",
            f"gateway_session_cache_expire_total {self.expire_total}",
            "# TYPE gateway_session_cache_invalidate_total counter",
            f"gateway_session_cache_invalidate_total {self.invalidate_total}",
        ]
|