diff --git a/app/main.py b/app/main.py index 6a429f2..a8d16c8 100644 --- a/app/main.py +++ b/app/main.py @@ -452,6 +452,93 @@ def _json_string(value: Any) -> str: return "{}" +def _openai_forced_tool_name(tool_choice: Any) -> str | None: + if not isinstance(tool_choice, dict): + return None + fn = tool_choice.get("function") + if isinstance(fn, dict): + name = fn.get("name") + if isinstance(name, str) and name.strip(): + return name.strip() + return None + + +def _anthropic_forced_tool_name(tool_choice: Any) -> str | None: + if not isinstance(tool_choice, dict): + return None + if tool_choice.get("type") == "tool": + name = tool_choice.get("name") + if isinstance(name, str) and name.strip(): + return name.strip() + fn = tool_choice.get("function") + if isinstance(fn, dict): + name = fn.get("name") + if isinstance(name, str) and name.strip(): + return name.strip() + return None + + +def _json_object_from_text(text: str) -> dict[str, Any] | None: + raw = text.strip() + if not raw: + return None + if raw.startswith("```") and raw.endswith("```"): + raw = raw[3:-3].strip() + if raw.lower().startswith("json"): + raw = raw[4:].strip() + try: + parsed = json.loads(raw) + except Exception: + return None + return parsed if isinstance(parsed, dict) else None + + +def _forced_tool_event_from_text(text: str, forced_tool_name: str) -> dict[str, Any] | None: + parsed = _json_object_from_text(text) + if parsed is None: + return None + + explicit_name: Any = parsed.get("name") or parsed.get("tool") + fn = parsed.get("function") + if explicit_name is None and isinstance(fn, dict): + explicit_name = fn.get("name") + if explicit_name is not None and str(explicit_name) != forced_tool_name: + return None + + tool_input: Any = None + if "input" in parsed: + tool_input = parsed.get("input") + elif "arguments" in parsed: + args = parsed.get("arguments") + if isinstance(args, str): + try: + tool_input = json.loads(args) + except Exception: + return None + else: + tool_input = args + elif isinstance(fn, dict) and "arguments" in fn: + args = fn.get("arguments") + if isinstance(args, str): + try: + tool_input = json.loads(args) + except Exception: + return None + else: + tool_input = args + else: + reserved = {"name", "tool", "function", "arguments", "input", "result"} + tool_input = {k: v for k, v in parsed.items() if k not in reserved} + + event: dict[str, Any] = { + "name": forced_tool_name, + "input": tool_input if tool_input is not None else {}, + } + if "result" in parsed: + event["result"] = parsed.get("result") + return event + + def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]: return { "id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"), @@ -800,6 +887,14 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): tool_id = str(item.get("id") or f"call_{idx}") tool_calls.append(_openai_tool_call(item, forced_id=tool_id)) saw_tool_call = True + if not saw_tool_call: + forced_tool_name = _openai_forced_tool_name(req.tool_choice) + if forced_tool_name: + fallback_event = _forced_tool_event_from_text(message_content, forced_tool_name) + if fallback_event is not None: + tool_calls.append(_openai_tool_call(fallback_event, forced_id="call_fallback_0")) + saw_tool_call = True + message_content = "" response = ChatCompletionResponse( id=f"chatcmpl-{uuid.uuid4().hex}", created=int(time.time()), @@ -1249,10 +1344,12 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): content_blocks.append({"type": "text", "text": text}) tool_events = result.get("toolEvents") or [] saw_pending_tool_use = False + saw_tool_event = False if isinstance(tool_events, list): for idx, item in enumerate(tool_events): if not isinstance(item, dict): continue + saw_tool_event = True tool_id = str(item.get("id") or f"toolu_nonstream_{idx}") content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id)) tool_result = _anthropic_tool_result_block(item, forced_id=tool_id) @@ -1261,7 +1358,21 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): else: saw_pending_tool_use = True + if not saw_tool_event: + forced_tool_name = _anthropic_forced_tool_name(req.tool_choice) + if forced_tool_name: + fallback_event = _forced_tool_event_from_text(text, forced_tool_name) + if fallback_event is not None: + content_blocks = [] + tool_id = "toolu_fallback_0" + content_blocks.append(_anthropic_tool_use_block(fallback_event, forced_id=tool_id)) + tool_result = _anthropic_tool_result_block(fallback_event, forced_id=tool_id) + saw_pending_tool_use = tool_result is None + if tool_result is not None: + content_blocks.append(tool_result) + response_body: dict = { + "id": message_id, "type": "message", "role": "assistant", diff --git a/tests/test_tool_call_bridge.py b/tests/test_tool_call_bridge.py index bcf16dd..13ee0df 100644 --- a/tests/test_tool_call_bridge.py +++ b/tests/test_tool_call_bridge.py @@ -224,6 +224,42 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): {"query": "gateway"}, ) + async def test_openai_non_stream_fallbacks_to_structured_tool_call_for_forced_tool(self) -> None: + fake_client = _FakeClient( + stream_events=[], + complete_result={ + "text": "```json\n{\"arguments\": {\"query\": \"gateway\"}}\n```", + "toolEvents": [], + "sessionId": "sess-fallback-openai", + }, + ) + req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "hi"}], + stream=False, + tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tool_choice={"type": "function", "function": {"name": "lookup"}}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + ): + response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + + payload = json.loads(response.body) + message = payload["choices"][0]["message"] + self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls") + self.assertEqual(message["content"], "") + self.assertIsInstance(message["tool_calls"], list) + self.assertEqual(message["tool_calls"][0]["function"]["name"], "lookup") + self.assertEqual( + json.loads(message["tool_calls"][0]["function"]["arguments"]), + {"query": "gateway"}, + ) + async def test_openai_stream_bridges_tool_and_text_events(self) -> None: fake_client = _FakeClient( stream_events=[ @@ -306,6 +342,46 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(payload["content"][1]["name"], "lookup") self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1") + async def test_anthropic_non_stream_fallbacks_to_structured_tool_blocks_for_forced_tool(self) -> None: + fake_client = _FakeClient( + stream_events=[], + complete_result={ + "text": "{\"input\": {\"k\": \"v\"}, \"result\": {\"value\": 1}}", + "toolEvents": [], + "sessionId": "sess-fallback-anthropic", + }, + ) + req = AnthropicMessagesRequest( + model="claude-3-5-sonnet-20241022", + max_tokens=256, + messages=[{"role": "user", "content": "hi"}], + stream=False, + tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}], + tool_choice={"type": "tool", "name": "lookup"}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object(main.settings, "api_keys", ["test-key"]), + ): + response = await main.v1_messages( + req, + _make_request( + "/v1/messages", + headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + ), + ) + + payload = json.loads(response.body) + types = [item["type"] for item in payload["content"]] + self.assertEqual(types, ["tool_use", "tool_result"]) + self.assertEqual(payload["stop_reason"], "end_turn") + self.assertEqual(payload["content"][0]["name"], "lookup") + self.assertEqual(payload["content"][1]["tool_use_id"], "toolu_fallback_0") + async def test_openai_stream_tool_call_indices_are_stable(self) -> None: fake_client = _FakeClient( stream_events=[