diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..5d70d99 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,95 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Primary docs to read first +- `README.md` (runtime commands, env model, API examples) +- `DESIGN.md` (architecture decisions, module boundaries, request lifecycle) +- `.env.example` (authoritative env var reference) + +No Cursor/Copilot rule files were found in this repo (`.cursorrules`, `.cursor/rules/`, `.github/copilot-instructions.md`). + +## Common development commands + +### Start locally +```bash +pip install -r requirements.txt +uvicorn app.main:app --reload --port 8317 +``` + +### Start with Docker Compose +```bash +cp .env.example .env +mkdir -p data secrets +docker compose up -d --build +docker compose logs -f +``` + +### Run tests +```bash +# current focused suite +python3 -m unittest tests/test_tool_call_bridge.py + +# discover all unittest tests under tests/ +python3 -m unittest discover -s tests -p "test_*.py" + +# run a single test method +python3 -m unittest tests.test_tool_call_bridge.ToolCallBridgeTests.test_openai_non_stream_bridges_tool_calls +``` + +### Smoke-check running gateway +```bash +API_KEY=$(grep '^API_KEYS=' .env | cut -d= -f2 | cut -d, -f1) +curl -s http://127.0.0.1:8317/healthz +curl -s http://127.0.0.1:8317/v1/models -H "Authorization: Bearer $API_KEY" +``` + +### Linting/type-checking status +- There is currently no repo-configured lint/type command (no `ruff`/`flake8`/`mypy` config found). +- Do not invent tooling commands; if linting is needed, add tooling in a dedicated change first. + +## Architecture (big picture) + +### What this service is +A FastAPI gateway that fronts Lingma and exposes: +- OpenAI-compatible API (`/v1/models`, `/v1/chat/completions`) +- Anthropic Messages-compatible API (`/v1/messages`, `/v1/messages/count_tokens`) + +Both protocols share the same backend pool, backpressure guard, stats, and session reuse logic. + +### Request lifecycle (important for most changes) +1. Authenticate request (`app/auth.py`) +2. Normalize inbound protocol payload to internal message shape (`openai_schema.py` / `anthropic_schema.py`) +3. Session-cache lookup (`app/session_cache.py`) for prefix-based reuse +4. Pick backend instance (`app/lingma_pool.py`) with affinity + least-in-flight +5. Acquire concurrency ticket (`app/concurrency.py`) +6. Call Lingma via websocket/LSP client (`app/lingma_client.py`) +7. Map upstream result/stream back to wire protocol in `app/main.py` +8. Record stats and release ticket (including stream-finally paths) + +### Core module boundaries +- `app/main.py`: API entrypoint + orchestration + wire-format adapters +- `app/lingma_pool.py`: multi-instance lifecycle, selection, health-aware fallback +- `app/lingma_client.py`: subprocess + LSP-over-WebSocket transport to Lingma +- `app/session_cache.py`: LRU+TTL cache of conversation-prefix -> upstream session id (+ instance binding) +- `app/concurrency.py`: in-flight guard and queue timeout/backpressure behavior +- `app/stats.py`: usage counters and Prometheus text + +### Protocol-specific notes +- Anthropic and OpenAI endpoints are separate adapters over shared internals. +- Response-side tool bridge is implemented: upstream Lingma tool events are surfaced as: + - OpenAI: `tool_calls` (stream + non-stream) + - Anthropic: `tool_use` / `tool_result` blocks (stream + non-stream) +- Request-side `tools` / `tool_choice` are accepted by schemas but not forwarded to Lingma. + +### Operational invariants to preserve +- One request must stay on one Lingma instance for session continuity. +- Session cache entries include instance identity; invalidate on unhealthy instance mismatch. +- Streaming paths must always release in-flight tickets in `finally`. +- Multi-instance mode must use isolated workdirs per instance. + +### Deployment/runtime model +- Container startup runs `python /app/app/bootstrap_lingma.py` before uvicorn. +- Compose mounts: + - `./data -> /app/data` (persistent Lingma binary/cache/workdirs) + - `./secrets -> /secrets:ro` (session bundles, secrets) diff --git a/DESIGN.md b/DESIGN.md index 43c35a6..30696db 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -47,7 +47,8 @@ - **逆向 Lingma 后端协议**:之前评估过(曾经的"B1 终极方案"),需要反编译二进制,维护成本高、政策风险大,放弃。 - **多租户 / 水平扩缩**:单容器即可;真要大规模部署 → 套层反代 + N 个网关副本就够,不在进程内解决。 -- **完整 function calling / tools**:OpenAI schema 里保留了字段,但目前不透传给 Lingma(Lingma 侧没有等价能力)。 +- **请求侧完整 function calling / tools 透传**:OpenAI schema 里保留了字段,但目前不会把 `tools`/`tool_choice` 透传给 Lingma(上游无等价输入协议)。 +- **响应侧工具事件桥接**:若 Lingma 上游产出 tool 事件,网关会向 OpenAI 输出 `tool_calls`,向 Anthropic 输出 `tool_use` / `tool_result`(stream + non-stream)。 - **多模态**:请求里的 image/audio 会被降级成占位符 `[image]` / `[audio]`,因为 Lingma chat 不支持。 --- @@ -591,7 +592,7 @@ FastAPI `lifespan` 退出 → `pool.close()` → 每个 `client.close()` → 进 | 需求 | 改哪些文件 | 关键入口 | |---|---|---| | 加一个新的 OpenAI 端点(如 embeddings) | `main.py`, `openai_schema.py` | 仿照 `v1_models` 加 `@app.post("/v1/embeddings", dependencies=[Depends(auth_guard)])` | -| 扩展 Anthropic 端点(如 count_tokens / tool_use 贯通) | `main.py::v1_messages`, `anthropic_schema.py` | count_tokens 只读:复用 `estimate_tokens`;tool_use 需要 Lingma 上游支持,payload 转发点在 `chat_stream` / `chat_complete` | +| 扩展 Anthropic 端点(如 count_tokens / tool_use 相关能力) | `main.py::v1_messages`, `anthropic_schema.py` | count_tokens 只读:复用 `estimate_tokens`;响应侧 `tool_use/tool_result` 桥接已支持,若要请求侧 tools 透传仍需改 `lingma_client.py` payload | | 加一种新的实例调度策略(如加权轮询) | `lingma_pool.py::pick()` | 当前是 affinity → least-in-flight → round-robin | | 改认证为 JWT / OAuth | `auth.py` | 三个 `require_*` 函数是全部入口;`main.py` 里只有 `*_guard` 代理 | | 增加限流(按 api_key 配额) | `concurrency.py` 加 `PerKeyGuard`;`main.py` 在 `chat_guard.try_acquire()` 后再来一层 | 注意 ticket 释放顺序(内层先释放) | @@ -627,7 +628,7 @@ uvicorn app.main:app --reload --port 8317 | 标签 | 描述 | 影响 | 计划 | |---|---|---|---| | D1 | `config.py` 还是纯 `dataclass` + `os.getenv`,未迁 `pydantic-settings` | 类型校验靠自己 cast | 低优,收益有限,有精力再做 | -| D3 | 无单元测试骨架 | 重构要靠 deploy 验证 | 想加 CI 时优先补 | +| D3 | 已有基础单测覆盖 tool-call bridge(OpenAI/Anthropic,stream + non-stream),但整体测试矩阵仍不完整 | 回归仍依赖手工验证与定向测试 | 后续补充会话复用、背压、鉴权和异常路径用例 | | Docker non-root | 容器还是 root 跑 | 容器逃逸时影响宿主 | 需要加 `gosu` + chown entrypoint,涉及数据迁移,谨慎推进 | | ADMIN_TOKEN 轮换 | 没有过期机制,只能重启 | 自用场景不影响 | 接 Vault / sops 时一并做 | | Lingma 版本漂移 | 新版 Lingma 改 LSP 方法或新增必需 cache 文件时会无声崩 | 注入失败会 fallback,但 chat 不回话题型的错误不易定位 | 加一个 `/internal/smoke` 端点做端到端自检 | diff --git a/README.md b/README.md index 64549c1..16c6f8d 100644 --- a/README.md +++ b/README.md @@ -221,7 +221,8 @@ curl -N http://127.0.0.1:8317/v1/messages \ 说明: - **模型名兼容**:客户端可以继续传 `claude-3-*` 等名字;未识别的 model 会回退到 `DEFAULT_MODEL` 对应的 Lingma key,后端实际仍由 Lingma 提供(Qwen 系列)。如需显式选模型,直接传 Lingma key(`dashscope_qmodel` 等)。 - **会话复用共享**:Anthropic 与 OpenAI 两个端点共用同一 `SessionCache`,只要 API key 相同、对话前缀相同,就会命中同一上游 `sessionId`。 -- **多模态**:`image` 块会被降级为 `[image]` 占位符(Lingma 不支持 vision);`tool_use` / `tool_result` 会以纯文本形式保留语义。 +- **多模态**:`image` 块会被降级为 `[image]` 占位符(Lingma 不支持 vision)。 +- **工具事件桥接**:当 Lingma 上游返回 `tool` 事件时,网关会输出为 OpenAI `tool_calls`(含 stream/non-stream)和 Anthropic `tool_use`/`tool_result` blocks(含 stream/non-stream);但请求侧 `tools`/`tool_choice` 仍不会透传到 Lingma。 - **鉴权**:优先 `x-api-key` 头(Anthropic 官方 SDK 默认),回退 `Authorization: Bearer`(方便 curl / OpenAI 风格客户端)。 ### 3.2 观测(`METRICS_TOKEN` 或 `API_KEYS`) diff --git a/app/lingma_client.py b/app/lingma_client.py index 2bb503b..f0447b2 100644 --- a/app/lingma_client.py +++ b/app/lingma_client.py @@ -9,7 +9,7 @@ import subprocess import time import uuid from pathlib import Path -from typing import AsyncIterator, Callable, Optional +from typing import Any, AsyncIterator, Callable, Optional import websockets @@ -103,6 +103,58 @@ class LspWsRpcClient: self._on_disconnect = on_disconnect self._closed = False + @staticmethod + def _extract_tool_event(params: dict[str, Any]) -> dict[str, Any] | None: + candidates: list[dict[str, Any]] = [] + if isinstance(params.get("toolCall"), dict): + candidates.append(params["toolCall"]) + if isinstance(params.get("tool_call"), dict): + candidates.append(params["tool_call"]) + if isinstance(params.get("tool"), dict): + candidates.append(params["tool"]) + data = params.get("data") + if isinstance(data, dict): + if isinstance(data.get("toolCall"), dict): + candidates.append(data["toolCall"]) + if isinstance(data.get("tool_call"), dict): + candidates.append(data["tool_call"]) + if isinstance(data.get("tool"), dict): + candidates.append(data["tool"]) + + if not candidates: + return None + + raw = candidates[0] + tool_id = ( + raw.get("toolCallId") + or raw.get("tool_call_id") + or raw.get("id") + or params.get("toolCallId") + or params.get("tool_call_id") + ) + name = raw.get("name") or raw.get("toolName") or raw.get("tool_name") + call_input = raw.get("input") + if call_input is None: + call_input = raw.get("arguments") + if call_input is None: + call_input = raw.get("args") + + result_payload = raw.get("result") + if result_payload is None: + result_payload = params.get("result") + if result_payload is None and isinstance(data, dict): + result_payload = data.get("result") + + if not tool_id: + return None + + return { + "id": str(tool_id), + "name": str(name or "tool"), + "input": call_input if call_input is not None else {}, + "result": result_payload, + } + async def start(self): self._reader_task = asyncio.create_task(self._reader_loop()) @@ -185,7 +237,16 @@ class LspWsRpcClient: stream["parts"].append(text) if stream["first_chunk_at"] is None: stream["first_chunk_at"] = time.monotonic() - stream["chunks"].put_nowait(text) + stream["chunks"].put_nowait({"type": "text", "text": text}) + + if method in {"tool/call/sync", "tool/invoke", "tool/call/approve"}: + req_id = params.get("requestId") + stream = self._chat_streams.get(req_id) + if stream is not None: + tool_event = self._extract_tool_event(params) + if tool_event is not None: + stream["tool_events"].append(tool_event) + stream["chunks"].put_nowait({"type": "tool", "tool": tool_event}) if method == "chat/finish": req_id = params.get("requestId") @@ -224,6 +285,7 @@ class LspWsRpcClient: "chunks": asyncio.Queue(), "done": asyncio.Event(), "finish": None, + "tool_events": [], "started_at": time.monotonic(), "first_chunk_at": None, "finish_at": None, @@ -239,7 +301,7 @@ class LspWsRpcClient: with contextlib.suppress(Exception): stream["chunks"].put_nowait(None) - async def consume_stream(self, request_id: str, timeout: float) -> AsyncIterator[str]: + async def consume_stream(self, request_id: str, timeout: float) -> AsyncIterator[dict[str, Any]]: stream = self._chat_streams.get(request_id) if stream is None: return @@ -266,6 +328,7 @@ class LspWsRpcClient: "finish": stream.get("finish") or {}, "firstTokenLatencyMs": first_ms, "totalLatencyMs": total_ms, + "toolEvents": stream.get("tool_events") or [], } @@ -722,8 +785,12 @@ class LingmaGatewayClient: session_id: str | None = None, is_reply: bool = False, out_meta: dict | None = None, - ) -> AsyncIterator[str]: - """Stream `chat/answer` chunks. + ) -> AsyncIterator[dict[str, Any]]: + """Stream chat events. + + Yields structured events: + * {"type": "text", "text": "..."} + * {"type": "tool", "tool": {...}} If `out_meta` is provided, the final `chat/finish` payload's sessionId (and the raw finish dict) is written into it when the stream ends or is @@ -739,10 +806,10 @@ class LingmaGatewayClient: self.rpc.create_stream(request_id) try: await self._kick_chat_ask(payload) - async for chunk in self.rpc.consume_stream( + async for event in self.rpc.consume_stream( request_id, timeout=max(60.0, self.rpc_timeout + 60.0) ): - yield chunk + yield event finally: # Runs on normal completion, exception, or consumer GeneratorExit (client disconnect). if out_meta is not None: @@ -753,6 +820,7 @@ class LingmaGatewayClient: out_meta["finish"] = finish out_meta["request_id"] = request_id out_meta["chars"] = len(stream_result.get("text") or "") + out_meta["tool_events"] = stream_result.get("toolEvents") or [] except Exception: pass self.rpc.pop_stream(request_id) diff --git a/app/main.py b/app/main.py index 79066f5..082c30c 100644 --- a/app/main.py +++ b/app/main.py @@ -6,6 +6,7 @@ import json import time import uuid from contextlib import asynccontextmanager +from typing import Any from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.responses import JSONResponse, StreamingResponse @@ -350,6 +351,78 @@ def _include_usage(stream_options: dict | None) -> bool: return bool(stream_options.get("include_usage")) +def _stream_event_type(event: Any) -> str: + if isinstance(event, dict): + t = event.get("type") + if t in {"text", "tool"}: + return t + return "text" + + +def _stream_text(event: Any) -> str: + if isinstance(event, dict): + if event.get("type") == "text": + text = event.get("text") + if isinstance(text, str): + return text + return "" + if isinstance(event, str): + return event + return "" + + +def _stream_tool_event(event: Any) -> dict[str, Any] | None: + if isinstance(event, dict) and event.get("type") == "tool": + tool = event.get("tool") + if isinstance(tool, dict): + return tool + return None + + +def _json_string(value: Any) -> str: + if isinstance(value, str): + return value + try: + return json.dumps(value if value is not None else {}, ensure_ascii=False) + except Exception: + return "{}" + + +def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]: + return { + "id": str(tool.get("id") or f"call_{uuid.uuid4().hex}"), + "type": "function", + "function": { + "name": str(tool.get("name") or "tool"), + "arguments": _json_string(tool.get("input")), + }, + } + + +def _anthropic_tool_use_block(tool: dict[str, Any]) -> dict[str, Any]: + return { + "type": "tool_use", + "id": str(tool.get("id") or f"toolu_{uuid.uuid4().hex}"), + "name": str(tool.get("name") or "tool"), + "input": tool.get("input") if tool.get("input") is not None else {}, + } + + +def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None: + if "result" not in tool: + return None + result = tool.get("result") + if isinstance(result, str): + content: Any = result + else: + content = _json_string(result) + return { + "type": "tool_result", + "tool_use_id": str(tool.get("id") or ""), + "content": content, + } + + @app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)]) async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): p = _require_pool() @@ -485,7 +558,37 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): is_reply=is_reply, out_meta=_meta, ): - completion_tokens_holder["n"] += estimate_tokens(chunk) + if _stream_event_type(chunk) == "tool": + tool = _stream_tool_event(chunk) + if not tool: + continue + payload = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [ + { + "index": 0, + **_openai_tool_call(tool), + } + ] + }, + "finish_reason": None, + } + ], + } + yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + continue + + text = _stream_text(chunk) + if not text: + continue + completion_tokens_holder["n"] += estimate_tokens(text) payload = { "id": completion_id, "object": "chat.completion.chunk", @@ -494,7 +597,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): "choices": [ { "index": 0, - "delta": {"content": chunk}, + "delta": {"content": text}, "finish_reason": None, } ], @@ -596,6 +699,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): sid = result.get("sessionId") if sid: await session_cache.put(write_key, sid, inst.name) + tool_events = result.get("toolEvents") or [] + message_content = result.get("text") or "" + tool_calls: list[dict[str, Any]] = [] + if isinstance(tool_events, list): + for item in tool_events: + if isinstance(item, dict): + tool_calls.append(_openai_tool_call(item)) response = ChatCompletionResponse( id=f"chatcmpl-{uuid.uuid4().hex}", created=int(time.time()), @@ -604,10 +714,15 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): ChatCompletionChoice( index=0, finish_reason="stop", - message={"role": "assistant", "content": result.get("text") or ""}, + message={ + "role": "assistant", + "content": message_content, + "tool_calls": tool_calls or None, + }, ) ], ) + data = response.model_dump() data["latency"] = { "first_token_ms": result.get("firstTokenLatencyMs"), @@ -810,6 +925,8 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta): success = False + block_index = 0 + text_block_open = False try: # 1) message_start — Anthropic SDKs read this first to get # the message envelope (id/model/initial usage). @@ -833,17 +950,6 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): } yield _sse("message_start", start_payload) - # 2) content_block_start for a single text block (index 0). - yield _sse( - "content_block_start", - { - "type": "content_block_start", - "index": 0, - "content_block": {"type": "text", "text": ""}, - }, - ) - - # 3) content_block_delta stream of text tokens. async for chunk in _inst.client.chat_stream( prompt, model, @@ -852,23 +958,80 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): is_reply=is_reply, out_meta=_meta, ): - if not chunk: + if _stream_event_type(chunk) == "tool": + if text_block_open: + yield _sse( + "content_block_stop", + {"type": "content_block_stop", "index": block_index}, + ) + block_index += 1 + text_block_open = False + + tool = _stream_tool_event(chunk) + if not tool: + continue + + tool_use_block = _anthropic_tool_use_block(tool) + yield _sse( + "content_block_start", + { + "type": "content_block_start", + "index": block_index, + "content_block": tool_use_block, + }, + ) + yield _sse( + "content_block_stop", + {"type": "content_block_stop", "index": block_index}, + ) + block_index += 1 + + tool_result_block = _anthropic_tool_result_block(tool) + if tool_result_block is not None: + yield _sse( + "content_block_start", + { + "type": "content_block_start", + "index": block_index, + "content_block": tool_result_block, + }, + ) + yield _sse( + "content_block_stop", + {"type": "content_block_stop", "index": block_index}, + ) + block_index += 1 continue - completion_tokens_holder["n"] += estimate_tokens(chunk) + + text = _stream_text(chunk) + if not text: + continue + completion_tokens_holder["n"] += estimate_tokens(text) + if not text_block_open: + yield _sse( + "content_block_start", + { + "type": "content_block_start", + "index": block_index, + "content_block": {"type": "text", "text": ""}, + }, + ) + text_block_open = True + yield _sse( "content_block_delta", { "type": "content_block_delta", - "index": 0, - "delta": {"type": "text_delta", "text": chunk}, + "index": block_index, + "delta": {"type": "text_delta", "text": text}, }, ) - # 4) content_block_stop closes the single text block. - yield _sse( - "content_block_stop", - {"type": "content_block_stop", "index": 0}, - ) + if text_block_open: + yield _sse( + "content_block_stop", + {"type": "content_block_stop", "index": block_index}, + ) # 5) message_delta carries the terminal stop_reason and # the final cumulative output_tokens count. @@ -972,12 +1135,25 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): if sid: await session_cache.put(write_key, sid, inst.name) + content_blocks: list[dict[str, Any]] = [] + if text: + content_blocks.append({"type": "text", "text": text}) + tool_events = result.get("toolEvents") or [] + if isinstance(tool_events, list): + for item in tool_events: + if not isinstance(item, dict): + continue + content_blocks.append(_anthropic_tool_use_block(item)) + tool_result = _anthropic_tool_result_block(item) + if tool_result is not None: + content_blocks.append(tool_result) + response_body: dict = { "id": message_id, "type": "message", "role": "assistant", "model": model, - "content": [{"type": "text", "text": text}], + "content": content_blocks, "stop_reason": _anthropic_stop_reason(completion_tokens, req.max_tokens), "stop_sequence": None, "usage": { diff --git a/tests/test_tool_call_bridge.py b/tests/test_tool_call_bridge.py new file mode 100644 index 0000000..e7126eb --- /dev/null +++ b/tests/test_tool_call_bridge.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import json +import sys +import types +import unittest +from unittest.mock import AsyncMock, patch + +# app.main imports playwright via auto_login; tests don't exercise that path. +# Inject a lightweight stub so unit tests run without installing playwright. +_playwright = types.ModuleType("playwright") +_playwright_async = types.ModuleType("playwright.async_api") + + +class _StubPlaywrightTimeoutError(Exception): + pass + + +async def _stub_async_playwright(): + raise RuntimeError("playwright is stubbed in unit tests") + + +_playwright_async.TimeoutError = _StubPlaywrightTimeoutError +_playwright_async.async_playwright = _stub_async_playwright +sys.modules.setdefault("playwright", _playwright) +sys.modules.setdefault("playwright.async_api", _playwright_async) + +from starlette.requests import Request + +from app.anthropic_schema import AnthropicMessagesRequest +from app.openai_schema import ChatCompletionsRequest +import app.main as main + + +class _FakeTicket: + def __init__(self) -> None: + self.released = False + + def release(self) -> None: + self.released = True + + +class _FakeGuard: + def __init__(self) -> None: + self.in_flight = 0 + + async def try_acquire(self) -> _FakeTicket: + return _FakeTicket() + + +class _FakeClient: + def __init__(self, *, stream_events: list[dict], complete_result: dict) -> None: + self._stream_events = stream_events + self._complete_result = complete_result + + async def query_models(self) -> dict: + return { + "chat": [ + { + "key": "org_auto", + "displayName": "Auto", + } + ] + } + + async def chat_complete(self, *args, **kwargs) -> dict: + return self._complete_result + + async def chat_stream(self, *args, **kwargs): + out_meta = kwargs.get("out_meta") + if isinstance(out_meta, dict): + out_meta["session_id"] = "sess-stream" + for event in self._stream_events: + yield event + + +class _FakeInstance: + def __init__(self, client: _FakeClient) -> None: + self.name = "inst-test" + self.client = client + self.in_flight = 0 + + +class _FakePool: + def __init__(self, inst: _FakeInstance) -> None: + self._inst = inst + + def pick(self, affinity_key: str | None = None) -> _FakeInstance: + return self._inst + + +def _make_request(path: str, headers: dict[str, str] | None = None) -> Request: + header_pairs = [] + for k, v in (headers or {}).items(): + header_pairs.append((k.lower().encode("latin-1"), v.encode("latin-1"))) + scope = { + "type": "http", + "http_version": "1.1", + "method": "POST", + "scheme": "http", + "path": path, + "raw_path": path.encode("latin-1"), + "query_string": b"", + "headers": header_pairs, + "client": ("testclient", 12345), + "server": ("testserver", 80), + "root_path": "", + } + return Request(scope) + + +async def _collect_stream(response) -> str: + chunks: list[str] = [] + async for part in response.body_iterator: + if isinstance(part, bytes): + chunks.append(part.decode("utf-8")) + else: + chunks.append(str(part)) + return "".join(chunks) + + +class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): + async def test_openai_non_stream_bridges_tool_calls(self) -> None: + fake_client = _FakeClient( + stream_events=[], + complete_result={ + "text": "done", + "toolEvents": [ + { + "id": "call_123", + "name": "search_docs", + "input": {"query": "gateway"}, + "result": {"ok": True}, + } + ], + "sessionId": "sess-1", + "firstTokenLatencyMs": 12, + "totalLatencyMs": 34, + }, + ) + req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "hi"}], + stream=False, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + ): + response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + + payload = json.loads(response.body) + message = payload["choices"][0]["message"] + self.assertEqual(message["content"], "done") + self.assertIsInstance(message["tool_calls"], list) + self.assertEqual(message["tool_calls"][0]["function"]["name"], "search_docs") + self.assertEqual( + json.loads(message["tool_calls"][0]["function"]["arguments"]), + {"query": "gateway"}, + ) + + async def test_openai_stream_bridges_tool_and_text_events(self) -> None: + fake_client = _FakeClient( + stream_events=[ + { + "type": "tool", + "tool": { + "id": "call_stream_1", + "name": "read_file", + "input": {"path": "README.md"}, + }, + }, + {"type": "text", "text": "hello"}, + ], + complete_result={}, + ) + req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "hi"}], + stream=True, + stream_options={"include_usage": True}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + ): + response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + body = await _collect_stream(response) + + self.assertIn('"tool_calls"', body) + self.assertIn('"content": "hello"', body) + self.assertIn('"usage"', body) + self.assertIn("data: [DONE]", body) + + async def test_anthropic_non_stream_bridges_tool_blocks(self) -> None: + fake_client = _FakeClient( + stream_events=[], + complete_result={ + "text": "ok", + "toolEvents": [ + { + "id": "toolu_1", + "name": "lookup", + "input": {"k": "v"}, + "result": {"value": 1}, + } + ], + "sessionId": "sess-2", + }, + ) + req = AnthropicMessagesRequest( + model="claude-3-5-sonnet-20241022", + max_tokens=256, + messages=[{"role": "user", "content": "hi"}], + stream=False, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object(main.settings, "api_keys", ["test-key"]), + ): + response = await main.v1_messages( + req, + _make_request( + "/v1/messages", + headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + ), + ) + + payload = json.loads(response.body) + types = [item["type"] for item in payload["content"]] + self.assertEqual(types, ["text", "tool_use", "tool_result"]) + self.assertEqual(payload["content"][1]["name"], "lookup") + self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1") + + async def test_anthropic_stream_bridges_tool_and_text_events(self) -> None: + fake_client = _FakeClient( + stream_events=[ + { + "type": "tool", + "tool": { + "id": "toolu_stream_1", + "name": "read", + "input": {"file": "a.txt"}, + "result": "done", + }, + }, + {"type": "text", "text": "world"}, + ], + complete_result={}, + ) + req = AnthropicMessagesRequest( + model="claude-3-5-sonnet-20241022", + max_tokens=256, + messages=[{"role": "user", "content": "hi"}], + stream=True, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object(main.settings, "api_keys", ["test-key"]), + ): + response = await main.v1_messages( + req, + _make_request( + "/v1/messages", + headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + ), + ) + body = await _collect_stream(response) + + self.assertIn("event: message_start", body) + self.assertIn('"type": "tool_use"', body) + self.assertIn('"type": "tool_result"', body) + self.assertIn('"type": "text_delta"', body) + self.assertIn("event: message_stop", body) + + +if __name__ == "__main__": + unittest.main()