diff --git a/app/http/tool_bridge.py b/app/http/tool_bridge.py index fc71fcd..f4e49fb 100644 --- a/app/http/tool_bridge.py +++ b/app/http/tool_bridge.py @@ -58,6 +58,40 @@ def _tool_event_allowed( return bool(forced_tool_name and tool_name == forced_tool_name) +def _allowed_tool_event( + tool: Any, + *, + tool_config: dict[str, Any] | None, + forced_tool_name: str | None = None, +) -> dict[str, Any] | None: + if not isinstance(tool, dict): + return None + tool_name = str(tool.get("name") or "") + if not _tool_event_allowed(tool_name, tool_config, forced_tool_name=forced_tool_name): + return None + return tool + + +def _allowed_tool_events( + tool_events: Any, + *, + tool_config: dict[str, Any] | None, + forced_tool_name: str | None = None, +) -> list[dict[str, Any]]: + if not isinstance(tool_events, list): + return [] + out: list[dict[str, Any]] = [] + for item in tool_events: + allowed = _allowed_tool_event( + item, + tool_config=tool_config, + forced_tool_name=forced_tool_name, + ) + if allowed is not None: + out.append(allowed) + return out + + def _openai_forced_tool_name(tool_choice: Any) -> str | None: if not isinstance(tool_choice, dict): return None @@ -222,6 +256,21 @@ def _forced_tool_event_from_text( return event +def _forced_tool_fallback_event( + text: str, + *, + forced_tool_name: str | None, + tools: list[dict[str, Any]] | None = None, +) -> dict[str, Any] | None: + if not forced_tool_name: + return None + return _forced_tool_event_from_text( + text, + forced_tool_name, + single_arg_name=_tool_code_single_arg_name(tools, forced_tool_name), + ) + + def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]: return { "id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"), diff --git a/app/main.py b/app/main.py index c34433c..9965505 100644 --- a/app/main.py +++ b/app/main.py @@ -39,17 +39,17 @@ from .http.responses_adapter import ( _sse_data, ) from .http.tool_bridge import ( + _allowed_tool_events, _anthropic_forced_tool_name, _anthropic_tool_name as _shared_anthropic_tool_name, _anthropic_tool_result_block, _anthropic_tool_use_block, _forced_tool_event_from_text, - _json_object_from_text, + _forced_tool_fallback_event, _json_string, _openai_forced_tool_name, _openai_tool_call, _openai_tool_name as _shared_openai_tool_name, - _tool_code_object_from_text, _tool_code_single_arg_name, _tool_event_allowed, ) @@ -892,36 +892,29 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): sid = result.get("sessionId") if sid: await session_cache.put(write_key, sid, inst.name) - tool_events = result.get("toolEvents") or [] + forced_tool_name = _openai_forced_tool_name(req.tool_choice) + tool_events = _allowed_tool_events( + result.get("toolEvents"), + tool_config=tool_config, + forced_tool_name=forced_tool_name, + ) message_content = result.get("text") or "" tool_calls: list[dict[str, Any]] = [] saw_tool_call = False - forced_tool_name = _openai_forced_tool_name(req.tool_choice) - if isinstance(tool_events, list): - for idx, item in enumerate(tool_events): - if isinstance(item, dict): - tool_name = str(item.get("name") or "") - if not _tool_event_allowed( - tool_name, - tool_config, - forced_tool_name=forced_tool_name, - ): - continue - - tool_id = str(item.get("id") or f"call_{idx}") - tool_calls.append(_openai_tool_call(item, forced_id=tool_id)) - saw_tool_call = True - if not saw_tool_call: - if forced_tool_name: - fallback_event = _forced_tool_event_from_text( - message_content, - forced_tool_name, - single_arg_name=_tool_code_single_arg_name(req.tools, forced_tool_name), - ) - if fallback_event is not None: - tool_calls.append(_openai_tool_call(fallback_event, forced_id="call_fallback_0")) - saw_tool_call = True - message_content = "" + for idx, item in enumerate(tool_events): + tool_id = str(item.get("id") or f"call_{idx}") + tool_calls.append(_openai_tool_call(item, forced_id=tool_id)) + saw_tool_call = True + if not saw_tool_call and forced_tool_name: + fallback_event = _forced_tool_fallback_event( + message_content, + forced_tool_name=forced_tool_name, + tools=req.tools, + ) + if fallback_event is not None: + tool_calls.append(_openai_tool_call(fallback_event, forced_id="call_fallback_0")) + saw_tool_call = True + message_content = "" response = ChatCompletionResponse( id=f"chatcmpl-{uuid.uuid4().hex}", created=int(time.time()), @@ -1664,47 +1657,38 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): content_blocks: list[dict[str, Any]] = [] if text: content_blocks.append({"type": "text", "text": text}) - tool_events = result.get("toolEvents") or [] + forced_tool_name = _anthropic_forced_tool_name(req.tool_choice) + tool_events = _allowed_tool_events( + result.get("toolEvents"), + tool_config=tool_config, + forced_tool_name=forced_tool_name, + ) saw_pending_tool_use = False saw_tool_event = False - if isinstance(tool_events, list): - for idx, item in enumerate(tool_events): - if not isinstance(item, dict): - continue - - tool_name = str(item.get("name") or "") - if not _tool_event_allowed( - tool_name, - tool_config, - forced_tool_name=_anthropic_forced_tool_name(req.tool_choice), - ): - continue + for idx, item in enumerate(tool_events): + saw_tool_event = True + tool_id = str(item.get("id") or f"toolu_nonstream_{idx}") + content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id)) + tool_result = _anthropic_tool_result_block(item, forced_id=tool_id) + if tool_result is not None: + content_blocks.append(tool_result) + else: + saw_pending_tool_use = True - saw_tool_event = True - tool_id = str(item.get("id") or f"toolu_nonstream_{idx}") - content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id)) - tool_result = _anthropic_tool_result_block(item, forced_id=tool_id) + if not saw_tool_event and forced_tool_name: + fallback_event = _forced_tool_fallback_event( + text, + forced_tool_name=forced_tool_name, + tools=req.tools, + ) + if fallback_event is not None: + content_blocks = [] + tool_id = "toolu_fallback_0" + content_blocks.append(_anthropic_tool_use_block(fallback_event, forced_id=tool_id)) + tool_result = _anthropic_tool_result_block(fallback_event, forced_id=tool_id) + saw_pending_tool_use = tool_result is None if tool_result is not None: content_blocks.append(tool_result) - else: - saw_pending_tool_use = True - - if not saw_tool_event: - forced_tool_name = _anthropic_forced_tool_name(req.tool_choice) - if forced_tool_name: - fallback_event = _forced_tool_event_from_text( - text, - forced_tool_name, - single_arg_name=_tool_code_single_arg_name(req.tools, forced_tool_name), - ) - if fallback_event is not None: - content_blocks = [] - tool_id = "toolu_fallback_0" - content_blocks.append(_anthropic_tool_use_block(fallback_event, forced_id=tool_id)) - tool_result = _anthropic_tool_result_block(fallback_event, forced_id=tool_id) - saw_pending_tool_use = tool_result is None - if tool_result is not None: - content_blocks.append(tool_result) response_body: dict = {