diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..a079912 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Makes `tests.*` importable for unittest module discovery. diff --git a/tests/test_tool_call_bridge.py b/tests/test_tool_call_bridge.py index b6da1f7..02e47d1 100644 --- a/tests/test_tool_call_bridge.py +++ b/tests/test_tool_call_bridge.py @@ -16,7 +16,14 @@ class _FakeSessionCache: self.put_calls: list[tuple[str, str, str]] = [] self.invalidate_calls: list[str] = [] - def build_key(self, api_key: str, messages: list[dict], *, tool_config=None, branch_context=None) -> str: + def build_key( + self, + api_key: str, + messages: list[dict], + *, + tool_config=None, + branch_context=None, + ) -> str: marker = "with_tool" if tool_config is not None else "no_tool" branch_marker = branch_context or "-" key = f"{api_key}:{len(messages)}:{marker}:branch={branch_marker}" @@ -33,6 +40,7 @@ class _FakeSessionCache: async def invalidate(self, key: str) -> None: self.invalidate_calls.append(key) + # app.main imports playwright via auto_login; tests don't exercise that path. # Inject a lightweight stub so unit tests run without installing playwright. _playwright = types.ModuleType("playwright") @@ -172,7 +180,9 @@ class _SettingsPatch: self._kwargs = kwargs def __enter__(self): - self._patchers = [patch.object(main.settings, k, v) for k, v in self._kwargs.items()] + self._patchers = [ + patch.object(main.settings, k, v) for k, v in self._kwargs.items() + ] for p in self._patchers: p.start() return self @@ -211,10 +221,16 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), ): - response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + response = await main.v1_chat_completions( + req, _make_request("/v1/chat/completions") + ) payload = json.loads(response.body) message = payload["choices"][0]["message"] @@ -227,11 +243,13 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): {"query": "gateway"}, ) - async def test_openai_non_stream_fallbacks_to_structured_tool_call_for_forced_tool(self) -> None: + async def test_openai_non_stream_does_not_synthesize_tool_call_from_plain_json( + self, + ) -> None: fake_client = _FakeClient( stream_events=[], complete_result={ - "text": "```json\n{\"arguments\": {\"query\": \"gateway\"}}\n```", + "text": '```json\n{"arguments": {"query": "gateway"}}\n```', "toolEvents": [], "sessionId": "sess-fallback-openai", }, @@ -240,34 +258,39 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): model="org_auto", messages=[{"role": "user", "content": "hi"}], stream=False, - tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tools=[ + {"type": "function", "function": {"name": "lookup", "parameters": {}}} + ], tool_choice={"type": "function", "function": {"name": "lookup"}}, ) with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), ): - response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + response = await main.v1_chat_completions( + req, _make_request("/v1/chat/completions") + ) payload = json.loads(response.body) message = payload["choices"][0]["message"] - self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls") - self.assertEqual(message["content"], "") - self.assertIsInstance(message["tool_calls"], list) - self.assertEqual(message["tool_calls"][0]["function"]["name"], "lookup") - self.assertEqual( - json.loads(message["tool_calls"][0]["function"]["arguments"]), - {"query": "gateway"}, - ) + self.assertEqual(payload["choices"][0]["finish_reason"], "stop") + self.assertIn("arguments", message["content"]) + self.assertIsNone(message["tool_calls"]) - async def test_openai_non_stream_fallbacks_to_tool_code_structured_tool_call_for_forced_tool(self) -> None: + async def test_openai_non_stream_does_not_synthesize_tool_call_from_tool_code( + self, + ) -> None: fake_client = _FakeClient( stream_events=[], complete_result={ - "text": "```tool_code\nlookup(query=\"gateway\")\n```", + "text": '```tool_code\nlookup(query="gateway")\n```', "toolEvents": [], "sessionId": "sess-fallback-tool-code-openai", }, @@ -276,33 +299,39 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): model="org_auto", messages=[{"role": "user", "content": "hi"}], stream=False, - tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tools=[ + {"type": "function", "function": {"name": "lookup", "parameters": {}}} + ], tool_choice={"type": "function", "function": {"name": "lookup"}}, ) with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), ): - response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + response = await main.v1_chat_completions( + req, _make_request("/v1/chat/completions") + ) payload = json.loads(response.body) message = payload["choices"][0]["message"] - self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls") - self.assertEqual(message["content"], "") - self.assertEqual(message["tool_calls"][0]["function"]["name"], "lookup") - self.assertEqual( - json.loads(message["tool_calls"][0]["function"]["arguments"]), - {"query": "gateway"}, - ) + self.assertEqual(payload["choices"][0]["finish_reason"], "stop") + self.assertIn('lookup(query="gateway")', message["content"]) + self.assertIsNone(message["tool_calls"]) - async def test_openai_non_stream_fallbacks_to_tool_code_structured_tool_call_for_forced_tool_with_positional_arg(self) -> None: + async def test_openai_non_stream_does_not_synthesize_tool_call_from_tool_code_positional_arg( + self, + ) -> None: fake_client = _FakeClient( stream_events=[], complete_result={ - "text": "```tool_code\nlookup(\"gateway\")\n```", + "text": '```tool_code\nlookup("gateway")\n```', "toolEvents": [], "sessionId": "sess-fallback-tool-code-openai-positional", }, @@ -330,26 +359,30 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), ): - response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + response = await main.v1_chat_completions( + req, _make_request("/v1/chat/completions") + ) payload = json.loads(response.body) message = payload["choices"][0]["message"] - self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls") - self.assertEqual(message["content"], "") - self.assertEqual(message["tool_calls"][0]["function"]["name"], "lookup") - self.assertEqual( - json.loads(message["tool_calls"][0]["function"]["arguments"]), - {"query": "gateway"}, - ) + self.assertEqual(payload["choices"][0]["finish_reason"], "stop") + self.assertIn('lookup("gateway")', message["content"]) + self.assertIsNone(message["tool_calls"]) - async def test_openai_stream_fallbacks_to_tool_code_structured_tool_call_for_forced_tool(self) -> None: + async def test_openai_stream_does_not_synthesize_tool_call_from_tool_code( + self, + ) -> None: fake_client = _FakeClient( stream_events=[ {"type": "text", "text": "```tool_code\n"}, - {"type": "text", "text": 'lookup(query=\"gateway\")\n'}, + {"type": "text", "text": 'lookup(query="gateway")\n'}, {"type": "text", "text": "```"}, ], complete_result={}, @@ -358,29 +391,41 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): model="org_auto", messages=[{"role": "user", "content": "hi"}], stream=True, - tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tools=[ + {"type": "function", "function": {"name": "lookup", "parameters": {}}} + ], tool_choice={"type": "function", "function": {"name": "lookup"}}, ) with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), ): - response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + response = await main.v1_chat_completions( + req, _make_request("/v1/chat/completions") + ) body = await _collect_stream(response) - chunks = [json.loads(line[6:]) for line in body.splitlines() if line.startswith("data: {")] - tool_call_chunk = next(chunk for chunk in chunks if chunk["choices"] and chunk["choices"][0]["delta"].get("tool_calls")) - tool_call = tool_call_chunk["choices"][0]["delta"]["tool_calls"][0] - - self.assertIn('"tool_calls"', body) - self.assertEqual(tool_call["function"]["name"], "lookup") - self.assertEqual(json.loads(tool_call["function"]["arguments"]), {"query": "gateway"}) - self.assertNotIn('lookup(query=\\"gateway\\")', body) - self.assertIn('"finish_reason": "tool_calls"', body) - self.assertIn('data: [DONE]', body) + chunks = [ + json.loads(line[6:]) + for line in body.splitlines() + if line.startswith("data: {") + ] + self.assertFalse( + any( + chunk.get("choices") and chunk["choices"][0]["delta"].get("tool_calls") + for chunk in chunks + ) + ) + self.assertNotIn('"tool_calls"', body) + self.assertIn('"finish_reason": "stop"', body) + self.assertIn("data: [DONE]", body) async def test_openai_stream_bridges_tool_and_text_events(self) -> None: fake_client = _FakeClient( @@ -407,10 +452,16 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), ): - response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + response = await main.v1_chat_completions( + req, _make_request("/v1/chat/completions") + ) body = await _collect_stream(response) self.assertIn('"tool_calls"', body) @@ -419,7 +470,6 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('"usage"', body) self.assertIn("data: [DONE]", body) - async def test_openai_stream_emits_text_delta_only_once_without_tools(self) -> None: fake_client = _FakeClient( stream_events=[ @@ -436,10 +486,16 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), ): - response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + response = await main.v1_chat_completions( + req, _make_request("/v1/chat/completions") + ) body = await _collect_stream(response) self.assertEqual(body.count('"content": "你好"'), 1) @@ -449,8 +505,22 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): async def test_openai_stream_filters_tool_events_by_allowlist(self) -> None: fake_client = _FakeClient( stream_events=[ - {"type": "tool", "tool": {"id": "call_blocked", "name": "write_file", "input": {"path": "a.txt"}}}, - {"type": "tool", "tool": {"id": "call_allowed", "name": "lookup", "input": {"query": "gateway"}}}, + { + "type": "tool", + "tool": { + "id": "call_blocked", + "name": "write_file", + "input": {"path": "a.txt"}, + }, + }, + { + "type": "tool", + "tool": { + "id": "call_allowed", + "name": "lookup", + "input": {"query": "gateway"}, + }, + }, {"type": "text", "text": "hello"}, ], complete_result={}, @@ -461,7 +531,10 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): stream=True, tools=[ {"type": "function", "function": {"name": "lookup", "parameters": {}}}, - {"type": "function", "function": {"name": "write_file", "parameters": {}}}, + { + "type": "function", + "function": {"name": "write_file", "parameters": {}}, + }, ], tool_choice={"type": "function", "function": {"name": "lookup"}}, ) @@ -469,11 +542,17 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), _SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]), ): - response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + response = await main.v1_chat_completions( + req, _make_request("/v1/chat/completions") + ) body = await _collect_stream(response) self.assertIn('"name": "lookup"', body) @@ -507,15 +586,22 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), patch.object(main.settings, "api_keys", ["test-key"]), ): response = await main.v1_messages( req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) @@ -526,11 +612,13 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(payload["content"][1]["name"], "lookup") self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1") - async def test_anthropic_non_stream_fallbacks_to_structured_tool_blocks_for_forced_tool(self) -> None: + async def test_anthropic_non_stream_does_not_synthesize_tool_blocks_from_plain_json( + self, + ) -> None: fake_client = _FakeClient( stream_events=[], complete_result={ - "text": "{\"input\": {\"k\": \"v\"}, \"result\": {\"value\": 1}}", + "text": '{"input": {"k": "v"}, "result": {"value": 1}}', "toolEvents": [], "sessionId": "sess-fallback-anthropic", }, @@ -540,31 +628,39 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): max_tokens=256, messages=[{"role": "user", "content": "hi"}], stream=False, - tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}], + tools=[ + {"name": "lookup", "input_schema": {"type": "object", "properties": {}}} + ], tool_choice={"type": "tool", "name": "lookup"}, ) with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), patch.object(main.settings, "api_keys", ["test-key"]), ): response = await main.v1_messages( req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) payload = json.loads(response.body) types = [item["type"] for item in payload["content"]] - self.assertEqual(types, ["tool_use", "tool_result"]) + self.assertEqual(types, ["text"]) self.assertEqual(payload["stop_reason"], "end_turn") - self.assertEqual(payload["content"][0]["name"], "lookup") - self.assertEqual(payload["content"][1]["tool_use_id"], "toolu_fallback_0") + self.assertIn('"input"', payload["content"][0]["text"]) async def test_openai_stream_tool_call_indices_are_stable(self) -> None: fake_client = _FakeClient( @@ -597,10 +693,16 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), ): - response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + response = await main.v1_chat_completions( + req, _make_request("/v1/chat/completions") + ) body = await _collect_stream(response) self.assertIn('"id": "call_a"', body) @@ -608,7 +710,9 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('"index": 0', body) self.assertIn('"index": 1', body) - async def test_anthropic_non_stream_returns_tool_use_stop_reason_when_result_missing(self) -> None: + async def test_anthropic_non_stream_returns_tool_use_stop_reason_when_result_missing( + self, + ) -> None: fake_client = _FakeClient( stream_events=[], complete_result={ @@ -632,15 +736,22 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), patch.object(main.settings, "api_keys", ["test-key"]), ): response = await main.v1_messages( req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) @@ -649,7 +760,9 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(len(payload["content"]), 1) self.assertEqual(payload["content"][0]["type"], "tool_use") - async def test_anthropic_stream_returns_tool_use_stop_reason_when_result_missing(self) -> None: + async def test_anthropic_stream_returns_tool_use_stop_reason_when_result_missing( + self, + ) -> None: fake_client = _FakeClient( stream_events=[ { @@ -672,15 +785,22 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), patch.object(main.settings, "api_keys", ["test-key"]), ): response = await main.v1_messages( req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) body = await _collect_stream(response) @@ -688,10 +808,15 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('"type": "tool_use"', body) self.assertIn('"stop_reason": "tool_use"', body) - async def test_anthropic_stream_does_not_fallback_to_synthetic_tool_blocks_for_forced_tool(self) -> None: + async def test_anthropic_stream_does_not_fallback_to_synthetic_tool_blocks_for_forced_tool( + self, + ) -> None: fake_client = _FakeClient( stream_events=[ - {"type": "text", "text": '```json\n{"input": {"k": "v"}, "result": {"value": 1}}\n```'} + { + "type": "text", + "text": '```json\n{"input": {"k": "v"}, "result": {"value": 1}}\n```', + } ], complete_result={}, ) @@ -700,22 +825,31 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): max_tokens=256, messages=[{"role": "user", "content": "hi"}], stream=True, - tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}], + tools=[ + {"name": "lookup", "input_schema": {"type": "object", "properties": {}}} + ], tool_choice={"type": "tool", "name": "lookup"}, ) with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), patch.object(main.settings, "api_keys", ["test-key"]), ): response = await main.v1_messages( req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) body = await _collect_stream(response) @@ -751,15 +885,22 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), patch.object(main.settings, "api_keys", ["test-key"]), ): response = await main.v1_messages( req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) body = await _collect_stream(response) @@ -771,8 +912,6 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('"type": "text_delta"', body) self.assertIn("event: message_stop", body) - - async def test_anthropic_stream_filters_tool_events_by_allowlist(self) -> None: fake_client = _FakeClient( stream_events=[ @@ -804,8 +943,14 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): messages=[{"role": "user", "content": "hi"}], stream=True, tools=[ - {"name": "lookup", "input_schema": {"type": "object", "properties": {}}}, - {"name": "write_file", "input_schema": {"type": "object", "properties": {}}}, + { + "name": "lookup", + "input_schema": {"type": "object", "properties": {}}, + }, + { + "name": "write_file", + "input_schema": {"type": "object", "properties": {}}, + }, ], tool_choice={"type": "tool", "name": "lookup"}, ) @@ -813,8 +958,12 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), patch.object(main.settings, "api_keys", ["test-key"]), _SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]), ): @@ -822,7 +971,10 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) body = await _collect_stream(response) @@ -832,23 +984,29 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('"type": "tool_result"', body) self.assertIn('"stop_reason": "end_turn"', body) - - async def test_openai_non_stream_forwards_tool_config_when_enabled(self) -> None: - spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + spy_client = _SpyClient( + stream_events=[], complete_result={"text": "ok", "toolEvents": []} + ) req = ChatCompletionsRequest( model="org_auto", messages=[{"role": "user", "content": "hi"}], stream=False, - tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tools=[ + {"type": "function", "function": {"name": "lookup", "parameters": {}}} + ], tool_choice={"type": "function", "function": {"name": "lookup"}}, ) with ( patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), _SettingsPatch(tool_forward_enabled=True), ): await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) @@ -861,23 +1019,33 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(spy_client.last_complete_args[2], "agent") async def test_openai_stream_forwards_tool_config_when_enabled(self) -> None: - spy_client = _SpyClient(stream_events=[{"type": "text", "text": "ok"}], complete_result={}) + spy_client = _SpyClient( + stream_events=[{"type": "text", "text": "ok"}], complete_result={} + ) req = ChatCompletionsRequest( model="org_auto", messages=[{"role": "user", "content": "hi"}], stream=True, - tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tools=[ + {"type": "function", "function": {"name": "lookup", "parameters": {}}} + ], tool_choice={"type": "function", "function": {"name": "lookup"}}, ) with ( patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), _SettingsPatch(tool_forward_enabled=True), ): - response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + response = await main.v1_chat_completions( + req, _make_request("/v1/chat/completions") + ) await _collect_stream(response) self.assertIn("tool_config", spy_client.last_stream_kwargs) @@ -887,21 +1055,31 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIsInstance(cfg["tool_choice"], dict) self.assertEqual(spy_client.last_stream_args[2], "agent") - async def test_openai_non_stream_does_not_forward_tool_config_when_disabled(self) -> None: - spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + async def test_openai_non_stream_does_not_forward_tool_config_when_disabled( + self, + ) -> None: + spy_client = _SpyClient( + stream_events=[], complete_result={"text": "ok", "toolEvents": []} + ) req = ChatCompletionsRequest( model="org_auto", messages=[{"role": "user", "content": "hi"}], stream=False, - tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tools=[ + {"type": "function", "function": {"name": "lookup", "parameters": {}}} + ], tool_choice={"type": "function", "function": {"name": "lookup"}}, ) with ( patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), _SettingsPatch(tool_forward_enabled=False), ): await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) @@ -910,17 +1088,20 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIsNone(spy_client.last_complete_kwargs["tool_config"]) self.assertEqual(spy_client.last_complete_args[2], "agent") - - async def test_openai_non_stream_filters_tools_by_allowlist(self) -> None: - spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + spy_client = _SpyClient( + stream_events=[], complete_result={"text": "ok", "toolEvents": []} + ) req = ChatCompletionsRequest( model="org_auto", messages=[{"role": "user", "content": "hi"}], stream=False, tools=[ {"type": "function", "function": {"name": "lookup", "parameters": {}}}, - {"type": "function", "function": {"name": "write_file", "parameters": {}}}, + { + "type": "function", + "function": {"name": "write_file", "parameters": {}}, + }, ], tool_choice={"type": "function", "function": {"name": "lookup"}}, ) @@ -928,35 +1109,53 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), _SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]), ): await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) cfg = spy_client.last_complete_kwargs["tool_config"] - self.assertEqual([tool["function"]["name"] for tool in cfg["tools"]], ["lookup"]) + self.assertEqual( + [tool["function"]["name"] for tool in cfg["tools"]], ["lookup"] + ) self.assertEqual(cfg["tool_choice"], req.tool_choice) - async def test_openai_non_stream_rejects_forced_tool_outside_allowlist(self) -> None: - spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + async def test_openai_non_stream_rejects_forced_tool_outside_allowlist( + self, + ) -> None: + spy_client = _SpyClient( + stream_events=[], complete_result={"text": "ok", "toolEvents": []} + ) req = ChatCompletionsRequest( model="org_auto", messages=[{"role": "user", "content": "hi"}], stream=False, - tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tools=[ + {"type": "function", "function": {"name": "lookup", "parameters": {}}} + ], tool_choice={"type": "function", "function": {"name": "write_file"}}, ) with ( patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), _SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]), ): with self.assertRaises(main.HTTPException) as cm: - await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + await main.v1_chat_completions( + req, _make_request("/v1/chat/completions") + ) self.assertEqual(cm.exception.status_code, 400) self.assertEqual(cm.exception.detail["error"]["type"], "invalid_request_error") @@ -975,7 +1174,9 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): {"role": "user", "content": "turn-2"}, ], stream=False, - tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tools=[ + {"type": "function", "function": {"name": "lookup", "parameters": {}}} + ], tool_choice={"type": "function", "function": {"name": "lookup"}}, ) @@ -983,8 +1184,12 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): patch.object(main, "session_cache", fake_cache), patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), _SettingsPatch(tool_forward_enabled=True), ): await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) @@ -993,12 +1198,15 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(fake_cache.get_calls, []) self.assertEqual(fake_cache.put_calls, []) - async def test_openai_session_reuse_lookup_key_separates_branches(self) -> None: fake_cache = _FakeSessionCache() fake_client = _FakeClient( stream_events=[], - complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-branch"}, + complete_result={ + "text": "ok", + "toolEvents": [], + "sessionId": "sess-branch", + }, ) req_a = ChatCompletionsRequest( @@ -1026,8 +1234,12 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): patch.object(main, "session_cache", fake_cache), patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), _SettingsPatch(default_ask_mode="chat", tool_forward_enabled=False), ): await main.v1_chat_completions(req_a, _make_request("/v1/chat/completions")) @@ -1037,9 +1249,15 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertNotEqual(fake_cache.get_calls[0], fake_cache.get_calls[2]) self.assertEqual(fake_cache.get_calls[1], fake_cache.get_calls[3]) - async def test_openai_and_anthropic_resolve_same_default_ask_mode_without_tooling(self) -> None: - openai_spy = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) - anthropic_spy = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + async def test_openai_and_anthropic_resolve_same_default_ask_mode_without_tooling( + self, + ) -> None: + openai_spy = _SpyClient( + stream_events=[], complete_result={"text": "ok", "toolEvents": []} + ) + anthropic_spy = _SpyClient( + stream_events=[], complete_result={"text": "ok", "toolEvents": []} + ) openai_req = ChatCompletionsRequest( model="org_auto", @@ -1056,17 +1274,27 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(openai_spy))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), _SettingsPatch(default_ask_mode="chat", tool_forward_enabled=False), ): - await main.v1_chat_completions(openai_req, _make_request("/v1/chat/completions")) + await main.v1_chat_completions( + openai_req, _make_request("/v1/chat/completions") + ) with ( patch.object(main, "pool", _FakePool(_FakeInstance(anthropic_spy))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), patch.object(main.settings, "api_keys", ["test-key"]), _SettingsPatch(default_ask_mode="chat", tool_forward_enabled=False), ): @@ -1074,7 +1302,10 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): anthropic_req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) @@ -1082,21 +1313,29 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(anthropic_spy.last_complete_args[2], "chat") async def test_anthropic_stream_forwards_tool_config_when_enabled(self) -> None: - spy_client = _SpyClient(stream_events=[{"type": "text", "text": "ok"}], complete_result={}) + spy_client = _SpyClient( + stream_events=[{"type": "text", "text": "ok"}], complete_result={} + ) req = AnthropicMessagesRequest( model="claude-3-5-sonnet-20241022", max_tokens=128, messages=[{"role": "user", "content": "hi"}], stream=True, - tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}], + tools=[ + {"name": "lookup", "input_schema": {"type": "object", "properties": {}}} + ], tool_choice={"type": "tool", "name": "lookup"}, ) with ( patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), patch.object(main.settings, "api_keys", ["test-key"]), _SettingsPatch(tool_forward_enabled=True, default_ask_mode="chat"), ): @@ -1104,7 +1343,10 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) await _collect_stream(response) @@ -1115,22 +1357,32 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(len(cfg["tools"]), 1) self.assertEqual(spy_client.last_stream_args[2], "agent") - async def test_anthropic_non_stream_does_not_forward_tool_config_when_disabled(self) -> None: - spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + async def test_anthropic_non_stream_does_not_forward_tool_config_when_disabled( + self, + ) -> None: + spy_client = _SpyClient( + stream_events=[], complete_result={"text": "ok", "toolEvents": []} + ) req = AnthropicMessagesRequest( model="claude-3-5-sonnet-20241022", max_tokens=128, messages=[{"role": "user", "content": "hi"}], stream=False, - tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}], + tools=[ + {"name": "lookup", "input_schema": {"type": "object", "properties": {}}} + ], tool_choice={"type": "tool", "name": "lookup"}, ) with ( patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), patch.object(main.settings, "api_keys", ["test-key"]), _SettingsPatch(tool_forward_enabled=False, default_ask_mode="chat"), ): @@ -1138,7 +1390,10 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) @@ -1147,21 +1402,32 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(spy_client.last_complete_args[2], "agent") async def test_anthropic_non_stream_with_tools_uses_agent_mode(self) -> None: - spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + spy_client = _SpyClient( + stream_events=[], complete_result={"text": "ok", "toolEvents": []} + ) req = AnthropicMessagesRequest( model="claude-3-5-sonnet-20241022", max_tokens=128, messages=[{"role": "user", "content": "hi"}], stream=False, - tools=[{"name": "write_file", "input_schema": {"type": "object", "properties": {}}}], + tools=[ + { + "name": "write_file", + "input_schema": {"type": "object", "properties": {}}, + } + ], tool_choice={"type": "auto"}, ) with ( patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), patch.object(main.settings, "api_keys", ["test-key"]), _SettingsPatch(tool_forward_enabled=True, default_ask_mode="chat"), ): @@ -1169,7 +1435,10 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) @@ -1179,17 +1448,24 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(len(cfg["tools"]), 1) self.assertEqual(spy_client.last_complete_args[2], "agent") - async def test_anthropic_non_stream_filters_tools_by_allowlist(self) -> None: - spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + spy_client = _SpyClient( + stream_events=[], complete_result={"text": "ok", "toolEvents": []} + ) req = AnthropicMessagesRequest( model="claude-3-5-sonnet-20241022", max_tokens=128, messages=[{"role": "user", "content": "hi"}], stream=False, tools=[ - {"name": "lookup", "input_schema": {"type": "object", "properties": {}}}, - {"name": "write_file", "input_schema": {"type": "object", "properties": {}}}, + { + "name": "lookup", + "input_schema": {"type": "object", "properties": {}}, + }, + { + "name": "write_file", + "input_schema": {"type": "object", "properties": {}}, + }, ], tool_choice={"type": "tool", "name": "lookup"}, ) @@ -1197,8 +1473,12 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), patch.object(main.settings, "api_keys", ["test-key"]), _SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]), ): @@ -1206,7 +1486,10 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) @@ -1214,22 +1497,32 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual([tool["name"] for tool in cfg["tools"]], ["lookup"]) self.assertEqual(cfg["tool_choice"], req.tool_choice) - async def test_anthropic_non_stream_rejects_forced_tool_outside_allowlist(self) -> None: - spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + async def test_anthropic_non_stream_rejects_forced_tool_outside_allowlist( + self, + ) -> None: + spy_client = _SpyClient( + stream_events=[], complete_result={"text": "ok", "toolEvents": []} + ) req = AnthropicMessagesRequest( model="claude-3-5-sonnet-20241022", max_tokens=128, messages=[{"role": "user", "content": "hi"}], stream=False, - tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}], + tools=[ + {"name": "lookup", "input_schema": {"type": "object", "properties": {}}} + ], tool_choice={"type": "tool", "name": "write_file"}, ) with ( patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), patch.object(main.settings, "api_keys", ["test-key"]), _SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]), ): @@ -1237,7 +1530,10 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) @@ -1261,7 +1557,9 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): {"role": "user", "content": "turn-2"}, ], stream=False, - tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}], + tools=[ + {"name": "lookup", "input_schema": {"type": "object", "properties": {}}} + ], tool_choice={"type": "auto"}, ) @@ -1269,15 +1567,22 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): patch.object(main, "session_cache", fake_cache), patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", _FakeGuard()), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), - patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), patch.object(main.settings, "api_keys", ["test-key"]), ): await main.v1_messages( req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) @@ -1297,13 +1602,23 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): req, _make_request( "/v1/messages/count_tokens", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) payload = json.loads(response.body) self.assertEqual(response.status_code, 200) - self.assertEqual(payload, {"input_tokens": main.estimate_tokens(main._messages_to_prompt(main.anthropic_to_internal_messages(req)))}) + self.assertEqual( + payload, + { + "input_tokens": main.estimate_tokens( + main._messages_to_prompt(main.anthropic_to_internal_messages(req)) + ) + }, + ) async def test_anthropic_count_tokens_requires_authentication(self) -> None: req = AnthropicMessagesRequest( @@ -1348,8 +1663,12 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(payload["type"], "error") self.assertEqual(payload["error"]["type"], "authentication_error") - async def test_anthropic_messages_backpressure_returns_overloaded_error(self) -> None: - fake_client = _FakeClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + async def test_anthropic_messages_backpressure_returns_overloaded_error( + self, + ) -> None: + fake_client = _FakeClient( + stream_events=[], complete_result={"text": "ok", "toolEvents": []} + ) req = AnthropicMessagesRequest( model="claude-3-5-sonnet-20241022", max_tokens=64, @@ -1364,14 +1683,19 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): with ( patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), patch.object(main, "chat_guard", fake_guard), - patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), patch.object(main.settings, "api_keys", ["test-key"]), ): response = await main.v1_messages( req, _make_request( "/v1/messages", - headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, ), ) @@ -1411,7 +1735,9 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(payload["object"], "response") self.assertEqual(payload["status"], "completed") self.assertEqual(payload["output_text"], "done") - self.assertEqual(payload["usage"], {"input_tokens": 4, "output_tokens": 2, "total_tokens": 6}) + self.assertEqual( + payload["usage"], {"input_tokens": 4, "output_tokens": 2, "total_tokens": 6} + ) self.assertEqual(payload["output"][0]["type"], "message") self.assertEqual(payload["output"][0]["content"][0]["type"], "output_text") @@ -1419,9 +1745,22 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): chat_req = mock_chat.await_args.args[0] self.assertIsInstance(chat_req, ChatCompletionsRequest) messages_dump = [m.model_dump() for m in chat_req.messages] - self.assertEqual(messages_dump, [{"role": "user", "content": "hello from responses", "name": None, "tool_call_id": None, "tool_calls": None}]) + self.assertEqual( + messages_dump, + [ + { + "role": "user", + "content": "hello from responses", + "name": None, + "tool_call_id": None, + "tool_calls": None, + } + ], + ) - async def test_responses_non_stream_maps_chat_tool_calls_to_function_call_output(self) -> None: + async def test_responses_non_stream_maps_chat_tool_calls_to_function_call_output( + self, + ) -> None: req = ResponsesRequest( model="org_auto", input="tool please", @@ -1444,7 +1783,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): "type": "function", "function": { "name": "lookup", - "arguments": "{\"q\":\"gateway\"}", + "arguments": '{"q":"gateway"}', }, } ], @@ -1461,29 +1800,50 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): payload = json.loads(response.body) self.assertEqual(payload["status"], "completed") self.assertEqual(payload["output_text"], "") - self.assertEqual(payload["usage"], {"input_tokens": 8, "output_tokens": 3, "total_tokens": 11}) + self.assertEqual( + payload["usage"], + {"input_tokens": 8, "output_tokens": 3, "total_tokens": 11}, + ) self.assertEqual(len(payload["output"]), 1) self.assertEqual(payload["output"][0]["type"], "function_call") self.assertEqual(payload["output"][0]["call_id"], "call_1") self.assertEqual(payload["output"][0]["id"], "call_1") self.assertEqual(payload["output"][0]["name"], "lookup") - self.assertEqual(payload["output"][0]["arguments"], "{\"q\":\"gateway\"}") + self.assertEqual(payload["output"][0]["arguments"], '{"q":"gateway"}') - async def test_responses_forwards_input_tools_and_tool_choice_to_chat_request(self) -> None: + async def test_responses_forwards_input_tools_and_tool_choice_to_chat_request( + self, + ) -> None: req = ResponsesRequest( model="org_auto", instructions="be concise", input=[ {"role": "user", "content": [{"type": "text", "text": "first"}]}, - {"type": "function_call_output", "call_id": "call_1", "output": {"ok": True}}, + { + "type": "function_call_output", + "call_id": "call_1", + "output": {"ok": True}, + }, "follow up", ], stream=False, - tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tools=[ + {"type": "function", "function": {"name": "lookup", "parameters": {}}} + ], tool_choice={"type": "function", "function": {"name": "lookup"}}, ) - mock_chat = AsyncMock(return_value=JSONResponse(content={"id": "chatcmpl-x", "created": 1, "model": "org_auto", "choices": [{"message": {"role": "assistant", "content": "ok"}}], "usage": {}})) + mock_chat = AsyncMock( + return_value=JSONResponse( + content={ + "id": "chatcmpl-x", + "created": 1, + "model": "org_auto", + "choices": [{"message": {"role": "assistant", "content": "ok"}}], + "usage": {}, + } + ) + ) with patch.object(main, "v1_chat_completions", mock_chat): await main.v1_responses(req, _make_request("/v1/responses")) @@ -1503,7 +1863,9 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(messages_dump[3]["role"], "user") self.assertEqual(messages_dump[3]["content"], "follow up") - async def test_responses_stream_bridges_text_tool_and_completed_events(self) -> None: + async def test_responses_stream_bridges_text_tool_and_completed_events( + self, + ) -> None: async def _chat_sse(): yield b'data: {"choices": [{"delta": {"content": "hello"}}]}\n\n' yield b'data: {"choices": [{"delta": {"tool_calls": [{"id": "call_1", "function": {"name": "lookup", "arguments": "{\\"q\\": \\"x\\"}"}}]}}]}\n\n' @@ -1535,12 +1897,12 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('"type": "response.completed"', body) self.assertIn('"input_tokens": 3', body) self.assertIn('"output_tokens": 2', body) - self.assertIn('data: [DONE]', body) + self.assertIn("data: [DONE]", body) async def test_responses_stream_accumulates_fragmented_tool_arguments(self) -> None: async def _chat_sse(): yield b'data: {"choices": [{"delta": {"tool_calls": [{"id": "call_1", "function": {"name": "lookup", "arguments": "{\\"q\\":"}}]}}]}\n\n' - yield b'data: {"choices": [{"delta": {"tool_calls": [{"id": "call_1", "function": {"name": "lookup", "arguments": " \\\"x\\\"}"}}]}}]}\n\n' + yield b'data: {"choices": [{"delta": {"tool_calls": [{"id": "call_1", "function": {"name": "lookup", "arguments": " \\"x\\"}"}}]}}]}\n\n' yield b"data: [DONE]\n\n" req = ResponsesRequest(model="org_auto", input="hi", stream=True) @@ -1554,17 +1916,19 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('"type": "response.function_call_arguments.delta"', body) self.assertIn('"delta": "{\\"q\\":"', body) - self.assertIn('"delta": " \\\"x\\\"}"', body) + self.assertIn('"delta": " \\"x\\"}"', body) self.assertIn('"type": "response.function_call_arguments.done"', body) - self.assertIn('"arguments": "{\\"q\\": \\\"x\\\"}"', body) + self.assertIn('"arguments": "{\\"q\\": \\"x\\"}"', body) self.assertIn('"type": "response.output_item.done"', body) - self.assertIn('"arguments": "{\\"q\\": \\\"x\\\"}"', body) - self.assertIn('data: [DONE]', body) + self.assertIn('"arguments": "{\\"q\\": \\"x\\"}"', body) + self.assertIn("data: [DONE]", body) - async def test_responses_stream_accumulates_fragmented_tool_arguments_without_repeated_id_or_name(self) -> None: + async def test_responses_stream_accumulates_fragmented_tool_arguments_without_repeated_id_or_name( + self, + ) -> None: async def _chat_sse(): yield b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "id": "call_1", "function": {"name": "lookup", "arguments": "{\\"q\\":"}}]}}]}\n\n' - yield b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": " \\\"x\\\"}"}}]}}]}\n\n' + yield b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": " \\"x\\"}"}}]}}]}\n\n' yield b"data: [DONE]\n\n" req = ResponsesRequest(model="org_auto", input="hi", stream=True) @@ -1579,18 +1943,22 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(body.count('"item_id": "call_1"'), 3) self.assertIn('"name": "lookup"', body) self.assertIn('"delta": "{\\"q\\":"', body) - self.assertIn('"delta": " \\\"x\\\"}"', body) - self.assertIn('"arguments": "{\\"q\\": \\\"x\\\"}"', body) - self.assertIn('data: [DONE]', body) + self.assertIn('"delta": " \\"x\\"}"', body) + self.assertIn('"arguments": "{\\"q\\": \\"x\\"}"', body) + self.assertIn("data: [DONE]", body) - async def test_responses_stream_emits_completed_when_upstream_closes_without_done(self) -> None: + async def test_responses_stream_emits_completed_when_upstream_closes_without_done( + self, + ) -> None: async def _chat_sse_without_done(): yield b'data: {"choices": [{"delta": {"content": "partial"}}]}\n\n' yield b'data: {"usage": {"prompt_tokens": 7, "completion_tokens": 1, "total_tokens": 8}, "choices": [{"delta": {}}]}\n\n' req = ResponsesRequest(model="org_auto", input="hi", stream=True) mock_chat = AsyncMock( - return_value=StreamingResponse(_chat_sse_without_done(), media_type="text/event-stream") + return_value=StreamingResponse( + _chat_sse_without_done(), media_type="text/event-stream" + ) ) with patch.object(main, "v1_chat_completions", mock_chat): @@ -1602,16 +1970,20 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('"type": "response.completed"', body) self.assertIn('"input_tokens": 7', body) self.assertIn('"output_tokens": 1', body) - self.assertIn('data: [DONE]', body) + self.assertIn("data: [DONE]", body) - async def test_responses_stream_emits_completed_when_upstream_iterator_errors(self) -> None: + async def test_responses_stream_emits_completed_when_upstream_iterator_errors( + self, + ) -> None: async def _chat_sse_error(): yield b'data: {"choices": [{"delta": {"content": "partial"}}]}\n\n' raise RuntimeError("boom") req = ResponsesRequest(model="org_auto", input="hi", stream=True) mock_chat = AsyncMock( - return_value=StreamingResponse(_chat_sse_error(), media_type="text/event-stream") + return_value=StreamingResponse( + _chat_sse_error(), media_type="text/event-stream" + ) ) with patch.object(main, "v1_chat_completions", mock_chat): @@ -1621,7 +1993,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('"type": "response.output_text.delta"', body) self.assertIn('"delta": "partial"', body) self.assertIn('"type": "response.completed"', body) - self.assertIn('data: [DONE]', body) + self.assertIn("data: [DONE]", body) async def test_responses_stream_emits_completed_when_upstream_cancels(self) -> None: async def _chat_sse_cancelled(): @@ -1630,7 +2002,9 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): req = ResponsesRequest(model="org_auto", input="hi", stream=True) mock_chat = AsyncMock( - return_value=StreamingResponse(_chat_sse_cancelled(), media_type="text/event-stream") + return_value=StreamingResponse( + _chat_sse_cancelled(), media_type="text/event-stream" + ) ) with patch.object(main, "v1_chat_completions", mock_chat): @@ -1640,8 +2014,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('"type": "response.output_text.delta"', body) self.assertIn('"delta": "partial"', body) self.assertIn('"type": "response.completed"', body) - self.assertIn('data: [DONE]', body) - + self.assertIn("data: [DONE]", body) async def test_responses_alias_matches_v1_responses_behavior(self) -> None: req = ResponsesRequest(model="org_auto", input="hello", stream=False) @@ -1669,7 +2042,9 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): mock_chat.assert_awaited_once() req = ResponsesRequest(model="org_auto", input="hi", stream=False) - mock_chat = AsyncMock(return_value=Response(content="not-json", media_type="text/plain")) + mock_chat = AsyncMock( + return_value=Response(content="not-json", media_type="text/plain") + ) with patch.object(main, "v1_chat_completions", mock_chat): with self.assertRaises(main.HTTPException) as cm: @@ -1693,31 +2068,43 @@ class SessionCacheToolFingerprintTests(unittest.TestCase): cfg_a = { "provider": "openai", - "tools": [{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + "tools": [ + {"type": "function", "function": {"name": "lookup", "parameters": {}}} + ], "tool_choice": {"type": "function", "function": {"name": "lookup"}}, } cfg_a_reordered = { "tool_choice": {"function": {"name": "lookup"}, "type": "function"}, - "tools": [{"function": {"parameters": {}, "name": "lookup"}, "type": "function"}], + "tools": [ + {"function": {"parameters": {}, "name": "lookup"}, "type": "function"} + ], "provider": "openai", } cfg_b = { "provider": "openai", - "tools": [{"type": "function", "function": {"name": "lookup_v2", "parameters": {}}}], + "tools": [ + { + "type": "function", + "function": {"name": "lookup_v2", "parameters": {}}, + } + ], "tool_choice": {"type": "function", "function": {"name": "lookup_v2"}}, } key_no_tool = cache.build_key("api-key", messages) key_a = cache.build_key("api-key", messages, tool_config=cfg_a) - key_a_reordered = cache.build_key("api-key", messages, tool_config=cfg_a_reordered) + key_a_reordered = cache.build_key( + "api-key", messages, tool_config=cfg_a_reordered + ) key_b = cache.build_key("api-key", messages, tool_config=cfg_b) self.assertNotEqual(key_no_tool, key_a) self.assertEqual(key_a, key_a_reordered) self.assertNotEqual(key_a, key_b) - - def test_handle_server_message_drops_unroutable_tool_event_without_request_id(self) -> None: + def test_handle_server_message_drops_unroutable_tool_event_without_request_id( + self, + ) -> None: from app.lingma_client import LspWsRpcClient rpc = LspWsRpcClient("ws://127.0.0.1:1") @@ -1861,7 +2248,6 @@ class SessionCacheToolFingerprintTests(unittest.TestCase): }, ) - def test_tool_sync_triggers_approve_and_invoke_result_requests(self) -> None: from app.lingma_client import LspWsRpcClient @@ -1900,7 +2286,9 @@ class SessionCacheToolFingerprintTests(unittest.TestCase): self.assertIn("tool/call/approve", methods) self.assertIn("tool/invokeResult", methods) - approve = next(item for item in decoded if item.get("method") == "tool/call/approve") + approve = next( + item for item in decoded if item.get("method") == "tool/call/approve" + ) self.assertEqual( approve["params"], { @@ -1912,7 +2300,9 @@ class SessionCacheToolFingerprintTests(unittest.TestCase): }, ) - invoke_result = next(item for item in decoded if item.get("method") == "tool/invokeResult") + invoke_result = next( + item for item in decoded if item.get("method") == "tool/invokeResult" + ) self.assertEqual(invoke_result["params"]["toolCallId"], "call-1") self.assertEqual(invoke_result["params"]["name"], "run_in_terminal") self.assertTrue(invoke_result["params"]["success"])