fix: harden responses streaming and tool-call fallback

Ensure /v1/responses streams always terminate with response.completed and normalize Lingma tool_code fallbacks into structured tool calls, including single-argument forms.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
mmc
2026-04-20 19:24:02 +08:00
parent 866a212573
commit d0df089282
6 changed files with 927 additions and 18 deletions

View File

@@ -5,6 +5,11 @@ import os
from dataclasses import dataclass, field
def _csv_env(raw: str) -> list[str]:
return [item.strip() for item in (raw or "").replace("\n", ",").split(",") if item.strip()]
@dataclass
class LingmaAccount:
username: str
@@ -45,6 +50,7 @@ class Settings:
session_cache_max_entries: int = 256
session_cache_ttl_sec: float = 1800.0
tool_forward_enabled: bool = False
tool_allowlist: list[str] = field(default_factory=list)
def _bool_env(name: str, default: bool) -> bool:
@@ -177,4 +183,5 @@ def load_settings() -> Settings:
session_cache_max_entries=int(os.getenv("SESSION_CACHE_MAX_ENTRIES", "256")),
session_cache_ttl_sec=float(os.getenv("SESSION_CACHE_TTL_SEC", "1800")),
tool_forward_enabled=_bool_env("TOOL_FORWARD_ENABLED", False),
tool_allowlist=_csv_env(os.getenv("TOOL_ALLOWLIST", "")),
)

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import ast
import asyncio
import hashlib
import json
@@ -358,6 +359,68 @@ def _include_usage(stream_options: dict | None) -> bool:
return bool(stream_options.get("include_usage"))
def _tool_allowlist() -> set[str]:
return {name.strip() for name in settings.tool_allowlist if isinstance(name, str) and name.strip()}
def _openai_tool_name(tool: Any) -> str | None:
if not isinstance(tool, dict):
return None
if tool.get("type") == "function":
fn = tool.get("function")
if isinstance(fn, dict):
name = fn.get("name")
if isinstance(name, str) and name.strip():
return name.strip()
name = tool.get("name")
if isinstance(name, str) and name.strip():
return name.strip()
return None
def _anthropic_tool_name(tool: Any) -> str | None:
if not isinstance(tool, dict):
return None
name = tool.get("name")
if isinstance(name, str) and name.strip():
return name.strip()
fn = tool.get("function")
if isinstance(fn, dict):
nested_name = fn.get("name")
if isinstance(nested_name, str) and nested_name.strip():
return nested_name.strip()
return None
def _filter_allowed_tools(tools: list[dict[str, Any]], *, provider: str) -> list[dict[str, Any]]:
allowlist = _tool_allowlist()
if not allowlist:
return tools
name_fn = _openai_tool_name if provider == "openai" else _anthropic_tool_name
return [tool for tool in tools if (name := name_fn(tool)) and name in allowlist]
def _ensure_tool_choice_allowed(tool_choice: Any, *, provider: str) -> None:
allowlist = _tool_allowlist()
if not allowlist:
return
forced_name = (
_openai_forced_tool_name(tool_choice)
if provider == "openai"
else _anthropic_forced_tool_name(tool_choice)
)
if forced_name and forced_name not in allowlist:
raise HTTPException(
status_code=400,
detail={
"error": {
"type": "invalid_request_error",
"message": f"tool '{forced_name}' is not allowed",
}
},
)
def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None:
if not settings.tool_forward_enabled:
return None
@@ -365,9 +428,11 @@ def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None:
has_choice = req.tool_choice is not None
if not has_tools and not has_choice:
return None
_ensure_tool_choice_allowed(req.tool_choice, provider="openai")
tools = _filter_allowed_tools(req.tools or [], provider="openai")
return {
"provider": "openai",
"tools": req.tools or [],
"tools": tools,
"tool_choice": req.tool_choice,
}
@@ -379,9 +444,11 @@ def _anthropic_tool_config(req: AnthropicMessagesRequest) -> dict[str, Any] | No
has_choice = req.tool_choice is not None
if not has_tools and not has_choice:
return None
_ensure_tool_choice_allowed(req.tool_choice, provider="anthropic")
tools = _filter_allowed_tools(req.tools or [], provider="anthropic")
return {
"provider": "anthropic",
"tools": req.tools or [],
"tools": tools,
"tool_choice": req.tool_choice,
}
@@ -537,8 +604,85 @@ def _json_object_from_text(text: str) -> dict[str, Any] | None:
return parsed if isinstance(parsed, dict) else None
def _forced_tool_event_from_text(text: str, forced_tool_name: str) -> dict[str, Any] | None:
def _tool_code_single_arg_name(tools: list[dict[str, Any]] | None, forced_tool_name: str) -> str | None:
if not isinstance(tools, list):
return None
for tool in tools:
if not isinstance(tool, dict):
continue
schema: dict[str, Any] | None = None
if tool.get("type") == "function":
fn = tool.get("function")
if isinstance(fn, dict) and fn.get("name") == forced_tool_name:
params = fn.get("parameters")
if isinstance(params, dict):
schema = params
elif tool.get("name") == forced_tool_name:
input_schema = tool.get("input_schema")
if isinstance(input_schema, dict):
schema = input_schema
if not isinstance(schema, dict):
continue
properties = schema.get("properties")
if not isinstance(properties, dict) or len(properties) != 1:
return None
only_name = next(iter(properties.keys()), None)
if isinstance(only_name, str) and only_name.strip():
return only_name
return None
return None
def _tool_code_object_from_text(
text: str,
forced_tool_name: str,
*,
single_arg_name: str | None = None,
) -> dict[str, Any] | None:
raw = text.strip()
if not raw.startswith("```tool_code") or not raw.endswith("```"):
return None
lines = raw.splitlines()
if len(lines) < 2:
return None
body = "\n".join(lines[1:-1]).strip()
try:
parsed = ast.parse(body, mode="eval")
except Exception:
return None
call = parsed.body
if not isinstance(call, ast.Call):
return None
if not isinstance(call.func, ast.Name) or call.func.id != forced_tool_name:
return None
arguments: dict[str, Any] = {}
if call.args:
if len(call.args) != 1 or call.keywords or not single_arg_name:
return None
try:
arguments[single_arg_name] = ast.literal_eval(call.args[0])
except Exception:
return None
return {"arguments": arguments}
for kw in call.keywords:
if kw.arg is None:
return None
try:
arguments[kw.arg] = ast.literal_eval(kw.value)
except Exception:
return None
return {"arguments": arguments}
def _forced_tool_event_from_text(
text: str,
forced_tool_name: str,
*,
single_arg_name: str | None = None,
) -> dict[str, Any] | None:
parsed = _json_object_from_text(text)
if parsed is None:
parsed = _tool_code_object_from_text(text, forced_tool_name, single_arg_name=single_arg_name)
if parsed is None:
return None
@@ -756,11 +900,14 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
completion_tokens_holder = {"n": 0}
stream_meta: dict = {}
forced_tool_name = _openai_forced_tool_name(req.tool_choice)
forced_tool_single_arg_name = _tool_code_single_arg_name(req.tools, forced_tool_name) if forced_tool_name else None
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
success = False
tool_call_indexes: dict[str, int] = {}
saw_tool_call = False
buffered_text_parts: list[str] = []
try:
async for chunk in _inst.client.chat_stream(
prompt,
@@ -809,6 +956,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
text = _stream_text(chunk)
if not text:
continue
buffered_text_parts.append(text)
completion_tokens_holder["n"] += estimate_tokens(text)
payload = {
"id": completion_id,
@@ -825,6 +973,39 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
if not saw_tool_call and forced_tool_name:
fallback_event = _forced_tool_event_from_text(
"".join(buffered_text_parts),
forced_tool_name,
single_arg_name=forced_tool_single_arg_name,
)
if fallback_event is not None:
saw_tool_call = True
tool_id = "call_fallback_0"
idx = 0
tool_call_indexes[tool_id] = idx
payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"tool_calls": [
{
"index": idx,
**_openai_tool_call(fallback_event, forced_id=tool_id),
}
]
},
"finish_reason": None,
}
],
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
done_payload = {
"id": completion_id,
"object": "chat.completion.chunk",
@@ -945,7 +1126,11 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
if not saw_tool_call:
forced_tool_name = _openai_forced_tool_name(req.tool_choice)
if forced_tool_name:
fallback_event = _forced_tool_event_from_text(message_content, forced_tool_name)
fallback_event = _forced_tool_event_from_text(
message_content,
forced_tool_name,
single_arg_name=_tool_code_single_arg_name(req.tools, forced_tool_name),
)
if fallback_event is not None:
tool_calls.append(_openai_tool_call(fallback_event, forced_id="call_fallback_0"))
saw_tool_call = True
@@ -1169,6 +1354,42 @@ async def _responses_stream_from_chat_stream(
created_at = int(time.time())
usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
completed_sent = False
output_item_id = f"msg_{uuid.uuid4().hex}"
output_index = 0
content_index = 0
output_text_parts: list[str] = []
function_call_items: list[dict[str, Any]] = []
function_call_index_by_id: dict[str, int] = {}
function_call_arguments_by_id: dict[str, str] = {}
function_call_name_by_id: dict[str, str] = {}
function_call_id_by_upstream_index: dict[int, str] = {}
def _message_item(status: str) -> dict[str, Any]:
return {
"id": output_item_id,
"type": "message",
"role": "assistant",
"status": status,
"content": [
{
"type": "output_text",
"text": "".join(output_text_parts),
}
],
}
def _function_call_item(call_id: str, *, status: str, name: str, arguments: str) -> dict[str, Any]:
return {
"id": call_id,
"type": "function_call",
"call_id": call_id,
"name": name,
"arguments": arguments,
"status": status,
}
def _completed_output_items() -> list[dict[str, Any]]:
return [_message_item("completed"), *function_call_items]
def _completed_frame() -> str:
return _sse_data(
@@ -1180,11 +1401,94 @@ async def _responses_stream_from_chat_stream(
"created_at": created_at,
"status": "completed",
"model": model,
"output": _completed_output_items(),
"usage": usage,
},
}
)
def _finish_output_item_frames() -> list[str]:
frames = [
_sse_data(
{
"type": "response.output_text.done",
"response_id": response_id,
"item_id": output_item_id,
"output_index": output_index,
"content_index": content_index,
"text": "".join(output_text_parts),
}
),
_sse_data(
{
"type": "response.output_item.done",
"response_id": response_id,
"output_index": output_index,
"item": _message_item("completed"),
}
),
]
for idx, item in enumerate(function_call_items, start=1):
frames.append(
_sse_data(
{
"type": "response.function_call_arguments.done",
"response_id": response_id,
"item_id": item["id"],
"output_index": idx,
"arguments": item["arguments"],
}
)
)
frames.append(
_sse_data(
{
"type": "response.output_item.done",
"response_id": response_id,
"output_index": idx,
"item": item,
}
)
)
return frames
def _ensure_function_call_item(call_id: str) -> list[str]:
existing_index = function_call_index_by_id.get(call_id)
name = function_call_name_by_id.get(call_id, "tool")
arguments = function_call_arguments_by_id.get(call_id, "")
if existing_index is not None:
function_call_items[existing_index] = _function_call_item(
call_id,
status="completed",
name=name,
arguments=arguments,
)
return []
item = _function_call_item(
call_id,
status="completed",
name=name,
arguments=arguments,
)
function_call_items.append(item)
item_index = len(function_call_items) - 1
function_call_index_by_id[call_id] = item_index
return [
_sse_data(
{
"type": "response.output_item.added",
"response_id": response_id,
"output_index": item_index + 1,
"item": _function_call_item(
call_id,
status="in_progress",
name=name,
arguments="",
),
}
)
]
yield _sse_data(
{
"type": "response.created",
@@ -1194,9 +1498,18 @@ async def _responses_stream_from_chat_stream(
"created_at": created_at,
"status": "in_progress",
"model": model,
"output": [],
},
}
)
yield _sse_data(
{
"type": "response.output_item.added",
"response_id": response_id,
"output_index": output_index,
"item": _message_item("in_progress"),
}
)
try:
async for part in chat_stream.body_iterator:
@@ -1207,6 +1520,8 @@ async def _responses_stream_from_chat_stream(
continue
body = frame[len("data:") :].strip()
if body == "[DONE]":
for event in _finish_output_item_frames():
yield event
yield _completed_frame()
yield "data: [DONE]\n\n"
completed_sent = True
@@ -1229,10 +1544,14 @@ async def _responses_stream_from_chat_stream(
text = delta.get("content")
if isinstance(text, str) and text:
output_text_parts.append(text)
yield _sse_data(
{
"type": "response.output_text.delta",
"response_id": response_id,
"item_id": output_item_id,
"output_index": output_index,
"content_index": content_index,
"delta": text,
}
)
@@ -1243,35 +1562,59 @@ async def _responses_stream_from_chat_stream(
if not isinstance(tool_call, dict):
continue
fn = tool_call.get("function") if isinstance(tool_call.get("function"), dict) else {}
call_id = str(tool_call.get("id") or f"call_{idx}")
yield _sse_data(
{
"type": "response.function_call.delta",
"response_id": response_id,
"item_id": call_id,
"name": str(fn.get("name") or "tool"),
"arguments": str(fn.get("arguments") or "{}"),
}
upstream_index_raw = tool_call.get("index")
upstream_index = upstream_index_raw if isinstance(upstream_index_raw, int) else idx
call_id = str(
tool_call.get("id")
or function_call_id_by_upstream_index.get(upstream_index)
or f"call_{upstream_index}"
)
function_call_id_by_upstream_index[upstream_index] = call_id
name = str(fn.get("name") or function_call_name_by_id.get(call_id) or "tool")
function_call_name_by_id[call_id] = name
arguments_delta = str(fn.get("arguments") or "")
accumulated_arguments = (
function_call_arguments_by_id.get(call_id, "") + arguments_delta
)
function_call_arguments_by_id[call_id] = accumulated_arguments
for event in _ensure_function_call_item(call_id):
yield event
if arguments_delta:
yield _sse_data(
{
"type": "response.function_call_arguments.delta",
"response_id": response_id,
"item_id": call_id,
"output_index": function_call_index_by_id[call_id] + 1,
"delta": arguments_delta,
}
)
except asyncio.CancelledError:
if not completed_sent:
for event in _finish_output_item_frames():
yield event
yield _completed_frame()
yield "data: [DONE]\n\n"
completed_sent = True
return
except Exception:
if not completed_sent:
for event in _finish_output_item_frames():
yield event
yield _completed_frame()
yield "data: [DONE]\n\n"
completed_sent = True
return
if not completed_sent:
for event in _finish_output_item_frames():
yield event
yield _completed_frame()
yield "data: [DONE]\n\n"
@app.post("/responses", dependencies=[Depends(auth_guard)])
@app.post("/v1/responses", dependencies=[Depends(auth_guard)])
async def v1_responses(req: ResponsesRequest, request: Request):
chat_req = _responses_to_chat_request(req)
@@ -1388,7 +1731,13 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
)
# ------------------------------------------------------------- session reuse
tool_config = _anthropic_tool_config(req)
try:
tool_config = _anthropic_tool_config(req)
except HTTPException as exc:
detail = exc.detail if isinstance(exc.detail, dict) else {}
error = detail.get("error") if isinstance(detail.get("error"), dict) else {}
message = error.get("message") or str(detail) or "invalid tool configuration"
return _anthropic_error(exc.status_code, "invalid_request_error", message)
has_tooling_context = _anthropic_has_tooling_context(req)
ask_mode = _resolve_ask_mode(req.model, has_tooling_context)
@@ -1749,7 +2098,11 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
if not saw_tool_event:
forced_tool_name = _anthropic_forced_tool_name(req.tool_choice)
if forced_tool_name:
fallback_event = _forced_tool_event_from_text(text, forced_tool_name)
fallback_event = _forced_tool_event_from_text(
text,
forced_tool_name,
single_arg_name=_tool_code_single_arg_name(req.tools, forced_tool_name),
)
if fallback_event is not None:
content_blocks = []
tool_id = "toolu_fallback_0"