"""Tests for tool-call bridging in the OpenAI- and Anthropic-style chat handlers.

Each test calls a handler in ``app.main`` directly (``v1_chat_completions`` or
``v1_messages``) while the module-level collaborators it uses (``pool``,
``chat_guard``, ``_ensure_instance_logged_in``, ``stats_collector.record_chat``
and, for the Anthropic tests, ``settings.api_keys``) are replaced with
in-memory fakes defined in this file.
"""

from __future__ import annotations

import contextlib
import json
import sys
import types
import unittest
from collections.abc import Iterator
from unittest.mock import AsyncMock, patch

# app.main imports playwright via auto_login; tests don't exercise that path.
# Inject a lightweight stub so unit tests run without installing playwright.
_playwright = types.ModuleType("playwright")
_playwright_async = types.ModuleType("playwright.async_api")


class _StubPlaywrightTimeoutError(Exception):
    """Stand-in for ``playwright.async_api.TimeoutError``."""


def _stub_async_playwright():
    """Fail loudly if code under test tries to start playwright.

    This is deliberately a plain function, not ``async def``: the real
    ``async_playwright`` is called synchronously and its result is used with
    ``async with``. A coroutine function would hand back an un-awaited
    coroutine, and the caller would fail with an unrelated error instead of
    the message below.

    Raises:
        RuntimeError: always.
    """
    raise RuntimeError("playwright is stubbed in unit tests")


_playwright_async.TimeoutError = _StubPlaywrightTimeoutError
_playwright_async.async_playwright = _stub_async_playwright
# setdefault: a real playwright installation, if already imported, wins.
sys.modules.setdefault("playwright", _playwright)
sys.modules.setdefault("playwright.async_api", _playwright_async)

# These imports must come after the stub injection above.
from starlette.requests import Request  # noqa: E402

from app.anthropic_schema import AnthropicMessagesRequest  # noqa: E402
from app.openai_schema import ChatCompletionsRequest  # noqa: E402
import app.main as main  # noqa: E402

# Headers sent with every Anthropic-style request in these tests; the API key
# matches the value patched into ``main.settings.api_keys``.
_ANTHROPIC_API_KEY = "test-key"
_ANTHROPIC_HEADERS = {
    "x-api-key": _ANTHROPIC_API_KEY,
    "anthropic-version": "2023-06-01",
}


class _FakeTicket:
    """Ticket returned by ``_FakeGuard``; records whether it was released."""

    def __init__(self) -> None:
        self.released = False

    def release(self) -> None:
        """Mark the ticket as released."""
        self.released = True


class _FakeGuard:
    """Replacement for ``main.chat_guard`` that always grants a ticket."""

    def __init__(self) -> None:
        self.in_flight = 0

    async def try_acquire(self) -> _FakeTicket:
        """Return a fresh ticket unconditionally."""
        return _FakeTicket()


class _FakeClient:
    """Upstream client double that replays canned results.

    Args:
        stream_events: events yielded, in order, by ``chat_stream``.
        complete_result: value returned by ``chat_complete``.
    """

    def __init__(self, *, stream_events: list[dict], complete_result: dict) -> None:
        self._stream_events = stream_events
        self._complete_result = complete_result

    async def query_models(self) -> dict:
        """Return a model listing containing a single ``org_auto`` chat model."""
        return {
            "chat": [
                {
                    "key": "org_auto",
                    "displayName": "Auto",
                }
            ]
        }

    async def chat_complete(self, *args, **kwargs) -> dict:
        """Return the canned completion result, ignoring all arguments."""
        return self._complete_result

    async def chat_stream(self, *args, **kwargs):
        """Yield the canned stream events.

        If the caller passes a dict as ``out_meta``, a fixed ``session_id`` is
        written into it before the first event is yielded.
        """
        out_meta = kwargs.get("out_meta")
        if isinstance(out_meta, dict):
            out_meta["session_id"] = "sess-stream"
        for event in self._stream_events:
            yield event


class _FakeInstance:
    """Pool member double exposing ``name``, ``client`` and ``in_flight``."""

    def __init__(self, client: _FakeClient) -> None:
        self.name = "inst-test"
        self.client = client
        self.in_flight = 0


class _FakePool:
    """Replacement for ``main.pool`` that always picks the same instance."""

    def __init__(self, inst: _FakeInstance) -> None:
        self._inst = inst

    def pick(self, affinity_key: str | None = None) -> _FakeInstance:
        """Return the single wrapped instance regardless of ``affinity_key``."""
        return self._inst


def _make_request(path: str, headers: dict[str, str] | None = None) -> Request:
    """Build a minimal Starlette POST ``Request`` for ``path``.

    Args:
        path: request path, e.g. ``"/v1/messages"``.
        headers: optional headers; names are lower-cased and both names and
            values are latin-1 encoded as the ASGI scope expects.

    Returns:
        A ``Request`` built from a hand-made ASGI HTTP scope (no body/receive).
    """
    header_pairs = [
        (name.lower().encode("latin-1"), value.encode("latin-1"))
        for name, value in (headers or {}).items()
    ]
    scope = {
        "type": "http",
        "http_version": "1.1",
        "method": "POST",
        "scheme": "http",
        "path": path,
        "raw_path": path.encode("latin-1"),
        "query_string": b"",
        "headers": header_pairs,
        "client": ("testclient", 12345),
        "server": ("testserver", 80),
        "root_path": "",
    }
    return Request(scope)


async def _collect_stream(response) -> str:
    """Drain ``response.body_iterator`` and return the concatenated text.

    ``bytes`` chunks are decoded as UTF-8; any other chunk is passed through
    ``str()``.
    """
    chunks: list[str] = []
    async for part in response.body_iterator:
        if isinstance(part, bytes):
            chunks.append(part.decode("utf-8"))
        else:
            chunks.append(str(part))
    return "".join(chunks)


@contextlib.contextmanager
def _patched_gateway(
    client: _FakeClient,
    *,
    api_keys: list[str] | None = None,
) -> Iterator[None]:
    """Patch the ``app.main`` collaborators shared by every test in this file.

    Args:
        client: fake upstream client served by the patched pool.
        api_keys: when given, also replaces ``main.settings.api_keys`` with
            this list; when ``None`` the setting is left untouched.

    Yields:
        ``None``; all patches are active for the duration of the ``with`` body
        and are undone on exit, including when the body raises.
    """
    with contextlib.ExitStack() as stack:
        stack.enter_context(
            patch.object(main, "pool", _FakePool(_FakeInstance(client)))
        )
        stack.enter_context(patch.object(main, "chat_guard", _FakeGuard()))
        stack.enter_context(
            patch.object(
                main,
                "_ensure_instance_logged_in",
                AsyncMock(return_value={"id": "u"}),
            )
        )
        stack.enter_context(
            patch.object(
                main.stats_collector,
                "record_chat",
                AsyncMock(return_value=None),
            )
        )
        if api_keys is not None:
            stack.enter_context(patch.object(main.settings, "api_keys", api_keys))
        yield


class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
    """Check that upstream tool events appear in both API response formats."""

    async def test_openai_non_stream_bridges_tool_calls(self) -> None:
        """A non-streamed OpenAI response carries text plus ``tool_calls``."""
        fake_client = _FakeClient(
            stream_events=[],
            complete_result={
                "text": "done",
                "toolEvents": [
                    {
                        "id": "call_123",
                        "name": "search_docs",
                        "input": {"query": "gateway"},
                        "result": {"ok": True},
                    }
                ],
                "sessionId": "sess-1",
                "firstTokenLatencyMs": 12,
                "totalLatencyMs": 34,
            },
        )
        req = ChatCompletionsRequest(
            model="org_auto",
            messages=[{"role": "user", "content": "hi"}],
            stream=False,
        )

        with _patched_gateway(fake_client):
            response = await main.v1_chat_completions(
                req, _make_request("/v1/chat/completions")
            )

        payload = json.loads(response.body)
        message = payload["choices"][0]["message"]
        self.assertEqual(message["content"], "done")
        self.assertIsInstance(message["tool_calls"], list)
        self.assertEqual(message["tool_calls"][0]["function"]["name"], "search_docs")
        self.assertEqual(
            json.loads(message["tool_calls"][0]["function"]["arguments"]),
            {"query": "gateway"},
        )

    async def test_openai_stream_bridges_tool_and_text_events(self) -> None:
        """A streamed OpenAI response has tool calls, text, usage and ``[DONE]``."""
        fake_client = _FakeClient(
            stream_events=[
                {
                    "type": "tool",
                    "tool": {
                        "id": "call_stream_1",
                        "name": "read_file",
                        "input": {"path": "README.md"},
                    },
                },
                {"type": "text", "text": "hello"},
            ],
            complete_result={},
        )
        req = ChatCompletionsRequest(
            model="org_auto",
            messages=[{"role": "user", "content": "hi"}],
            stream=True,
            stream_options={"include_usage": True},
        )

        # The stream is drained inside the patch scope, as the body iterator
        # runs lazily and still needs the patched collaborators.
        with _patched_gateway(fake_client):
            response = await main.v1_chat_completions(
                req, _make_request("/v1/chat/completions")
            )
            body = await _collect_stream(response)

        self.assertIn('"tool_calls"', body)
        self.assertIn('"content": "hello"', body)
        self.assertIn('"usage"', body)
        self.assertIn("data: [DONE]", body)

    async def test_anthropic_non_stream_bridges_tool_blocks(self) -> None:
        """A non-streamed Anthropic response has text, tool_use, tool_result."""
        fake_client = _FakeClient(
            stream_events=[],
            complete_result={
                "text": "ok",
                "toolEvents": [
                    {
                        "id": "toolu_1",
                        "name": "lookup",
                        "input": {"k": "v"},
                        "result": {"value": 1},
                    }
                ],
                "sessionId": "sess-2",
            },
        )
        req = AnthropicMessagesRequest(
            model="claude-3-5-sonnet-20241022",
            max_tokens=256,
            messages=[{"role": "user", "content": "hi"}],
            stream=False,
        )

        with _patched_gateway(fake_client, api_keys=[_ANTHROPIC_API_KEY]):
            response = await main.v1_messages(
                req,
                _make_request("/v1/messages", headers=_ANTHROPIC_HEADERS),
            )

        payload = json.loads(response.body)
        # Named ``block_types`` so the imported ``types`` module is not shadowed.
        block_types = [item["type"] for item in payload["content"]]
        self.assertEqual(block_types, ["text", "tool_use", "tool_result"])
        self.assertEqual(payload["content"][1]["name"], "lookup")
        self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1")

    async def test_anthropic_stream_bridges_tool_and_text_events(self) -> None:
        """A streamed Anthropic response emits tool and text events in SSE form."""
        fake_client = _FakeClient(
            stream_events=[
                {
                    "type": "tool",
                    "tool": {
                        "id": "toolu_stream_1",
                        "name": "read",
                        "input": {"file": "a.txt"},
                        "result": "done",
                    },
                },
                {"type": "text", "text": "world"},
            ],
            complete_result={},
        )
        req = AnthropicMessagesRequest(
            model="claude-3-5-sonnet-20241022",
            max_tokens=256,
            messages=[{"role": "user", "content": "hi"}],
            stream=True,
        )

        # Drain the stream while the patches are still active (lazy iterator).
        with _patched_gateway(fake_client, api_keys=[_ANTHROPIC_API_KEY]):
            response = await main.v1_messages(
                req,
                _make_request("/v1/messages", headers=_ANTHROPIC_HEADERS),
            )
            body = await _collect_stream(response)

        self.assertIn("event: message_start", body)
        self.assertIn('"type": "tool_use"', body)
        self.assertIn('"type": "tool_result"', body)
        self.assertIn('"type": "text_delta"', body)
        self.assertIn("event: message_stop", body)


if __name__ == "__main__":
    unittest.main()