diff --git a/app/http/tool_bridge.py b/app/http/tool_bridge.py index 10ab7ae..6179281 100644 --- a/app/http/tool_bridge.py +++ b/app/http/tool_bridge.py @@ -182,6 +182,20 @@ def _json_tool_candidate_from_text(text: str) -> dict[str, Any] | None: return None +def _extract_tool_calls_from_text(text: str) -> list[dict[str, Any]] | None: + text = text.strip() + match = re.search(r"\[tool_calls\]\s*(\[.*\])", text, re.DOTALL) + if not match: + return None + try: + parsed = json.loads(match.group(1)) + if isinstance(parsed, list) and len(parsed) > 0 and isinstance(parsed[0], dict): + return parsed + except Exception: + pass + return None + + def _tool_code_single_arg_name( tools: list[dict[str, Any]] | None, forced_tool_name: str ) -> str | None: diff --git a/app/main.py b/app/main.py index bb455f3..75fa904 100644 --- a/app/main.py +++ b/app/main.py @@ -44,6 +44,7 @@ from .http.tool_bridge import ( _anthropic_tool_result_block, _anthropic_tool_use_block, _extract_function_call_event_from_text, + _extract_tool_calls_from_text, _forced_tool_fallback_event, _json_string, _openai_forced_tool_name, @@ -656,10 +657,52 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): continue buffered_text_parts.append(text) completion_tokens_holder["n"] += estimate_tokens(text) + + full_text = "".join(buffered_text_parts) + if req.tools: + if "[tool_calls]".startswith(full_text) or "[tool_calls]" in full_text: + continue + if forced_tool_name and not saw_tool_call: continue - yield _text_payload(text) + + # Yield all buffered text + text_to_yield = "".join(buffered_text_parts) + buffered_text_parts.clear() + yield _text_payload(text_to_yield) + if buffered_text_parts and not saw_tool_call: + merged_text = "".join(buffered_text_parts) + + extracted_tool_calls = _extract_tool_calls_from_text(merged_text) + if extracted_tool_calls: + saw_tool_call = True + for i, tc in enumerate(extracted_tool_calls): + tool_id = str(tc.get("id") or f"call_inferred_{i}") + tool_call_indexes[tool_id] = i + payload = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [ + { + "index": i, + **_openai_tool_call(tc, forced_id=tool_id), + } + ] + }, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + buffered_text_parts.clear() + if buffered_text_parts and forced_tool_name and not saw_tool_call: merged_text = "".join(buffered_text_parts) inferred = _extract_function_call_event_from_text( @@ -808,6 +851,16 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): tool_id = str(item.get("id") or f"call_{idx}") tool_calls.append(_openai_tool_call(item, forced_id=tool_id)) saw_tool_call = True + + if not saw_tool_call: + extracted_tool_calls = _extract_tool_calls_from_text(message_content) + if extracted_tool_calls: + for idx, tc in enumerate(extracted_tool_calls): + tool_id = str(tc.get("id") or f"call_inferred_{idx}") + tool_calls.append(_openai_tool_call(tc, forced_id=tool_id)) + saw_tool_call = True + message_content = "" + if not saw_tool_call and forced_tool_name: inferred = _extract_function_call_event_from_text( message_content,