diff --git a/app/http/tool_bridge.py b/app/http/tool_bridge.py index 6ab982d..10ab7ae 100644 --- a/app/http/tool_bridge.py +++ b/app/http/tool_bridge.py @@ -161,6 +161,27 @@ def _json_object_from_text(text: str) -> dict[str, Any] | None: return parsed if isinstance(parsed, dict) else None +def _json_tool_candidate_from_text(text: str) -> dict[str, Any] | None: + raw = text.strip() + if not raw: + return None + if raw.startswith("```") and raw.endswith("```"): + raw = raw[3:-3].strip() + if raw.lower().startswith("json"): + raw = raw[4:].strip() + try: + parsed = json.loads(raw) + except Exception: + return None + if isinstance(parsed, dict): + return parsed + if isinstance(parsed, list) and parsed: + first = parsed[0] + if isinstance(first, dict): + return first + return None + + def _tool_code_single_arg_name( tools: list[dict[str, Any]] | None, forced_tool_name: str ) -> str | None: @@ -199,11 +220,15 @@ def _tool_code_object_from_text( single_arg_name: str | None = None, ) -> dict[str, Any] | None: raw = text.strip() - if not raw.startswith("```tool_code") or not raw.endswith("```"): + if not raw.startswith("```") or not raw.endswith("```"): return None lines = raw.splitlines() if len(lines) < 2: return None + fence = lines[0].strip().lower() + language = fence[3:].strip() + if language and language not in {"tool_code", "python", "py"}: + return None body = "\n".join(lines[1:-1]).strip() try: parsed = ast.parse(body, mode="eval") @@ -239,7 +264,7 @@ def _forced_tool_event_from_text( *, single_arg_name: str | None = None, ) -> dict[str, Any] | None: - parsed = _json_object_from_text(text) + parsed = _json_tool_candidate_from_text(text) if parsed is None: parsed = _tool_code_object_from_text( text, forced_tool_name, single_arg_name=single_arg_name diff --git a/tests/test_tool_call_bridge.py b/tests/test_tool_call_bridge.py index 02e47d1..84c6421 100644 --- a/tests/test_tool_call_bridge.py +++ b/tests/test_tool_call_bridge.py @@ -243,7 +243,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): {"query": "gateway"}, ) - async def test_openai_non_stream_does_not_synthesize_tool_call_from_plain_json( + async def test_openai_non_stream_synthesizes_tool_call_from_plain_json( self, ) -> None: fake_client = _FakeClient( @@ -280,11 +280,15 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): payload = json.loads(response.body) message = payload["choices"][0]["message"] - self.assertEqual(payload["choices"][0]["finish_reason"], "stop") - self.assertIn("arguments", message["content"]) - self.assertIsNone(message["tool_calls"]) + self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls") + self.assertEqual(message["content"], "") + self.assertEqual(message["tool_calls"][0]["function"]["name"], "lookup") + self.assertEqual( + json.loads(message["tool_calls"][0]["function"]["arguments"]), + {"query": "gateway"}, + ) - async def test_openai_non_stream_does_not_synthesize_tool_call_from_tool_code( + async def test_openai_non_stream_synthesizes_tool_call_from_tool_code( self, ) -> None: fake_client = _FakeClient( @@ -321,11 +325,15 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): payload = json.loads(response.body) message = payload["choices"][0]["message"] - self.assertEqual(payload["choices"][0]["finish_reason"], "stop") - self.assertIn('lookup(query="gateway")', message["content"]) - self.assertIsNone(message["tool_calls"]) + self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls") + self.assertEqual(message["content"], "") + self.assertEqual(message["tool_calls"][0]["function"]["name"], "lookup") + self.assertEqual( + json.loads(message["tool_calls"][0]["function"]["arguments"]), + {"query": "gateway"}, + ) - async def test_openai_non_stream_does_not_synthesize_tool_call_from_tool_code_positional_arg( + async def test_openai_non_stream_synthesizes_tool_call_from_tool_code_positional_arg( self, ) -> None: fake_client = _FakeClient( @@ -372,11 +380,15 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): payload = json.loads(response.body) message = payload["choices"][0]["message"] - self.assertEqual(payload["choices"][0]["finish_reason"], "stop") - self.assertIn('lookup("gateway")', message["content"]) - self.assertIsNone(message["tool_calls"]) + self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls") + self.assertEqual(message["content"], "") + self.assertEqual(message["tool_calls"][0]["function"]["name"], "lookup") + self.assertEqual( + json.loads(message["tool_calls"][0]["function"]["arguments"]), + {"query": "gateway"}, + ) - async def test_openai_stream_does_not_synthesize_tool_call_from_tool_code( + async def test_openai_stream_synthesizes_tool_call_from_tool_code( self, ) -> None: fake_client = _FakeClient( @@ -417,16 +429,59 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): for line in body.splitlines() if line.startswith("data: {") ] - self.assertFalse( + self.assertTrue( any( chunk.get("choices") and chunk["choices"][0]["delta"].get("tool_calls") for chunk in chunks ) ) - self.assertNotIn('"tool_calls"', body) - self.assertIn('"finish_reason": "stop"', body) + self.assertIn('"tool_calls"', body) + self.assertIn('"finish_reason": "tool_calls"', body) self.assertIn("data: [DONE]", body) + async def test_openai_non_stream_synthesizes_tool_call_from_json_array(self) -> None: + fake_client = _FakeClient( + stream_events=[], + complete_result={ + "text": '```json\n[{"name": "lookup", "arguments": {"query": "gateway"}}]\n```', + "toolEvents": [], + "sessionId": "sess-fallback-openai-json-array", + }, + ) + req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "hi"}], + stream=False, + tools=[ + {"type": "function", "function": {"name": "lookup", "parameters": {}}} + ], + tool_choice={"type": "function", "function": {"name": "lookup"}}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), + ): + response = await main.v1_chat_completions( + req, _make_request("/v1/chat/completions") + ) + + payload = json.loads(response.body) + message = payload["choices"][0]["message"] + self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls") + self.assertEqual(message["content"], "") + self.assertEqual(message["tool_calls"][0]["function"]["name"], "lookup") + self.assertEqual( + json.loads(message["tool_calls"][0]["function"]["arguments"]), + {"query": "gateway"}, + ) + async def test_openai_stream_bridges_tool_and_text_events(self) -> None: fake_client = _FakeClient( stream_events=[