feat: harden cache reuse semantics and expand protocol regressions
Stabilize cross-protocol ask-mode/streaming behavior and reduce session-reuse branch collisions, then add focused docs/tests for multimodal normalization and pool/stats/config paths. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -518,7 +518,7 @@ FastAPI `lifespan` 退出 → `pool.close()` → 每个 `client.close()` → 进
|
||||
### 5.3 session cache 只哈希 user/system/developer 消息
|
||||
|
||||
- **问题**:OpenAI 客户端常常会规范化 / 裁剪 assistant 消息(例如 trim 末尾空白、去掉思考内容),导致下一轮的 `messages[:-1]` 跟上一轮的 `messages` 不完全字节相等。
|
||||
- **方案**:`hash_user_context` 只对 `system / user / developer` 三种 role 做 SHA1;assistant/tool 不参与。只要**用户输入路径**稳定,哈希就稳定。
|
||||
- **方案**:`hash_user_context` 只对 `system / user / developer` 三种 role 做 SHA1;assistant/tool 不参与。只要**用户输入路径**稳定,哈希就稳定。多模态会先在归一化阶段降级为占位符(如 `[image]` / `[audio]`)再参与哈希,因此会保留“模态存在”信号但不保留原始媒体内容。
|
||||
- **权衡**:理论上客户端篡改 assistant 语义(比如把模型的回答改成相反的)时,cache 依然命中,但 Lingma 侧自己持有 session 原版历史,下一轮还是按原版继续。对用户意图的偏离不可见。这是 OK 的——客户端本来就不该篡改 assistant 内容。
|
||||
|
||||
### 5.4 session cache 写入用 `write_key = hash(messages)`,查询用 `lookup_key = hash(messages[:-1])`
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
- OpenAI:`/v1/models`、`/v1/chat/completions`(含 stream)
|
||||
- Anthropic:`/v1/messages`、`/v1/messages/count_tokens`(含 stream)
|
||||
- 内置:多实例池、会话复用、Prometheus 指标、登录态 bundle 注入
|
||||
- 多模态降级:OpenAI `image_url` / `input_image` 转 `[image]`,`input_audio` 转 `[audio]`;Anthropic `image` 转 `[image]`
|
||||
|
||||
> 架构设计与二开细节请看 [`DESIGN.md`](./DESIGN.md)。
|
||||
|
||||
|
||||
@@ -119,10 +119,8 @@ def anthropic_to_internal_messages(req: AnthropicMessagesRequest) -> list[dict]:
|
||||
"""Project an Anthropic request into the gateway's internal message list.
|
||||
|
||||
Internal shape matches what `_messages_to_prompt` already expects:
|
||||
`[{"role": "system"|"user"|"assistant", "content": "..."}]`. This means
|
||||
session-cache hashing is identical across OpenAI and Anthropic callers —
|
||||
a user who migrates between the two endpoints keeps their session affinity
|
||||
as long as they send the same conversation prefix.
|
||||
`[{"role": "system"|"user"|"assistant", "content": "..."}]`. This keeps
|
||||
user-input cache hashing aligned across OpenAI and Anthropic callers.
|
||||
"""
|
||||
out: list[dict] = []
|
||||
if req.system:
|
||||
|
||||
136
app/main.py
136
app/main.py
@@ -38,7 +38,7 @@ from .openai_schema import (
|
||||
flatten_content,
|
||||
)
|
||||
from .session_bundle import encode_bundle, pack_workdir
|
||||
from .session_cache import SessionCache
|
||||
from .session_cache import SessionCache, hash_branch_context
|
||||
from .stats import StatsCollector, estimate_tokens
|
||||
|
||||
|
||||
@@ -57,6 +57,12 @@ session_cache = SessionCache(
|
||||
ttl_sec=settings.session_cache_ttl_sec,
|
||||
)
|
||||
|
||||
STREAMING_RESPONSE_HEADERS = {
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
|
||||
|
||||
def _require_pool() -> LingmaPool:
|
||||
if pool is None:
|
||||
@@ -416,6 +422,43 @@ def _anthropic_has_tooling_context(req: AnthropicMessagesRequest) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _resolve_ask_mode(model: str, has_tooling_context: bool) -> str:
|
||||
model_name = (model or "").lower()
|
||||
if model_name in {"lingma-agent", "agent"} or has_tooling_context:
|
||||
return "agent"
|
||||
return settings.default_ask_mode
|
||||
|
||||
|
||||
async def _apply_cached_instance_or_invalidate(
|
||||
*,
|
||||
protocol: str,
|
||||
inst: PoolInstance,
|
||||
cached_instance_name: str | None,
|
||||
cached_session_id: str | None,
|
||||
lookup_key: str | None,
|
||||
) -> str | None:
|
||||
if cached_instance_name and inst.name != cached_instance_name:
|
||||
logger.info(
|
||||
"%s session cache instance %s unhealthy, falling back to %s",
|
||||
protocol,
|
||||
cached_instance_name,
|
||||
inst.name,
|
||||
)
|
||||
if lookup_key:
|
||||
await session_cache.invalidate(lookup_key)
|
||||
return None
|
||||
return cached_session_id
|
||||
|
||||
|
||||
|
||||
def _streaming_response(event_stream) -> StreamingResponse:
|
||||
return StreamingResponse(
|
||||
event_stream,
|
||||
media_type="text/event-stream",
|
||||
headers=STREAMING_RESPONSE_HEADERS,
|
||||
)
|
||||
|
||||
|
||||
def _stream_event_type(event: Any) -> str:
|
||||
if isinstance(event, dict):
|
||||
t = event.get("type")
|
||||
@@ -595,9 +638,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
tool_config = _openai_tool_config(req)
|
||||
has_tooling_context = _openai_has_tooling_context(req, messages_dump)
|
||||
|
||||
ask_mode = settings.default_ask_mode
|
||||
if req.model.lower() in {"lingma-agent", "agent"} or has_tooling_context:
|
||||
ask_mode = "agent"
|
||||
ask_mode = _resolve_ask_mode(req.model, has_tooling_context)
|
||||
|
||||
reuse_eligible = (
|
||||
session_cache.enabled
|
||||
@@ -610,29 +651,38 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
cached_session_id: str | None = None
|
||||
cached_instance_name: str | None = None
|
||||
if reuse_eligible:
|
||||
lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
|
||||
write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config)
|
||||
prefix_branch_context = hash_branch_context(messages_dump[:-1])
|
||||
lookup_key = session_cache.build_key(
|
||||
api_key,
|
||||
messages_dump[:-1],
|
||||
tool_config=tool_config,
|
||||
branch_context=prefix_branch_context,
|
||||
)
|
||||
write_key = session_cache.build_key(
|
||||
api_key,
|
||||
messages_dump,
|
||||
tool_config=tool_config,
|
||||
branch_context=hash_branch_context(messages_dump),
|
||||
)
|
||||
entry = await session_cache.get(lookup_key)
|
||||
if entry is None:
|
||||
legacy_lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
|
||||
entry = await session_cache.get(legacy_lookup_key)
|
||||
if entry is not None:
|
||||
lookup_key = legacy_lookup_key
|
||||
if entry is not None:
|
||||
cached_session_id = entry.session_id
|
||||
cached_instance_name = entry.instance_name or None
|
||||
|
||||
# Instance selection: prefer cached instance for continuity, else normal affinity.
|
||||
affinity = cached_instance_name or _affinity_key_for(req)
|
||||
inst = p.pick(affinity_key=affinity)
|
||||
|
||||
# If cache pointed at a specific instance that's no longer healthy, we already
|
||||
# fell back via pool.pick -> drop the cached session since Lingma on a
|
||||
# different process won't know about it.
|
||||
if cached_instance_name and inst.name != cached_instance_name:
|
||||
logger.info(
|
||||
"session cache instance %s unhealthy, falling back to %s (dropping cached session)",
|
||||
cached_instance_name,
|
||||
inst.name,
|
||||
)
|
||||
cached_session_id = None
|
||||
if lookup_key:
|
||||
await session_cache.invalidate(lookup_key)
|
||||
cached_session_id = await _apply_cached_instance_or_invalidate(
|
||||
protocol="chat",
|
||||
inst=inst,
|
||||
cached_instance_name=cached_instance_name,
|
||||
cached_session_id=cached_session_id,
|
||||
lookup_key=lookup_key,
|
||||
)
|
||||
|
||||
await _ensure_instance_logged_in(inst)
|
||||
|
||||
@@ -831,15 +881,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
_ticket.release()
|
||||
|
||||
ticket_transferred = True
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
return _streaming_response(event_stream())
|
||||
|
||||
|
||||
try:
|
||||
result = await inst.client.chat_complete(
|
||||
@@ -1329,9 +1372,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
tool_config = _anthropic_tool_config(req)
|
||||
has_tooling_context = _anthropic_has_tooling_context(req)
|
||||
|
||||
ask_mode = settings.default_ask_mode
|
||||
if req.model.lower() in {"lingma-agent", "agent"} or has_tooling_context:
|
||||
ask_mode = "agent"
|
||||
ask_mode = _resolve_ask_mode(req.model, has_tooling_context)
|
||||
|
||||
reuse_eligible = (
|
||||
session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 and not has_tooling_context
|
||||
@@ -1341,9 +1382,25 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
cached_session_id: str | None = None
|
||||
cached_instance_name: str | None = None
|
||||
if reuse_eligible:
|
||||
lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
|
||||
write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config)
|
||||
prefix_branch_context = hash_branch_context(messages_dump[:-1])
|
||||
lookup_key = session_cache.build_key(
|
||||
api_key,
|
||||
messages_dump[:-1],
|
||||
tool_config=tool_config,
|
||||
branch_context=prefix_branch_context,
|
||||
)
|
||||
write_key = session_cache.build_key(
|
||||
api_key,
|
||||
messages_dump,
|
||||
tool_config=tool_config,
|
||||
branch_context=hash_branch_context(messages_dump),
|
||||
)
|
||||
entry = await session_cache.get(lookup_key)
|
||||
if entry is None:
|
||||
legacy_lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
|
||||
entry = await session_cache.get(legacy_lookup_key)
|
||||
if entry is not None:
|
||||
lookup_key = legacy_lookup_key
|
||||
if entry is not None:
|
||||
cached_session_id = entry.session_id
|
||||
cached_instance_name = entry.instance_name or None
|
||||
@@ -1613,15 +1670,8 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
_ticket.release()
|
||||
|
||||
ticket_transferred = True
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
return _streaming_response(event_stream())
|
||||
|
||||
|
||||
# ------------------------------------------------------------- non-stream
|
||||
try:
|
||||
|
||||
@@ -26,7 +26,7 @@ class SessionEntry:
|
||||
def hash_user_context(messages: list[dict]) -> str:
|
||||
"""Hash the user/system/developer turns of a message list.
|
||||
|
||||
We deliberately skip `assistant`/`tool` messages because:
|
||||
We deliberately skip `assistant`/`tool` messages here because:
|
||||
- Clients may subtly reformat or trim assistant replies between turns,
|
||||
breaking exact-match keying.
|
||||
- Only the *inputs* are stable, and they're sufficient to identify a
|
||||
@@ -43,6 +43,28 @@ def hash_user_context(messages: list[dict]) -> str:
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def hash_branch_context(messages: list[dict]) -> str:
|
||||
"""Hash assistant/tool turns to reduce branch collisions."""
|
||||
h = hashlib.sha1()
|
||||
for m in messages:
|
||||
role = m.get("role", "")
|
||||
if role not in ("assistant", "tool"):
|
||||
continue
|
||||
content = m.get("content")
|
||||
text = content if isinstance(content, str) else flatten_content(content)
|
||||
tool_calls = m.get("tool_calls")
|
||||
if tool_calls is not None:
|
||||
try:
|
||||
tool_calls_text = json.dumps(tool_calls, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
except Exception:
|
||||
tool_calls_text = str(tool_calls)
|
||||
else:
|
||||
tool_calls_text = ""
|
||||
tool_call_id = m.get("tool_call_id") or ""
|
||||
h.update(f"{role}\x1f{text or ''}\x1f{tool_calls_text}\x1f{tool_call_id}\x1e".encode("utf-8"))
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def _tool_fingerprint(tool_config: dict | None) -> str:
|
||||
if not isinstance(tool_config, dict):
|
||||
return "-"
|
||||
@@ -90,11 +112,21 @@ class SessionCache:
|
||||
def enabled(self) -> bool:
|
||||
return self.max > 0
|
||||
|
||||
def build_key(self, api_key: str, messages: list[dict], *, tool_config: dict | None = None) -> str:
|
||||
def build_key(
|
||||
self,
|
||||
api_key: str,
|
||||
messages: list[dict],
|
||||
*,
|
||||
tool_config: dict | None = None,
|
||||
branch_context: str | None = None,
|
||||
) -> str:
|
||||
# API key scoping prevents cross-tenant session leakage even when
|
||||
# different clients happen to produce identical histories.
|
||||
key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12]
|
||||
return f"{key_scope}:{hash_user_context(messages)}:{_tool_fingerprint(tool_config)}"
|
||||
base = f"{key_scope}:{hash_user_context(messages)}:{_tool_fingerprint(tool_config)}"
|
||||
if not branch_context:
|
||||
return base
|
||||
return f"{base}:{branch_context}"
|
||||
|
||||
async def get(self, key: str) -> SessionEntry | None:
|
||||
if not self.enabled:
|
||||
|
||||
193
tests/test_pool_stats_config.py
Normal file
193
tests/test_pool_stats_config.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
# app.lingma_pool imports auto_login; tests here don't execute Playwright paths.
|
||||
# Stub module import so test environments without playwright can import pool code.
|
||||
_playwright = types.ModuleType("playwright")
|
||||
_playwright_async = types.ModuleType("playwright.async_api")
|
||||
|
||||
|
||||
class _StubPlaywrightTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
async def _stub_async_playwright():
|
||||
raise RuntimeError("playwright is stubbed in unit tests")
|
||||
|
||||
|
||||
_playwright_async.TimeoutError = _StubPlaywrightTimeoutError
|
||||
_playwright_async.async_playwright = _stub_async_playwright
|
||||
sys.modules.setdefault("playwright", _playwright)
|
||||
sys.modules.setdefault("playwright.async_api", _playwright_async)
|
||||
|
||||
from app.config import _parse_accounts, load_settings
|
||||
from app.lingma_pool import LingmaPool
|
||||
from app.stats import StatsCollector, estimate_tokens
|
||||
|
||||
|
||||
def _affinity_key_for_bucket(pool_size: int, bucket_index: int) -> str:
|
||||
for i in range(20000):
|
||||
key = f"k-{i}"
|
||||
if abs(hash(key)) % pool_size == bucket_index:
|
||||
return key
|
||||
raise RuntimeError("failed to find affinity key")
|
||||
|
||||
|
||||
class _FakeInstance:
|
||||
def __init__(self, idx: int, *, healthy: bool, in_flight: int):
|
||||
self.name = f"inst-{idx}"
|
||||
self.cfg = SimpleNamespace(index=idx)
|
||||
self._healthy = healthy
|
||||
self.in_flight = in_flight
|
||||
|
||||
@property
|
||||
def healthy(self) -> bool:
|
||||
return self._healthy
|
||||
|
||||
|
||||
class LingmaPoolRoutingTests(unittest.TestCase):
|
||||
def test_pool_pick_prefers_healthy_affinity_bucket(self) -> None:
|
||||
inst0 = _FakeInstance(0, healthy=True, in_flight=0)
|
||||
inst1 = _FakeInstance(1, healthy=True, in_flight=9)
|
||||
pool = LingmaPool([inst0, inst1])
|
||||
|
||||
key = _affinity_key_for_bucket(2, 1)
|
||||
picked = pool.pick(affinity_key=key)
|
||||
|
||||
self.assertIs(picked, inst1)
|
||||
|
||||
def test_pool_pick_falls_back_to_least_in_flight_when_affinity_unhealthy(self) -> None:
|
||||
inst0 = _FakeInstance(0, healthy=True, in_flight=1)
|
||||
inst1 = _FakeInstance(1, healthy=False, in_flight=0)
|
||||
inst2 = _FakeInstance(2, healthy=True, in_flight=1)
|
||||
pool = LingmaPool([inst0, inst1, inst2])
|
||||
|
||||
key = _affinity_key_for_bucket(3, 1)
|
||||
picked = pool.pick(affinity_key=key)
|
||||
|
||||
self.assertIs(picked, inst0)
|
||||
|
||||
def test_pool_pick_round_robin_when_all_unhealthy(self) -> None:
|
||||
inst0 = _FakeInstance(0, healthy=False, in_flight=0)
|
||||
inst1 = _FakeInstance(1, healthy=False, in_flight=0)
|
||||
inst2 = _FakeInstance(2, healthy=False, in_flight=0)
|
||||
pool = LingmaPool([inst0, inst1, inst2])
|
||||
|
||||
self.assertIs(pool.pick(), inst0)
|
||||
self.assertIs(pool.pick(), inst1)
|
||||
self.assertIs(pool.pick(), inst2)
|
||||
self.assertIs(pool.pick(), inst0)
|
||||
|
||||
def test_pool_prometheus_lines_include_required_metrics(self) -> None:
|
||||
inst0 = _FakeInstance(0, healthy=True, in_flight=2)
|
||||
inst1 = _FakeInstance(1, healthy=False, in_flight=5)
|
||||
pool = LingmaPool([inst0, inst1])
|
||||
|
||||
text = "\n".join(pool.prometheus_lines())
|
||||
|
||||
self.assertIn("# TYPE gateway_pool_instance_in_flight gauge", text)
|
||||
self.assertIn("# TYPE gateway_pool_instance_ready gauge", text)
|
||||
self.assertIn('gateway_pool_instance_in_flight{name="inst-0",idx="0"} 2', text)
|
||||
self.assertIn('gateway_pool_instance_ready{name="inst-0",idx="0"} 1', text)
|
||||
self.assertIn('gateway_pool_instance_ready{name="inst-1",idx="1"} 0', text)
|
||||
|
||||
|
||||
class StatsCollectorTests(unittest.IsolatedAsyncioTestCase):
|
||||
def test_estimate_tokens_empty_short_utf8(self) -> None:
|
||||
self.assertEqual(estimate_tokens(""), 0)
|
||||
self.assertGreaterEqual(estimate_tokens("a"), 1)
|
||||
self.assertEqual(estimate_tokens("你好世界"), 3)
|
||||
|
||||
async def test_record_chat_updates_counters_and_clamps_negative_tokens(self) -> None:
|
||||
s = StatsCollector()
|
||||
|
||||
await s.record_chat(stream=True, success=True, prompt_tokens=-3, completion_tokens=5)
|
||||
await s.record_chat(stream=False, success=False, prompt_tokens=2, completion_tokens=-7)
|
||||
snap = await s.snapshot()
|
||||
|
||||
self.assertEqual(snap["chat_requests_total"], 2)
|
||||
self.assertEqual(snap["chat_requests_success"], 1)
|
||||
self.assertEqual(snap["chat_requests_error"], 1)
|
||||
self.assertEqual(snap["chat_stream_requests"], 1)
|
||||
self.assertEqual(snap["chat_non_stream_requests"], 1)
|
||||
self.assertEqual(snap["prompt_tokens_estimated_total"], 2)
|
||||
self.assertEqual(snap["completion_tokens_estimated_total"], 5)
|
||||
|
||||
async def test_snapshot_and_prometheus_text_consistency(self) -> None:
|
||||
s = StatsCollector()
|
||||
|
||||
await s.record_chat(stream=True, success=True, prompt_tokens=3, completion_tokens=4)
|
||||
snap = await s.snapshot()
|
||||
text = await s.prometheus_text()
|
||||
|
||||
self.assertEqual(snap["total_tokens_estimated"], 7)
|
||||
self.assertIn("gateway_total_tokens_estimated 7", text)
|
||||
self.assertIn("gateway_chat_requests_total 1", text)
|
||||
self.assertTrue(text.endswith("\n"))
|
||||
|
||||
|
||||
class ConfigParsingTests(unittest.TestCase):
|
||||
def test_parse_accounts_accepts_json_csv_newline_formats(self) -> None:
|
||||
raw_json = json.dumps([
|
||||
{"username": "u1", "password": "p1"},
|
||||
{"username": "u2", "password": "p2"},
|
||||
])
|
||||
parsed_json = _parse_accounts(raw_json)
|
||||
self.assertEqual([a.username for a in parsed_json], ["u1", "u2"])
|
||||
|
||||
parsed_csv = _parse_accounts("u3:p3,u4:p4")
|
||||
self.assertEqual([a.username for a in parsed_csv], ["u3", "u4"])
|
||||
|
||||
parsed_nl = _parse_accounts("u5:p5\nu6:p6")
|
||||
self.assertEqual([a.username for a in parsed_nl], ["u5", "u6"])
|
||||
|
||||
def test_parse_accounts_allows_bundle_only_in_json(self) -> None:
|
||||
raw = json.dumps([{"session_bundle": "abc"}])
|
||||
parsed = _parse_accounts(raw)
|
||||
|
||||
self.assertEqual(len(parsed), 1)
|
||||
self.assertEqual(parsed[0].username, "")
|
||||
self.assertEqual(parsed[0].password, "")
|
||||
self.assertEqual(parsed[0].session_bundle_b64, "abc")
|
||||
|
||||
def test_parse_accounts_csv_splits_only_first_colon(self) -> None:
|
||||
parsed = _parse_accounts("u:p:with:colon")
|
||||
|
||||
self.assertEqual(len(parsed), 1)
|
||||
self.assertEqual(parsed[0].username, "u")
|
||||
self.assertEqual(parsed[0].password, "p:with:colon")
|
||||
|
||||
def test_load_settings_creates_bundle_only_account_without_credentials(self) -> None:
|
||||
with patch.dict(os.environ, {"LINGMA_SESSION_BUNDLE": "abc"}, clear=True):
|
||||
settings = load_settings()
|
||||
|
||||
self.assertEqual(len(settings.accounts), 1)
|
||||
self.assertEqual(settings.accounts[0].username, "")
|
||||
self.assertEqual(settings.accounts[0].password, "")
|
||||
self.assertEqual(settings.accounts[0].session_bundle_b64, "abc")
|
||||
|
||||
def test_load_settings_invalid_instance_count_fallback(self) -> None:
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"LINGMA_ACCOUNTS": "u1:p1,u2:p2", "LINGMA_INSTANCE_COUNT": "not-a-number"},
|
||||
clear=True,
|
||||
):
|
||||
settings_with_accounts = load_settings()
|
||||
|
||||
self.assertEqual(settings_with_accounts.instance_count, 2)
|
||||
|
||||
with patch.dict(os.environ, {"LINGMA_INSTANCE_COUNT": "not-a-number"}, clear=True):
|
||||
settings_without_accounts = load_settings()
|
||||
|
||||
self.assertEqual(settings_without_accounts.instance_count, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -17,11 +17,12 @@ class SchemaNormalizationTests(unittest.TestCase):
|
||||
[
|
||||
{"type": "text", "text": "hello"},
|
||||
{"type": "image_url", "image_url": {"url": "x"}},
|
||||
{"type": "input_image", "image_url": {"url": "y"}},
|
||||
{"type": "input_audio", "input_audio": {"data": "x"}},
|
||||
{"type": "text", "text": "world"},
|
||||
]
|
||||
)
|
||||
self.assertEqual(out, "hello\n[image]\n[audio]\nworld")
|
||||
self.assertEqual(out, "hello\n[image]\n[image]\n[audio]\nworld")
|
||||
|
||||
def test_anthropic_flatten_content_with_tool_blocks(self) -> None:
|
||||
out = flatten_anthropic_content(
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from app.session_cache import SessionCache, hash_user_context
|
||||
from app.session_cache import SessionCache, hash_branch_context, hash_user_context
|
||||
|
||||
|
||||
class SessionCacheToolingTests(unittest.IsolatedAsyncioTestCase):
|
||||
@@ -17,6 +17,21 @@ class SessionCacheToolingTests(unittest.IsolatedAsyncioTestCase):
|
||||
]
|
||||
self.assertEqual(hash_user_context(base), hash_user_context(with_extra))
|
||||
|
||||
def test_hash_branch_context_distinguishes_assistant_tool_branch(self) -> None:
|
||||
base = [
|
||||
{"role": "system", "content": "S"},
|
||||
{"role": "user", "content": "U"},
|
||||
{"role": "assistant", "content": "A1"},
|
||||
{"role": "tool", "content": "T1", "tool_call_id": "call-1"},
|
||||
]
|
||||
changed = [
|
||||
{"role": "system", "content": "S"},
|
||||
{"role": "user", "content": "U"},
|
||||
{"role": "assistant", "content": "A2"},
|
||||
{"role": "tool", "content": "T1", "tool_call_id": "call-1"},
|
||||
]
|
||||
self.assertNotEqual(hash_branch_context(base), hash_branch_context(changed))
|
||||
|
||||
def test_build_key_changes_with_tool_config(self) -> None:
|
||||
cache = SessionCache(max_entries=8, ttl_sec=60)
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
@@ -26,6 +41,14 @@ class SessionCacheToolingTests(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(key1, key2)
|
||||
self.assertNotEqual(key1, key3)
|
||||
|
||||
def test_build_key_keeps_legacy_shape_without_branch_context(self) -> None:
|
||||
cache = SessionCache(max_entries=8, ttl_sec=60)
|
||||
msgs = [{"role": "user", "content": "hi"}]
|
||||
legacy = cache.build_key("k", msgs)
|
||||
with_branch = cache.build_key("k", msgs, branch_context="abc")
|
||||
self.assertEqual(legacy.count(":"), 2)
|
||||
self.assertEqual(with_branch.count(":"), 3)
|
||||
|
||||
async def test_lru_evicts_oldest(self) -> None:
|
||||
cache = SessionCache(max_entries=2, ttl_sec=600)
|
||||
await cache.put("k1", "s1")
|
||||
|
||||
@@ -15,9 +15,10 @@ class _FakeSessionCache:
|
||||
self.put_calls: list[tuple[str, str, str]] = []
|
||||
self.invalidate_calls: list[str] = []
|
||||
|
||||
def build_key(self, api_key: str, messages: list[dict], *, tool_config=None) -> str:
|
||||
def build_key(self, api_key: str, messages: list[dict], *, tool_config=None, branch_context=None) -> str:
|
||||
marker = "with_tool" if tool_config is not None else "no_tool"
|
||||
key = f"{api_key}:{len(messages)}:{marker}"
|
||||
branch_marker = branch_context or "-"
|
||||
key = f"{api_key}:{len(messages)}:{marker}:branch={branch_marker}"
|
||||
self.keys.append(key)
|
||||
return key
|
||||
|
||||
@@ -635,6 +636,93 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(fake_cache.put_calls, [])
|
||||
|
||||
|
||||
async def test_openai_session_reuse_lookup_key_separates_branches(self) -> None:
|
||||
fake_cache = _FakeSessionCache()
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[],
|
||||
complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-branch"},
|
||||
)
|
||||
|
||||
req_a = ChatCompletionsRequest(
|
||||
model="org_auto",
|
||||
messages=[
|
||||
{"role": "system", "content": "S"},
|
||||
{"role": "user", "content": "U"},
|
||||
{"role": "assistant", "content": "A1"},
|
||||
{"role": "user", "content": "next"},
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
req_b = ChatCompletionsRequest(
|
||||
model="org_auto",
|
||||
messages=[
|
||||
{"role": "system", "content": "S"},
|
||||
{"role": "user", "content": "U"},
|
||||
{"role": "assistant", "content": "A2"},
|
||||
{"role": "user", "content": "next"},
|
||||
],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "session_cache", fake_cache),
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
_SettingsPatch(default_ask_mode="chat", tool_forward_enabled=False),
|
||||
):
|
||||
await main.v1_chat_completions(req_a, _make_request("/v1/chat/completions"))
|
||||
await main.v1_chat_completions(req_b, _make_request("/v1/chat/completions"))
|
||||
|
||||
self.assertGreaterEqual(len(fake_cache.get_calls), 4)
|
||||
self.assertNotEqual(fake_cache.get_calls[0], fake_cache.get_calls[2])
|
||||
self.assertEqual(fake_cache.get_calls[1], fake_cache.get_calls[3])
|
||||
|
||||
async def test_openai_and_anthropic_resolve_same_default_ask_mode_without_tooling(self) -> None:
|
||||
openai_spy = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
|
||||
anthropic_spy = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
|
||||
|
||||
openai_req = ChatCompletionsRequest(
|
||||
model="org_auto",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=False,
|
||||
)
|
||||
anthropic_req = AnthropicMessagesRequest(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=128,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(openai_spy))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
_SettingsPatch(default_ask_mode="chat", tool_forward_enabled=False),
|
||||
):
|
||||
await main.v1_chat_completions(openai_req, _make_request("/v1/chat/completions"))
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(anthropic_spy))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||
_SettingsPatch(default_ask_mode="chat", tool_forward_enabled=False),
|
||||
):
|
||||
await main.v1_messages(
|
||||
anthropic_req,
|
||||
_make_request(
|
||||
"/v1/messages",
|
||||
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(openai_spy.last_complete_args[2], "chat")
|
||||
self.assertEqual(anthropic_spy.last_complete_args[2], "chat")
|
||||
|
||||
async def test_anthropic_non_stream_with_tools_uses_agent_mode(self) -> None:
|
||||
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
|
||||
req = AnthropicMessagesRequest(
|
||||
|
||||
Reference in New Issue
Block a user