diff --git a/DESIGN.md b/DESIGN.md index b87fe36..873fe54 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -518,7 +518,7 @@ FastAPI `lifespan` 退出 → `pool.close()` → 每个 `client.close()` → 进 ### 5.3 session cache 只哈希 user/system/developer 消息 - **问题**:OpenAI 客户端常常会规范化 / 裁剪 assistant 消息(例如 trim 末尾空白、去掉思考内容),导致下一轮的 `messages[:-1]` 跟上一轮的 `messages` 不完全字节相等。 -- **方案**:`hash_user_context` 只对 `system / user / developer` 三种 role 做 SHA1;assistant/tool 不参与。只要**用户输入路径**稳定,哈希就稳定。 +- **方案**:`hash_user_context` 只对 `system / user / developer` 三种 role 做 SHA1;assistant/tool 不参与。只要**用户输入路径**稳定,哈希就稳定。多模态会先在归一化阶段降级为占位符(如 `[image]` / `[audio]`)再参与哈希,因此会保留“模态存在”信号但不保留原始媒体内容。 - **权衡**:理论上客户端篡改 assistant 语义(比如把模型的回答改成相反的)时,cache 依然命中,但 Lingma 侧自己持有 session 原版历史,下一轮还是按原版继续。对用户意图的偏离不可见。这是 OK 的——客户端本来就不该篡改 assistant 内容。 ### 5.4 session cache 写入用 `write_key = hash(messages)`,查询用 `lookup_key = hash(messages[:-1])` diff --git a/README.md b/README.md index 9c5b681..bb3e43e 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ - OpenAI:`/v1/models`、`/v1/chat/completions`(含 stream) - Anthropic:`/v1/messages`、`/v1/messages/count_tokens`(含 stream) - 内置:多实例池、会话复用、Prometheus 指标、登录态 bundle 注入 +- 多模态降级:OpenAI `image_url` / `input_image` 转 `[image]`,`input_audio` 转 `[audio]`;Anthropic `image` 转 `[image]` > 架构设计与二开细节请看 [`DESIGN.md`](./DESIGN.md)。 diff --git a/app/anthropic_schema.py b/app/anthropic_schema.py index af854b1..05dece0 100644 --- a/app/anthropic_schema.py +++ b/app/anthropic_schema.py @@ -119,10 +119,8 @@ def anthropic_to_internal_messages(req: AnthropicMessagesRequest) -> list[dict]: """Project an Anthropic request into the gateway's internal message list. Internal shape matches what `_messages_to_prompt` already expects: - `[{"role": "system"|"user"|"assistant", "content": "..."}]`. This means - session-cache hashing is identical across OpenAI and Anthropic callers — - a user who migrates between the two endpoints keeps their session affinity - as long as they send the same conversation prefix. + `[{"role": "system"|"user"|"assistant", "content": "..."}]`. This keeps + user-input cache hashing aligned across OpenAI and Anthropic callers. """ out: list[dict] = [] if req.system: diff --git a/app/main.py b/app/main.py index 9e00920..c0c7068 100644 --- a/app/main.py +++ b/app/main.py @@ -38,7 +38,7 @@ from .openai_schema import ( flatten_content, ) from .session_bundle import encode_bundle, pack_workdir -from .session_cache import SessionCache +from .session_cache import SessionCache, hash_branch_context from .stats import StatsCollector, estimate_tokens @@ -57,6 +57,12 @@ session_cache = SessionCache( ttl_sec=settings.session_cache_ttl_sec, ) +STREAMING_RESPONSE_HEADERS = { + "Cache-Control": "no-cache, no-transform", + "X-Accel-Buffering": "no", + "Connection": "keep-alive", +} + def _require_pool() -> LingmaPool: if pool is None: @@ -416,6 +422,43 @@ def _anthropic_has_tooling_context(req: AnthropicMessagesRequest) -> bool: return False +def _resolve_ask_mode(model: str, has_tooling_context: bool) -> str: + model_name = (model or "").lower() + if model_name in {"lingma-agent", "agent"} or has_tooling_context: + return "agent" + return settings.default_ask_mode + + +async def _apply_cached_instance_or_invalidate( + *, + protocol: str, + inst: PoolInstance, + cached_instance_name: str | None, + cached_session_id: str | None, + lookup_key: str | None, +) -> str | None: + if cached_instance_name and inst.name != cached_instance_name: + logger.info( + "%s session cache instance %s unhealthy, falling back to %s", + protocol, + cached_instance_name, + inst.name, + ) + if lookup_key: + await session_cache.invalidate(lookup_key) + return None + return cached_session_id + + + +def _streaming_response(event_stream) -> StreamingResponse: + return StreamingResponse( + event_stream, + media_type="text/event-stream", + headers=STREAMING_RESPONSE_HEADERS, + ) + + def _stream_event_type(event: Any) -> str: if isinstance(event, dict): t = event.get("type") @@ -595,9 +638,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): tool_config = _openai_tool_config(req) has_tooling_context = _openai_has_tooling_context(req, messages_dump) - ask_mode = settings.default_ask_mode - if req.model.lower() in {"lingma-agent", "agent"} or has_tooling_context: - ask_mode = "agent" + ask_mode = _resolve_ask_mode(req.model, has_tooling_context) reuse_eligible = ( session_cache.enabled @@ -610,29 +651,38 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): cached_session_id: str | None = None cached_instance_name: str | None = None if reuse_eligible: - lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config) - write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config) + prefix_branch_context = hash_branch_context(messages_dump[:-1]) + lookup_key = session_cache.build_key( + api_key, + messages_dump[:-1], + tool_config=tool_config, + branch_context=prefix_branch_context, + ) + write_key = session_cache.build_key( + api_key, + messages_dump, + tool_config=tool_config, + branch_context=hash_branch_context(messages_dump), + ) entry = await session_cache.get(lookup_key) + if entry is None: + legacy_lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config) + entry = await session_cache.get(legacy_lookup_key) + if entry is not None: + lookup_key = legacy_lookup_key if entry is not None: cached_session_id = entry.session_id cached_instance_name = entry.instance_name or None - - # Instance selection: prefer cached instance for continuity, else normal affinity. affinity = cached_instance_name or _affinity_key_for(req) inst = p.pick(affinity_key=affinity) - # If cache pointed at a specific instance that's no longer healthy, we already - # fell back via pool.pick -> drop the cached session since Lingma on a - # different process won't know about it. - if cached_instance_name and inst.name != cached_instance_name: - logger.info( - "session cache instance %s unhealthy, falling back to %s (dropping cached session)", - cached_instance_name, - inst.name, - ) - cached_session_id = None - if lookup_key: - await session_cache.invalidate(lookup_key) + cached_session_id = await _apply_cached_instance_or_invalidate( + protocol="chat", + inst=inst, + cached_instance_name=cached_instance_name, + cached_session_id=cached_session_id, + lookup_key=lookup_key, + ) await _ensure_instance_logged_in(inst) @@ -831,15 +881,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): _ticket.release() ticket_transferred = True - return StreamingResponse( - event_stream(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache, no-transform", - "X-Accel-Buffering": "no", - "Connection": "keep-alive", - }, - ) + return _streaming_response(event_stream()) + try: result = await inst.client.chat_complete( @@ -1329,9 +1372,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): tool_config = _anthropic_tool_config(req) has_tooling_context = _anthropic_has_tooling_context(req) - ask_mode = settings.default_ask_mode - if req.model.lower() in {"lingma-agent", "agent"} or has_tooling_context: - ask_mode = "agent" + ask_mode = _resolve_ask_mode(req.model, has_tooling_context) reuse_eligible = ( session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 and not has_tooling_context @@ -1341,9 +1382,25 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): cached_session_id: str | None = None cached_instance_name: str | None = None if reuse_eligible: - lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config) - write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config) + prefix_branch_context = hash_branch_context(messages_dump[:-1]) + lookup_key = session_cache.build_key( + api_key, + messages_dump[:-1], + tool_config=tool_config, + branch_context=prefix_branch_context, + ) + write_key = session_cache.build_key( + api_key, + messages_dump, + tool_config=tool_config, + branch_context=hash_branch_context(messages_dump), + ) entry = await session_cache.get(lookup_key) + if entry is None: + legacy_lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config) + entry = await session_cache.get(legacy_lookup_key) + if entry is not None: + lookup_key = legacy_lookup_key if entry is not None: cached_session_id = entry.session_id cached_instance_name = entry.instance_name or None @@ -1613,15 +1670,8 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): _ticket.release() ticket_transferred = True - return StreamingResponse( - event_stream(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache, no-transform", - "X-Accel-Buffering": "no", - "Connection": "keep-alive", - }, - ) + return _streaming_response(event_stream()) + # ------------------------------------------------------------- non-stream try: diff --git a/app/session_cache.py b/app/session_cache.py index 1ada29a..9dde779 100644 --- a/app/session_cache.py +++ b/app/session_cache.py @@ -26,7 +26,7 @@ class SessionEntry: def hash_user_context(messages: list[dict]) -> str: """Hash the user/system/developer turns of a message list. - We deliberately skip `assistant`/`tool` messages because: + We deliberately skip `assistant`/`tool` messages here because: - Clients may subtly reformat or trim assistant replies between turns, breaking exact-match keying. - Only the *inputs* are stable, and they're sufficient to identify a @@ -43,6 +43,28 @@ def hash_user_context(messages: list[dict]) -> str: return h.hexdigest() +def hash_branch_context(messages: list[dict]) -> str: + """Hash assistant/tool turns to reduce branch collisions.""" + h = hashlib.sha1() + for m in messages: + role = m.get("role", "") + if role not in ("assistant", "tool"): + continue + content = m.get("content") + text = content if isinstance(content, str) else flatten_content(content) + tool_calls = m.get("tool_calls") + if tool_calls is not None: + try: + tool_calls_text = json.dumps(tool_calls, ensure_ascii=False, sort_keys=True, separators=(",", ":")) + except Exception: + tool_calls_text = str(tool_calls) + else: + tool_calls_text = "" + tool_call_id = m.get("tool_call_id") or "" + h.update(f"{role}\x1f{text or ''}\x1f{tool_calls_text}\x1f{tool_call_id}\x1e".encode("utf-8")) + return h.hexdigest() + + def _tool_fingerprint(tool_config: dict | None) -> str: if not isinstance(tool_config, dict): return "-" @@ -90,11 +112,21 @@ class SessionCache: def enabled(self) -> bool: return self.max > 0 - def build_key(self, api_key: str, messages: list[dict], *, tool_config: dict | None = None) -> str: + def build_key( + self, + api_key: str, + messages: list[dict], + *, + tool_config: dict | None = None, + branch_context: str | None = None, + ) -> str: # API key scoping prevents cross-tenant session leakage even when # different clients happen to produce identical histories. key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12] - return f"{key_scope}:{hash_user_context(messages)}:{_tool_fingerprint(tool_config)}" + base = f"{key_scope}:{hash_user_context(messages)}:{_tool_fingerprint(tool_config)}" + if not branch_context: + return base + return f"{base}:{branch_context}" async def get(self, key: str) -> SessionEntry | None: if not self.enabled: diff --git a/tests/test_pool_stats_config.py b/tests/test_pool_stats_config.py new file mode 100644 index 0000000..81099db --- /dev/null +++ b/tests/test_pool_stats_config.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +import json +import os +import sys +import types +import unittest +from types import SimpleNamespace +from unittest.mock import patch + +# app.lingma_pool imports auto_login; tests here don't execute Playwright paths. +# Stub module import so test environments without playwright can import pool code. +_playwright = types.ModuleType("playwright") +_playwright_async = types.ModuleType("playwright.async_api") + + +class _StubPlaywrightTimeoutError(Exception): + pass + + +async def _stub_async_playwright(): + raise RuntimeError("playwright is stubbed in unit tests") + + +_playwright_async.TimeoutError = _StubPlaywrightTimeoutError +_playwright_async.async_playwright = _stub_async_playwright +sys.modules.setdefault("playwright", _playwright) +sys.modules.setdefault("playwright.async_api", _playwright_async) + +from app.config import _parse_accounts, load_settings +from app.lingma_pool import LingmaPool +from app.stats import StatsCollector, estimate_tokens + + +def _affinity_key_for_bucket(pool_size: int, bucket_index: int) -> str: + for i in range(20000): + key = f"k-{i}" + if abs(hash(key)) % pool_size == bucket_index: + return key + raise RuntimeError("failed to find affinity key") + + +class _FakeInstance: + def __init__(self, idx: int, *, healthy: bool, in_flight: int): + self.name = f"inst-{idx}" + self.cfg = SimpleNamespace(index=idx) + self._healthy = healthy + self.in_flight = in_flight + + @property + def healthy(self) -> bool: + return self._healthy + + +class LingmaPoolRoutingTests(unittest.TestCase): + def test_pool_pick_prefers_healthy_affinity_bucket(self) -> None: + inst0 = _FakeInstance(0, healthy=True, in_flight=0) + inst1 = _FakeInstance(1, healthy=True, in_flight=9) + pool = LingmaPool([inst0, inst1]) + + key = _affinity_key_for_bucket(2, 1) + picked = pool.pick(affinity_key=key) + + self.assertIs(picked, inst1) + + def test_pool_pick_falls_back_to_least_in_flight_when_affinity_unhealthy(self) -> None: + inst0 = _FakeInstance(0, healthy=True, in_flight=1) + inst1 = _FakeInstance(1, healthy=False, in_flight=0) + inst2 = _FakeInstance(2, healthy=True, in_flight=1) + pool = LingmaPool([inst0, inst1, inst2]) + + key = _affinity_key_for_bucket(3, 1) + picked = pool.pick(affinity_key=key) + + self.assertIs(picked, inst0) + + def test_pool_pick_round_robin_when_all_unhealthy(self) -> None: + inst0 = _FakeInstance(0, healthy=False, in_flight=0) + inst1 = _FakeInstance(1, healthy=False, in_flight=0) + inst2 = _FakeInstance(2, healthy=False, in_flight=0) + pool = LingmaPool([inst0, inst1, inst2]) + + self.assertIs(pool.pick(), inst0) + self.assertIs(pool.pick(), inst1) + self.assertIs(pool.pick(), inst2) + self.assertIs(pool.pick(), inst0) + + def test_pool_prometheus_lines_include_required_metrics(self) -> None: + inst0 = _FakeInstance(0, healthy=True, in_flight=2) + inst1 = _FakeInstance(1, healthy=False, in_flight=5) + pool = LingmaPool([inst0, inst1]) + + text = "\n".join(pool.prometheus_lines()) + + self.assertIn("# TYPE gateway_pool_instance_in_flight gauge", text) + self.assertIn("# TYPE gateway_pool_instance_ready gauge", text) + self.assertIn('gateway_pool_instance_in_flight{name="inst-0",idx="0"} 2', text) + self.assertIn('gateway_pool_instance_ready{name="inst-0",idx="0"} 1', text) + self.assertIn('gateway_pool_instance_ready{name="inst-1",idx="1"} 0', text) + + +class StatsCollectorTests(unittest.IsolatedAsyncioTestCase): + def test_estimate_tokens_empty_short_utf8(self) -> None: + self.assertEqual(estimate_tokens(""), 0) + self.assertGreaterEqual(estimate_tokens("a"), 1) + self.assertEqual(estimate_tokens("你好世界"), 3) + + async def test_record_chat_updates_counters_and_clamps_negative_tokens(self) -> None: + s = StatsCollector() + + await s.record_chat(stream=True, success=True, prompt_tokens=-3, completion_tokens=5) + await s.record_chat(stream=False, success=False, prompt_tokens=2, completion_tokens=-7) + snap = await s.snapshot() + + self.assertEqual(snap["chat_requests_total"], 2) + self.assertEqual(snap["chat_requests_success"], 1) + self.assertEqual(snap["chat_requests_error"], 1) + self.assertEqual(snap["chat_stream_requests"], 1) + self.assertEqual(snap["chat_non_stream_requests"], 1) + self.assertEqual(snap["prompt_tokens_estimated_total"], 2) + self.assertEqual(snap["completion_tokens_estimated_total"], 5) + + async def test_snapshot_and_prometheus_text_consistency(self) -> None: + s = StatsCollector() + + await s.record_chat(stream=True, success=True, prompt_tokens=3, completion_tokens=4) + snap = await s.snapshot() + text = await s.prometheus_text() + + self.assertEqual(snap["total_tokens_estimated"], 7) + self.assertIn("gateway_total_tokens_estimated 7", text) + self.assertIn("gateway_chat_requests_total 1", text) + self.assertTrue(text.endswith("\n")) + + +class ConfigParsingTests(unittest.TestCase): + def test_parse_accounts_accepts_json_csv_newline_formats(self) -> None: + raw_json = json.dumps([ + {"username": "u1", "password": "p1"}, + {"username": "u2", "password": "p2"}, + ]) + parsed_json = _parse_accounts(raw_json) + self.assertEqual([a.username for a in parsed_json], ["u1", "u2"]) + + parsed_csv = _parse_accounts("u3:p3,u4:p4") + self.assertEqual([a.username for a in parsed_csv], ["u3", "u4"]) + + parsed_nl = _parse_accounts("u5:p5\nu6:p6") + self.assertEqual([a.username for a in parsed_nl], ["u5", "u6"]) + + def test_parse_accounts_allows_bundle_only_in_json(self) -> None: + raw = json.dumps([{"session_bundle": "abc"}]) + parsed = _parse_accounts(raw) + + self.assertEqual(len(parsed), 1) + self.assertEqual(parsed[0].username, "") + self.assertEqual(parsed[0].password, "") + self.assertEqual(parsed[0].session_bundle_b64, "abc") + + def test_parse_accounts_csv_splits_only_first_colon(self) -> None: + parsed = _parse_accounts("u:p:with:colon") + + self.assertEqual(len(parsed), 1) + self.assertEqual(parsed[0].username, "u") + self.assertEqual(parsed[0].password, "p:with:colon") + + def test_load_settings_creates_bundle_only_account_without_credentials(self) -> None: + with patch.dict(os.environ, {"LINGMA_SESSION_BUNDLE": "abc"}, clear=True): + settings = load_settings() + + self.assertEqual(len(settings.accounts), 1) + self.assertEqual(settings.accounts[0].username, "") + self.assertEqual(settings.accounts[0].password, "") + self.assertEqual(settings.accounts[0].session_bundle_b64, "abc") + + def test_load_settings_invalid_instance_count_fallback(self) -> None: + with patch.dict( + os.environ, + {"LINGMA_ACCOUNTS": "u1:p1,u2:p2", "LINGMA_INSTANCE_COUNT": "not-a-number"}, + clear=True, + ): + settings_with_accounts = load_settings() + + self.assertEqual(settings_with_accounts.instance_count, 2) + + with patch.dict(os.environ, {"LINGMA_INSTANCE_COUNT": "not-a-number"}, clear=True): + settings_without_accounts = load_settings() + + self.assertEqual(settings_without_accounts.instance_count, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_schema_normalization.py b/tests/test_schema_normalization.py index fbfbd95..9f7c7bf 100644 --- a/tests/test_schema_normalization.py +++ b/tests/test_schema_normalization.py @@ -17,11 +17,12 @@ class SchemaNormalizationTests(unittest.TestCase): [ {"type": "text", "text": "hello"}, {"type": "image_url", "image_url": {"url": "x"}}, + {"type": "input_image", "image_url": {"url": "y"}}, {"type": "input_audio", "input_audio": {"data": "x"}}, {"type": "text", "text": "world"}, ] ) - self.assertEqual(out, "hello\n[image]\n[audio]\nworld") + self.assertEqual(out, "hello\n[image]\n[image]\n[audio]\nworld") def test_anthropic_flatten_content_with_tool_blocks(self) -> None: out = flatten_anthropic_content( diff --git a/tests/test_session_cache_tooling.py b/tests/test_session_cache_tooling.py index e168d65..92b0d94 100644 --- a/tests/test_session_cache_tooling.py +++ b/tests/test_session_cache_tooling.py @@ -2,7 +2,7 @@ from __future__ import annotations import unittest -from app.session_cache import SessionCache, hash_user_context +from app.session_cache import SessionCache, hash_branch_context, hash_user_context class SessionCacheToolingTests(unittest.IsolatedAsyncioTestCase): @@ -17,6 +17,21 @@ class SessionCacheToolingTests(unittest.IsolatedAsyncioTestCase): ] self.assertEqual(hash_user_context(base), hash_user_context(with_extra)) + def test_hash_branch_context_distinguishes_assistant_tool_branch(self) -> None: + base = [ + {"role": "system", "content": "S"}, + {"role": "user", "content": "U"}, + {"role": "assistant", "content": "A1"}, + {"role": "tool", "content": "T1", "tool_call_id": "call-1"}, + ] + changed = [ + {"role": "system", "content": "S"}, + {"role": "user", "content": "U"}, + {"role": "assistant", "content": "A2"}, + {"role": "tool", "content": "T1", "tool_call_id": "call-1"}, + ] + self.assertNotEqual(hash_branch_context(base), hash_branch_context(changed)) + def test_build_key_changes_with_tool_config(self) -> None: cache = SessionCache(max_entries=8, ttl_sec=60) msgs = [{"role": "user", "content": "hi"}] @@ -26,6 +41,14 @@ class SessionCacheToolingTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(key1, key2) self.assertNotEqual(key1, key3) + def test_build_key_keeps_legacy_shape_without_branch_context(self) -> None: + cache = SessionCache(max_entries=8, ttl_sec=60) + msgs = [{"role": "user", "content": "hi"}] + legacy = cache.build_key("k", msgs) + with_branch = cache.build_key("k", msgs, branch_context="abc") + self.assertEqual(legacy.count(":"), 2) + self.assertEqual(with_branch.count(":"), 3) + async def test_lru_evicts_oldest(self) -> None: cache = SessionCache(max_entries=2, ttl_sec=600) await cache.put("k1", "s1") diff --git a/tests/test_tool_call_bridge.py b/tests/test_tool_call_bridge.py index 57119fe..8b7125a 100644 --- a/tests/test_tool_call_bridge.py +++ b/tests/test_tool_call_bridge.py @@ -15,9 +15,10 @@ class _FakeSessionCache: self.put_calls: list[tuple[str, str, str]] = [] self.invalidate_calls: list[str] = [] - def build_key(self, api_key: str, messages: list[dict], *, tool_config=None) -> str: + def build_key(self, api_key: str, messages: list[dict], *, tool_config=None, branch_context=None) -> str: marker = "with_tool" if tool_config is not None else "no_tool" - key = f"{api_key}:{len(messages)}:{marker}" + branch_marker = branch_context or "-" + key = f"{api_key}:{len(messages)}:{marker}:branch={branch_marker}" self.keys.append(key) return key @@ -635,6 +636,93 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(fake_cache.put_calls, []) + async def test_openai_session_reuse_lookup_key_separates_branches(self) -> None: + fake_cache = _FakeSessionCache() + fake_client = _FakeClient( + stream_events=[], + complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-branch"}, + ) + + req_a = ChatCompletionsRequest( + model="org_auto", + messages=[ + {"role": "system", "content": "S"}, + {"role": "user", "content": "U"}, + {"role": "assistant", "content": "A1"}, + {"role": "user", "content": "next"}, + ], + stream=False, + ) + req_b = ChatCompletionsRequest( + model="org_auto", + messages=[ + {"role": "system", "content": "S"}, + {"role": "user", "content": "U"}, + {"role": "assistant", "content": "A2"}, + {"role": "user", "content": "next"}, + ], + stream=False, + ) + + with ( + patch.object(main, "session_cache", fake_cache), + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + _SettingsPatch(default_ask_mode="chat", tool_forward_enabled=False), + ): + await main.v1_chat_completions(req_a, _make_request("/v1/chat/completions")) + await main.v1_chat_completions(req_b, _make_request("/v1/chat/completions")) + + self.assertGreaterEqual(len(fake_cache.get_calls), 4) + self.assertNotEqual(fake_cache.get_calls[0], fake_cache.get_calls[2]) + self.assertEqual(fake_cache.get_calls[1], fake_cache.get_calls[3]) + + async def test_openai_and_anthropic_resolve_same_default_ask_mode_without_tooling(self) -> None: + openai_spy = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + anthropic_spy = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + + openai_req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "hi"}], + stream=False, + ) + anthropic_req = AnthropicMessagesRequest( + model="claude-3-5-sonnet-20241022", + max_tokens=128, + messages=[{"role": "user", "content": "hi"}], + stream=False, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(openai_spy))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + _SettingsPatch(default_ask_mode="chat", tool_forward_enabled=False), + ): + await main.v1_chat_completions(openai_req, _make_request("/v1/chat/completions")) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(anthropic_spy))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object(main.settings, "api_keys", ["test-key"]), + _SettingsPatch(default_ask_mode="chat", tool_forward_enabled=False), + ): + await main.v1_messages( + anthropic_req, + _make_request( + "/v1/messages", + headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + ), + ) + + self.assertEqual(openai_spy.last_complete_args[2], "chat") + self.assertEqual(anthropic_spy.last_complete_args[2], "chat") + async def test_anthropic_non_stream_with_tools_uses_agent_mode(self) -> None: spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) req = AnthropicMessagesRequest(