feat: bridge Lingma tool events to OpenAI/Anthropic responses

Add structured tool event propagation from Lingma stream/finish metadata and map it to OpenAI tool_calls and Anthropic tool_use/tool_result in both streaming and non-streaming responses. Add focused bridge tests and update docs/design notes to match current behavior.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
GitHub Actions
2026-04-18 22:34:43 +08:00
parent b3fd8800f7
commit 1c7b86e2c0
6 changed files with 668 additions and 35 deletions

95
CLAUDE.md Normal file
View File

@@ -0,0 +1,95 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Primary docs to read first
- `README.md` (runtime commands, env model, API examples)
- `DESIGN.md` (architecture decisions, module boundaries, request lifecycle)
- `.env.example` (authoritative env var reference)
No Cursor/Copilot rule files were found in this repo (`.cursorrules`, `.cursor/rules/`, `.github/copilot-instructions.md`).
## Common development commands
### Start locally
```bash
pip install -r requirements.txt
uvicorn app.main:app --reload --port 8317
```
### Start with Docker Compose
```bash
cp .env.example .env
mkdir -p data secrets
docker compose up -d --build
docker compose logs -f
```
### Run tests
```bash
# current focused suite
python3 -m unittest tests/test_tool_call_bridge.py
# discover all unittest tests under tests/
python3 -m unittest discover -s tests -p "test_*.py"
# run a single test method
python3 -m unittest tests.test_tool_call_bridge.ToolCallBridgeTests.test_openai_non_stream_bridges_tool_calls
```
### Smoke-check running gateway
```bash
API_KEY=$(grep '^API_KEYS=' .env | cut -d= -f2 | cut -d, -f1)
curl -s http://127.0.0.1:8317/healthz
curl -s http://127.0.0.1:8317/v1/models -H "Authorization: Bearer $API_KEY"
```
### Linting/type-checking status
- There is currently no repo-configured lint/type command (no `ruff`/`flake8`/`mypy` config found).
- Do not invent tooling commands; if linting is needed, add tooling in a dedicated change first.
## Architecture (big picture)
### What this service is
A FastAPI gateway that fronts Lingma and exposes:
- OpenAI-compatible API (`/v1/models`, `/v1/chat/completions`)
- Anthropic Messages-compatible API (`/v1/messages`, `/v1/messages/count_tokens`)
Both protocols share the same backend pool, backpressure guard, stats, and session reuse logic.
### Request lifecycle (important for most changes)
1. Authenticate request (`app/auth.py`)
2. Normalize inbound protocol payload to internal message shape (`openai_schema.py` / `anthropic_schema.py`)
3. Session-cache lookup (`app/session_cache.py`) for prefix-based reuse
4. Pick backend instance (`app/lingma_pool.py`) with affinity + least-in-flight
5. Acquire concurrency ticket (`app/concurrency.py`)
6. Call Lingma via websocket/LSP client (`app/lingma_client.py`)
7. Map upstream result/stream back to wire protocol in `app/main.py`
8. Record stats and release ticket (including stream-finally paths)
### Core module boundaries
- `app/main.py`: API entrypoint + orchestration + wire-format adapters
- `app/lingma_pool.py`: multi-instance lifecycle, selection, health-aware fallback
- `app/lingma_client.py`: subprocess + LSP-over-WebSocket transport to Lingma
- `app/session_cache.py`: LRU+TTL cache of conversation-prefix -> upstream session id (+ instance binding)
- `app/concurrency.py`: in-flight guard and queue timeout/backpressure behavior
- `app/stats.py`: usage counters and Prometheus text
### Protocol-specific notes
- Anthropic and OpenAI endpoints are separate adapters over shared internals.
- Response-side tool bridge is implemented: upstream Lingma tool events are surfaced as:
- OpenAI: `tool_calls` (stream + non-stream)
- Anthropic: `tool_use` / `tool_result` blocks (stream + non-stream)
- Request-side `tools` / `tool_choice` are accepted by schemas but not forwarded to Lingma.
### Operational invariants to preserve
- One request must stay on one Lingma instance for session continuity.
- Session cache entries include instance identity; invalidate on unhealthy instance mismatch.
- Streaming paths must always release in-flight tickets in `finally`.
- Multi-instance mode must use isolated workdirs per instance.
### Deployment/runtime model
- Container startup runs `python /app/app/bootstrap_lingma.py` before uvicorn.
- Compose mounts:
- `./data -> /app/data` (persistent Lingma binary/cache/workdirs)
- `./secrets -> /secrets:ro` (session bundles, secrets)

View File

@@ -47,7 +47,8 @@
- **逆向 Lingma 后端协议**:之前评估过(曾经的"B1 终极方案"),需要反编译二进制,维护成本高、政策风险大,放弃。 - **逆向 Lingma 后端协议**:之前评估过(曾经的"B1 终极方案"),需要反编译二进制,维护成本高、政策风险大,放弃。
- **多租户 / 水平扩缩**:单容器即可;真要大规模部署 → 套层反代 + N 个网关副本就够,不在进程内解决。 - **多租户 / 水平扩缩**:单容器即可;真要大规模部署 → 套层反代 + N 个网关副本就够,不在进程内解决。
- **完整 function calling / tools**OpenAI schema 里保留了字段,但目前不透传给 LingmaLingma 侧没有等价能力)。 - **请求侧完整 function calling / tools 透传**OpenAI schema 里保留了字段,但目前不会把 `tools`/`tool_choice` 透传给 Lingma上游无等价输入协议)。
- **响应侧工具事件桥接**:若 Lingma 上游产出 tool 事件,网关会向 OpenAI 输出 `tool_calls`,向 Anthropic 输出 `tool_use` / `tool_result`stream + non-stream
- **多模态**:请求里的 image/audio 会被降级成占位符 `[image]` / `[audio]`,因为 Lingma chat 不支持。 - **多模态**:请求里的 image/audio 会被降级成占位符 `[image]` / `[audio]`,因为 Lingma chat 不支持。
--- ---
@@ -591,7 +592,7 @@ FastAPI `lifespan` 退出 → `pool.close()` → 每个 `client.close()` → 进
| 需求 | 改哪些文件 | 关键入口 | | 需求 | 改哪些文件 | 关键入口 |
|---|---|---| |---|---|---|
| 加一个新的 OpenAI 端点(如 embeddings | `main.py`, `openai_schema.py` | 仿照 `v1_models``@app.post("/v1/embeddings", dependencies=[Depends(auth_guard)])` | | 加一个新的 OpenAI 端点(如 embeddings | `main.py`, `openai_schema.py` | 仿照 `v1_models``@app.post("/v1/embeddings", dependencies=[Depends(auth_guard)])` |
| 扩展 Anthropic 端点(如 count_tokens / tool_use 贯通 | `main.py::v1_messages`, `anthropic_schema.py` | count_tokens 只读:复用 `estimate_tokens`tool_use 需要 Lingma 上游支持payload 转发点在 `chat_stream` / `chat_complete` | | 扩展 Anthropic 端点(如 count_tokens / tool_use 相关能力 | `main.py::v1_messages`, `anthropic_schema.py` | count_tokens 只读:复用 `estimate_tokens`响应侧 `tool_use/tool_result` 桥接已支持,若要请求侧 tools 透传仍需改 `lingma_client.py` payload |
| 加一种新的实例调度策略(如加权轮询) | `lingma_pool.py::pick()` | 当前是 affinity → least-in-flight → round-robin | | 加一种新的实例调度策略(如加权轮询) | `lingma_pool.py::pick()` | 当前是 affinity → least-in-flight → round-robin |
| 改认证为 JWT / OAuth | `auth.py` | 三个 `require_*` 函数是全部入口;`main.py` 里只有 `*_guard` 代理 | | 改认证为 JWT / OAuth | `auth.py` | 三个 `require_*` 函数是全部入口;`main.py` 里只有 `*_guard` 代理 |
| 增加限流(按 api_key 配额) | `concurrency.py``PerKeyGuard``main.py``chat_guard.try_acquire()` 后再来一层 | 注意 ticket 释放顺序(内层先释放) | | 增加限流(按 api_key 配额) | `concurrency.py``PerKeyGuard``main.py``chat_guard.try_acquire()` 后再来一层 | 注意 ticket 释放顺序(内层先释放) |
@@ -627,7 +628,7 @@ uvicorn app.main:app --reload --port 8317
| 标签 | 描述 | 影响 | 计划 | | 标签 | 描述 | 影响 | 计划 |
|---|---|---|---| |---|---|---|---|
| D1 | `config.py` 还是纯 `dataclass` + `os.getenv`,未迁 `pydantic-settings` | 类型校验靠自己 cast | 低优,收益有限,有精力再做 | | D1 | `config.py` 还是纯 `dataclass` + `os.getenv`,未迁 `pydantic-settings` | 类型校验靠自己 cast | 低优,收益有限,有精力再做 |
| D3 | 无单元测试骨架 | 重构要靠 deploy 验证 | 想加 CI 时优先补 | | D3 | 已有基础单测覆盖 tool-call bridgeOpenAI/Anthropicstream + non-stream但整体测试矩阵仍不完整 | 回归仍依赖手工验证与定向测试 | 后续补充会话复用、背压、鉴权和异常路径用例 |
| Docker non-root | 容器还是 root 跑 | 容器逃逸时影响宿主 | 需要加 `gosu` + chown entrypoint涉及数据迁移谨慎推进 | | Docker non-root | 容器还是 root 跑 | 容器逃逸时影响宿主 | 需要加 `gosu` + chown entrypoint涉及数据迁移谨慎推进 |
| ADMIN_TOKEN 轮换 | 没有过期机制,只能重启 | 自用场景不影响 | 接 Vault / sops 时一并做 | | ADMIN_TOKEN 轮换 | 没有过期机制,只能重启 | 自用场景不影响 | 接 Vault / sops 时一并做 |
| Lingma 版本漂移 | 新版 Lingma 改 LSP 方法或新增必需 cache 文件时会无声崩 | 注入失败会 fallback但 chat 不回话题型的错误不易定位 | 加一个 `/internal/smoke` 端点做端到端自检 | | Lingma 版本漂移 | 新版 Lingma 改 LSP 方法或新增必需 cache 文件时会无声崩 | 注入失败会 fallback但 chat 不回话题型的错误不易定位 | 加一个 `/internal/smoke` 端点做端到端自检 |

View File

@@ -221,7 +221,8 @@ curl -N http://127.0.0.1:8317/v1/messages \
说明: 说明:
- **模型名兼容**:客户端可以继续传 `claude-3-*` 等名字;未识别的 model 会回退到 `DEFAULT_MODEL` 对应的 Lingma key后端实际仍由 Lingma 提供Qwen 系列)。如需显式选模型,直接传 Lingma key`dashscope_qmodel` 等)。 - **模型名兼容**:客户端可以继续传 `claude-3-*` 等名字;未识别的 model 会回退到 `DEFAULT_MODEL` 对应的 Lingma key后端实际仍由 Lingma 提供Qwen 系列)。如需显式选模型,直接传 Lingma key`dashscope_qmodel` 等)。
- **会话复用共享**Anthropic 与 OpenAI 两个端点共用同一 `SessionCache`,只要 API key 相同、对话前缀相同,就会命中同一上游 `sessionId` - **会话复用共享**Anthropic 与 OpenAI 两个端点共用同一 `SessionCache`,只要 API key 相同、对话前缀相同,就会命中同一上游 `sessionId`
- **多模态**`image` 块会被降级为 `[image]` 占位符Lingma 不支持 vision`tool_use` / `tool_result` 会以纯文本形式保留语义 - **多模态**`image` 块会被降级为 `[image]` 占位符Lingma 不支持 vision
- **工具事件桥接**:当 Lingma 上游返回 `tool` 事件时,网关会输出为 OpenAI `tool_calls`(含 stream/non-stream和 Anthropic `tool_use`/`tool_result` blocks含 stream/non-stream但请求侧 `tools`/`tool_choice` 仍不会透传到 Lingma。
- **鉴权**:优先 `x-api-key`Anthropic 官方 SDK 默认),回退 `Authorization: Bearer`(方便 curl / OpenAI 风格客户端)。 - **鉴权**:优先 `x-api-key`Anthropic 官方 SDK 默认),回退 `Authorization: Bearer`(方便 curl / OpenAI 风格客户端)。
### 3.2 观测(`METRICS_TOKEN` 或 `API_KEYS` ### 3.2 观测(`METRICS_TOKEN` 或 `API_KEYS`

View File

@@ -9,7 +9,7 @@ import subprocess
import time import time
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import AsyncIterator, Callable, Optional from typing import Any, AsyncIterator, Callable, Optional
import websockets import websockets
@@ -103,6 +103,58 @@ class LspWsRpcClient:
self._on_disconnect = on_disconnect self._on_disconnect = on_disconnect
self._closed = False self._closed = False
@staticmethod
def _extract_tool_event(params: dict[str, Any]) -> dict[str, Any] | None:
candidates: list[dict[str, Any]] = []
if isinstance(params.get("toolCall"), dict):
candidates.append(params["toolCall"])
if isinstance(params.get("tool_call"), dict):
candidates.append(params["tool_call"])
if isinstance(params.get("tool"), dict):
candidates.append(params["tool"])
data = params.get("data")
if isinstance(data, dict):
if isinstance(data.get("toolCall"), dict):
candidates.append(data["toolCall"])
if isinstance(data.get("tool_call"), dict):
candidates.append(data["tool_call"])
if isinstance(data.get("tool"), dict):
candidates.append(data["tool"])
if not candidates:
return None
raw = candidates[0]
tool_id = (
raw.get("toolCallId")
or raw.get("tool_call_id")
or raw.get("id")
or params.get("toolCallId")
or params.get("tool_call_id")
)
name = raw.get("name") or raw.get("toolName") or raw.get("tool_name")
call_input = raw.get("input")
if call_input is None:
call_input = raw.get("arguments")
if call_input is None:
call_input = raw.get("args")
result_payload = raw.get("result")
if result_payload is None:
result_payload = params.get("result")
if result_payload is None and isinstance(data, dict):
result_payload = data.get("result")
if not tool_id:
return None
return {
"id": str(tool_id),
"name": str(name or "tool"),
"input": call_input if call_input is not None else {},
"result": result_payload,
}
async def start(self): async def start(self):
self._reader_task = asyncio.create_task(self._reader_loop()) self._reader_task = asyncio.create_task(self._reader_loop())
@@ -185,7 +237,16 @@ class LspWsRpcClient:
stream["parts"].append(text) stream["parts"].append(text)
if stream["first_chunk_at"] is None: if stream["first_chunk_at"] is None:
stream["first_chunk_at"] = time.monotonic() stream["first_chunk_at"] = time.monotonic()
stream["chunks"].put_nowait(text) stream["chunks"].put_nowait({"type": "text", "text": text})
if method in {"tool/call/sync", "tool/invoke", "tool/call/approve"}:
req_id = params.get("requestId")
stream = self._chat_streams.get(req_id)
if stream is not None:
tool_event = self._extract_tool_event(params)
if tool_event is not None:
stream["tool_events"].append(tool_event)
stream["chunks"].put_nowait({"type": "tool", "tool": tool_event})
if method == "chat/finish": if method == "chat/finish":
req_id = params.get("requestId") req_id = params.get("requestId")
@@ -224,6 +285,7 @@ class LspWsRpcClient:
"chunks": asyncio.Queue(), "chunks": asyncio.Queue(),
"done": asyncio.Event(), "done": asyncio.Event(),
"finish": None, "finish": None,
"tool_events": [],
"started_at": time.monotonic(), "started_at": time.monotonic(),
"first_chunk_at": None, "first_chunk_at": None,
"finish_at": None, "finish_at": None,
@@ -239,7 +301,7 @@ class LspWsRpcClient:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
stream["chunks"].put_nowait(None) stream["chunks"].put_nowait(None)
async def consume_stream(self, request_id: str, timeout: float) -> AsyncIterator[str]: async def consume_stream(self, request_id: str, timeout: float) -> AsyncIterator[dict[str, Any]]:
stream = self._chat_streams.get(request_id) stream = self._chat_streams.get(request_id)
if stream is None: if stream is None:
return return
@@ -266,6 +328,7 @@ class LspWsRpcClient:
"finish": stream.get("finish") or {}, "finish": stream.get("finish") or {},
"firstTokenLatencyMs": first_ms, "firstTokenLatencyMs": first_ms,
"totalLatencyMs": total_ms, "totalLatencyMs": total_ms,
"toolEvents": stream.get("tool_events") or [],
} }
@@ -722,8 +785,12 @@ class LingmaGatewayClient:
session_id: str | None = None, session_id: str | None = None,
is_reply: bool = False, is_reply: bool = False,
out_meta: dict | None = None, out_meta: dict | None = None,
) -> AsyncIterator[str]: ) -> AsyncIterator[dict[str, Any]]:
"""Stream `chat/answer` chunks. """Stream chat events.
Yields structured events:
* {"type": "text", "text": "..."}
* {"type": "tool", "tool": {...}}
If `out_meta` is provided, the final `chat/finish` payload's sessionId If `out_meta` is provided, the final `chat/finish` payload's sessionId
(and the raw finish dict) is written into it when the stream ends or is (and the raw finish dict) is written into it when the stream ends or is
@@ -739,10 +806,10 @@ class LingmaGatewayClient:
self.rpc.create_stream(request_id) self.rpc.create_stream(request_id)
try: try:
await self._kick_chat_ask(payload) await self._kick_chat_ask(payload)
async for chunk in self.rpc.consume_stream( async for event in self.rpc.consume_stream(
request_id, timeout=max(60.0, self.rpc_timeout + 60.0) request_id, timeout=max(60.0, self.rpc_timeout + 60.0)
): ):
yield chunk yield event
finally: finally:
# Runs on normal completion, exception, or consumer GeneratorExit (client disconnect). # Runs on normal completion, exception, or consumer GeneratorExit (client disconnect).
if out_meta is not None: if out_meta is not None:
@@ -753,6 +820,7 @@ class LingmaGatewayClient:
out_meta["finish"] = finish out_meta["finish"] = finish
out_meta["request_id"] = request_id out_meta["request_id"] = request_id
out_meta["chars"] = len(stream_result.get("text") or "") out_meta["chars"] = len(stream_result.get("text") or "")
out_meta["tool_events"] = stream_result.get("toolEvents") or []
except Exception: except Exception:
pass pass
self.rpc.pop_stream(request_id) self.rpc.pop_stream(request_id)

View File

@@ -6,6 +6,7 @@ import json
import time import time
import uuid import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Any
from fastapi import Depends, FastAPI, HTTPException, Request from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
@@ -350,6 +351,78 @@ def _include_usage(stream_options: dict | None) -> bool:
return bool(stream_options.get("include_usage")) return bool(stream_options.get("include_usage"))
def _stream_event_type(event: Any) -> str:
if isinstance(event, dict):
t = event.get("type")
if t in {"text", "tool"}:
return t
return "text"
def _stream_text(event: Any) -> str:
if isinstance(event, dict):
if event.get("type") == "text":
text = event.get("text")
if isinstance(text, str):
return text
return ""
if isinstance(event, str):
return event
return ""
def _stream_tool_event(event: Any) -> dict[str, Any] | None:
if isinstance(event, dict) and event.get("type") == "tool":
tool = event.get("tool")
if isinstance(tool, dict):
return tool
return None
def _json_string(value: Any) -> str:
if isinstance(value, str):
return value
try:
return json.dumps(value if value is not None else {}, ensure_ascii=False)
except Exception:
return "{}"
def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]:
return {
"id": str(tool.get("id") or f"call_{uuid.uuid4().hex}"),
"type": "function",
"function": {
"name": str(tool.get("name") or "tool"),
"arguments": _json_string(tool.get("input")),
},
}
def _anthropic_tool_use_block(tool: dict[str, Any]) -> dict[str, Any]:
return {
"type": "tool_use",
"id": str(tool.get("id") or f"toolu_{uuid.uuid4().hex}"),
"name": str(tool.get("name") or "tool"),
"input": tool.get("input") if tool.get("input") is not None else {},
}
def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None:
if "result" not in tool:
return None
result = tool.get("result")
if isinstance(result, str):
content: Any = result
else:
content = _json_string(result)
return {
"type": "tool_result",
"tool_use_id": str(tool.get("id") or ""),
"content": content,
}
@app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)]) @app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)])
async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
p = _require_pool() p = _require_pool()
@@ -485,7 +558,37 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
is_reply=is_reply, is_reply=is_reply,
out_meta=_meta, out_meta=_meta,
): ):
completion_tokens_holder["n"] += estimate_tokens(chunk) if _stream_event_type(chunk) == "tool":
tool = _stream_tool_event(chunk)
if not tool:
continue
payload = {
"id": completion_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"tool_calls": [
{
"index": 0,
**_openai_tool_call(tool),
}
]
},
"finish_reason": None,
}
],
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
continue
text = _stream_text(chunk)
if not text:
continue
completion_tokens_holder["n"] += estimate_tokens(text)
payload = { payload = {
"id": completion_id, "id": completion_id,
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
@@ -494,7 +597,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
"choices": [ "choices": [
{ {
"index": 0, "index": 0,
"delta": {"content": chunk}, "delta": {"content": text},
"finish_reason": None, "finish_reason": None,
} }
], ],
@@ -596,6 +699,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
sid = result.get("sessionId") sid = result.get("sessionId")
if sid: if sid:
await session_cache.put(write_key, sid, inst.name) await session_cache.put(write_key, sid, inst.name)
tool_events = result.get("toolEvents") or []
message_content = result.get("text") or ""
tool_calls: list[dict[str, Any]] = []
if isinstance(tool_events, list):
for item in tool_events:
if isinstance(item, dict):
tool_calls.append(_openai_tool_call(item))
response = ChatCompletionResponse( response = ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex}", id=f"chatcmpl-{uuid.uuid4().hex}",
created=int(time.time()), created=int(time.time()),
@@ -604,10 +714,15 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
ChatCompletionChoice( ChatCompletionChoice(
index=0, index=0,
finish_reason="stop", finish_reason="stop",
message={"role": "assistant", "content": result.get("text") or ""}, message={
"role": "assistant",
"content": message_content,
"tool_calls": tool_calls or None,
},
) )
], ],
) )
data = response.model_dump() data = response.model_dump()
data["latency"] = { data["latency"] = {
"first_token_ms": result.get("firstTokenLatencyMs"), "first_token_ms": result.get("firstTokenLatencyMs"),
@@ -810,6 +925,8 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta): async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
success = False success = False
block_index = 0
text_block_open = False
try: try:
# 1) message_start — Anthropic SDKs read this first to get # 1) message_start — Anthropic SDKs read this first to get
# the message envelope (id/model/initial usage). # the message envelope (id/model/initial usage).
@@ -833,17 +950,6 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
} }
yield _sse("message_start", start_payload) yield _sse("message_start", start_payload)
# 2) content_block_start for a single text block (index 0).
yield _sse(
"content_block_start",
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
)
# 3) content_block_delta stream of text tokens.
async for chunk in _inst.client.chat_stream( async for chunk in _inst.client.chat_stream(
prompt, prompt,
model, model,
@@ -852,23 +958,80 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
is_reply=is_reply, is_reply=is_reply,
out_meta=_meta, out_meta=_meta,
): ):
if not chunk: if _stream_event_type(chunk) == "tool":
if text_block_open:
yield _sse(
"content_block_stop",
{"type": "content_block_stop", "index": block_index},
)
block_index += 1
text_block_open = False
tool = _stream_tool_event(chunk)
if not tool:
continue
tool_use_block = _anthropic_tool_use_block(tool)
yield _sse(
"content_block_start",
{
"type": "content_block_start",
"index": block_index,
"content_block": tool_use_block,
},
)
yield _sse(
"content_block_stop",
{"type": "content_block_stop", "index": block_index},
)
block_index += 1
tool_result_block = _anthropic_tool_result_block(tool)
if tool_result_block is not None:
yield _sse(
"content_block_start",
{
"type": "content_block_start",
"index": block_index,
"content_block": tool_result_block,
},
)
yield _sse(
"content_block_stop",
{"type": "content_block_stop", "index": block_index},
)
block_index += 1
continue continue
completion_tokens_holder["n"] += estimate_tokens(chunk)
text = _stream_text(chunk)
if not text:
continue
completion_tokens_holder["n"] += estimate_tokens(text)
if not text_block_open:
yield _sse(
"content_block_start",
{
"type": "content_block_start",
"index": block_index,
"content_block": {"type": "text", "text": ""},
},
)
text_block_open = True
yield _sse( yield _sse(
"content_block_delta", "content_block_delta",
{ {
"type": "content_block_delta", "type": "content_block_delta",
"index": 0, "index": block_index,
"delta": {"type": "text_delta", "text": chunk}, "delta": {"type": "text_delta", "text": text},
}, },
) )
# 4) content_block_stop closes the single text block. if text_block_open:
yield _sse( yield _sse(
"content_block_stop", "content_block_stop",
{"type": "content_block_stop", "index": 0}, {"type": "content_block_stop", "index": block_index},
) )
# 5) message_delta carries the terminal stop_reason and # 5) message_delta carries the terminal stop_reason and
# the final cumulative output_tokens count. # the final cumulative output_tokens count.
@@ -972,12 +1135,25 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
if sid: if sid:
await session_cache.put(write_key, sid, inst.name) await session_cache.put(write_key, sid, inst.name)
content_blocks: list[dict[str, Any]] = []
if text:
content_blocks.append({"type": "text", "text": text})
tool_events = result.get("toolEvents") or []
if isinstance(tool_events, list):
for item in tool_events:
if not isinstance(item, dict):
continue
content_blocks.append(_anthropic_tool_use_block(item))
tool_result = _anthropic_tool_result_block(item)
if tool_result is not None:
content_blocks.append(tool_result)
response_body: dict = { response_body: dict = {
"id": message_id, "id": message_id,
"type": "message", "type": "message",
"role": "assistant", "role": "assistant",
"model": model, "model": model,
"content": [{"type": "text", "text": text}], "content": content_blocks,
"stop_reason": _anthropic_stop_reason(completion_tokens, req.max_tokens), "stop_reason": _anthropic_stop_reason(completion_tokens, req.max_tokens),
"stop_sequence": None, "stop_sequence": None,
"usage": { "usage": {

View File

@@ -0,0 +1,292 @@
from __future__ import annotations
import json
import sys
import types
import unittest
from unittest.mock import AsyncMock, patch
# app.main imports playwright via auto_login; tests don't exercise that path.
# Inject a lightweight stub so unit tests run without installing playwright.
_playwright = types.ModuleType("playwright")
_playwright_async = types.ModuleType("playwright.async_api")
class _StubPlaywrightTimeoutError(Exception):
pass
async def _stub_async_playwright():
raise RuntimeError("playwright is stubbed in unit tests")
_playwright_async.TimeoutError = _StubPlaywrightTimeoutError
_playwright_async.async_playwright = _stub_async_playwright
sys.modules.setdefault("playwright", _playwright)
sys.modules.setdefault("playwright.async_api", _playwright_async)
from starlette.requests import Request
from app.anthropic_schema import AnthropicMessagesRequest
from app.openai_schema import ChatCompletionsRequest
import app.main as main
class _FakeTicket:
def __init__(self) -> None:
self.released = False
def release(self) -> None:
self.released = True
class _FakeGuard:
def __init__(self) -> None:
self.in_flight = 0
async def try_acquire(self) -> _FakeTicket:
return _FakeTicket()
class _FakeClient:
def __init__(self, *, stream_events: list[dict], complete_result: dict) -> None:
self._stream_events = stream_events
self._complete_result = complete_result
async def query_models(self) -> dict:
return {
"chat": [
{
"key": "org_auto",
"displayName": "Auto",
}
]
}
async def chat_complete(self, *args, **kwargs) -> dict:
return self._complete_result
async def chat_stream(self, *args, **kwargs):
out_meta = kwargs.get("out_meta")
if isinstance(out_meta, dict):
out_meta["session_id"] = "sess-stream"
for event in self._stream_events:
yield event
class _FakeInstance:
def __init__(self, client: _FakeClient) -> None:
self.name = "inst-test"
self.client = client
self.in_flight = 0
class _FakePool:
def __init__(self, inst: _FakeInstance) -> None:
self._inst = inst
def pick(self, affinity_key: str | None = None) -> _FakeInstance:
return self._inst
def _make_request(path: str, headers: dict[str, str] | None = None) -> Request:
header_pairs = []
for k, v in (headers or {}).items():
header_pairs.append((k.lower().encode("latin-1"), v.encode("latin-1")))
scope = {
"type": "http",
"http_version": "1.1",
"method": "POST",
"scheme": "http",
"path": path,
"raw_path": path.encode("latin-1"),
"query_string": b"",
"headers": header_pairs,
"client": ("testclient", 12345),
"server": ("testserver", 80),
"root_path": "",
}
return Request(scope)
async def _collect_stream(response) -> str:
chunks: list[str] = []
async for part in response.body_iterator:
if isinstance(part, bytes):
chunks.append(part.decode("utf-8"))
else:
chunks.append(str(part))
return "".join(chunks)
class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
async def test_openai_non_stream_bridges_tool_calls(self) -> None:
fake_client = _FakeClient(
stream_events=[],
complete_result={
"text": "done",
"toolEvents": [
{
"id": "call_123",
"name": "search_docs",
"input": {"query": "gateway"},
"result": {"ok": True},
}
],
"sessionId": "sess-1",
"firstTokenLatencyMs": 12,
"totalLatencyMs": 34,
},
)
req = ChatCompletionsRequest(
model="org_auto",
messages=[{"role": "user", "content": "hi"}],
stream=False,
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
):
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
payload = json.loads(response.body)
message = payload["choices"][0]["message"]
self.assertEqual(message["content"], "done")
self.assertIsInstance(message["tool_calls"], list)
self.assertEqual(message["tool_calls"][0]["function"]["name"], "search_docs")
self.assertEqual(
json.loads(message["tool_calls"][0]["function"]["arguments"]),
{"query": "gateway"},
)
async def test_openai_stream_bridges_tool_and_text_events(self) -> None:
fake_client = _FakeClient(
stream_events=[
{
"type": "tool",
"tool": {
"id": "call_stream_1",
"name": "read_file",
"input": {"path": "README.md"},
},
},
{"type": "text", "text": "hello"},
],
complete_result={},
)
req = ChatCompletionsRequest(
model="org_auto",
messages=[{"role": "user", "content": "hi"}],
stream=True,
stream_options={"include_usage": True},
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
):
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
body = await _collect_stream(response)
self.assertIn('"tool_calls"', body)
self.assertIn('"content": "hello"', body)
self.assertIn('"usage"', body)
self.assertIn("data: [DONE]", body)
async def test_anthropic_non_stream_bridges_tool_blocks(self) -> None:
fake_client = _FakeClient(
stream_events=[],
complete_result={
"text": "ok",
"toolEvents": [
{
"id": "toolu_1",
"name": "lookup",
"input": {"k": "v"},
"result": {"value": 1},
}
],
"sessionId": "sess-2",
},
)
req = AnthropicMessagesRequest(
model="claude-3-5-sonnet-20241022",
max_tokens=256,
messages=[{"role": "user", "content": "hi"}],
stream=False,
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
patch.object(main.settings, "api_keys", ["test-key"]),
):
response = await main.v1_messages(
req,
_make_request(
"/v1/messages",
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
),
)
payload = json.loads(response.body)
types = [item["type"] for item in payload["content"]]
self.assertEqual(types, ["text", "tool_use", "tool_result"])
self.assertEqual(payload["content"][1]["name"], "lookup")
self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1")
async def test_anthropic_stream_bridges_tool_and_text_events(self) -> None:
fake_client = _FakeClient(
stream_events=[
{
"type": "tool",
"tool": {
"id": "toolu_stream_1",
"name": "read",
"input": {"file": "a.txt"},
"result": "done",
},
},
{"type": "text", "text": "world"},
],
complete_result={},
)
req = AnthropicMessagesRequest(
model="claude-3-5-sonnet-20241022",
max_tokens=256,
messages=[{"role": "user", "content": "hi"}],
stream=True,
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
patch.object(main.settings, "api_keys", ["test-key"]),
):
response = await main.v1_messages(
req,
_make_request(
"/v1/messages",
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
),
)
body = await _collect_stream(response)
self.assertIn("event: message_start", body)
self.assertIn('"type": "tool_use"', body)
self.assertIn('"type": "tool_result"', body)
self.assertIn('"type": "text_delta"', body)
self.assertIn("event: message_stop", body)
if __name__ == "__main__":
unittest.main()