feat: harden cache reuse semantics and expand protocol regressions

Stabilize cross-protocol ask-mode/streaming behavior and reduce session-reuse branch collisions, then add focused docs/tests for multimodal normalization and pool/stats/config paths.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
GitHub Actions
2026-04-20 14:26:11 +08:00
parent b96b91e5b7
commit 12a4d9584e
9 changed files with 441 additions and 55 deletions

View File

@@ -0,0 +1,193 @@
from __future__ import annotations
import json
import os
import sys
import types
import unittest
from types import SimpleNamespace
from unittest.mock import patch
# app.lingma_pool imports auto_login; tests here don't execute Playwright paths.
# Stub module import so test environments without playwright can import pool code.
_playwright = types.ModuleType("playwright")
_playwright_async = types.ModuleType("playwright.async_api")
class _StubPlaywrightTimeoutError(Exception):
pass
async def _stub_async_playwright():
raise RuntimeError("playwright is stubbed in unit tests")
_playwright_async.TimeoutError = _StubPlaywrightTimeoutError
_playwright_async.async_playwright = _stub_async_playwright
sys.modules.setdefault("playwright", _playwright)
sys.modules.setdefault("playwright.async_api", _playwright_async)
from app.config import _parse_accounts, load_settings
from app.lingma_pool import LingmaPool
from app.stats import StatsCollector, estimate_tokens
def _affinity_key_for_bucket(pool_size: int, bucket_index: int) -> str:
for i in range(20000):
key = f"k-{i}"
if abs(hash(key)) % pool_size == bucket_index:
return key
raise RuntimeError("failed to find affinity key")
class _FakeInstance:
def __init__(self, idx: int, *, healthy: bool, in_flight: int):
self.name = f"inst-{idx}"
self.cfg = SimpleNamespace(index=idx)
self._healthy = healthy
self.in_flight = in_flight
@property
def healthy(self) -> bool:
return self._healthy
class LingmaPoolRoutingTests(unittest.TestCase):
def test_pool_pick_prefers_healthy_affinity_bucket(self) -> None:
inst0 = _FakeInstance(0, healthy=True, in_flight=0)
inst1 = _FakeInstance(1, healthy=True, in_flight=9)
pool = LingmaPool([inst0, inst1])
key = _affinity_key_for_bucket(2, 1)
picked = pool.pick(affinity_key=key)
self.assertIs(picked, inst1)
def test_pool_pick_falls_back_to_least_in_flight_when_affinity_unhealthy(self) -> None:
inst0 = _FakeInstance(0, healthy=True, in_flight=1)
inst1 = _FakeInstance(1, healthy=False, in_flight=0)
inst2 = _FakeInstance(2, healthy=True, in_flight=1)
pool = LingmaPool([inst0, inst1, inst2])
key = _affinity_key_for_bucket(3, 1)
picked = pool.pick(affinity_key=key)
self.assertIs(picked, inst0)
def test_pool_pick_round_robin_when_all_unhealthy(self) -> None:
inst0 = _FakeInstance(0, healthy=False, in_flight=0)
inst1 = _FakeInstance(1, healthy=False, in_flight=0)
inst2 = _FakeInstance(2, healthy=False, in_flight=0)
pool = LingmaPool([inst0, inst1, inst2])
self.assertIs(pool.pick(), inst0)
self.assertIs(pool.pick(), inst1)
self.assertIs(pool.pick(), inst2)
self.assertIs(pool.pick(), inst0)
def test_pool_prometheus_lines_include_required_metrics(self) -> None:
inst0 = _FakeInstance(0, healthy=True, in_flight=2)
inst1 = _FakeInstance(1, healthy=False, in_flight=5)
pool = LingmaPool([inst0, inst1])
text = "\n".join(pool.prometheus_lines())
self.assertIn("# TYPE gateway_pool_instance_in_flight gauge", text)
self.assertIn("# TYPE gateway_pool_instance_ready gauge", text)
self.assertIn('gateway_pool_instance_in_flight{name="inst-0",idx="0"} 2', text)
self.assertIn('gateway_pool_instance_ready{name="inst-0",idx="0"} 1', text)
self.assertIn('gateway_pool_instance_ready{name="inst-1",idx="1"} 0', text)
class StatsCollectorTests(unittest.IsolatedAsyncioTestCase):
def test_estimate_tokens_empty_short_utf8(self) -> None:
self.assertEqual(estimate_tokens(""), 0)
self.assertGreaterEqual(estimate_tokens("a"), 1)
self.assertEqual(estimate_tokens("你好世界"), 3)
async def test_record_chat_updates_counters_and_clamps_negative_tokens(self) -> None:
s = StatsCollector()
await s.record_chat(stream=True, success=True, prompt_tokens=-3, completion_tokens=5)
await s.record_chat(stream=False, success=False, prompt_tokens=2, completion_tokens=-7)
snap = await s.snapshot()
self.assertEqual(snap["chat_requests_total"], 2)
self.assertEqual(snap["chat_requests_success"], 1)
self.assertEqual(snap["chat_requests_error"], 1)
self.assertEqual(snap["chat_stream_requests"], 1)
self.assertEqual(snap["chat_non_stream_requests"], 1)
self.assertEqual(snap["prompt_tokens_estimated_total"], 2)
self.assertEqual(snap["completion_tokens_estimated_total"], 5)
async def test_snapshot_and_prometheus_text_consistency(self) -> None:
s = StatsCollector()
await s.record_chat(stream=True, success=True, prompt_tokens=3, completion_tokens=4)
snap = await s.snapshot()
text = await s.prometheus_text()
self.assertEqual(snap["total_tokens_estimated"], 7)
self.assertIn("gateway_total_tokens_estimated 7", text)
self.assertIn("gateway_chat_requests_total 1", text)
self.assertTrue(text.endswith("\n"))
class ConfigParsingTests(unittest.TestCase):
def test_parse_accounts_accepts_json_csv_newline_formats(self) -> None:
raw_json = json.dumps([
{"username": "u1", "password": "p1"},
{"username": "u2", "password": "p2"},
])
parsed_json = _parse_accounts(raw_json)
self.assertEqual([a.username for a in parsed_json], ["u1", "u2"])
parsed_csv = _parse_accounts("u3:p3,u4:p4")
self.assertEqual([a.username for a in parsed_csv], ["u3", "u4"])
parsed_nl = _parse_accounts("u5:p5\nu6:p6")
self.assertEqual([a.username for a in parsed_nl], ["u5", "u6"])
def test_parse_accounts_allows_bundle_only_in_json(self) -> None:
raw = json.dumps([{"session_bundle": "abc"}])
parsed = _parse_accounts(raw)
self.assertEqual(len(parsed), 1)
self.assertEqual(parsed[0].username, "")
self.assertEqual(parsed[0].password, "")
self.assertEqual(parsed[0].session_bundle_b64, "abc")
def test_parse_accounts_csv_splits_only_first_colon(self) -> None:
parsed = _parse_accounts("u:p:with:colon")
self.assertEqual(len(parsed), 1)
self.assertEqual(parsed[0].username, "u")
self.assertEqual(parsed[0].password, "p:with:colon")
def test_load_settings_creates_bundle_only_account_without_credentials(self) -> None:
with patch.dict(os.environ, {"LINGMA_SESSION_BUNDLE": "abc"}, clear=True):
settings = load_settings()
self.assertEqual(len(settings.accounts), 1)
self.assertEqual(settings.accounts[0].username, "")
self.assertEqual(settings.accounts[0].password, "")
self.assertEqual(settings.accounts[0].session_bundle_b64, "abc")
def test_load_settings_invalid_instance_count_fallback(self) -> None:
with patch.dict(
os.environ,
{"LINGMA_ACCOUNTS": "u1:p1,u2:p2", "LINGMA_INSTANCE_COUNT": "not-a-number"},
clear=True,
):
settings_with_accounts = load_settings()
self.assertEqual(settings_with_accounts.instance_count, 2)
with patch.dict(os.environ, {"LINGMA_INSTANCE_COUNT": "not-a-number"}, clear=True):
settings_without_accounts = load_settings()
self.assertEqual(settings_without_accounts.instance_count, 1)
if __name__ == "__main__":
unittest.main()

View File

@@ -17,11 +17,12 @@ class SchemaNormalizationTests(unittest.TestCase):
[
{"type": "text", "text": "hello"},
{"type": "image_url", "image_url": {"url": "x"}},
{"type": "input_image", "image_url": {"url": "y"}},
{"type": "input_audio", "input_audio": {"data": "x"}},
{"type": "text", "text": "world"},
]
)
self.assertEqual(out, "hello\n[image]\n[audio]\nworld")
self.assertEqual(out, "hello\n[image]\n[image]\n[audio]\nworld")
def test_anthropic_flatten_content_with_tool_blocks(self) -> None:
out = flatten_anthropic_content(

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import unittest
from app.session_cache import SessionCache, hash_user_context
from app.session_cache import SessionCache, hash_branch_context, hash_user_context
class SessionCacheToolingTests(unittest.IsolatedAsyncioTestCase):
@@ -17,6 +17,21 @@ class SessionCacheToolingTests(unittest.IsolatedAsyncioTestCase):
]
self.assertEqual(hash_user_context(base), hash_user_context(with_extra))
def test_hash_branch_context_distinguishes_assistant_tool_branch(self) -> None:
base = [
{"role": "system", "content": "S"},
{"role": "user", "content": "U"},
{"role": "assistant", "content": "A1"},
{"role": "tool", "content": "T1", "tool_call_id": "call-1"},
]
changed = [
{"role": "system", "content": "S"},
{"role": "user", "content": "U"},
{"role": "assistant", "content": "A2"},
{"role": "tool", "content": "T1", "tool_call_id": "call-1"},
]
self.assertNotEqual(hash_branch_context(base), hash_branch_context(changed))
def test_build_key_changes_with_tool_config(self) -> None:
cache = SessionCache(max_entries=8, ttl_sec=60)
msgs = [{"role": "user", "content": "hi"}]
@@ -26,6 +41,14 @@ class SessionCacheToolingTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(key1, key2)
self.assertNotEqual(key1, key3)
def test_build_key_keeps_legacy_shape_without_branch_context(self) -> None:
cache = SessionCache(max_entries=8, ttl_sec=60)
msgs = [{"role": "user", "content": "hi"}]
legacy = cache.build_key("k", msgs)
with_branch = cache.build_key("k", msgs, branch_context="abc")
self.assertEqual(legacy.count(":"), 2)
self.assertEqual(with_branch.count(":"), 3)
async def test_lru_evicts_oldest(self) -> None:
cache = SessionCache(max_entries=2, ttl_sec=600)
await cache.put("k1", "s1")

View File

@@ -15,9 +15,10 @@ class _FakeSessionCache:
self.put_calls: list[tuple[str, str, str]] = []
self.invalidate_calls: list[str] = []
def build_key(self, api_key: str, messages: list[dict], *, tool_config=None) -> str:
def build_key(self, api_key: str, messages: list[dict], *, tool_config=None, branch_context=None) -> str:
marker = "with_tool" if tool_config is not None else "no_tool"
key = f"{api_key}:{len(messages)}:{marker}"
branch_marker = branch_context or "-"
key = f"{api_key}:{len(messages)}:{marker}:branch={branch_marker}"
self.keys.append(key)
return key
@@ -635,6 +636,93 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(fake_cache.put_calls, [])
async def test_openai_session_reuse_lookup_key_separates_branches(self) -> None:
fake_cache = _FakeSessionCache()
fake_client = _FakeClient(
stream_events=[],
complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-branch"},
)
req_a = ChatCompletionsRequest(
model="org_auto",
messages=[
{"role": "system", "content": "S"},
{"role": "user", "content": "U"},
{"role": "assistant", "content": "A1"},
{"role": "user", "content": "next"},
],
stream=False,
)
req_b = ChatCompletionsRequest(
model="org_auto",
messages=[
{"role": "system", "content": "S"},
{"role": "user", "content": "U"},
{"role": "assistant", "content": "A2"},
{"role": "user", "content": "next"},
],
stream=False,
)
with (
patch.object(main, "session_cache", fake_cache),
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
_SettingsPatch(default_ask_mode="chat", tool_forward_enabled=False),
):
await main.v1_chat_completions(req_a, _make_request("/v1/chat/completions"))
await main.v1_chat_completions(req_b, _make_request("/v1/chat/completions"))
self.assertGreaterEqual(len(fake_cache.get_calls), 4)
self.assertNotEqual(fake_cache.get_calls[0], fake_cache.get_calls[2])
self.assertEqual(fake_cache.get_calls[1], fake_cache.get_calls[3])
async def test_openai_and_anthropic_resolve_same_default_ask_mode_without_tooling(self) -> None:
openai_spy = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
anthropic_spy = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
openai_req = ChatCompletionsRequest(
model="org_auto",
messages=[{"role": "user", "content": "hi"}],
stream=False,
)
anthropic_req = AnthropicMessagesRequest(
model="claude-3-5-sonnet-20241022",
max_tokens=128,
messages=[{"role": "user", "content": "hi"}],
stream=False,
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(openai_spy))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
_SettingsPatch(default_ask_mode="chat", tool_forward_enabled=False),
):
await main.v1_chat_completions(openai_req, _make_request("/v1/chat/completions"))
with (
patch.object(main, "pool", _FakePool(_FakeInstance(anthropic_spy))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
patch.object(main.settings, "api_keys", ["test-key"]),
_SettingsPatch(default_ask_mode="chat", tool_forward_enabled=False),
):
await main.v1_messages(
anthropic_req,
_make_request(
"/v1/messages",
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
),
)
self.assertEqual(openai_spy.last_complete_args[2], "chat")
self.assertEqual(anthropic_spy.last_complete_args[2], "chat")
async def test_anthropic_non_stream_with_tools_uses_agent_mode(self) -> None:
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
req = AnthropicMessagesRequest(