fix: harden responses streaming and tool-call fallback
Ensure /v1/responses streams always terminate with response.completed and normalize Lingma tool_code fallbacks into structured tool calls, including single-argument forms. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
383
app/main.py
383
app/main.py
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
@@ -358,6 +359,68 @@ def _include_usage(stream_options: dict | None) -> bool:
|
||||
return bool(stream_options.get("include_usage"))
|
||||
|
||||
|
||||
def _tool_allowlist() -> set[str]:
|
||||
return {name.strip() for name in settings.tool_allowlist if isinstance(name, str) and name.strip()}
|
||||
|
||||
|
||||
def _openai_tool_name(tool: Any) -> str | None:
|
||||
if not isinstance(tool, dict):
|
||||
return None
|
||||
if tool.get("type") == "function":
|
||||
fn = tool.get("function")
|
||||
if isinstance(fn, dict):
|
||||
name = fn.get("name")
|
||||
if isinstance(name, str) and name.strip():
|
||||
return name.strip()
|
||||
name = tool.get("name")
|
||||
if isinstance(name, str) and name.strip():
|
||||
return name.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _anthropic_tool_name(tool: Any) -> str | None:
|
||||
if not isinstance(tool, dict):
|
||||
return None
|
||||
name = tool.get("name")
|
||||
if isinstance(name, str) and name.strip():
|
||||
return name.strip()
|
||||
fn = tool.get("function")
|
||||
if isinstance(fn, dict):
|
||||
nested_name = fn.get("name")
|
||||
if isinstance(nested_name, str) and nested_name.strip():
|
||||
return nested_name.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _filter_allowed_tools(tools: list[dict[str, Any]], *, provider: str) -> list[dict[str, Any]]:
|
||||
allowlist = _tool_allowlist()
|
||||
if not allowlist:
|
||||
return tools
|
||||
name_fn = _openai_tool_name if provider == "openai" else _anthropic_tool_name
|
||||
return [tool for tool in tools if (name := name_fn(tool)) and name in allowlist]
|
||||
|
||||
|
||||
def _ensure_tool_choice_allowed(tool_choice: Any, *, provider: str) -> None:
|
||||
allowlist = _tool_allowlist()
|
||||
if not allowlist:
|
||||
return
|
||||
forced_name = (
|
||||
_openai_forced_tool_name(tool_choice)
|
||||
if provider == "openai"
|
||||
else _anthropic_forced_tool_name(tool_choice)
|
||||
)
|
||||
if forced_name and forced_name not in allowlist:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": {
|
||||
"type": "invalid_request_error",
|
||||
"message": f"tool '{forced_name}' is not allowed",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None:
|
||||
if not settings.tool_forward_enabled:
|
||||
return None
|
||||
@@ -365,9 +428,11 @@ def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None:
|
||||
has_choice = req.tool_choice is not None
|
||||
if not has_tools and not has_choice:
|
||||
return None
|
||||
_ensure_tool_choice_allowed(req.tool_choice, provider="openai")
|
||||
tools = _filter_allowed_tools(req.tools or [], provider="openai")
|
||||
return {
|
||||
"provider": "openai",
|
||||
"tools": req.tools or [],
|
||||
"tools": tools,
|
||||
"tool_choice": req.tool_choice,
|
||||
}
|
||||
|
||||
@@ -379,9 +444,11 @@ def _anthropic_tool_config(req: AnthropicMessagesRequest) -> dict[str, Any] | No
|
||||
has_choice = req.tool_choice is not None
|
||||
if not has_tools and not has_choice:
|
||||
return None
|
||||
_ensure_tool_choice_allowed(req.tool_choice, provider="anthropic")
|
||||
tools = _filter_allowed_tools(req.tools or [], provider="anthropic")
|
||||
return {
|
||||
"provider": "anthropic",
|
||||
"tools": req.tools or [],
|
||||
"tools": tools,
|
||||
"tool_choice": req.tool_choice,
|
||||
}
|
||||
|
||||
@@ -537,8 +604,85 @@ def _json_object_from_text(text: str) -> dict[str, Any] | None:
|
||||
return parsed if isinstance(parsed, dict) else None
|
||||
|
||||
|
||||
def _forced_tool_event_from_text(text: str, forced_tool_name: str) -> dict[str, Any] | None:
|
||||
def _tool_code_single_arg_name(tools: list[dict[str, Any]] | None, forced_tool_name: str) -> str | None:
|
||||
if not isinstance(tools, list):
|
||||
return None
|
||||
for tool in tools:
|
||||
if not isinstance(tool, dict):
|
||||
continue
|
||||
schema: dict[str, Any] | None = None
|
||||
if tool.get("type") == "function":
|
||||
fn = tool.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name") == forced_tool_name:
|
||||
params = fn.get("parameters")
|
||||
if isinstance(params, dict):
|
||||
schema = params
|
||||
elif tool.get("name") == forced_tool_name:
|
||||
input_schema = tool.get("input_schema")
|
||||
if isinstance(input_schema, dict):
|
||||
schema = input_schema
|
||||
if not isinstance(schema, dict):
|
||||
continue
|
||||
properties = schema.get("properties")
|
||||
if not isinstance(properties, dict) or len(properties) != 1:
|
||||
return None
|
||||
only_name = next(iter(properties.keys()), None)
|
||||
if isinstance(only_name, str) and only_name.strip():
|
||||
return only_name
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _tool_code_object_from_text(
|
||||
text: str,
|
||||
forced_tool_name: str,
|
||||
*,
|
||||
single_arg_name: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
raw = text.strip()
|
||||
if not raw.startswith("```tool_code") or not raw.endswith("```"):
|
||||
return None
|
||||
lines = raw.splitlines()
|
||||
if len(lines) < 2:
|
||||
return None
|
||||
body = "\n".join(lines[1:-1]).strip()
|
||||
try:
|
||||
parsed = ast.parse(body, mode="eval")
|
||||
except Exception:
|
||||
return None
|
||||
call = parsed.body
|
||||
if not isinstance(call, ast.Call):
|
||||
return None
|
||||
if not isinstance(call.func, ast.Name) or call.func.id != forced_tool_name:
|
||||
return None
|
||||
arguments: dict[str, Any] = {}
|
||||
if call.args:
|
||||
if len(call.args) != 1 or call.keywords or not single_arg_name:
|
||||
return None
|
||||
try:
|
||||
arguments[single_arg_name] = ast.literal_eval(call.args[0])
|
||||
except Exception:
|
||||
return None
|
||||
return {"arguments": arguments}
|
||||
for kw in call.keywords:
|
||||
if kw.arg is None:
|
||||
return None
|
||||
try:
|
||||
arguments[kw.arg] = ast.literal_eval(kw.value)
|
||||
except Exception:
|
||||
return None
|
||||
return {"arguments": arguments}
|
||||
|
||||
|
||||
def _forced_tool_event_from_text(
|
||||
text: str,
|
||||
forced_tool_name: str,
|
||||
*,
|
||||
single_arg_name: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
parsed = _json_object_from_text(text)
|
||||
if parsed is None:
|
||||
parsed = _tool_code_object_from_text(text, forced_tool_name, single_arg_name=single_arg_name)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
@@ -756,11 +900,14 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
|
||||
completion_tokens_holder = {"n": 0}
|
||||
stream_meta: dict = {}
|
||||
forced_tool_name = _openai_forced_tool_name(req.tool_choice)
|
||||
forced_tool_single_arg_name = _tool_code_single_arg_name(req.tools, forced_tool_name) if forced_tool_name else None
|
||||
|
||||
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
|
||||
success = False
|
||||
tool_call_indexes: dict[str, int] = {}
|
||||
saw_tool_call = False
|
||||
buffered_text_parts: list[str] = []
|
||||
try:
|
||||
async for chunk in _inst.client.chat_stream(
|
||||
prompt,
|
||||
@@ -809,6 +956,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
text = _stream_text(chunk)
|
||||
if not text:
|
||||
continue
|
||||
buffered_text_parts.append(text)
|
||||
completion_tokens_holder["n"] += estimate_tokens(text)
|
||||
payload = {
|
||||
"id": completion_id,
|
||||
@@ -825,6 +973,39 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
if not saw_tool_call and forced_tool_name:
|
||||
fallback_event = _forced_tool_event_from_text(
|
||||
"".join(buffered_text_parts),
|
||||
forced_tool_name,
|
||||
single_arg_name=forced_tool_single_arg_name,
|
||||
)
|
||||
if fallback_event is not None:
|
||||
saw_tool_call = True
|
||||
tool_id = "call_fallback_0"
|
||||
idx = 0
|
||||
tool_call_indexes[tool_id] = idx
|
||||
payload = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": idx,
|
||||
**_openai_tool_call(fallback_event, forced_id=tool_id),
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
done_payload = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
@@ -945,7 +1126,11 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
if not saw_tool_call:
|
||||
forced_tool_name = _openai_forced_tool_name(req.tool_choice)
|
||||
if forced_tool_name:
|
||||
fallback_event = _forced_tool_event_from_text(message_content, forced_tool_name)
|
||||
fallback_event = _forced_tool_event_from_text(
|
||||
message_content,
|
||||
forced_tool_name,
|
||||
single_arg_name=_tool_code_single_arg_name(req.tools, forced_tool_name),
|
||||
)
|
||||
if fallback_event is not None:
|
||||
tool_calls.append(_openai_tool_call(fallback_event, forced_id="call_fallback_0"))
|
||||
saw_tool_call = True
|
||||
@@ -1169,6 +1354,42 @@ async def _responses_stream_from_chat_stream(
|
||||
created_at = int(time.time())
|
||||
usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
completed_sent = False
|
||||
output_item_id = f"msg_{uuid.uuid4().hex}"
|
||||
output_index = 0
|
||||
content_index = 0
|
||||
output_text_parts: list[str] = []
|
||||
function_call_items: list[dict[str, Any]] = []
|
||||
function_call_index_by_id: dict[str, int] = {}
|
||||
function_call_arguments_by_id: dict[str, str] = {}
|
||||
function_call_name_by_id: dict[str, str] = {}
|
||||
function_call_id_by_upstream_index: dict[int, str] = {}
|
||||
|
||||
def _message_item(status: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": output_item_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": status,
|
||||
"content": [
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": "".join(output_text_parts),
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def _function_call_item(call_id: str, *, status: str, name: str, arguments: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": call_id,
|
||||
"type": "function_call",
|
||||
"call_id": call_id,
|
||||
"name": name,
|
||||
"arguments": arguments,
|
||||
"status": status,
|
||||
}
|
||||
|
||||
def _completed_output_items() -> list[dict[str, Any]]:
|
||||
return [_message_item("completed"), *function_call_items]
|
||||
|
||||
def _completed_frame() -> str:
|
||||
return _sse_data(
|
||||
@@ -1180,11 +1401,94 @@ async def _responses_stream_from_chat_stream(
|
||||
"created_at": created_at,
|
||||
"status": "completed",
|
||||
"model": model,
|
||||
"output": _completed_output_items(),
|
||||
"usage": usage,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def _finish_output_item_frames() -> list[str]:
|
||||
frames = [
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.output_text.done",
|
||||
"response_id": response_id,
|
||||
"item_id": output_item_id,
|
||||
"output_index": output_index,
|
||||
"content_index": content_index,
|
||||
"text": "".join(output_text_parts),
|
||||
}
|
||||
),
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.output_item.done",
|
||||
"response_id": response_id,
|
||||
"output_index": output_index,
|
||||
"item": _message_item("completed"),
|
||||
}
|
||||
),
|
||||
]
|
||||
for idx, item in enumerate(function_call_items, start=1):
|
||||
frames.append(
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.function_call_arguments.done",
|
||||
"response_id": response_id,
|
||||
"item_id": item["id"],
|
||||
"output_index": idx,
|
||||
"arguments": item["arguments"],
|
||||
}
|
||||
)
|
||||
)
|
||||
frames.append(
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.output_item.done",
|
||||
"response_id": response_id,
|
||||
"output_index": idx,
|
||||
"item": item,
|
||||
}
|
||||
)
|
||||
)
|
||||
return frames
|
||||
|
||||
def _ensure_function_call_item(call_id: str) -> list[str]:
|
||||
existing_index = function_call_index_by_id.get(call_id)
|
||||
name = function_call_name_by_id.get(call_id, "tool")
|
||||
arguments = function_call_arguments_by_id.get(call_id, "")
|
||||
if existing_index is not None:
|
||||
function_call_items[existing_index] = _function_call_item(
|
||||
call_id,
|
||||
status="completed",
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
return []
|
||||
item = _function_call_item(
|
||||
call_id,
|
||||
status="completed",
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
function_call_items.append(item)
|
||||
item_index = len(function_call_items) - 1
|
||||
function_call_index_by_id[call_id] = item_index
|
||||
return [
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.output_item.added",
|
||||
"response_id": response_id,
|
||||
"output_index": item_index + 1,
|
||||
"item": _function_call_item(
|
||||
call_id,
|
||||
status="in_progress",
|
||||
name=name,
|
||||
arguments="",
|
||||
),
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
yield _sse_data(
|
||||
{
|
||||
"type": "response.created",
|
||||
@@ -1194,9 +1498,18 @@ async def _responses_stream_from_chat_stream(
|
||||
"created_at": created_at,
|
||||
"status": "in_progress",
|
||||
"model": model,
|
||||
"output": [],
|
||||
},
|
||||
}
|
||||
)
|
||||
yield _sse_data(
|
||||
{
|
||||
"type": "response.output_item.added",
|
||||
"response_id": response_id,
|
||||
"output_index": output_index,
|
||||
"item": _message_item("in_progress"),
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
async for part in chat_stream.body_iterator:
|
||||
@@ -1207,6 +1520,8 @@ async def _responses_stream_from_chat_stream(
|
||||
continue
|
||||
body = frame[len("data:") :].strip()
|
||||
if body == "[DONE]":
|
||||
for event in _finish_output_item_frames():
|
||||
yield event
|
||||
yield _completed_frame()
|
||||
yield "data: [DONE]\n\n"
|
||||
completed_sent = True
|
||||
@@ -1229,10 +1544,14 @@ async def _responses_stream_from_chat_stream(
|
||||
|
||||
text = delta.get("content")
|
||||
if isinstance(text, str) and text:
|
||||
output_text_parts.append(text)
|
||||
yield _sse_data(
|
||||
{
|
||||
"type": "response.output_text.delta",
|
||||
"response_id": response_id,
|
||||
"item_id": output_item_id,
|
||||
"output_index": output_index,
|
||||
"content_index": content_index,
|
||||
"delta": text,
|
||||
}
|
||||
)
|
||||
@@ -1243,35 +1562,59 @@ async def _responses_stream_from_chat_stream(
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
fn = tool_call.get("function") if isinstance(tool_call.get("function"), dict) else {}
|
||||
call_id = str(tool_call.get("id") or f"call_{idx}")
|
||||
yield _sse_data(
|
||||
{
|
||||
"type": "response.function_call.delta",
|
||||
"response_id": response_id,
|
||||
"item_id": call_id,
|
||||
"name": str(fn.get("name") or "tool"),
|
||||
"arguments": str(fn.get("arguments") or "{}"),
|
||||
}
|
||||
upstream_index_raw = tool_call.get("index")
|
||||
upstream_index = upstream_index_raw if isinstance(upstream_index_raw, int) else idx
|
||||
call_id = str(
|
||||
tool_call.get("id")
|
||||
or function_call_id_by_upstream_index.get(upstream_index)
|
||||
or f"call_{upstream_index}"
|
||||
)
|
||||
function_call_id_by_upstream_index[upstream_index] = call_id
|
||||
name = str(fn.get("name") or function_call_name_by_id.get(call_id) or "tool")
|
||||
function_call_name_by_id[call_id] = name
|
||||
arguments_delta = str(fn.get("arguments") or "")
|
||||
accumulated_arguments = (
|
||||
function_call_arguments_by_id.get(call_id, "") + arguments_delta
|
||||
)
|
||||
function_call_arguments_by_id[call_id] = accumulated_arguments
|
||||
for event in _ensure_function_call_item(call_id):
|
||||
yield event
|
||||
if arguments_delta:
|
||||
yield _sse_data(
|
||||
{
|
||||
"type": "response.function_call_arguments.delta",
|
||||
"response_id": response_id,
|
||||
"item_id": call_id,
|
||||
"output_index": function_call_index_by_id[call_id] + 1,
|
||||
"delta": arguments_delta,
|
||||
}
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
if not completed_sent:
|
||||
for event in _finish_output_item_frames():
|
||||
yield event
|
||||
yield _completed_frame()
|
||||
yield "data: [DONE]\n\n"
|
||||
completed_sent = True
|
||||
return
|
||||
except Exception:
|
||||
if not completed_sent:
|
||||
for event in _finish_output_item_frames():
|
||||
yield event
|
||||
yield _completed_frame()
|
||||
yield "data: [DONE]\n\n"
|
||||
completed_sent = True
|
||||
return
|
||||
|
||||
if not completed_sent:
|
||||
for event in _finish_output_item_frames():
|
||||
yield event
|
||||
yield _completed_frame()
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
|
||||
@app.post("/responses", dependencies=[Depends(auth_guard)])
|
||||
@app.post("/v1/responses", dependencies=[Depends(auth_guard)])
|
||||
async def v1_responses(req: ResponsesRequest, request: Request):
|
||||
chat_req = _responses_to_chat_request(req)
|
||||
@@ -1388,7 +1731,13 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------- session reuse
|
||||
tool_config = _anthropic_tool_config(req)
|
||||
try:
|
||||
tool_config = _anthropic_tool_config(req)
|
||||
except HTTPException as exc:
|
||||
detail = exc.detail if isinstance(exc.detail, dict) else {}
|
||||
error = detail.get("error") if isinstance(detail.get("error"), dict) else {}
|
||||
message = error.get("message") or str(detail) or "invalid tool configuration"
|
||||
return _anthropic_error(exc.status_code, "invalid_request_error", message)
|
||||
has_tooling_context = _anthropic_has_tooling_context(req)
|
||||
|
||||
ask_mode = _resolve_ask_mode(req.model, has_tooling_context)
|
||||
@@ -1749,7 +2098,11 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
if not saw_tool_event:
|
||||
forced_tool_name = _anthropic_forced_tool_name(req.tool_choice)
|
||||
if forced_tool_name:
|
||||
fallback_event = _forced_tool_event_from_text(text, forced_tool_name)
|
||||
fallback_event = _forced_tool_event_from_text(
|
||||
text,
|
||||
forced_tool_name,
|
||||
single_arg_name=_tool_code_single_arg_name(req.tools, forced_tool_name),
|
||||
)
|
||||
if fallback_event is not None:
|
||||
content_blocks = []
|
||||
tool_id = "toolu_fallback_0"
|
||||
|
||||
Reference in New Issue
Block a user