From d0df08928232bd2f20325a5726edf9ba9cbde831 Mon Sep 17 00:00:00 2001 From: mmc <853506518@qq.com> Date: Mon, 20 Apr 2026 19:24:02 +0800 Subject: [PATCH] fix: harden responses streaming and tool-call fallback Ensure /v1/responses streams always terminate with response.completed and normalize Lingma tool_code fallbacks into structured tool calls, including single-argument forms. Co-Authored-By: Claude Opus 4.7 --- .env.example | 2 + CLAUDE.md | 177 +++++++++++++++ app/config.py | 7 + app/main.py | 383 ++++++++++++++++++++++++++++++-- tests/test_pool_stats_config.py | 11 + tests/test_tool_call_bridge.py | 365 +++++++++++++++++++++++++++++- 6 files changed, 927 insertions(+), 18 deletions(-) diff --git a/.env.example b/.env.example index bf063c3..d75be95 100644 --- a/.env.example +++ b/.env.example @@ -48,6 +48,8 @@ DEFAULT_ASK_MODE=chat # 请求侧 tools/tool_choice 透传到 Lingma(默认关闭,开启后可支持工具写文件等场景) TOOL_FORWARD_ENABLED=false +# 可选:允许透传的工具名白名单,逗号分隔;为空表示不额外限制 +TOOL_ALLOWLIST= # 专属域(可选) DEDICATED_DOMAIN_URL= diff --git a/CLAUDE.md b/CLAUDE.md index 5d70d99..9687fe7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -93,3 +93,180 @@ Both protocols share the same backend pool, backpressure guard, stats, and sessi - Compose mounts: - `./data -> /app/data` (persistent Lingma binary/cache/workdirs) - `./secrets -> /secrets:ro` (session bundles, secrets) + + +# CLAUDE.md + +Behavioral guidelines to reduce common LLM coding mistakes. Merge with project-specific instructions as needed. + +**Tradeoff:** These guidelines bias toward caution over speed. For trivial tasks, use judgment. + +## 1. Think Before Coding + +**Don't assume. Don't hide confusion. Surface tradeoffs.** + +Before implementing: +- State your assumptions explicitly. If uncertain, ask. +- If multiple interpretations exist, present them - don't pick silently. +- If a simpler approach exists, say so. Push back when warranted. +- If something is unclear, stop. Name what's confusing. Ask. + +## 2. Simplicity First + +**Minimum code that solves the problem. Nothing speculative.** + +- No features beyond what was asked. +- No abstractions for single-use code. +- No "flexibility" or "configurability" that wasn't requested. +- No error handling for impossible scenarios. +- If you write 200 lines and it could be 50, rewrite it. + +Ask yourself: "Would a senior engineer say this is overcomplicated?" If yes, simplify. + +## 3. Surgical Changes + +**Touch only what you must. Clean up only your own mess.** + +When editing existing code: +- Don't "improve" adjacent code, comments, or formatting. +- Don't refactor things that aren't broken. +- Match existing style, even if you'd do it differently. +- If you notice unrelated dead code, mention it - don't delete it. + +When your changes create orphans: +- Remove imports/variables/functions that YOUR changes made unused. +- Don't remove pre-existing dead code unless asked. + +The test: Every changed line should trace directly to the user's request. + +## 4. Goal-Driven Execution + +**Define success criteria. Loop until verified.** + +Transform tasks into verifiable goals: +- "Add validation" → "Write tests for invalid inputs, then make them pass" +- "Fix the bug" → "Write a test that reproduces it, then make it pass" +- "Refactor X" → "Ensure tests pass before and after" + +For multi-step tasks, state a brief plan: +``` +1. [Step] → verify: [check] +2. [Step] → verify: [check] +3. [Step] → verify: [check] +``` + +Strong success criteria let you loop independently. Weak criteria ("make it work") require constant clarification. + +--- + +**These guidelines are working if:** fewer unnecessary changes in diffs, fewer rewrites due to overcomplication, and clarifying questions come before implementation rather than after mistakes. + +# CLAUDE.md + +Behavioral guidelines to reduce common LLM coding mistakes. Merge with project-specific instructions as needed. + +**Tradeoff:** These guidelines bias toward caution over speed. For trivial tasks, use judgment. + +## 1. Think Before Coding + +**Don't assume. Don't hide confusion. Surface tradeoffs.** + +Before implementing: +- State your assumptions explicitly. If uncertain, ask. +- If multiple interpretations exist, present them - don't pick silently. +- If a simpler approach exists, say so. Push back when warranted. +- If something is unclear, stop. Name what's confusing. Ask. + +## 2. Simplicity First + +**Minimum code that solves the problem. Nothing speculative.** + +- No features beyond what was asked. +- No abstractions for single-use code. +- No "flexibility" or "configurability" that wasn't requested. +- No error handling for impossible scenarios. +- If you write 200 lines and it could be 50, rewrite it. + +Ask yourself: "Would a senior engineer say this is overcomplicated?" If yes, simplify. + +## 3. Surgical Changes + +**Touch only what you must. Clean up only your own mess.** + +When editing existing code: +- Don't "improve" adjacent code, comments, or formatting. +- Don't refactor things that aren't broken. +- Match existing style, even if you'd do it differently. +- If you notice unrelated dead code, mention it - don't delete it. + +When your changes create orphans: +- Remove imports/variables/functions that YOUR changes made unused. +- Don't remove pre-existing dead code unless asked. + +The test: Every changed line should trace directly to the user's request. + +## 4. Goal-Driven Execution + +**Define success criteria. Loop until verified.** + +Transform tasks into verifiable goals: +- "Add validation" → "Write tests for invalid inputs, then make them pass" +- "Fix the bug" → "Write a test that reproduces it, then make it pass" +- "Refactor X" → "Ensure tests pass before and after" + +For multi-step tasks, state a brief plan: +``` +1. [Step] → verify: [check] +2. [Step] → verify: [check] +3. [Step] → verify: [check] +``` + +Strong success criteria let you loop independently. Weak criteria ("make it work") require constant clarification. + +--- + +**These guidelines are working if:** fewer unnecessary changes in diffs, fewer rewrites due to overcomplication, and clarifying questions come before implementation rather than after mistakes. + + +# GitNexus — Code Intelligence + +This project is indexed by GitNexus as **lingma-openai-gateway** (1093 symbols, 2685 relationships, 97 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely. + +> If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first. + +## Always Do + +- **MUST run impact analysis before editing any symbol.** Before modifying a function, class, or method, run `gitnexus_impact({target: "symbolName", direction: "upstream"})` and report the blast radius (direct callers, affected processes, risk level) to the user. +- **MUST run `gitnexus_detect_changes()` before committing** to verify your changes only affect expected symbols and execution flows. +- **MUST warn the user** if impact analysis returns HIGH or CRITICAL risk before proceeding with edits. +- When exploring unfamiliar code, use `gitnexus_query({query: "concept"})` to find execution flows instead of grepping. It returns process-grouped results ranked by relevance. +- When you need full context on a specific symbol — callers, callees, which execution flows it participates in — use `gitnexus_context({name: "symbolName"})`. + +## Never Do + +- NEVER edit a function, class, or method without first running `gitnexus_impact` on it. +- NEVER ignore HIGH or CRITICAL risk warnings from impact analysis. +- NEVER rename symbols with find-and-replace — use `gitnexus_rename` which understands the call graph. +- NEVER commit changes without running `gitnexus_detect_changes()` to check affected scope. + +## Resources + +| Resource | Use for | +|----------|---------| +| `gitnexus://repo/lingma-openai-gateway/context` | Codebase overview, check index freshness | +| `gitnexus://repo/lingma-openai-gateway/clusters` | All functional areas | +| `gitnexus://repo/lingma-openai-gateway/processes` | All execution flows | +| `gitnexus://repo/lingma-openai-gateway/process/{name}` | Step-by-step execution trace | + +## CLI + +| Task | Read this skill file | +|------|---------------------| +| Understand architecture / "How does X work?" | `.claude/skills/gitnexus/gitnexus-exploring/SKILL.md` | +| Blast radius / "What breaks if I change X?" | `.claude/skills/gitnexus/gitnexus-impact-analysis/SKILL.md` | +| Trace bugs / "Why is X failing?" | `.claude/skills/gitnexus/gitnexus-debugging/SKILL.md` | +| Rename / extract / split / refactor | `.claude/skills/gitnexus/gitnexus-refactoring/SKILL.md` | +| Tools, resources, schema reference | `.claude/skills/gitnexus/gitnexus-guide/SKILL.md` | +| Index, status, clean, wiki CLI commands | `.claude/skills/gitnexus/gitnexus-cli/SKILL.md` | + + diff --git a/app/config.py b/app/config.py index 75f9e3d..ea1e4d7 100644 --- a/app/config.py +++ b/app/config.py @@ -5,6 +5,11 @@ import os from dataclasses import dataclass, field + +def _csv_env(raw: str) -> list[str]: + return [item.strip() for item in (raw or "").replace("\n", ",").split(",") if item.strip()] + + @dataclass class LingmaAccount: username: str @@ -45,6 +50,7 @@ class Settings: session_cache_max_entries: int = 256 session_cache_ttl_sec: float = 1800.0 tool_forward_enabled: bool = False + tool_allowlist: list[str] = field(default_factory=list) def _bool_env(name: str, default: bool) -> bool: @@ -177,4 +183,5 @@ def load_settings() -> Settings: session_cache_max_entries=int(os.getenv("SESSION_CACHE_MAX_ENTRIES", "256")), session_cache_ttl_sec=float(os.getenv("SESSION_CACHE_TTL_SEC", "1800")), tool_forward_enabled=_bool_env("TOOL_FORWARD_ENABLED", False), + tool_allowlist=_csv_env(os.getenv("TOOL_ALLOWLIST", "")), ) diff --git a/app/main.py b/app/main.py index 420fb19..dd262bd 100644 --- a/app/main.py +++ b/app/main.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ast import asyncio import hashlib import json @@ -358,6 +359,68 @@ def _include_usage(stream_options: dict | None) -> bool: return bool(stream_options.get("include_usage")) +def _tool_allowlist() -> set[str]: + return {name.strip() for name in settings.tool_allowlist if isinstance(name, str) and name.strip()} + + +def _openai_tool_name(tool: Any) -> str | None: + if not isinstance(tool, dict): + return None + if tool.get("type") == "function": + fn = tool.get("function") + if isinstance(fn, dict): + name = fn.get("name") + if isinstance(name, str) and name.strip(): + return name.strip() + name = tool.get("name") + if isinstance(name, str) and name.strip(): + return name.strip() + return None + + +def _anthropic_tool_name(tool: Any) -> str | None: + if not isinstance(tool, dict): + return None + name = tool.get("name") + if isinstance(name, str) and name.strip(): + return name.strip() + fn = tool.get("function") + if isinstance(fn, dict): + nested_name = fn.get("name") + if isinstance(nested_name, str) and nested_name.strip(): + return nested_name.strip() + return None + + +def _filter_allowed_tools(tools: list[dict[str, Any]], *, provider: str) -> list[dict[str, Any]]: + allowlist = _tool_allowlist() + if not allowlist: + return tools + name_fn = _openai_tool_name if provider == "openai" else _anthropic_tool_name + return [tool for tool in tools if (name := name_fn(tool)) and name in allowlist] + + +def _ensure_tool_choice_allowed(tool_choice: Any, *, provider: str) -> None: + allowlist = _tool_allowlist() + if not allowlist: + return + forced_name = ( + _openai_forced_tool_name(tool_choice) + if provider == "openai" + else _anthropic_forced_tool_name(tool_choice) + ) + if forced_name and forced_name not in allowlist: + raise HTTPException( + status_code=400, + detail={ + "error": { + "type": "invalid_request_error", + "message": f"tool '{forced_name}' is not allowed", + } + }, + ) + + def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None: if not settings.tool_forward_enabled: return None @@ -365,9 +428,11 @@ def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None: has_choice = req.tool_choice is not None if not has_tools and not has_choice: return None + _ensure_tool_choice_allowed(req.tool_choice, provider="openai") + tools = _filter_allowed_tools(req.tools or [], provider="openai") return { "provider": "openai", - "tools": req.tools or [], + "tools": tools, "tool_choice": req.tool_choice, } @@ -379,9 +444,11 @@ def _anthropic_tool_config(req: AnthropicMessagesRequest) -> dict[str, Any] | No has_choice = req.tool_choice is not None if not has_tools and not has_choice: return None + _ensure_tool_choice_allowed(req.tool_choice, provider="anthropic") + tools = _filter_allowed_tools(req.tools or [], provider="anthropic") return { "provider": "anthropic", - "tools": req.tools or [], + "tools": tools, "tool_choice": req.tool_choice, } @@ -537,8 +604,85 @@ def _json_object_from_text(text: str) -> dict[str, Any] | None: return parsed if isinstance(parsed, dict) else None -def _forced_tool_event_from_text(text: str, forced_tool_name: str) -> dict[str, Any] | None: +def _tool_code_single_arg_name(tools: list[dict[str, Any]] | None, forced_tool_name: str) -> str | None: + if not isinstance(tools, list): + return None + for tool in tools: + if not isinstance(tool, dict): + continue + schema: dict[str, Any] | None = None + if tool.get("type") == "function": + fn = tool.get("function") + if isinstance(fn, dict) and fn.get("name") == forced_tool_name: + params = fn.get("parameters") + if isinstance(params, dict): + schema = params + elif tool.get("name") == forced_tool_name: + input_schema = tool.get("input_schema") + if isinstance(input_schema, dict): + schema = input_schema + if not isinstance(schema, dict): + continue + properties = schema.get("properties") + if not isinstance(properties, dict) or len(properties) != 1: + return None + only_name = next(iter(properties.keys()), None) + if isinstance(only_name, str) and only_name.strip(): + return only_name + return None + return None + + +def _tool_code_object_from_text( + text: str, + forced_tool_name: str, + *, + single_arg_name: str | None = None, +) -> dict[str, Any] | None: + raw = text.strip() + if not raw.startswith("```tool_code") or not raw.endswith("```"): + return None + lines = raw.splitlines() + if len(lines) < 2: + return None + body = "\n".join(lines[1:-1]).strip() + try: + parsed = ast.parse(body, mode="eval") + except Exception: + return None + call = parsed.body + if not isinstance(call, ast.Call): + return None + if not isinstance(call.func, ast.Name) or call.func.id != forced_tool_name: + return None + arguments: dict[str, Any] = {} + if call.args: + if len(call.args) != 1 or call.keywords or not single_arg_name: + return None + try: + arguments[single_arg_name] = ast.literal_eval(call.args[0]) + except Exception: + return None + return {"arguments": arguments} + for kw in call.keywords: + if kw.arg is None: + return None + try: + arguments[kw.arg] = ast.literal_eval(kw.value) + except Exception: + return None + return {"arguments": arguments} + + +def _forced_tool_event_from_text( + text: str, + forced_tool_name: str, + *, + single_arg_name: str | None = None, +) -> dict[str, Any] | None: parsed = _json_object_from_text(text) + if parsed is None: + parsed = _tool_code_object_from_text(text, forced_tool_name, single_arg_name=single_arg_name) if parsed is None: return None @@ -756,11 +900,14 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): completion_id = f"chatcmpl-{uuid.uuid4().hex}" completion_tokens_holder = {"n": 0} stream_meta: dict = {} + forced_tool_name = _openai_forced_tool_name(req.tool_choice) + forced_tool_single_arg_name = _tool_code_single_arg_name(req.tools, forced_tool_name) if forced_tool_name else None async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta): success = False tool_call_indexes: dict[str, int] = {} saw_tool_call = False + buffered_text_parts: list[str] = [] try: async for chunk in _inst.client.chat_stream( prompt, @@ -809,6 +956,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): text = _stream_text(chunk) if not text: continue + buffered_text_parts.append(text) completion_tokens_holder["n"] += estimate_tokens(text) payload = { "id": completion_id, @@ -825,6 +973,39 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): } yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + if not saw_tool_call and forced_tool_name: + fallback_event = _forced_tool_event_from_text( + "".join(buffered_text_parts), + forced_tool_name, + single_arg_name=forced_tool_single_arg_name, + ) + if fallback_event is not None: + saw_tool_call = True + tool_id = "call_fallback_0" + idx = 0 + tool_call_indexes[tool_id] = idx + payload = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [ + { + "index": idx, + **_openai_tool_call(fallback_event, forced_id=tool_id), + } + ] + }, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + done_payload = { "id": completion_id, "object": "chat.completion.chunk", @@ -945,7 +1126,11 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): if not saw_tool_call: forced_tool_name = _openai_forced_tool_name(req.tool_choice) if forced_tool_name: - fallback_event = _forced_tool_event_from_text(message_content, forced_tool_name) + fallback_event = _forced_tool_event_from_text( + message_content, + forced_tool_name, + single_arg_name=_tool_code_single_arg_name(req.tools, forced_tool_name), + ) if fallback_event is not None: tool_calls.append(_openai_tool_call(fallback_event, forced_id="call_fallback_0")) saw_tool_call = True @@ -1169,6 +1354,42 @@ async def _responses_stream_from_chat_stream( created_at = int(time.time()) usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} completed_sent = False + output_item_id = f"msg_{uuid.uuid4().hex}" + output_index = 0 + content_index = 0 + output_text_parts: list[str] = [] + function_call_items: list[dict[str, Any]] = [] + function_call_index_by_id: dict[str, int] = {} + function_call_arguments_by_id: dict[str, str] = {} + function_call_name_by_id: dict[str, str] = {} + function_call_id_by_upstream_index: dict[int, str] = {} + + def _message_item(status: str) -> dict[str, Any]: + return { + "id": output_item_id, + "type": "message", + "role": "assistant", + "status": status, + "content": [ + { + "type": "output_text", + "text": "".join(output_text_parts), + } + ], + } + + def _function_call_item(call_id: str, *, status: str, name: str, arguments: str) -> dict[str, Any]: + return { + "id": call_id, + "type": "function_call", + "call_id": call_id, + "name": name, + "arguments": arguments, + "status": status, + } + + def _completed_output_items() -> list[dict[str, Any]]: + return [_message_item("completed"), *function_call_items] def _completed_frame() -> str: return _sse_data( @@ -1180,11 +1401,94 @@ async def _responses_stream_from_chat_stream( "created_at": created_at, "status": "completed", "model": model, + "output": _completed_output_items(), "usage": usage, }, } ) + def _finish_output_item_frames() -> list[str]: + frames = [ + _sse_data( + { + "type": "response.output_text.done", + "response_id": response_id, + "item_id": output_item_id, + "output_index": output_index, + "content_index": content_index, + "text": "".join(output_text_parts), + } + ), + _sse_data( + { + "type": "response.output_item.done", + "response_id": response_id, + "output_index": output_index, + "item": _message_item("completed"), + } + ), + ] + for idx, item in enumerate(function_call_items, start=1): + frames.append( + _sse_data( + { + "type": "response.function_call_arguments.done", + "response_id": response_id, + "item_id": item["id"], + "output_index": idx, + "arguments": item["arguments"], + } + ) + ) + frames.append( + _sse_data( + { + "type": "response.output_item.done", + "response_id": response_id, + "output_index": idx, + "item": item, + } + ) + ) + return frames + + def _ensure_function_call_item(call_id: str) -> list[str]: + existing_index = function_call_index_by_id.get(call_id) + name = function_call_name_by_id.get(call_id, "tool") + arguments = function_call_arguments_by_id.get(call_id, "") + if existing_index is not None: + function_call_items[existing_index] = _function_call_item( + call_id, + status="completed", + name=name, + arguments=arguments, + ) + return [] + item = _function_call_item( + call_id, + status="completed", + name=name, + arguments=arguments, + ) + function_call_items.append(item) + item_index = len(function_call_items) - 1 + function_call_index_by_id[call_id] = item_index + return [ + _sse_data( + { + "type": "response.output_item.added", + "response_id": response_id, + "output_index": item_index + 1, + "item": _function_call_item( + call_id, + status="in_progress", + name=name, + arguments="", + ), + } + ) + ] + yield _sse_data( { "type": "response.created", @@ -1194,9 +1498,18 @@ async def _responses_stream_from_chat_stream( "created_at": created_at, "status": "in_progress", "model": model, + "output": [], }, } ) + yield _sse_data( + { + "type": "response.output_item.added", + "response_id": response_id, + "output_index": output_index, + "item": _message_item("in_progress"), + } + ) try: async for part in chat_stream.body_iterator: @@ -1207,6 +1520,8 @@ async def _responses_stream_from_chat_stream( continue body = frame[len("data:") :].strip() if body == "[DONE]": + for event in _finish_output_item_frames(): + yield event yield _completed_frame() yield "data: [DONE]\n\n" completed_sent = True @@ -1229,10 +1544,14 @@ async def _responses_stream_from_chat_stream( text = delta.get("content") if isinstance(text, str) and text: + output_text_parts.append(text) yield _sse_data( { "type": "response.output_text.delta", "response_id": response_id, + "item_id": output_item_id, + "output_index": output_index, + "content_index": content_index, "delta": text, } ) @@ -1243,35 +1562,59 @@ async def _responses_stream_from_chat_stream( if not isinstance(tool_call, dict): continue fn = tool_call.get("function") if isinstance(tool_call.get("function"), dict) else {} - call_id = str(tool_call.get("id") or f"call_{idx}") - yield _sse_data( - { - "type": "response.function_call.delta", - "response_id": response_id, - "item_id": call_id, - "name": str(fn.get("name") or "tool"), - "arguments": str(fn.get("arguments") or "{}"), - } + upstream_index_raw = tool_call.get("index") + upstream_index = upstream_index_raw if isinstance(upstream_index_raw, int) else idx + call_id = str( + tool_call.get("id") + or function_call_id_by_upstream_index.get(upstream_index) + or f"call_{upstream_index}" ) + function_call_id_by_upstream_index[upstream_index] = call_id + name = str(fn.get("name") or function_call_name_by_id.get(call_id) or "tool") + function_call_name_by_id[call_id] = name + arguments_delta = str(fn.get("arguments") or "") + accumulated_arguments = ( + function_call_arguments_by_id.get(call_id, "") + arguments_delta + ) + function_call_arguments_by_id[call_id] = accumulated_arguments + for event in _ensure_function_call_item(call_id): + yield event + if arguments_delta: + yield _sse_data( + { + "type": "response.function_call_arguments.delta", + "response_id": response_id, + "item_id": call_id, + "output_index": function_call_index_by_id[call_id] + 1, + "delta": arguments_delta, + } + ) except asyncio.CancelledError: if not completed_sent: + for event in _finish_output_item_frames(): + yield event yield _completed_frame() yield "data: [DONE]\n\n" completed_sent = True return except Exception: if not completed_sent: + for event in _finish_output_item_frames(): + yield event yield _completed_frame() yield "data: [DONE]\n\n" completed_sent = True return if not completed_sent: + for event in _finish_output_item_frames(): + yield event yield _completed_frame() yield "data: [DONE]\n\n" +@app.post("/responses", dependencies=[Depends(auth_guard)]) @app.post("/v1/responses", dependencies=[Depends(auth_guard)]) async def v1_responses(req: ResponsesRequest, request: Request): chat_req = _responses_to_chat_request(req) @@ -1388,7 +1731,13 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): ) # ------------------------------------------------------------- session reuse - tool_config = _anthropic_tool_config(req) + try: + tool_config = _anthropic_tool_config(req) + except HTTPException as exc: + detail = exc.detail if isinstance(exc.detail, dict) else {} + error = detail.get("error") if isinstance(detail.get("error"), dict) else {} + message = error.get("message") or str(detail) or "invalid tool configuration" + return _anthropic_error(exc.status_code, "invalid_request_error", message) has_tooling_context = _anthropic_has_tooling_context(req) ask_mode = _resolve_ask_mode(req.model, has_tooling_context) @@ -1749,7 +2098,11 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): if not saw_tool_event: forced_tool_name = _anthropic_forced_tool_name(req.tool_choice) if forced_tool_name: - fallback_event = _forced_tool_event_from_text(text, forced_tool_name) + fallback_event = _forced_tool_event_from_text( + text, + forced_tool_name, + single_arg_name=_tool_code_single_arg_name(req.tools, forced_tool_name), + ) if fallback_event is not None: content_blocks = [] tool_id = "toolu_fallback_0" diff --git a/tests/test_pool_stats_config.py b/tests/test_pool_stats_config.py index 81099db..37af01c 100644 --- a/tests/test_pool_stats_config.py +++ b/tests/test_pool_stats_config.py @@ -187,6 +187,17 @@ class ConfigParsingTests(unittest.TestCase): settings_without_accounts = load_settings() self.assertEqual(settings_without_accounts.instance_count, 1) + def test_load_settings_parses_tool_allowlist_csv(self) -> None: + with patch.dict(os.environ, {"TOOL_ALLOWLIST": " lookup , write_file ,,search_docs "}, clear=True): + settings = load_settings() + + self.assertEqual(settings.tool_allowlist, ["lookup", "write_file", "search_docs"]) + + def test_load_settings_empty_tool_allowlist(self) -> None: + with patch.dict(os.environ, {"TOOL_ALLOWLIST": " , , "}, clear=True): + settings = load_settings() + + self.assertEqual(settings.tool_allowlist, []) if __name__ == "__main__": diff --git a/tests/test_tool_call_bridge.py b/tests/test_tool_call_bridge.py index ea7f12a..e196f86 100644 --- a/tests/test_tool_call_bridge.py +++ b/tests/test_tool_call_bridge.py @@ -263,6 +263,120 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): {"query": "gateway"}, ) + async def test_openai_non_stream_fallbacks_to_tool_code_structured_tool_call_for_forced_tool(self) -> None: + fake_client = _FakeClient( + stream_events=[], + complete_result={ + "text": "```tool_code\nlookup(query=\"gateway\")\n```", + "toolEvents": [], + "sessionId": "sess-fallback-tool-code-openai", + }, + ) + req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "hi"}], + stream=False, + tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tool_choice={"type": "function", "function": {"name": "lookup"}}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + ): + response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + + payload = json.loads(response.body) + message = payload["choices"][0]["message"] + self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls") + self.assertEqual(message["content"], "") + self.assertEqual(message["tool_calls"][0]["function"]["name"], "lookup") + self.assertEqual( + json.loads(message["tool_calls"][0]["function"]["arguments"]), + {"query": "gateway"}, + ) + + async def test_openai_non_stream_fallbacks_to_tool_code_structured_tool_call_for_forced_tool_with_positional_arg(self) -> None: + fake_client = _FakeClient( + stream_events=[], + complete_result={ + "text": "```tool_code\nlookup(\"gateway\")\n```", + "toolEvents": [], + "sessionId": "sess-fallback-tool-code-openai-positional", + }, + ) + req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "hi"}], + stream=False, + tools=[ + { + "type": "function", + "function": { + "name": "lookup", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }, + } + ], + tool_choice={"type": "function", "function": {"name": "lookup"}}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + ): + response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + + payload = json.loads(response.body) + message = payload["choices"][0]["message"] + self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls") + self.assertEqual(message["content"], "") + self.assertEqual(message["tool_calls"][0]["function"]["name"], "lookup") + self.assertEqual( + json.loads(message["tool_calls"][0]["function"]["arguments"]), + {"query": "gateway"}, + ) + + async def test_openai_stream_fallbacks_to_tool_code_structured_tool_call_for_forced_tool(self) -> None: + fake_client = _FakeClient( + stream_events=[ + {"type": "text", "text": "```tool_code\n"}, + {"type": "text", "text": 'lookup(query=\"gateway\")\n'}, + {"type": "text", "text": "```"}, + ], + complete_result={}, + ) + req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "hi"}], + stream=True, + tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tool_choice={"type": "function", "function": {"name": "lookup"}}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + ): + response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + body = await _collect_stream(response) + + self.assertIn('"tool_calls"', body) + self.assertIn('"name": "lookup"', body) + self.assertIn('{"query": "gateway"}', body) + self.assertIn('"finish_reason": "tool_calls"', body) + self.assertIn('data: [DONE]', body) + async def test_openai_stream_bridges_tool_and_text_events(self) -> None: fake_client = _FakeClient( stream_events=[ @@ -300,6 +414,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('"usage"', body) self.assertIn("data: [DONE]", body) + async def test_anthropic_non_stream_bridges_tool_blocks(self) -> None: fake_client = _FakeClient( stream_events=[], @@ -605,6 +720,57 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(spy_client.last_complete_args[2], "agent") + + async def test_openai_non_stream_filters_tools_by_allowlist(self) -> None: + spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "hi"}], + stream=False, + tools=[ + {"type": "function", "function": {"name": "lookup", "parameters": {}}}, + {"type": "function", "function": {"name": "write_file", "parameters": {}}}, + ], + tool_choice={"type": "function", "function": {"name": "lookup"}}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + _SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]), + ): + await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + + cfg = spy_client.last_complete_kwargs["tool_config"] + self.assertEqual([tool["function"]["name"] for tool in cfg["tools"]], ["lookup"]) + self.assertEqual(cfg["tool_choice"], req.tool_choice) + + async def test_openai_non_stream_rejects_forced_tool_outside_allowlist(self) -> None: + spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "hi"}], + stream=False, + tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tool_choice={"type": "function", "function": {"name": "write_file"}}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + _SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]), + ): + with self.assertRaises(main.HTTPException) as cm: + await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + + self.assertEqual(cm.exception.status_code, 400) + self.assertEqual(cm.exception.detail["error"]["type"], "invalid_request_error") + self.assertIn("write_file", cm.exception.detail["error"]["message"]) + async def test_openai_tooling_context_disables_session_reuse_cache(self) -> None: fake_cache = _FakeSessionCache() fake_client = _FakeClient( @@ -757,6 +923,74 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(len(cfg["tools"]), 1) self.assertEqual(spy_client.last_complete_args[2], "agent") + + async def test_anthropic_non_stream_filters_tools_by_allowlist(self) -> None: + spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + req = AnthropicMessagesRequest( + model="claude-3-5-sonnet-20241022", + max_tokens=128, + messages=[{"role": "user", "content": "hi"}], + stream=False, + tools=[ + {"name": "lookup", "input_schema": {"type": "object", "properties": {}}}, + {"name": "write_file", "input_schema": {"type": "object", "properties": {}}}, + ], + tool_choice={"type": "tool", "name": "lookup"}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object(main.settings, "api_keys", ["test-key"]), + _SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]), + ): + await main.v1_messages( + req, + _make_request( + "/v1/messages", + headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + ), + ) + + cfg = spy_client.last_complete_kwargs["tool_config"] + self.assertEqual([tool["name"] for tool in cfg["tools"]], ["lookup"]) + self.assertEqual(cfg["tool_choice"], req.tool_choice) + + async def test_anthropic_non_stream_rejects_forced_tool_outside_allowlist(self) -> None: + spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + req = AnthropicMessagesRequest( + model="claude-3-5-sonnet-20241022", + max_tokens=128, + messages=[{"role": "user", "content": "hi"}], + stream=False, + tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}], + tool_choice={"type": "tool", "name": "write_file"}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object(main.settings, "api_keys", ["test-key"]), + _SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]), + ): + response = await main.v1_messages( + req, + _make_request( + "/v1/messages", + headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + ), + ) + + self.assertEqual(response.status_code, 400) + payload = json.loads(response.body) + self.assertEqual(payload["type"], "error") + self.assertEqual(payload["error"]["type"], "invalid_request_error") + self.assertIn("write_file", payload["error"]["message"]) + async def test_anthropic_tooling_context_disables_session_reuse_cache(self) -> None: fake_cache = _FakeSessionCache() fake_client = _FakeClient( @@ -833,6 +1067,54 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): messages_dump = [m.model_dump() for m in chat_req.messages] self.assertEqual(messages_dump, [{"role": "user", "content": "hello from responses", "name": None, "tool_call_id": None, "tool_calls": None}]) + async def test_responses_non_stream_maps_chat_tool_calls_to_function_call_output(self) -> None: + req = ResponsesRequest( + model="org_auto", + input="tool please", + stream=False, + ) + chat_payload = { + "id": "chatcmpl-tools1", + "created": 234, + "model": "org_auto", + "choices": [ + { + "index": 0, + "finish_reason": "tool_calls", + "message": { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "lookup", + "arguments": "{\"q\":\"gateway\"}", + }, + } + ], + }, + } + ], + "usage": {"prompt_tokens": 8, "completion_tokens": 3, "total_tokens": 11}, + } + + mock_chat = AsyncMock(return_value=JSONResponse(content=chat_payload)) + with patch.object(main, "v1_chat_completions", mock_chat): + response = await main.v1_responses(req, _make_request("/v1/responses")) + + payload = json.loads(response.body) + self.assertEqual(payload["status"], "completed") + self.assertEqual(payload["output_text"], "") + self.assertEqual(payload["usage"], {"input_tokens": 8, "output_tokens": 3, "total_tokens": 11}) + self.assertEqual(len(payload["output"]), 1) + self.assertEqual(payload["output"][0]["type"], "function_call") + self.assertEqual(payload["output"][0]["call_id"], "call_1") + self.assertEqual(payload["output"][0]["id"], "call_1") + self.assertEqual(payload["output"][0]["name"], "lookup") + self.assertEqual(payload["output"][0]["arguments"], "{\"q\":\"gateway\"}") + async def test_responses_forwards_input_tools_and_tool_choice_to_chat_request(self) -> None: req = ResponsesRequest( model="org_auto", @@ -883,17 +1165,70 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): response = await main.v1_responses(req, _make_request("/v1/responses")) body = await _collect_stream(response) - self.assertIn('"type": "response.created"', body) + self.assertIn('"type": "response.output_item.added"', body) self.assertIn('"type": "response.output_text.delta"', body) self.assertIn('"delta": "hello"', body) - self.assertIn('"type": "response.function_call.delta"', body) + self.assertIn('"type": "response.function_call_arguments.delta"', body) self.assertIn('"item_id": "call_1"', body) + self.assertIn('"output_index": 1', body) + self.assertIn('"delta": "{\\"q\\": \\"x\\"}"', body) + self.assertIn('"type": "response.function_call_arguments.done"', body) + self.assertIn('"arguments": "{\\"q\\": \\"x\\"}"', body) + self.assertIn('"type": "response.output_item.done"', body) + self.assertIn('"type": "function_call"', body) self.assertIn('"name": "lookup"', body) + self.assertIn('"arguments": "{\\"q\\": \\"x\\"}"', body) self.assertIn('"type": "response.completed"', body) self.assertIn('"input_tokens": 3', body) self.assertIn('"output_tokens": 2', body) self.assertIn('data: [DONE]', body) + async def test_responses_stream_accumulates_fragmented_tool_arguments(self) -> None: + async def _chat_sse(): + yield b'data: {"choices": [{"delta": {"tool_calls": [{"id": "call_1", "function": {"name": "lookup", "arguments": "{\\"q\\":"}}]}}]}\n\n' + yield b'data: {"choices": [{"delta": {"tool_calls": [{"id": "call_1", "function": {"name": "lookup", "arguments": " \\\"x\\\"}"}}]}}]}\n\n' + yield b"data: [DONE]\n\n" + + req = ResponsesRequest(model="org_auto", input="hi", stream=True) + mock_chat = AsyncMock( + return_value=StreamingResponse(_chat_sse(), media_type="text/event-stream") + ) + + with patch.object(main, "v1_chat_completions", mock_chat): + response = await main.v1_responses(req, _make_request("/v1/responses")) + body = await _collect_stream(response) + + self.assertIn('"type": "response.function_call_arguments.delta"', body) + self.assertIn('"delta": "{\\"q\\":"', body) + self.assertIn('"delta": " \\\"x\\\"}"', body) + self.assertIn('"type": "response.function_call_arguments.done"', body) + self.assertIn('"arguments": "{\\"q\\": \\\"x\\\"}"', body) + self.assertIn('"type": "response.output_item.done"', body) + self.assertIn('"arguments": "{\\"q\\": \\\"x\\\"}"', body) + self.assertIn('data: [DONE]', body) + + async def test_responses_stream_accumulates_fragmented_tool_arguments_without_repeated_id_or_name(self) -> None: + async def _chat_sse(): + yield b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "id": "call_1", "function": {"name": "lookup", "arguments": "{\\"q\\":"}}]}}]}\n\n' + yield b'data: {"choices": [{"delta": {"tool_calls": [{"index": 0, "function": {"arguments": " \\\"x\\\"}"}}]}}]}\n\n' + yield b"data: [DONE]\n\n" + + req = ResponsesRequest(model="org_auto", input="hi", stream=True) + mock_chat = AsyncMock( + return_value=StreamingResponse(_chat_sse(), media_type="text/event-stream") + ) + + with patch.object(main, "v1_chat_completions", mock_chat): + response = await main.v1_responses(req, _make_request("/v1/responses")) + body = await _collect_stream(response) + + self.assertEqual(body.count('"item_id": "call_1"'), 3) + self.assertIn('"name": "lookup"', body) + self.assertIn('"delta": "{\\"q\\":"', body) + self.assertIn('"delta": " \\\"x\\\"}"', body) + self.assertIn('"arguments": "{\\"q\\": \\\"x\\\"}"', body) + self.assertIn('data: [DONE]', body) + async def test_responses_stream_emits_completed_when_upstream_closes_without_done(self) -> None: async def _chat_sse_without_done(): yield b'data: {"choices": [{"delta": {"content": "partial"}}]}\n\n' @@ -954,7 +1289,31 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('data: [DONE]', body) - async def test_responses_non_stream_returns_502_on_invalid_upstream_json(self) -> None: + async def test_responses_alias_matches_v1_responses_behavior(self) -> None: + req = ResponsesRequest(model="org_auto", input="hello", stream=False) + chat_payload = { + "id": "chatcmpl-alias1", + "created": 123, + "model": "org_auto", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": {"role": "assistant", "content": "done"}, + } + ], + "usage": {"prompt_tokens": 4, "completion_tokens": 2, "total_tokens": 6}, + } + + mock_chat = AsyncMock(return_value=JSONResponse(content=chat_payload)) + with patch.object(main, "v1_chat_completions", mock_chat): + response = await main.v1_responses(req, _make_request("/responses")) + + payload = json.loads(response.body) + self.assertEqual(payload["id"], "resp_alias1") + self.assertEqual(payload["status"], "completed") + mock_chat.assert_awaited_once() + req = ResponsesRequest(model="org_auto", input="hi", stream=False) mock_chat = AsyncMock(return_value=Response(content="not-json", media_type="text/plain"))