fix: close forced tool-choice with structured fallback
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
111
app/main.py
111
app/main.py
@@ -452,6 +452,93 @@ def _json_string(value: Any) -> str:
|
|||||||
return "{}"
|
return "{}"
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_forced_tool_name(tool_choice: Any) -> str | None:
|
||||||
|
if not isinstance(tool_choice, dict):
|
||||||
|
return None
|
||||||
|
fn = tool_choice.get("function")
|
||||||
|
if isinstance(fn, dict):
|
||||||
|
name = fn.get("name")
|
||||||
|
if isinstance(name, str) and name.strip():
|
||||||
|
return name.strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _anthropic_forced_tool_name(tool_choice: Any) -> str | None:
|
||||||
|
if not isinstance(tool_choice, dict):
|
||||||
|
return None
|
||||||
|
if tool_choice.get("type") == "tool":
|
||||||
|
name = tool_choice.get("name")
|
||||||
|
if isinstance(name, str) and name.strip():
|
||||||
|
return name.strip()
|
||||||
|
fn = tool_choice.get("function")
|
||||||
|
if isinstance(fn, dict):
|
||||||
|
name = fn.get("name")
|
||||||
|
if isinstance(name, str) and name.strip():
|
||||||
|
return name.strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _json_object_from_text(text: str) -> dict[str, Any] | None:
|
||||||
|
raw = text.strip()
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
if raw.startswith("```") and raw.endswith("```"):
|
||||||
|
raw = raw[3:-3].strip()
|
||||||
|
if raw.lower().startswith("json"):
|
||||||
|
raw = raw[4:].strip()
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return parsed if isinstance(parsed, dict) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_forced_tool_arguments(value: Any) -> Any:
    """Decode a tool-argument payload that may be JSON-encoded as a string.

    Strings are parsed with ``json.loads`` (raising ``ValueError`` on invalid
    JSON); any other value is returned unchanged.
    """
    if isinstance(value, str):
        return json.loads(value)
    return value


def _forced_tool_event_from_text(text: str, forced_tool_name: str) -> dict[str, Any] | None:
    """Build a synthetic tool event from plain-text JSON for a forced tool.

    The text is parsed with ``_json_object_from_text``. The tool input is
    taken from the first of these that is present:

    1. a top-level ``"input"`` key (used as-is),
    2. a top-level ``"arguments"`` key (JSON-decoded when it is a string),
    3. ``"arguments"`` inside a ``"function"`` mapping (same decoding),
    4. otherwise every top-level key except the envelope keys ``name``,
       ``tool``, ``function``, ``arguments``, ``input`` and ``result``.

    Args:
        text: Text that may contain a JSON object.
        forced_tool_name: Name of the tool the request forced.

    Returns:
        ``{"name": forced_tool_name, "input": {...}}`` plus a ``"result"``
        entry when the parsed object has a ``"result"`` key. ``None`` is
        returned when the text is not a JSON object, when it names a
        different tool, when string arguments are not valid JSON, or when
        the resolved input is not a JSON object.
    """
    parsed = _json_object_from_text(text)
    if parsed is None:
        return None

    # Falsy "name"/"tool" values (e.g. "") count as "no explicit name" so
    # both keys behave the same; a nested function name is then consulted.
    explicit_name: Any = parsed.get("name") or parsed.get("tool") or None
    fn = parsed.get("function")
    if explicit_name is None and isinstance(fn, dict):
        explicit_name = fn.get("name")
    if explicit_name is not None and str(explicit_name) != forced_tool_name:
        return None

    try:
        if "input" in parsed:
            tool_input: Any = parsed["input"]
        elif "arguments" in parsed:
            tool_input = _decode_forced_tool_arguments(parsed["arguments"])
        elif isinstance(fn, dict) and "arguments" in fn:
            tool_input = _decode_forced_tool_arguments(fn["arguments"])
        else:
            reserved = {"name", "tool", "function", "arguments", "input", "result"}
            tool_input = {k: v for k, v in parsed.items() if k not in reserved}
    except (ValueError, RecursionError):
        # String arguments that are not valid JSON: no usable tool call.
        return None

    if tool_input is None:
        tool_input = {}
    # Tool arguments are expected to be a JSON object; a list or scalar
    # here is not a usable call, so decline rather than emit a malformed one.
    if not isinstance(tool_input, dict):
        return None

    event: dict[str, Any] = {"name": forced_tool_name, "input": tool_input}
    if "result" in parsed:
        event["result"] = parsed["result"]
    return event
|
||||||
|
|
||||||
|
|
||||||
def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]:
|
def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"),
|
"id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"),
|
||||||
@@ -800,6 +887,14 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
tool_id = str(item.get("id") or f"call_{idx}")
|
tool_id = str(item.get("id") or f"call_{idx}")
|
||||||
tool_calls.append(_openai_tool_call(item, forced_id=tool_id))
|
tool_calls.append(_openai_tool_call(item, forced_id=tool_id))
|
||||||
saw_tool_call = True
|
saw_tool_call = True
|
||||||
|
if not saw_tool_call:
|
||||||
|
forced_tool_name = _openai_forced_tool_name(req.tool_choice)
|
||||||
|
if forced_tool_name:
|
||||||
|
fallback_event = _forced_tool_event_from_text(message_content, forced_tool_name)
|
||||||
|
if fallback_event is not None:
|
||||||
|
tool_calls.append(_openai_tool_call(fallback_event, forced_id="call_fallback_0"))
|
||||||
|
saw_tool_call = True
|
||||||
|
message_content = ""
|
||||||
response = ChatCompletionResponse(
|
response = ChatCompletionResponse(
|
||||||
id=f"chatcmpl-{uuid.uuid4().hex}",
|
id=f"chatcmpl-{uuid.uuid4().hex}",
|
||||||
created=int(time.time()),
|
created=int(time.time()),
|
||||||
@@ -1249,10 +1344,12 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
content_blocks.append({"type": "text", "text": text})
|
content_blocks.append({"type": "text", "text": text})
|
||||||
tool_events = result.get("toolEvents") or []
|
tool_events = result.get("toolEvents") or []
|
||||||
saw_pending_tool_use = False
|
saw_pending_tool_use = False
|
||||||
|
saw_tool_event = False
|
||||||
if isinstance(tool_events, list):
|
if isinstance(tool_events, list):
|
||||||
for idx, item in enumerate(tool_events):
|
for idx, item in enumerate(tool_events):
|
||||||
if not isinstance(item, dict):
|
if not isinstance(item, dict):
|
||||||
continue
|
continue
|
||||||
|
saw_tool_event = True
|
||||||
tool_id = str(item.get("id") or f"toolu_nonstream_{idx}")
|
tool_id = str(item.get("id") or f"toolu_nonstream_{idx}")
|
||||||
content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id))
|
content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id))
|
||||||
tool_result = _anthropic_tool_result_block(item, forced_id=tool_id)
|
tool_result = _anthropic_tool_result_block(item, forced_id=tool_id)
|
||||||
@@ -1261,7 +1358,21 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
else:
|
else:
|
||||||
saw_pending_tool_use = True
|
saw_pending_tool_use = True
|
||||||
|
|
||||||
|
if not saw_tool_event:
|
||||||
|
forced_tool_name = _anthropic_forced_tool_name(req.tool_choice)
|
||||||
|
if forced_tool_name:
|
||||||
|
fallback_event = _forced_tool_event_from_text(text, forced_tool_name)
|
||||||
|
if fallback_event is not None:
|
||||||
|
content_blocks = []
|
||||||
|
tool_id = "toolu_fallback_0"
|
||||||
|
content_blocks.append(_anthropic_tool_use_block(fallback_event, forced_id=tool_id))
|
||||||
|
tool_result = _anthropic_tool_result_block(fallback_event, forced_id=tool_id)
|
||||||
|
saw_pending_tool_use = tool_result is None
|
||||||
|
if tool_result is not None:
|
||||||
|
content_blocks.append(tool_result)
|
||||||
|
|
||||||
response_body: dict = {
|
response_body: dict = {
|
||||||
|
|
||||||
"id": message_id,
|
"id": message_id,
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
|
|||||||
@@ -224,6 +224,42 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
|||||||
{"query": "gateway"},
|
{"query": "gateway"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def test_openai_non_stream_fallbacks_to_structured_tool_call_for_forced_tool(self) -> None:
    # The upstream reply carries no tool events, only fenced JSON text, while
    # the request forces the "lookup" function: the endpoint is expected to
    # convert that text into a structured tool call.
    upstream_result = {
        "text": '```json\n{"arguments": {"query": "gateway"}}\n```',
        "toolEvents": [],
        "sessionId": "sess-fallback-openai",
    }
    fake_client = _FakeClient(stream_events=[], complete_result=upstream_result)

    lookup_tool = {"type": "function", "function": {"name": "lookup", "parameters": {}}}
    forced_choice = {"type": "function", "function": {"name": "lookup"}}
    req = ChatCompletionsRequest(
        model="org_auto",
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
        tools=[lookup_tool],
        tool_choice=forced_choice,
    )

    with (
        patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
        patch.object(main, "chat_guard", _FakeGuard()),
        patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
        patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
    ):
        response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))

    choice = json.loads(response.body)["choices"][0]
    message = choice["message"]

    self.assertEqual(choice["finish_reason"], "tool_calls")
    # The text is consumed by the fallback, so no content is echoed back.
    self.assertEqual(message["content"], "")
    self.assertIsInstance(message["tool_calls"], list)

    called_function = message["tool_calls"][0]["function"]
    self.assertEqual(called_function["name"], "lookup")
    self.assertEqual(json.loads(called_function["arguments"]), {"query": "gateway"})
|
||||||
|
|
||||||
async def test_openai_stream_bridges_tool_and_text_events(self) -> None:
|
async def test_openai_stream_bridges_tool_and_text_events(self) -> None:
|
||||||
fake_client = _FakeClient(
|
fake_client = _FakeClient(
|
||||||
stream_events=[
|
stream_events=[
|
||||||
@@ -306,6 +342,46 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.assertEqual(payload["content"][1]["name"], "lookup")
|
self.assertEqual(payload["content"][1]["name"], "lookup")
|
||||||
self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1")
|
self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1")
|
||||||
|
|
||||||
|
async def test_anthropic_non_stream_fallbacks_to_structured_tool_blocks_for_forced_tool(self) -> None:
    # The upstream reply carries no tool events, only JSON text with an
    # "input" and a "result", while the request forces the "lookup" tool:
    # the endpoint is expected to emit tool_use + tool_result blocks.
    upstream_result = {
        "text": '{"input": {"k": "v"}, "result": {"value": 1}}',
        "toolEvents": [],
        "sessionId": "sess-fallback-anthropic",
    }
    fake_client = _FakeClient(stream_events=[], complete_result=upstream_result)

    req = AnthropicMessagesRequest(
        model="claude-3-5-sonnet-20241022",
        max_tokens=256,
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
        tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}],
        tool_choice={"type": "tool", "name": "lookup"},
    )
    http_request = _make_request(
        "/v1/messages",
        headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
    )

    with (
        patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
        patch.object(main, "chat_guard", _FakeGuard()),
        patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
        patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
        patch.object(main.settings, "api_keys", ["test-key"]),
    ):
        response = await main.v1_messages(req, http_request)

    payload = json.loads(response.body)
    blocks = payload["content"]

    # The original text block is replaced by the synthesized tool blocks.
    self.assertEqual([block["type"] for block in blocks], ["tool_use", "tool_result"])
    self.assertEqual(payload["stop_reason"], "end_turn")

    tool_use_block, tool_result_block = blocks[0], blocks[1]
    self.assertEqual(tool_use_block["name"], "lookup")
    self.assertEqual(tool_result_block["tool_use_id"], "toolu_fallback_0")
|
||||||
|
|
||||||
async def test_openai_stream_tool_call_indices_are_stable(self) -> None:
|
async def test_openai_stream_tool_call_indices_are_stable(self) -> None:
|
||||||
fake_client = _FakeClient(
|
fake_client = _FakeClient(
|
||||||
stream_events=[
|
stream_events=[
|
||||||
|
|||||||
Reference in New Issue
Block a user