feat: intercept literal [tool_calls] arrays in generated text and map to actual function calls
This commit is contained in:
@@ -182,6 +182,20 @@ def _json_tool_candidate_from_text(text: str) -> dict[str, Any] | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_tool_calls_from_text(text: str) -> list[dict[str, Any]] | None:
|
||||||
|
text = text.strip()
|
||||||
|
match = re.search(r"\[tool_calls\]\s*(\[.*\])", text, re.DOTALL)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = json.loads(match.group(1))
|
||||||
|
if isinstance(parsed, list) and len(parsed) > 0 and isinstance(parsed[0], dict):
|
||||||
|
return parsed
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _tool_code_single_arg_name(
|
def _tool_code_single_arg_name(
|
||||||
tools: list[dict[str, Any]] | None, forced_tool_name: str
|
tools: list[dict[str, Any]] | None, forced_tool_name: str
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
|
|||||||
55
app/main.py
55
app/main.py
@@ -44,6 +44,7 @@ from .http.tool_bridge import (
|
|||||||
_anthropic_tool_result_block,
|
_anthropic_tool_result_block,
|
||||||
_anthropic_tool_use_block,
|
_anthropic_tool_use_block,
|
||||||
_extract_function_call_event_from_text,
|
_extract_function_call_event_from_text,
|
||||||
|
_extract_tool_calls_from_text,
|
||||||
_forced_tool_fallback_event,
|
_forced_tool_fallback_event,
|
||||||
_json_string,
|
_json_string,
|
||||||
_openai_forced_tool_name,
|
_openai_forced_tool_name,
|
||||||
@@ -656,9 +657,51 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
continue
|
continue
|
||||||
buffered_text_parts.append(text)
|
buffered_text_parts.append(text)
|
||||||
completion_tokens_holder["n"] += estimate_tokens(text)
|
completion_tokens_holder["n"] += estimate_tokens(text)
|
||||||
|
|
||||||
|
full_text = "".join(buffered_text_parts)
|
||||||
|
if req.tools:
|
||||||
|
if "[tool_calls]".startswith(full_text) or "[tool_calls]" in full_text:
|
||||||
|
continue
|
||||||
|
|
||||||
if forced_tool_name and not saw_tool_call:
|
if forced_tool_name and not saw_tool_call:
|
||||||
continue
|
continue
|
||||||
yield _text_payload(text)
|
|
||||||
|
# Yield all buffered text
|
||||||
|
text_to_yield = "".join(buffered_text_parts)
|
||||||
|
buffered_text_parts.clear()
|
||||||
|
yield _text_payload(text_to_yield)
|
||||||
|
|
||||||
|
if buffered_text_parts and not saw_tool_call:
|
||||||
|
merged_text = "".join(buffered_text_parts)
|
||||||
|
|
||||||
|
extracted_tool_calls = _extract_tool_calls_from_text(merged_text)
|
||||||
|
if extracted_tool_calls:
|
||||||
|
saw_tool_call = True
|
||||||
|
for i, tc in enumerate(extracted_tool_calls):
|
||||||
|
tool_id = str(tc.get("id") or f"call_inferred_{i}")
|
||||||
|
tool_call_indexes[tool_id] = i
|
||||||
|
payload = {
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": created,
|
||||||
|
"model": model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"index": i,
|
||||||
|
**_openai_tool_call(tc, forced_id=tool_id),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||||
|
buffered_text_parts.clear()
|
||||||
|
|
||||||
if buffered_text_parts and forced_tool_name and not saw_tool_call:
|
if buffered_text_parts and forced_tool_name and not saw_tool_call:
|
||||||
merged_text = "".join(buffered_text_parts)
|
merged_text = "".join(buffered_text_parts)
|
||||||
@@ -808,6 +851,16 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
tool_id = str(item.get("id") or f"call_{idx}")
|
tool_id = str(item.get("id") or f"call_{idx}")
|
||||||
tool_calls.append(_openai_tool_call(item, forced_id=tool_id))
|
tool_calls.append(_openai_tool_call(item, forced_id=tool_id))
|
||||||
saw_tool_call = True
|
saw_tool_call = True
|
||||||
|
|
||||||
|
if not saw_tool_call:
|
||||||
|
extracted_tool_calls = _extract_tool_calls_from_text(message_content)
|
||||||
|
if extracted_tool_calls:
|
||||||
|
for idx, tc in enumerate(extracted_tool_calls):
|
||||||
|
tool_id = str(tc.get("id") or f"call_inferred_{idx}")
|
||||||
|
tool_calls.append(_openai_tool_call(tc, forced_id=tool_id))
|
||||||
|
saw_tool_call = True
|
||||||
|
message_content = ""
|
||||||
|
|
||||||
if not saw_tool_call and forced_tool_name:
|
if not saw_tool_call and forced_tool_name:
|
||||||
inferred = _extract_function_call_event_from_text(
|
inferred = _extract_function_call_event_from_text(
|
||||||
message_content,
|
message_content,
|
||||||
|
|||||||
Reference in New Issue
Block a user