fix: harden responses streaming and tool-call fallback

Ensure /v1/responses streams always terminate with response.completed and normalize Lingma tool_code fallbacks into structured tool calls, including single-argument forms.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
mmc
2026-04-20 19:24:02 +08:00
parent 866a212573
commit d0df089282
6 changed files with 927 additions and 18 deletions

View File

@@ -48,6 +48,8 @@ DEFAULT_ASK_MODE=chat
# 请求侧 tools/tool_choice 透传到 Lingma默认关闭开启后可支持工具写文件等场景
TOOL_FORWARD_ENABLED=false
# 可选:允许透传的工具名白名单,逗号分隔;为空表示不额外限制
TOOL_ALLOWLIST=
# 专属域(可选)
DEDICATED_DOMAIN_URL=

177
CLAUDE.md
View File

@@ -93,3 +93,180 @@ Both protocols share the same backend pool, backpressure guard, stats, and sessi
- Compose mounts:
- `./data -> /app/data` (persistent Lingma binary/cache/workdirs)
- `./secrets -> /secrets:ro` (session bundles, secrets)
# CLAUDE.md
Behavioral guidelines to reduce common LLM coding mistakes. Merge with project-specific instructions as needed.
**Tradeoff:** These guidelines bias toward caution over speed. For trivial tasks, use judgment.
## 1. Think Before Coding
**Don't assume. Don't hide confusion. Surface tradeoffs.**
Before implementing:
- State your assumptions explicitly. If uncertain, ask.
- If multiple interpretations exist, present them - don't pick silently.
- If a simpler approach exists, say so. Push back when warranted.
- If something is unclear, stop. Name what's confusing. Ask.
## 2. Simplicity First
**Minimum code that solves the problem. Nothing speculative.**
- No features beyond what was asked.
- No abstractions for single-use code.
- No "flexibility" or "configurability" that wasn't requested.
- No error handling for impossible scenarios.
- If you write 200 lines and it could be 50, rewrite it.
Ask yourself: "Would a senior engineer say this is overcomplicated?" If yes, simplify.
## 3. Surgical Changes
**Touch only what you must. Clean up only your own mess.**
When editing existing code:
- Don't "improve" adjacent code, comments, or formatting.
- Don't refactor things that aren't broken.
- Match existing style, even if you'd do it differently.
- If you notice unrelated dead code, mention it - don't delete it.
When your changes create orphans:
- Remove imports/variables/functions that YOUR changes made unused.
- Don't remove pre-existing dead code unless asked.
The test: Every changed line should trace directly to the user's request.
## 4. Goal-Driven Execution
**Define success criteria. Loop until verified.**
Transform tasks into verifiable goals:
- "Add validation" → "Write tests for invalid inputs, then make them pass"
- "Fix the bug" → "Write a test that reproduces it, then make it pass"
- "Refactor X" → "Ensure tests pass before and after"
For multi-step tasks, state a brief plan:
```
1. [Step] → verify: [check]
2. [Step] → verify: [check]
3. [Step] → verify: [check]
```
Strong success criteria let you loop independently. Weak criteria ("make it work") require constant clarification.
---
**These guidelines are working if:** fewer unnecessary changes in diffs, fewer rewrites due to overcomplication, and clarifying questions come before implementation rather than after mistakes.
# CLAUDE.md
Behavioral guidelines to reduce common LLM coding mistakes. Merge with project-specific instructions as needed.
**Tradeoff:** These guidelines bias toward caution over speed. For trivial tasks, use judgment.
## 1. Think Before Coding
**Don't assume. Don't hide confusion. Surface tradeoffs.**
Before implementing:
- State your assumptions explicitly. If uncertain, ask.
- If multiple interpretations exist, present them - don't pick silently.
- If a simpler approach exists, say so. Push back when warranted.
- If something is unclear, stop. Name what's confusing. Ask.
## 2. Simplicity First
**Minimum code that solves the problem. Nothing speculative.**
- No features beyond what was asked.
- No abstractions for single-use code.
- No "flexibility" or "configurability" that wasn't requested.
- No error handling for impossible scenarios.
- If you write 200 lines and it could be 50, rewrite it.
Ask yourself: "Would a senior engineer say this is overcomplicated?" If yes, simplify.
## 3. Surgical Changes
**Touch only what you must. Clean up only your own mess.**
When editing existing code:
- Don't "improve" adjacent code, comments, or formatting.
- Don't refactor things that aren't broken.
- Match existing style, even if you'd do it differently.
- If you notice unrelated dead code, mention it - don't delete it.
When your changes create orphans:
- Remove imports/variables/functions that YOUR changes made unused.
- Don't remove pre-existing dead code unless asked.
The test: Every changed line should trace directly to the user's request.
## 4. Goal-Driven Execution
**Define success criteria. Loop until verified.**
Transform tasks into verifiable goals:
- "Add validation" → "Write tests for invalid inputs, then make them pass"
- "Fix the bug" → "Write a test that reproduces it, then make it pass"
- "Refactor X" → "Ensure tests pass before and after"
For multi-step tasks, state a brief plan:
```
1. [Step] → verify: [check]
2. [Step] → verify: [check]
3. [Step] → verify: [check]
```
Strong success criteria let you loop independently. Weak criteria ("make it work") require constant clarification.
---
**These guidelines are working if:** fewer unnecessary changes in diffs, fewer rewrites due to overcomplication, and clarifying questions come before implementation rather than after mistakes.
<!-- gitnexus:start -->
# GitNexus — Code Intelligence
This project is indexed by GitNexus as **lingma-openai-gateway** (1093 symbols, 2685 relationships, 97 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely.
> If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first.
## Always Do
- **MUST run impact analysis before editing any symbol.** Before modifying a function, class, or method, run `gitnexus_impact({target: "symbolName", direction: "upstream"})` and report the blast radius (direct callers, affected processes, risk level) to the user.
- **MUST run `gitnexus_detect_changes()` before committing** to verify your changes only affect expected symbols and execution flows.
- **MUST warn the user** if impact analysis returns HIGH or CRITICAL risk before proceeding with edits.
- When exploring unfamiliar code, use `gitnexus_query({query: "concept"})` to find execution flows instead of grepping. It returns process-grouped results ranked by relevance.
- When you need full context on a specific symbol — callers, callees, which execution flows it participates in — use `gitnexus_context({name: "symbolName"})`.
## Never Do
- NEVER edit a function, class, or method without first running `gitnexus_impact` on it.
- NEVER ignore HIGH or CRITICAL risk warnings from impact analysis.
- NEVER rename symbols with find-and-replace — use `gitnexus_rename` which understands the call graph.
- NEVER commit changes without running `gitnexus_detect_changes()` to check affected scope.
## Resources
| Resource | Use for |
|----------|---------|
| `gitnexus://repo/lingma-openai-gateway/context` | Codebase overview, check index freshness |
| `gitnexus://repo/lingma-openai-gateway/clusters` | All functional areas |
| `gitnexus://repo/lingma-openai-gateway/processes` | All execution flows |
| `gitnexus://repo/lingma-openai-gateway/process/{name}` | Step-by-step execution trace |
## CLI
| Task | Read this skill file |
|------|---------------------|
| Understand architecture / "How does X work?" | `.claude/skills/gitnexus/gitnexus-exploring/SKILL.md` |
| Blast radius / "What breaks if I change X?" | `.claude/skills/gitnexus/gitnexus-impact-analysis/SKILL.md` |
| Trace bugs / "Why is X failing?" | `.claude/skills/gitnexus/gitnexus-debugging/SKILL.md` |
| Rename / extract / split / refactor | `.claude/skills/gitnexus/gitnexus-refactoring/SKILL.md` |
| Tools, resources, schema reference | `.claude/skills/gitnexus/gitnexus-guide/SKILL.md` |
| Index, status, clean, wiki CLI commands | `.claude/skills/gitnexus/gitnexus-cli/SKILL.md` |
<!-- gitnexus:end -->

View File

@@ -5,6 +5,11 @@ import os
from dataclasses import dataclass, field
def _csv_env(raw: str) -> list[str]:
return [item.strip() for item in (raw or "").replace("\n", ",").split(",") if item.strip()]
@dataclass
class LingmaAccount:
username: str
@@ -45,6 +50,7 @@ class Settings:
session_cache_max_entries: int = 256
session_cache_ttl_sec: float = 1800.0
tool_forward_enabled: bool = False
tool_allowlist: list[str] = field(default_factory=list)
def _bool_env(name: str, default: bool) -> bool:
@@ -177,4 +183,5 @@ def load_settings() -> Settings:
session_cache_max_entries=int(os.getenv("SESSION_CACHE_MAX_ENTRIES", "256")),
session_cache_ttl_sec=float(os.getenv("SESSION_CACHE_TTL_SEC", "1800")),
tool_forward_enabled=_bool_env("TOOL_FORWARD_ENABLED", False),
tool_allowlist=_csv_env(os.getenv("TOOL_ALLOWLIST", "")),
)

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import ast
import asyncio
import hashlib
import json
@@ -358,6 +359,68 @@ def _include_usage(stream_options: dict | None) -> bool:
return bool(stream_options.get("include_usage"))
def _tool_allowlist() -> set[str]:
return {name.strip() for name in settings.tool_allowlist if isinstance(name, str) and name.strip()}
def _openai_tool_name(tool: Any) -> str | None:
if not isinstance(tool, dict):
return None
if tool.get("type") == "function":
fn = tool.get("function")
if isinstance(fn, dict):
name = fn.get("name")
if isinstance(name, str) and name.strip():
return name.strip()
name = tool.get("name")
if isinstance(name, str) and name.strip():
return name.strip()
return None
def _anthropic_tool_name(tool: Any) -> str | None:
if not isinstance(tool, dict):
return None
name = tool.get("name")
if isinstance(name, str) and name.strip():
return name.strip()
fn = tool.get("function")
if isinstance(fn, dict):
nested_name = fn.get("name")
if isinstance(nested_name, str) and nested_name.strip():
return nested_name.strip()
return None
def _filter_allowed_tools(tools: list[dict[str, Any]], *, provider: str) -> list[dict[str, Any]]:
allowlist = _tool_allowlist()
if not allowlist:
return tools
name_fn = _openai_tool_name if provider == "openai" else _anthropic_tool_name
return [tool for tool in tools if (name := name_fn(tool)) and name in allowlist]
def _ensure_tool_choice_allowed(tool_choice: Any, *, provider: str) -> None:
allowlist = _tool_allowlist()
if not allowlist:
return
forced_name = (
_openai_forced_tool_name(tool_choice)
if provider == "openai"
else _anthropic_forced_tool_name(tool_choice)
)
if forced_name and forced_name not in allowlist:
raise HTTPException(
status_code=400,
detail={
"error": {
"type": "invalid_request_error",
"message": f"tool '{forced_name}' is not allowed",
}
},
)
def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None:
if not settings.tool_forward_enabled:
return None
@@ -365,9 +428,11 @@ def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None:
has_choice = req.tool_choice is not None
if not has_tools and not has_choice:
return None
_ensure_tool_choice_allowed(req.tool_choice, provider="openai")
tools = _filter_allowed_tools(req.tools or [], provider="openai")
return {
"provider": "openai",
"tools": req.tools or [],
"tools": tools,
"tool_choice": req.tool_choice,
}
@@ -379,9 +444,11 @@ def _anthropic_tool_config(req: AnthropicMessagesRequest) -> dict[str, Any] | No
has_choice = req.tool_choice is not None
if not has_tools and not has_choice:
return None
_ensure_tool_choice_allowed(req.tool_choice, provider="anthropic")
tools = _filter_allowed_tools(req.tools or [], provider="anthropic")
return {
"provider": "anthropic",
"tools": req.tools or [],
"tools": tools,
"tool_choice": req.tool_choice,
}
@@ -537,8 +604,85 @@ def _json_object_from_text(text: str) -> dict[str, Any] | None:
return parsed if isinstance(parsed, dict) else None
def _forced_tool_event_from_text(text: str, forced_tool_name: str) -> dict[str, Any] | None:
def _tool_code_single_arg_name(tools: list[dict[str, Any]] | None, forced_tool_name: str) -> str | None:
if not isinstance(tools, list):
return None
for tool in tools:
if not isinstance(tool, dict):
continue
schema: dict[str, Any] | None = None
if tool.get("type") == "function":
fn = tool.get("function")
if isinstance(fn, dict) and fn.get("name") == forced_tool_name:
params = fn.get("parameters")
if isinstance(params, dict):
schema = params
elif tool.get("name") == forced_tool_name:
input_schema = tool.get("input_schema")
if isinstance(input_schema, dict):
schema = input_schema
if not isinstance(schema, dict):
continue
properties = schema.get("properties")
if not isinstance(properties, dict) or len(properties) != 1:
return None
only_name = next(iter(properties.keys()), None)
if isinstance(only_name, str) and only_name.strip():
return only_name
return None
return None
def _tool_code_object_from_text(
text: str,
forced_tool_name: str,
*,
single_arg_name: str | None = None,
) -> dict[str, Any] | None:
raw = text.strip()
if not raw.startswith("```tool_code") or not raw.endswith("```"):
return None
lines = raw.splitlines()
if len(lines) < 2:
return None
body = "\n".join(lines[1:-1]).strip()
try:
parsed = ast.parse(body, mode="eval")
except Exception:
return None
call = parsed.body
if not isinstance(call, ast.Call):
return None
if not isinstance(call.func, ast.Name) or call.func.id != forced_tool_name:
return None
arguments: dict[str, Any] = {}
if call.args:
if len(call.args) != 1 or call.keywords or not single_arg_name:
return None
try:
arguments[single_arg_name] = ast.literal_eval(call.args[0])
except Exception:
return None
return {"arguments": arguments}
for kw in call.keywords:
if kw.arg is None:
return None
try:
arguments[kw.arg] = ast.literal_eval(kw.value)
except Exception:
return None
return {"arguments": arguments}
def _forced_tool_event_from_text(
text: str,
forced_tool_name: str,
*,
single_arg_name: str | None = None,
) -> dict[str, Any] | None:
parsed = _json_object_from_text(text)
if parsed is None:
parsed = _tool_code_object_from_text(text, forced_tool_name, single_arg_name=single_arg_name)
if parsed is None:
return None
@@ -756,11 +900,14 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
completion_tokens_holder = {"n": 0}
stream_meta: dict = {}
forced_tool_name = _openai_forced_tool_name(req.tool_choice)
forced_tool_single_arg_name = _tool_code_single_arg_name(req.tools, forced_tool_name) if forced_tool_name else None
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
success = False
tool_call_indexes: dict[str, int] = {}
saw_tool_call = False
buffered_text_parts: list[str] = []
try:
async for chunk in _inst.client.chat_stream(
prompt,
@@ -809,6 +956,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
text = _stream_text(chunk)
if not text:
continue
buffered_text_parts.append(text)
completion_tokens_holder["n"] += estimate_tokens(text)
payload = {
"id": completion_id,
@@ -825,6 +973,39 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
if not saw_tool_call and forced_tool_name:
fallback_event = _forced_tool_event_from_text(
"".join(buffered_text_parts),
forced_tool_name,
single_arg_name=forced_tool_single_arg_name,
)
if fallback_event is not None:
saw_tool_call = True
tool_id = "call_fallback_0"
idx = 0
tool_call_indexes[tool_id] = idx
payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"tool_calls": [
{
"index": idx,
**_openai_tool_call(fallback_event, forced_id=tool_id),
}
]
},
"finish_reason": None,
}
],
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
done_payload = {
"id": completion_id,
"object": "chat.completion.chunk",
@@ -945,7 +1126,11 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
if not saw_tool_call:
forced_tool_name = _openai_forced_tool_name(req.tool_choice)
if forced_tool_name:
fallback_event = _forced_tool_event_from_text(message_content, forced_tool_name)
fallback_event = _forced_tool_event_from_text(
message_content,
forced_tool_name,
single_arg_name=_tool_code_single_arg_name(req.tools, forced_tool_name),
)
if fallback_event is not None:
tool_calls.append(_openai_tool_call(fallback_event, forced_id="call_fallback_0"))
saw_tool_call = True
@@ -1169,6 +1354,42 @@ async def _responses_stream_from_chat_stream(
created_at = int(time.time())
usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
completed_sent = False
output_item_id = f"msg_{uuid.uuid4().hex}"
output_index = 0
content_index = 0
output_text_parts: list[str] = []
function_call_items: list[dict[str, Any]] = []
function_call_index_by_id: dict[str, int] = {}
function_call_arguments_by_id: dict[str, str] = {}
function_call_name_by_id: dict[str, str] = {}
function_call_id_by_upstream_index: dict[int, str] = {}
def _message_item(status: str) -> dict[str, Any]:
return {
"id": output_item_id,
"type": "message",
"role": "assistant",
"status": status,
"content": [
{
"type": "output_text",
"text": "".join(output_text_parts),
}
],
}
def _function_call_item(call_id: str, *, status: str, name: str, arguments: str) -> dict[str, Any]:
return {
"id": call_id,
"type": "function_call",
"call_id": call_id,
"name": name,
"arguments": arguments,
"status": status,
}
def _completed_output_items() -> list[dict[str, Any]]:
return [_message_item("completed"), *function_call_items]
def _completed_frame() -> str:
return _sse_data(
@@ -1180,11 +1401,94 @@ async def _responses_stream_from_chat_stream(
"created_at": created_at,
"status": "completed",
"model": model,
"output": _completed_output_items(),
"usage": usage,
},
}
)
def _finish_output_item_frames() -> list[str]:
frames = [
_sse_data(
{
"type": "response.output_text.done",
"response_id": response_id,
"item_id": output_item_id,
"output_index": output_index,
"content_index": content_index,
"text": "".join(output_text_parts),
}
),
_sse_data(
{
"type": "response.output_item.done",
"response_id": response_id,
"output_index": output_index,
"item": _message_item("completed"),
}
),
]
for idx, item in enumerate(function_call_items, start=1):
frames.append(
_sse_data(
{
"type": "response.function_call_arguments.done",
"response_id": response_id,
"item_id": item["id"],
"output_index": idx,
"arguments": item["arguments"],
}
)
)
frames.append(
_sse_data(
{
"type": "response.output_item.done",
"response_id": response_id,
"output_index": idx,
"item": item,
}
)
)
return frames
def _ensure_function_call_item(call_id: str) -> list[str]:
existing_index = function_call_index_by_id.get(call_id)
name = function_call_name_by_id.get(call_id, "tool")
arguments = function_call_arguments_by_id.get(call_id, "")
if existing_index is not None:
function_call_items[existing_index] = _function_call_item(
call_id,
status="completed",
name=name,
arguments=arguments,
)
return []
item = _function_call_item(
call_id,
status="completed",
name=name,
arguments=arguments,
)
function_call_items.append(item)
item_index = len(function_call_items) - 1
function_call_index_by_id[call_id] = item_index
return [
_sse_data(
{
"type": "response.output_item.added",
"response_id": response_id,
"output_index": item_index + 1,
"item": _function_call_item(
call_id,
status="in_progress",
name=name,
arguments="",
),
}
)
]
yield _sse_data(
{
"type": "response.created",
@@ -1194,9 +1498,18 @@ async def _responses_stream_from_chat_stream(
"created_at": created_at,
"status": "in_progress",
"model": model,
"output": [],
},
}
)
yield _sse_data(
{
"type": "response.output_item.added",
"response_id": response_id,
"output_index": output_index,
"item": _message_item("in_progress"),
}
)
try:
async for part in chat_stream.body_iterator:
@@ -1207,6 +1520,8 @@ async def _responses_stream_from_chat_stream(
continue
body = frame[len("data:") :].strip()
if body == "[DONE]":
for event in _finish_output_item_frames():
yield event
yield _completed_frame()
yield "data: [DONE]\n\n"
completed_sent = True
@@ -1229,10 +1544,14 @@ async def _responses_stream_from_chat_stream(
text = delta.get("content")
if isinstance(text, str) and text:
output_text_parts.append(text)
yield _sse_data(
{
"type": "response.output_text.delta",
"response_id": response_id,
"item_id": output_item_id,
"output_index": output_index,
"content_index": content_index,
"delta": text,
}
)
@@ -1243,35 +1562,59 @@ async def _responses_stream_from_chat_stream(
if not isinstance(tool_call, dict):
continue
fn = tool_call.get("function") if isinstance(tool_call.get("function"), dict) else {}
call_id = str(tool_call.get("id") or f"call_{idx}")
upstream_index_raw = tool_call.get("index")
upstream_index = upstream_index_raw if isinstance(upstream_index_raw, int) else idx
call_id = str(
tool_call.get("id")
or function_call_id_by_upstream_index.get(upstream_index)
or f"call_{upstream_index}"
)
function_call_id_by_upstream_index[upstream_index] = call_id
name = str(fn.get("name") or function_call_name_by_id.get(call_id) or "tool")
function_call_name_by_id[call_id] = name
arguments_delta = str(fn.get("arguments") or "")
accumulated_arguments = (
function_call_arguments_by_id.get(call_id, "") + arguments_delta
)
function_call_arguments_by_id[call_id] = accumulated_arguments
for event in _ensure_function_call_item(call_id):
yield event
if arguments_delta:
yield _sse_data(
{
"type": "response.function_call.delta",
"type": "response.function_call_arguments.delta",
"response_id": response_id,
"item_id": call_id,
"name": str(fn.get("name") or "tool"),
"arguments": str(fn.get("arguments") or "{}"),
"output_index": function_call_index_by_id[call_id] + 1,
"delta": arguments_delta,
}
)
except asyncio.CancelledError:
if not completed_sent:
for event in _finish_output_item_frames():
yield event
yield _completed_frame()
yield "data: [DONE]\n\n"
completed_sent = True
return
except Exception:
if not completed_sent:
for event in _finish_output_item_frames():
yield event
yield _completed_frame()
yield "data: [DONE]\n\n"
completed_sent = True
return
if not completed_sent:
for event in _finish_output_item_frames():
yield event
yield _completed_frame()
yield "data: [DONE]\n\n"
@app.post("/responses", dependencies=[Depends(auth_guard)])
@app.post("/v1/responses", dependencies=[Depends(auth_guard)])
async def v1_responses(req: ResponsesRequest, request: Request):
chat_req = _responses_to_chat_request(req)
@@ -1388,7 +1731,13 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
)
# ------------------------------------------------------------- session reuse
try:
tool_config = _anthropic_tool_config(req)
except HTTPException as exc:
detail = exc.detail if isinstance(exc.detail, dict) else {}
error = detail.get("error") if isinstance(detail.get("error"), dict) else {}
message = error.get("message") or str(detail) or "invalid tool configuration"
return _anthropic_error(exc.status_code, "invalid_request_error", message)
has_tooling_context = _anthropic_has_tooling_context(req)
ask_mode = _resolve_ask_mode(req.model, has_tooling_context)
@@ -1749,7 +2098,11 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
if not saw_tool_event:
forced_tool_name = _anthropic_forced_tool_name(req.tool_choice)
if forced_tool_name:
fallback_event = _forced_tool_event_from_text(text, forced_tool_name)
fallback_event = _forced_tool_event_from_text(
text,
forced_tool_name,
single_arg_name=_tool_code_single_arg_name(req.tools, forced_tool_name),
)
if fallback_event is not None:
content_blocks = []
tool_id = "toolu_fallback_0"

View File

@@ -187,6 +187,17 @@ class ConfigParsingTests(unittest.TestCase):
settings_without_accounts = load_settings()
self.assertEqual(settings_without_accounts.instance_count, 1)
def test_load_settings_parses_tool_allowlist_csv(self) -> None:
with patch.dict(os.environ, {"TOOL_ALLOWLIST": " lookup , write_file ,,search_docs "}, clear=True):
settings = load_settings()
self.assertEqual(settings.tool_allowlist, ["lookup", "write_file", "search_docs"])
def test_load_settings_empty_tool_allowlist(self) -> None:
with patch.dict(os.environ, {"TOOL_ALLOWLIST": " , , "}, clear=True):
settings = load_settings()
self.assertEqual(settings.tool_allowlist, [])
if __name__ == "__main__":

View File

@@ -263,6 +263,120 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
{"query": "gateway"},
)
async def test_openai_non_stream_fallbacks_to_tool_code_structured_tool_call_for_forced_tool(self) -> None:
fake_client = _FakeClient(
stream_events=[],
complete_result={
"text": "```tool_code\nlookup(query=\"gateway\")\n```",
"toolEvents": [],
"sessionId": "sess-fallback-tool-code-openai",
},
)
req = ChatCompletionsRequest(
model="org_auto",
messages=[{"role": "user", "content": "hi"}],
stream=False,
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
tool_choice={"type": "function", "function": {"name": "lookup"}},
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
):
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
payload = json.loads(response.body)
message = payload["choices"][0]["message"]
self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls")
self.assertEqual(message["content"], "")
self.assertEqual(message["tool_calls"][0]["function"]["name"], "lookup")
self.assertEqual(
json.loads(message["tool_calls"][0]["function"]["arguments"]),
{"query": "gateway"},
)
async def test_openai_non_stream_fallbacks_to_tool_code_structured_tool_call_for_forced_tool_with_positional_arg(self) -> None:
fake_client = _FakeClient(
stream_events=[],
complete_result={
"text": "```tool_code\nlookup(\"gateway\")\n```",
"toolEvents": [],
"sessionId": "sess-fallback-tool-code-openai-positional",
},
)
req = ChatCompletionsRequest(
model="org_auto",
messages=[{"role": "user", "content": "hi"}],
stream=False,
tools=[
{
"type": "function",
"function": {
"name": "lookup",
"parameters": {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
},
},
}
],
tool_choice={"type": "function", "function": {"name": "lookup"}},
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
):
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
payload = json.loads(response.body)
message = payload["choices"][0]["message"]
self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls")
self.assertEqual(message["content"], "")
self.assertEqual(message["tool_calls"][0]["function"]["name"], "lookup")
self.assertEqual(
json.loads(message["tool_calls"][0]["function"]["arguments"]),
{"query": "gateway"},
)
async def test_openai_stream_fallbacks_to_tool_code_structured_tool_call_for_forced_tool(self) -> None:
fake_client = _FakeClient(
stream_events=[
{"type": "text", "text": "```tool_code\n"},
{"type": "text", "text": 'lookup(query=\"gateway\")\n'},
{"type": "text", "text": "```"},
],
complete_result={},
)
req = ChatCompletionsRequest(
model="org_auto",
messages=[{"role": "user", "content": "hi"}],
stream=True,
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
tool_choice={"type": "function", "function": {"name": "lookup"}},
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
):
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
body = await _collect_stream(response)
self.assertIn('"tool_calls"', body)
self.assertIn('"name": "lookup"', body)
self.assertIn('{"query": "gateway"}', body)
self.assertIn('"finish_reason": "tool_calls"', body)
self.assertIn('data: [DONE]', body)
async def test_openai_stream_bridges_tool_and_text_events(self) -> None:
fake_client = _FakeClient(
stream_events=[
@@ -300,6 +414,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
self.assertIn('"usage"', body)
self.assertIn("data: [DONE]", body)
async def test_anthropic_non_stream_bridges_tool_blocks(self) -> None:
fake_client = _FakeClient(
stream_events=[],
@@ -605,6 +720,57 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(spy_client.last_complete_args[2], "agent")
async def test_openai_non_stream_filters_tools_by_allowlist(self) -> None:
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
req = ChatCompletionsRequest(
model="org_auto",
messages=[{"role": "user", "content": "hi"}],
stream=False,
tools=[
{"type": "function", "function": {"name": "lookup", "parameters": {}}},
{"type": "function", "function": {"name": "write_file", "parameters": {}}},
],
tool_choice={"type": "function", "function": {"name": "lookup"}},
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
_SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]),
):
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
cfg = spy_client.last_complete_kwargs["tool_config"]
self.assertEqual([tool["function"]["name"] for tool in cfg["tools"]], ["lookup"])
self.assertEqual(cfg["tool_choice"], req.tool_choice)
async def test_openai_non_stream_rejects_forced_tool_outside_allowlist(self) -> None:
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
req = ChatCompletionsRequest(
model="org_auto",
messages=[{"role": "user", "content": "hi"}],
stream=False,
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
tool_choice={"type": "function", "function": {"name": "write_file"}},
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
_SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]),
):
with self.assertRaises(main.HTTPException) as cm:
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
self.assertEqual(cm.exception.status_code, 400)
self.assertEqual(cm.exception.detail["error"]["type"], "invalid_request_error")
self.assertIn("write_file", cm.exception.detail["error"]["message"])
async def test_openai_tooling_context_disables_session_reuse_cache(self) -> None:
fake_cache = _FakeSessionCache()
fake_client = _FakeClient(
@@ -757,6 +923,74 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
self.assertEqual(len(cfg["tools"]), 1)
self.assertEqual(spy_client.last_complete_args[2], "agent")
async def test_anthropic_non_stream_filters_tools_by_allowlist(self) -> None:
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
req = AnthropicMessagesRequest(
model="claude-3-5-sonnet-20241022",
max_tokens=128,
messages=[{"role": "user", "content": "hi"}],
stream=False,
tools=[
{"name": "lookup", "input_schema": {"type": "object", "properties": {}}},
{"name": "write_file", "input_schema": {"type": "object", "properties": {}}},
],
tool_choice={"type": "tool", "name": "lookup"},
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
patch.object(main.settings, "api_keys", ["test-key"]),
_SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]),
):
await main.v1_messages(
req,
_make_request(
"/v1/messages",
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
),
)
cfg = spy_client.last_complete_kwargs["tool_config"]
self.assertEqual([tool["name"] for tool in cfg["tools"]], ["lookup"])
self.assertEqual(cfg["tool_choice"], req.tool_choice)
async def test_anthropic_non_stream_rejects_forced_tool_outside_allowlist(self) -> None:
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
req = AnthropicMessagesRequest(
model="claude-3-5-sonnet-20241022",
max_tokens=128,
messages=[{"role": "user", "content": "hi"}],
stream=False,
tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}],
tool_choice={"type": "tool", "name": "write_file"},
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
patch.object(main.settings, "api_keys", ["test-key"]),
_SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]),
):
response = await main.v1_messages(
req,
_make_request(
"/v1/messages",
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
),
)
self.assertEqual(response.status_code, 400)
payload = json.loads(response.body)
self.assertEqual(payload["type"], "error")
self.assertEqual(payload["error"]["type"], "invalid_request_error")
self.assertIn("write_file", payload["error"]["message"])
async def test_anthropic_tooling_context_disables_session_reuse_cache(self) -> None:
fake_cache = _FakeSessionCache()
fake_client = _FakeClient(
@@ -833,6 +1067,54 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
messages_dump = [m.model_dump() for m in chat_req.messages]
self.assertEqual(messages_dump, [{"role": "user", "content": "hello from responses", "name": None, "tool_call_id": None, "tool_calls": None}])
async def test_responses_non_stream_maps_chat_tool_calls_to_function_call_output(self) -> None:
req = ResponsesRequest(
model="org_auto",
input="tool please",
stream=False,
)
chat_payload = {
"id": "chatcmpl-tools1",
"created": 234,
"model": "org_auto",
"choices": [
{
"index": 0,
"finish_reason": "tool_calls",
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "lookup",
"arguments": "{\"q\":\"gateway\"}",
},
}
],
},
}
],
"usage": {"prompt_tokens": 8, "completion_tokens": 3, "total_tokens": 11},
}
mock_chat = AsyncMock(return_value=JSONResponse(content=chat_payload))
with patch.object(main, "v1_chat_completions", mock_chat):
response = await main.v1_responses(req, _make_request("/v1/responses"))
payload = json.loads(response.body)
self.assertEqual(payload["status"], "completed")
self.assertEqual(payload["output_text"], "")
self.assertEqual(payload["usage"], {"input_tokens": 8, "output_tokens": 3, "total_tokens": 11})
self.assertEqual(len(payload["output"]), 1)
self.assertEqual(payload["output"][0]["type"], "function_call")
self.assertEqual(payload["output"][0]["call_id"], "call_1")
self.assertEqual(payload["output"][0]["id"], "call_1")
self.assertEqual(payload["output"][0]["name"], "lookup")
self.assertEqual(payload["output"][0]["arguments"], "{\"q\":\"gateway\"}")
async def test_responses_forwards_input_tools_and_tool_choice_to_chat_request(self) -> None:
req = ResponsesRequest(
model="org_auto",
@@ -883,17 +1165,70 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
response = await main.v1_responses(req, _make_request("/v1/responses"))
body = await _collect_stream(response)
self.assertIn('"type": "response.created"', body)
self.assertIn('"type": "response.output_item.added"', body)
self.assertIn('"type": "response.output_text.delta"', body)
self.assertIn('"delta": "hello"', body)
self.assertIn('"type": "response.function_call.delta"', body)
self.assertIn('"type": "response.function_call_arguments.delta"', body)
self.assertIn('"item_id": "call_1"', body)
self.assertIn('"output_index": 1', body)
self.assertIn('"delta": "{\\"q\\": \\"x\\"}"', body)
self.assertIn('"type": "response.function_call_arguments.done"', body)
self.assertIn('"arguments": "{\\"q\\": \\"x\\"}"', body)
self.assertIn('"type": "response.output_item.done"', body)
self.assertIn('"type": "function_call"', body)
self.assertIn('"name": "lookup"', body)
self.assertIn('"arguments": "{\\"q\\": \\"x\\"}"', body)
self.assertIn('"type": "response.completed"', body)
self.assertIn('"input_tokens": 3', body)
self.assertIn('"output_tokens": 2', body)
self.assertIn('data: [DONE]', body)
async def test_responses_stream_accumulates_fragmented_tool_arguments(self) -> None:
async def _chat_sse():
yield b'data: {"choices": [{"delta": {"tool_calls": [{"id": "call_1", "function": {"name": "lookup", "arguments": "{\\"q\\":"}}]}}]}\n\n'
yield b'data: {"choices": [{"delta": {"tool_calls": [{"id": "call_1", "function": {"name": "lookup", "arguments": " \\\"x\\\"}"}}]}}]}\n\n'
yield b"data: [DONE]\n\n"
req = ResponsesRequest(model="org_auto", input="hi", stream=True)
mock_chat = AsyncMock(
return_value=StreamingResponse(_chat_sse(), media_type="text/event-stream")
)
with patch.object(main, "v1_chat_completions", mock_chat):
response = await main.v1_responses(req, _make_request("/v1/responses"))
body = await _collect_stream(response)
self.assertIn('"type": "response.function_call_arguments.delta"', body)
self.assertIn('"delta": "{\\"q\\":"', body)
self.assertIn('"delta": " \\\"x\\\"}"', body)
self.assertIn('"type": "response.function_call_arguments.done"', body)
self.assertIn('"arguments": "{\\"q\\": \\\"x\\\"}"', body)
self.assertIn('"type": "response.output_item.done"', body)
self.assertIn('"arguments": "{\\"q\\": \\\"x\\\"}"', body)
self.assertIn('data: [DONE]', body)
async def test_responses_stream_accumulates_fragmented_tool_arguments_without_repeated_id_or_name(self) -> None:
async def _chat_sse():
yield b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "id": "call_1", "function": {"name": "lookup", "arguments": "{\\"q\\":"}}]}}]}\n\n'
yield b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": " \\\"x\\\"}"}}]}}]}\n\n'
yield b"data: [DONE]\n\n"
req = ResponsesRequest(model="org_auto", input="hi", stream=True)
mock_chat = AsyncMock(
return_value=StreamingResponse(_chat_sse(), media_type="text/event-stream")
)
with patch.object(main, "v1_chat_completions", mock_chat):
response = await main.v1_responses(req, _make_request("/v1/responses"))
body = await _collect_stream(response)
self.assertEqual(body.count('"item_id": "call_1"'), 3)
self.assertIn('"name": "lookup"', body)
self.assertIn('"delta": "{\\"q\\":"', body)
self.assertIn('"delta": " \\\"x\\\"}"', body)
self.assertIn('"arguments": "{\\"q\\": \\\"x\\\"}"', body)
self.assertIn('data: [DONE]', body)
async def test_responses_stream_emits_completed_when_upstream_closes_without_done(self) -> None:
async def _chat_sse_without_done():
yield b'data: {"choices": [{"delta": {"content": "partial"}}]}\n\n'
@@ -954,7 +1289,31 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
self.assertIn('data: [DONE]', body)
async def test_responses_non_stream_returns_502_on_invalid_upstream_json(self) -> None:
async def test_responses_alias_matches_v1_responses_behavior(self) -> None:
req = ResponsesRequest(model="org_auto", input="hello", stream=False)
chat_payload = {
"id": "chatcmpl-alias1",
"created": 123,
"model": "org_auto",
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {"role": "assistant", "content": "done"},
}
],
"usage": {"prompt_tokens": 4, "completion_tokens": 2, "total_tokens": 6},
}
mock_chat = AsyncMock(return_value=JSONResponse(content=chat_payload))
with patch.object(main, "v1_chat_completions", mock_chat):
response = await main.v1_responses(req, _make_request("/responses"))
payload = json.loads(response.body)
self.assertEqual(payload["id"], "resp_alias1")
self.assertEqual(payload["status"], "completed")
mock_chat.assert_awaited_once()
req = ResponsesRequest(model="org_auto", input="hi", stream=False)
mock_chat = AsyncMock(return_value=Response(content="not-json", media_type="text/plain"))