feat: bridge Lingma tool events to OpenAI/Anthropic responses
Add structured tool event propagation from Lingma stream/finish metadata and map it to OpenAI tool_calls and Anthropic tool_use/tool_result in both streaming and non-streaming responses. Add focused bridge tests and update docs/design notes to match current behavior. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
95
CLAUDE.md
Normal file
95
CLAUDE.md
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
# CLAUDE.md
|
||||||
|
|
||||||
|
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||||
|
|
||||||
|
## Primary docs to read first
|
||||||
|
- `README.md` (runtime commands, env model, API examples)
|
||||||
|
- `DESIGN.md` (architecture decisions, module boundaries, request lifecycle)
|
||||||
|
- `.env.example` (authoritative env var reference)
|
||||||
|
|
||||||
|
No Cursor/Copilot rule files were found in this repo (`.cursorrules`, `.cursor/rules/`, `.github/copilot-instructions.md`).
|
||||||
|
|
||||||
|
## Common development commands
|
||||||
|
|
||||||
|
### Start locally
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
uvicorn app.main:app --reload --port 8317
|
||||||
|
```
|
||||||
|
|
||||||
|
### Start with Docker Compose
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
mkdir -p data secrets
|
||||||
|
docker compose up -d --build
|
||||||
|
docker compose logs -f
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run tests
|
||||||
|
```bash
|
||||||
|
# current focused suite
|
||||||
|
python3 -m unittest tests/test_tool_call_bridge.py
|
||||||
|
|
||||||
|
# discover all unittest tests under tests/
|
||||||
|
python3 -m unittest discover -s tests -p "test_*.py"
|
||||||
|
|
||||||
|
# run a single test method
|
||||||
|
python3 -m unittest tests.test_tool_call_bridge.ToolCallBridgeTests.test_openai_non_stream_bridges_tool_calls
|
||||||
|
```
|
||||||
|
|
||||||
|
### Smoke-check running gateway
|
||||||
|
```bash
|
||||||
|
API_KEY=$(grep '^API_KEYS=' .env | cut -d= -f2 | cut -d, -f1)
|
||||||
|
curl -s http://127.0.0.1:8317/healthz
|
||||||
|
curl -s http://127.0.0.1:8317/v1/models -H "Authorization: Bearer $API_KEY"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Linting/type-checking status
|
||||||
|
- There is currently no repo-configured lint/type command (no `ruff`/`flake8`/`mypy` config found).
|
||||||
|
- Do not invent tooling commands; if linting is needed, add tooling in a dedicated change first.
|
||||||
|
|
||||||
|
## Architecture (big picture)
|
||||||
|
|
||||||
|
### What this service is
|
||||||
|
A FastAPI gateway that fronts Lingma and exposes:
|
||||||
|
- OpenAI-compatible API (`/v1/models`, `/v1/chat/completions`)
|
||||||
|
- Anthropic Messages-compatible API (`/v1/messages`, `/v1/messages/count_tokens`)
|
||||||
|
|
||||||
|
Both protocols share the same backend pool, backpressure guard, stats, and session reuse logic.
|
||||||
|
|
||||||
|
### Request lifecycle (important for most changes)
|
||||||
|
1. Authenticate request (`app/auth.py`)
|
||||||
|
2. Normalize inbound protocol payload to internal message shape (`openai_schema.py` / `anthropic_schema.py`)
|
||||||
|
3. Session-cache lookup (`app/session_cache.py`) for prefix-based reuse
|
||||||
|
4. Pick backend instance (`app/lingma_pool.py`) with affinity + least-in-flight
|
||||||
|
5. Acquire concurrency ticket (`app/concurrency.py`)
|
||||||
|
6. Call Lingma via websocket/LSP client (`app/lingma_client.py`)
|
||||||
|
7. Map upstream result/stream back to wire protocol in `app/main.py`
|
||||||
|
8. Record stats and release ticket (including stream-finally paths)
|
||||||
|
|
||||||
|
### Core module boundaries
|
||||||
|
- `app/main.py`: API entrypoint + orchestration + wire-format adapters
|
||||||
|
- `app/lingma_pool.py`: multi-instance lifecycle, selection, health-aware fallback
|
||||||
|
- `app/lingma_client.py`: subprocess + LSP-over-WebSocket transport to Lingma
|
||||||
|
- `app/session_cache.py`: LRU+TTL cache of conversation-prefix -> upstream session id (+ instance binding)
|
||||||
|
- `app/concurrency.py`: in-flight guard and queue timeout/backpressure behavior
|
||||||
|
- `app/stats.py`: usage counters and Prometheus text
|
||||||
|
|
||||||
|
### Protocol-specific notes
|
||||||
|
- Anthropic and OpenAI endpoints are separate adapters over shared internals.
|
||||||
|
- Response-side tool bridge is implemented: upstream Lingma tool events are surfaced as:
|
||||||
|
- OpenAI: `tool_calls` (stream + non-stream)
|
||||||
|
- Anthropic: `tool_use` / `tool_result` blocks (stream + non-stream)
|
||||||
|
- Request-side `tools` / `tool_choice` are accepted by schemas but not forwarded to Lingma.
|
||||||
|
|
||||||
|
### Operational invariants to preserve
|
||||||
|
- One request must stay on one Lingma instance for session continuity.
|
||||||
|
- Session cache entries include instance identity; invalidate on unhealthy instance mismatch.
|
||||||
|
- Streaming paths must always release in-flight tickets in `finally`.
|
||||||
|
- Multi-instance mode must use isolated workdirs per instance.
|
||||||
|
|
||||||
|
### Deployment/runtime model
|
||||||
|
- Container startup runs `python /app/app/bootstrap_lingma.py` before uvicorn.
|
||||||
|
- Compose mounts:
|
||||||
|
- `./data -> /app/data` (persistent Lingma binary/cache/workdirs)
|
||||||
|
- `./secrets -> /secrets:ro` (session bundles, secrets)
|
||||||
@@ -47,7 +47,8 @@
|
|||||||
|
|
||||||
- **逆向 Lingma 后端协议**:之前评估过(曾经的"B1 终极方案"),需要反编译二进制,维护成本高、政策风险大,放弃。
|
- **逆向 Lingma 后端协议**:之前评估过(曾经的"B1 终极方案"),需要反编译二进制,维护成本高、政策风险大,放弃。
|
||||||
- **多租户 / 水平扩缩**:单容器即可;真要大规模部署 → 套层反代 + N 个网关副本就够,不在进程内解决。
|
- **多租户 / 水平扩缩**:单容器即可;真要大规模部署 → 套层反代 + N 个网关副本就够,不在进程内解决。
|
||||||
- **完整 function calling / tools**:OpenAI schema 里保留了字段,但目前不透传给 Lingma(Lingma 侧没有等价能力)。
|
- **请求侧完整 function calling / tools 透传**:OpenAI schema 里保留了字段,但目前不会把 `tools`/`tool_choice` 透传给 Lingma(上游无等价输入协议)。
|
||||||
|
- **响应侧工具事件桥接**:若 Lingma 上游产出 tool 事件,网关会向 OpenAI 输出 `tool_calls`,向 Anthropic 输出 `tool_use` / `tool_result`(stream + non-stream)。
|
||||||
- **多模态**:请求里的 image/audio 会被降级成占位符 `[image]` / `[audio]`,因为 Lingma chat 不支持。
|
- **多模态**:请求里的 image/audio 会被降级成占位符 `[image]` / `[audio]`,因为 Lingma chat 不支持。
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -591,7 +592,7 @@ FastAPI `lifespan` 退出 → `pool.close()` → 每个 `client.close()` → 进
|
|||||||
| 需求 | 改哪些文件 | 关键入口 |
|
| 需求 | 改哪些文件 | 关键入口 |
|
||||||
|---|---|---|
|
|---|---|---|
|
||||||
| 加一个新的 OpenAI 端点(如 embeddings) | `main.py`, `openai_schema.py` | 仿照 `v1_models` 加 `@app.post("/v1/embeddings", dependencies=[Depends(auth_guard)])` |
|
| 加一个新的 OpenAI 端点(如 embeddings) | `main.py`, `openai_schema.py` | 仿照 `v1_models` 加 `@app.post("/v1/embeddings", dependencies=[Depends(auth_guard)])` |
|
||||||
| 扩展 Anthropic 端点(如 count_tokens / tool_use 贯通) | `main.py::v1_messages`, `anthropic_schema.py` | count_tokens 只读:复用 `estimate_tokens`;tool_use 需要 Lingma 上游支持,payload 转发点在 `chat_stream` / `chat_complete` |
|
| 扩展 Anthropic 端点(如 count_tokens / tool_use 相关能力) | `main.py::v1_messages`, `anthropic_schema.py` | count_tokens 只读:复用 `estimate_tokens`;响应侧 `tool_use/tool_result` 桥接已支持,若要请求侧 tools 透传仍需改 `lingma_client.py` payload |
|
||||||
| 加一种新的实例调度策略(如加权轮询) | `lingma_pool.py::pick()` | 当前是 affinity → least-in-flight → round-robin |
|
| 加一种新的实例调度策略(如加权轮询) | `lingma_pool.py::pick()` | 当前是 affinity → least-in-flight → round-robin |
|
||||||
| 改认证为 JWT / OAuth | `auth.py` | 三个 `require_*` 函数是全部入口;`main.py` 里只有 `*_guard` 代理 |
|
| 改认证为 JWT / OAuth | `auth.py` | 三个 `require_*` 函数是全部入口;`main.py` 里只有 `*_guard` 代理 |
|
||||||
| 增加限流(按 api_key 配额) | `concurrency.py` 加 `PerKeyGuard`;`main.py` 在 `chat_guard.try_acquire()` 后再来一层 | 注意 ticket 释放顺序(内层先释放) |
|
| 增加限流(按 api_key 配额) | `concurrency.py` 加 `PerKeyGuard`;`main.py` 在 `chat_guard.try_acquire()` 后再来一层 | 注意 ticket 释放顺序(内层先释放) |
|
||||||
@@ -627,7 +628,7 @@ uvicorn app.main:app --reload --port 8317
|
|||||||
| 标签 | 描述 | 影响 | 计划 |
|
| 标签 | 描述 | 影响 | 计划 |
|
||||||
|---|---|---|---|
|
|---|---|---|---|
|
||||||
| D1 | `config.py` 还是纯 `dataclass` + `os.getenv`,未迁 `pydantic-settings` | 类型校验靠自己 cast | 低优,收益有限,有精力再做 |
|
| D1 | `config.py` 还是纯 `dataclass` + `os.getenv`,未迁 `pydantic-settings` | 类型校验靠自己 cast | 低优,收益有限,有精力再做 |
|
||||||
| D3 | 无单元测试骨架 | 重构要靠 deploy 验证 | 想加 CI 时优先补 |
|
| D3 | 已有基础单测覆盖 tool-call bridge(OpenAI/Anthropic,stream + non-stream),但整体测试矩阵仍不完整 | 回归仍依赖手工验证与定向测试 | 后续补充会话复用、背压、鉴权和异常路径用例 |
|
||||||
| Docker non-root | 容器还是 root 跑 | 容器逃逸时影响宿主 | 需要加 `gosu` + chown entrypoint,涉及数据迁移,谨慎推进 |
|
| Docker non-root | 容器还是 root 跑 | 容器逃逸时影响宿主 | 需要加 `gosu` + chown entrypoint,涉及数据迁移,谨慎推进 |
|
||||||
| ADMIN_TOKEN 轮换 | 没有过期机制,只能重启 | 自用场景不影响 | 接 Vault / sops 时一并做 |
|
| ADMIN_TOKEN 轮换 | 没有过期机制,只能重启 | 自用场景不影响 | 接 Vault / sops 时一并做 |
|
||||||
| Lingma 版本漂移 | 新版 Lingma 改 LSP 方法或新增必需 cache 文件时会无声崩 | 注入失败会 fallback,但 chat 不回话题型的错误不易定位 | 加一个 `/internal/smoke` 端点做端到端自检 |
|
| Lingma 版本漂移 | 新版 Lingma 改 LSP 方法或新增必需 cache 文件时会无声崩 | 注入失败会 fallback,但 chat 不回话题型的错误不易定位 | 加一个 `/internal/smoke` 端点做端到端自检 |
|
||||||
|
|||||||
@@ -221,7 +221,8 @@ curl -N http://127.0.0.1:8317/v1/messages \
|
|||||||
说明:
|
说明:
|
||||||
- **模型名兼容**:客户端可以继续传 `claude-3-*` 等名字;未识别的 model 会回退到 `DEFAULT_MODEL` 对应的 Lingma key,后端实际仍由 Lingma 提供(Qwen 系列)。如需显式选模型,直接传 Lingma key(`dashscope_qmodel` 等)。
|
- **模型名兼容**:客户端可以继续传 `claude-3-*` 等名字;未识别的 model 会回退到 `DEFAULT_MODEL` 对应的 Lingma key,后端实际仍由 Lingma 提供(Qwen 系列)。如需显式选模型,直接传 Lingma key(`dashscope_qmodel` 等)。
|
||||||
- **会话复用共享**:Anthropic 与 OpenAI 两个端点共用同一 `SessionCache`,只要 API key 相同、对话前缀相同,就会命中同一上游 `sessionId`。
|
- **会话复用共享**:Anthropic 与 OpenAI 两个端点共用同一 `SessionCache`,只要 API key 相同、对话前缀相同,就会命中同一上游 `sessionId`。
|
||||||
- **多模态**:`image` 块会被降级为 `[image]` 占位符(Lingma 不支持 vision);`tool_use` / `tool_result` 会以纯文本形式保留语义。
|
- **多模态**:`image` 块会被降级为 `[image]` 占位符(Lingma 不支持 vision)。
|
||||||
|
- **工具事件桥接**:当 Lingma 上游返回 `tool` 事件时,网关会输出为 OpenAI `tool_calls`(含 stream/non-stream)和 Anthropic `tool_use`/`tool_result` blocks(含 stream/non-stream);但请求侧 `tools`/`tool_choice` 仍不会透传到 Lingma。
|
||||||
- **鉴权**:优先 `x-api-key` 头(Anthropic 官方 SDK 默认),回退 `Authorization: Bearer`(方便 curl / OpenAI 风格客户端)。
|
- **鉴权**:优先 `x-api-key` 头(Anthropic 官方 SDK 默认),回退 `Authorization: Bearer`(方便 curl / OpenAI 风格客户端)。
|
||||||
|
|
||||||
### 3.2 观测(`METRICS_TOKEN` 或 `API_KEYS`)
|
### 3.2 观测(`METRICS_TOKEN` 或 `API_KEYS`)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import subprocess
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import AsyncIterator, Callable, Optional
|
from typing import Any, AsyncIterator, Callable, Optional
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
@@ -103,6 +103,58 @@ class LspWsRpcClient:
|
|||||||
self._on_disconnect = on_disconnect
|
self._on_disconnect = on_disconnect
|
||||||
self._closed = False
|
self._closed = False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_tool_event(params: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
candidates: list[dict[str, Any]] = []
|
||||||
|
if isinstance(params.get("toolCall"), dict):
|
||||||
|
candidates.append(params["toolCall"])
|
||||||
|
if isinstance(params.get("tool_call"), dict):
|
||||||
|
candidates.append(params["tool_call"])
|
||||||
|
if isinstance(params.get("tool"), dict):
|
||||||
|
candidates.append(params["tool"])
|
||||||
|
data = params.get("data")
|
||||||
|
if isinstance(data, dict):
|
||||||
|
if isinstance(data.get("toolCall"), dict):
|
||||||
|
candidates.append(data["toolCall"])
|
||||||
|
if isinstance(data.get("tool_call"), dict):
|
||||||
|
candidates.append(data["tool_call"])
|
||||||
|
if isinstance(data.get("tool"), dict):
|
||||||
|
candidates.append(data["tool"])
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
return None
|
||||||
|
|
||||||
|
raw = candidates[0]
|
||||||
|
tool_id = (
|
||||||
|
raw.get("toolCallId")
|
||||||
|
or raw.get("tool_call_id")
|
||||||
|
or raw.get("id")
|
||||||
|
or params.get("toolCallId")
|
||||||
|
or params.get("tool_call_id")
|
||||||
|
)
|
||||||
|
name = raw.get("name") or raw.get("toolName") or raw.get("tool_name")
|
||||||
|
call_input = raw.get("input")
|
||||||
|
if call_input is None:
|
||||||
|
call_input = raw.get("arguments")
|
||||||
|
if call_input is None:
|
||||||
|
call_input = raw.get("args")
|
||||||
|
|
||||||
|
result_payload = raw.get("result")
|
||||||
|
if result_payload is None:
|
||||||
|
result_payload = params.get("result")
|
||||||
|
if result_payload is None and isinstance(data, dict):
|
||||||
|
result_payload = data.get("result")
|
||||||
|
|
||||||
|
if not tool_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": str(tool_id),
|
||||||
|
"name": str(name or "tool"),
|
||||||
|
"input": call_input if call_input is not None else {},
|
||||||
|
"result": result_payload,
|
||||||
|
}
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
self._reader_task = asyncio.create_task(self._reader_loop())
|
self._reader_task = asyncio.create_task(self._reader_loop())
|
||||||
|
|
||||||
@@ -185,7 +237,16 @@ class LspWsRpcClient:
|
|||||||
stream["parts"].append(text)
|
stream["parts"].append(text)
|
||||||
if stream["first_chunk_at"] is None:
|
if stream["first_chunk_at"] is None:
|
||||||
stream["first_chunk_at"] = time.monotonic()
|
stream["first_chunk_at"] = time.monotonic()
|
||||||
stream["chunks"].put_nowait(text)
|
stream["chunks"].put_nowait({"type": "text", "text": text})
|
||||||
|
|
||||||
|
if method in {"tool/call/sync", "tool/invoke", "tool/call/approve"}:
|
||||||
|
req_id = params.get("requestId")
|
||||||
|
stream = self._chat_streams.get(req_id)
|
||||||
|
if stream is not None:
|
||||||
|
tool_event = self._extract_tool_event(params)
|
||||||
|
if tool_event is not None:
|
||||||
|
stream["tool_events"].append(tool_event)
|
||||||
|
stream["chunks"].put_nowait({"type": "tool", "tool": tool_event})
|
||||||
|
|
||||||
if method == "chat/finish":
|
if method == "chat/finish":
|
||||||
req_id = params.get("requestId")
|
req_id = params.get("requestId")
|
||||||
@@ -224,6 +285,7 @@ class LspWsRpcClient:
|
|||||||
"chunks": asyncio.Queue(),
|
"chunks": asyncio.Queue(),
|
||||||
"done": asyncio.Event(),
|
"done": asyncio.Event(),
|
||||||
"finish": None,
|
"finish": None,
|
||||||
|
"tool_events": [],
|
||||||
"started_at": time.monotonic(),
|
"started_at": time.monotonic(),
|
||||||
"first_chunk_at": None,
|
"first_chunk_at": None,
|
||||||
"finish_at": None,
|
"finish_at": None,
|
||||||
@@ -239,7 +301,7 @@ class LspWsRpcClient:
|
|||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
stream["chunks"].put_nowait(None)
|
stream["chunks"].put_nowait(None)
|
||||||
|
|
||||||
async def consume_stream(self, request_id: str, timeout: float) -> AsyncIterator[str]:
|
async def consume_stream(self, request_id: str, timeout: float) -> AsyncIterator[dict[str, Any]]:
|
||||||
stream = self._chat_streams.get(request_id)
|
stream = self._chat_streams.get(request_id)
|
||||||
if stream is None:
|
if stream is None:
|
||||||
return
|
return
|
||||||
@@ -266,6 +328,7 @@ class LspWsRpcClient:
|
|||||||
"finish": stream.get("finish") or {},
|
"finish": stream.get("finish") or {},
|
||||||
"firstTokenLatencyMs": first_ms,
|
"firstTokenLatencyMs": first_ms,
|
||||||
"totalLatencyMs": total_ms,
|
"totalLatencyMs": total_ms,
|
||||||
|
"toolEvents": stream.get("tool_events") or [],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -722,8 +785,12 @@ class LingmaGatewayClient:
|
|||||||
session_id: str | None = None,
|
session_id: str | None = None,
|
||||||
is_reply: bool = False,
|
is_reply: bool = False,
|
||||||
out_meta: dict | None = None,
|
out_meta: dict | None = None,
|
||||||
) -> AsyncIterator[str]:
|
) -> AsyncIterator[dict[str, Any]]:
|
||||||
"""Stream `chat/answer` chunks.
|
"""Stream chat events.
|
||||||
|
|
||||||
|
Yields structured events:
|
||||||
|
* {"type": "text", "text": "..."}
|
||||||
|
* {"type": "tool", "tool": {...}}
|
||||||
|
|
||||||
If `out_meta` is provided, the final `chat/finish` payload's sessionId
|
If `out_meta` is provided, the final `chat/finish` payload's sessionId
|
||||||
(and the raw finish dict) is written into it when the stream ends or is
|
(and the raw finish dict) is written into it when the stream ends or is
|
||||||
@@ -739,10 +806,10 @@ class LingmaGatewayClient:
|
|||||||
self.rpc.create_stream(request_id)
|
self.rpc.create_stream(request_id)
|
||||||
try:
|
try:
|
||||||
await self._kick_chat_ask(payload)
|
await self._kick_chat_ask(payload)
|
||||||
async for chunk in self.rpc.consume_stream(
|
async for event in self.rpc.consume_stream(
|
||||||
request_id, timeout=max(60.0, self.rpc_timeout + 60.0)
|
request_id, timeout=max(60.0, self.rpc_timeout + 60.0)
|
||||||
):
|
):
|
||||||
yield chunk
|
yield event
|
||||||
finally:
|
finally:
|
||||||
# Runs on normal completion, exception, or consumer GeneratorExit (client disconnect).
|
# Runs on normal completion, exception, or consumer GeneratorExit (client disconnect).
|
||||||
if out_meta is not None:
|
if out_meta is not None:
|
||||||
@@ -753,6 +820,7 @@ class LingmaGatewayClient:
|
|||||||
out_meta["finish"] = finish
|
out_meta["finish"] = finish
|
||||||
out_meta["request_id"] = request_id
|
out_meta["request_id"] = request_id
|
||||||
out_meta["chars"] = len(stream_result.get("text") or "")
|
out_meta["chars"] = len(stream_result.get("text") or "")
|
||||||
|
out_meta["tool_events"] = stream_result.get("toolEvents") or []
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
self.rpc.pop_stream(request_id)
|
self.rpc.pop_stream(request_id)
|
||||||
|
|||||||
224
app/main.py
224
app/main.py
@@ -6,6 +6,7 @@ import json
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import Depends, FastAPI, HTTPException, Request
|
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
@@ -350,6 +351,78 @@ def _include_usage(stream_options: dict | None) -> bool:
|
|||||||
return bool(stream_options.get("include_usage"))
|
return bool(stream_options.get("include_usage"))
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_event_type(event: Any) -> str:
|
||||||
|
if isinstance(event, dict):
|
||||||
|
t = event.get("type")
|
||||||
|
if t in {"text", "tool"}:
|
||||||
|
return t
|
||||||
|
return "text"
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_text(event: Any) -> str:
|
||||||
|
if isinstance(event, dict):
|
||||||
|
if event.get("type") == "text":
|
||||||
|
text = event.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
return text
|
||||||
|
return ""
|
||||||
|
if isinstance(event, str):
|
||||||
|
return event
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _stream_tool_event(event: Any) -> dict[str, Any] | None:
|
||||||
|
if isinstance(event, dict) and event.get("type") == "tool":
|
||||||
|
tool = event.get("tool")
|
||||||
|
if isinstance(tool, dict):
|
||||||
|
return tool
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _json_string(value: Any) -> str:
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value
|
||||||
|
try:
|
||||||
|
return json.dumps(value if value is not None else {}, ensure_ascii=False)
|
||||||
|
except Exception:
|
||||||
|
return "{}"
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": str(tool.get("id") or f"call_{uuid.uuid4().hex}"),
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": str(tool.get("name") or "tool"),
|
||||||
|
"arguments": _json_string(tool.get("input")),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _anthropic_tool_use_block(tool: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": str(tool.get("id") or f"toolu_{uuid.uuid4().hex}"),
|
||||||
|
"name": str(tool.get("name") or "tool"),
|
||||||
|
"input": tool.get("input") if tool.get("input") is not None else {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
if "result" not in tool:
|
||||||
|
return None
|
||||||
|
result = tool.get("result")
|
||||||
|
if isinstance(result, str):
|
||||||
|
content: Any = result
|
||||||
|
else:
|
||||||
|
content = _json_string(result)
|
||||||
|
return {
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": str(tool.get("id") or ""),
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)])
|
@app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)])
|
||||||
async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||||
p = _require_pool()
|
p = _require_pool()
|
||||||
@@ -485,7 +558,37 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
is_reply=is_reply,
|
is_reply=is_reply,
|
||||||
out_meta=_meta,
|
out_meta=_meta,
|
||||||
):
|
):
|
||||||
completion_tokens_holder["n"] += estimate_tokens(chunk)
|
if _stream_event_type(chunk) == "tool":
|
||||||
|
tool = _stream_tool_event(chunk)
|
||||||
|
if not tool:
|
||||||
|
continue
|
||||||
|
payload = {
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": created,
|
||||||
|
"model": model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
**_openai_tool_call(tool),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
text = _stream_text(chunk)
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
completion_tokens_holder["n"] += estimate_tokens(text)
|
||||||
payload = {
|
payload = {
|
||||||
"id": completion_id,
|
"id": completion_id,
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
@@ -494,7 +597,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"delta": {"content": chunk},
|
"delta": {"content": text},
|
||||||
"finish_reason": None,
|
"finish_reason": None,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -596,6 +699,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
sid = result.get("sessionId")
|
sid = result.get("sessionId")
|
||||||
if sid:
|
if sid:
|
||||||
await session_cache.put(write_key, sid, inst.name)
|
await session_cache.put(write_key, sid, inst.name)
|
||||||
|
tool_events = result.get("toolEvents") or []
|
||||||
|
message_content = result.get("text") or ""
|
||||||
|
tool_calls: list[dict[str, Any]] = []
|
||||||
|
if isinstance(tool_events, list):
|
||||||
|
for item in tool_events:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
tool_calls.append(_openai_tool_call(item))
|
||||||
response = ChatCompletionResponse(
|
response = ChatCompletionResponse(
|
||||||
id=f"chatcmpl-{uuid.uuid4().hex}",
|
id=f"chatcmpl-{uuid.uuid4().hex}",
|
||||||
created=int(time.time()),
|
created=int(time.time()),
|
||||||
@@ -604,10 +714,15 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
ChatCompletionChoice(
|
ChatCompletionChoice(
|
||||||
index=0,
|
index=0,
|
||||||
finish_reason="stop",
|
finish_reason="stop",
|
||||||
message={"role": "assistant", "content": result.get("text") or ""},
|
message={
|
||||||
|
"role": "assistant",
|
||||||
|
"content": message_content,
|
||||||
|
"tool_calls": tool_calls or None,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
data = response.model_dump()
|
data = response.model_dump()
|
||||||
data["latency"] = {
|
data["latency"] = {
|
||||||
"first_token_ms": result.get("firstTokenLatencyMs"),
|
"first_token_ms": result.get("firstTokenLatencyMs"),
|
||||||
@@ -810,6 +925,8 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
|
|
||||||
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
|
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
|
||||||
success = False
|
success = False
|
||||||
|
block_index = 0
|
||||||
|
text_block_open = False
|
||||||
try:
|
try:
|
||||||
# 1) message_start — Anthropic SDKs read this first to get
|
# 1) message_start — Anthropic SDKs read this first to get
|
||||||
# the message envelope (id/model/initial usage).
|
# the message envelope (id/model/initial usage).
|
||||||
@@ -833,17 +950,6 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
}
|
}
|
||||||
yield _sse("message_start", start_payload)
|
yield _sse("message_start", start_payload)
|
||||||
|
|
||||||
# 2) content_block_start for a single text block (index 0).
|
|
||||||
yield _sse(
|
|
||||||
"content_block_start",
|
|
||||||
{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": 0,
|
|
||||||
"content_block": {"type": "text", "text": ""},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3) content_block_delta stream of text tokens.
|
|
||||||
async for chunk in _inst.client.chat_stream(
|
async for chunk in _inst.client.chat_stream(
|
||||||
prompt,
|
prompt,
|
||||||
model,
|
model,
|
||||||
@@ -852,23 +958,80 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
is_reply=is_reply,
|
is_reply=is_reply,
|
||||||
out_meta=_meta,
|
out_meta=_meta,
|
||||||
):
|
):
|
||||||
if not chunk:
|
if _stream_event_type(chunk) == "tool":
|
||||||
|
if text_block_open:
|
||||||
|
yield _sse(
|
||||||
|
"content_block_stop",
|
||||||
|
{"type": "content_block_stop", "index": block_index},
|
||||||
|
)
|
||||||
|
block_index += 1
|
||||||
|
text_block_open = False
|
||||||
|
|
||||||
|
tool = _stream_tool_event(chunk)
|
||||||
|
if not tool:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tool_use_block = _anthropic_tool_use_block(tool)
|
||||||
|
yield _sse(
|
||||||
|
"content_block_start",
|
||||||
|
{
|
||||||
|
"type": "content_block_start",
|
||||||
|
"index": block_index,
|
||||||
|
"content_block": tool_use_block,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
yield _sse(
|
||||||
|
"content_block_stop",
|
||||||
|
{"type": "content_block_stop", "index": block_index},
|
||||||
|
)
|
||||||
|
block_index += 1
|
||||||
|
|
||||||
|
tool_result_block = _anthropic_tool_result_block(tool)
|
||||||
|
if tool_result_block is not None:
|
||||||
|
yield _sse(
|
||||||
|
"content_block_start",
|
||||||
|
{
|
||||||
|
"type": "content_block_start",
|
||||||
|
"index": block_index,
|
||||||
|
"content_block": tool_result_block,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
yield _sse(
|
||||||
|
"content_block_stop",
|
||||||
|
{"type": "content_block_stop", "index": block_index},
|
||||||
|
)
|
||||||
|
block_index += 1
|
||||||
continue
|
continue
|
||||||
completion_tokens_holder["n"] += estimate_tokens(chunk)
|
|
||||||
|
text = _stream_text(chunk)
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
completion_tokens_holder["n"] += estimate_tokens(text)
|
||||||
|
if not text_block_open:
|
||||||
|
yield _sse(
|
||||||
|
"content_block_start",
|
||||||
|
{
|
||||||
|
"type": "content_block_start",
|
||||||
|
"index": block_index,
|
||||||
|
"content_block": {"type": "text", "text": ""},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
text_block_open = True
|
||||||
|
|
||||||
yield _sse(
|
yield _sse(
|
||||||
"content_block_delta",
|
"content_block_delta",
|
||||||
{
|
{
|
||||||
"type": "content_block_delta",
|
"type": "content_block_delta",
|
||||||
"index": 0,
|
"index": block_index,
|
||||||
"delta": {"type": "text_delta", "text": chunk},
|
"delta": {"type": "text_delta", "text": text},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4) content_block_stop closes the single text block.
|
if text_block_open:
|
||||||
yield _sse(
|
yield _sse(
|
||||||
"content_block_stop",
|
"content_block_stop",
|
||||||
{"type": "content_block_stop", "index": 0},
|
{"type": "content_block_stop", "index": block_index},
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5) message_delta carries the terminal stop_reason and
|
# 5) message_delta carries the terminal stop_reason and
|
||||||
# the final cumulative output_tokens count.
|
# the final cumulative output_tokens count.
|
||||||
@@ -972,12 +1135,25 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
if sid:
|
if sid:
|
||||||
await session_cache.put(write_key, sid, inst.name)
|
await session_cache.put(write_key, sid, inst.name)
|
||||||
|
|
||||||
|
content_blocks: list[dict[str, Any]] = []
|
||||||
|
if text:
|
||||||
|
content_blocks.append({"type": "text", "text": text})
|
||||||
|
tool_events = result.get("toolEvents") or []
|
||||||
|
if isinstance(tool_events, list):
|
||||||
|
for item in tool_events:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
content_blocks.append(_anthropic_tool_use_block(item))
|
||||||
|
tool_result = _anthropic_tool_result_block(item)
|
||||||
|
if tool_result is not None:
|
||||||
|
content_blocks.append(tool_result)
|
||||||
|
|
||||||
response_body: dict = {
|
response_body: dict = {
|
||||||
"id": message_id,
|
"id": message_id,
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"model": model,
|
"model": model,
|
||||||
"content": [{"type": "text", "text": text}],
|
"content": content_blocks,
|
||||||
"stop_reason": _anthropic_stop_reason(completion_tokens, req.max_tokens),
|
"stop_reason": _anthropic_stop_reason(completion_tokens, req.max_tokens),
|
||||||
"stop_sequence": None,
|
"stop_sequence": None,
|
||||||
"usage": {
|
"usage": {
|
||||||
|
|||||||
292
tests/test_tool_call_bridge.py
Normal file
292
tests/test_tool_call_bridge.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
# app.main imports playwright via auto_login; tests don't exercise that path.
|
||||||
|
# Inject a lightweight stub so unit tests run without installing playwright.
|
||||||
|
_playwright = types.ModuleType("playwright")
|
||||||
|
_playwright_async = types.ModuleType("playwright.async_api")
|
||||||
|
|
||||||
|
|
||||||
|
class _StubPlaywrightTimeoutError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def _stub_async_playwright():
|
||||||
|
raise RuntimeError("playwright is stubbed in unit tests")
|
||||||
|
|
||||||
|
|
||||||
|
_playwright_async.TimeoutError = _StubPlaywrightTimeoutError
|
||||||
|
_playwright_async.async_playwright = _stub_async_playwright
|
||||||
|
sys.modules.setdefault("playwright", _playwright)
|
||||||
|
sys.modules.setdefault("playwright.async_api", _playwright_async)
|
||||||
|
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
from app.anthropic_schema import AnthropicMessagesRequest
|
||||||
|
from app.openai_schema import ChatCompletionsRequest
|
||||||
|
import app.main as main
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeTicket:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.released = False
|
||||||
|
|
||||||
|
def release(self) -> None:
|
||||||
|
self.released = True
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeGuard:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.in_flight = 0
|
||||||
|
|
||||||
|
async def try_acquire(self) -> _FakeTicket:
|
||||||
|
return _FakeTicket()
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeClient:
|
||||||
|
def __init__(self, *, stream_events: list[dict], complete_result: dict) -> None:
|
||||||
|
self._stream_events = stream_events
|
||||||
|
self._complete_result = complete_result
|
||||||
|
|
||||||
|
async def query_models(self) -> dict:
|
||||||
|
return {
|
||||||
|
"chat": [
|
||||||
|
{
|
||||||
|
"key": "org_auto",
|
||||||
|
"displayName": "Auto",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
async def chat_complete(self, *args, **kwargs) -> dict:
|
||||||
|
return self._complete_result
|
||||||
|
|
||||||
|
async def chat_stream(self, *args, **kwargs):
|
||||||
|
out_meta = kwargs.get("out_meta")
|
||||||
|
if isinstance(out_meta, dict):
|
||||||
|
out_meta["session_id"] = "sess-stream"
|
||||||
|
for event in self._stream_events:
|
||||||
|
yield event
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeInstance:
|
||||||
|
def __init__(self, client: _FakeClient) -> None:
|
||||||
|
self.name = "inst-test"
|
||||||
|
self.client = client
|
||||||
|
self.in_flight = 0
|
||||||
|
|
||||||
|
|
||||||
|
class _FakePool:
|
||||||
|
def __init__(self, inst: _FakeInstance) -> None:
|
||||||
|
self._inst = inst
|
||||||
|
|
||||||
|
def pick(self, affinity_key: str | None = None) -> _FakeInstance:
|
||||||
|
return self._inst
|
||||||
|
|
||||||
|
|
||||||
|
def _make_request(path: str, headers: dict[str, str] | None = None) -> Request:
|
||||||
|
header_pairs = []
|
||||||
|
for k, v in (headers or {}).items():
|
||||||
|
header_pairs.append((k.lower().encode("latin-1"), v.encode("latin-1")))
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"http_version": "1.1",
|
||||||
|
"method": "POST",
|
||||||
|
"scheme": "http",
|
||||||
|
"path": path,
|
||||||
|
"raw_path": path.encode("latin-1"),
|
||||||
|
"query_string": b"",
|
||||||
|
"headers": header_pairs,
|
||||||
|
"client": ("testclient", 12345),
|
||||||
|
"server": ("testserver", 80),
|
||||||
|
"root_path": "",
|
||||||
|
}
|
||||||
|
return Request(scope)
|
||||||
|
|
||||||
|
|
||||||
|
async def _collect_stream(response) -> str:
|
||||||
|
chunks: list[str] = []
|
||||||
|
async for part in response.body_iterator:
|
||||||
|
if isinstance(part, bytes):
|
||||||
|
chunks.append(part.decode("utf-8"))
|
||||||
|
else:
|
||||||
|
chunks.append(str(part))
|
||||||
|
return "".join(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
||||||
|
async def test_openai_non_stream_bridges_tool_calls(self) -> None:
|
||||||
|
fake_client = _FakeClient(
|
||||||
|
stream_events=[],
|
||||||
|
complete_result={
|
||||||
|
"text": "done",
|
||||||
|
"toolEvents": [
|
||||||
|
{
|
||||||
|
"id": "call_123",
|
||||||
|
"name": "search_docs",
|
||||||
|
"input": {"query": "gateway"},
|
||||||
|
"result": {"ok": True},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sessionId": "sess-1",
|
||||||
|
"firstTokenLatencyMs": 12,
|
||||||
|
"totalLatencyMs": 34,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
req = ChatCompletionsRequest(
|
||||||
|
model="org_auto",
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||||
|
patch.object(main, "chat_guard", _FakeGuard()),
|
||||||
|
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||||
|
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||||
|
):
|
||||||
|
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||||
|
|
||||||
|
payload = json.loads(response.body)
|
||||||
|
message = payload["choices"][0]["message"]
|
||||||
|
self.assertEqual(message["content"], "done")
|
||||||
|
self.assertIsInstance(message["tool_calls"], list)
|
||||||
|
self.assertEqual(message["tool_calls"][0]["function"]["name"], "search_docs")
|
||||||
|
self.assertEqual(
|
||||||
|
json.loads(message["tool_calls"][0]["function"]["arguments"]),
|
||||||
|
{"query": "gateway"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_openai_stream_bridges_tool_and_text_events(self) -> None:
|
||||||
|
fake_client = _FakeClient(
|
||||||
|
stream_events=[
|
||||||
|
{
|
||||||
|
"type": "tool",
|
||||||
|
"tool": {
|
||||||
|
"id": "call_stream_1",
|
||||||
|
"name": "read_file",
|
||||||
|
"input": {"path": "README.md"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "hello"},
|
||||||
|
],
|
||||||
|
complete_result={},
|
||||||
|
)
|
||||||
|
req = ChatCompletionsRequest(
|
||||||
|
model="org_auto",
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
stream=True,
|
||||||
|
stream_options={"include_usage": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||||
|
patch.object(main, "chat_guard", _FakeGuard()),
|
||||||
|
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||||
|
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||||
|
):
|
||||||
|
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||||
|
body = await _collect_stream(response)
|
||||||
|
|
||||||
|
self.assertIn('"tool_calls"', body)
|
||||||
|
self.assertIn('"content": "hello"', body)
|
||||||
|
self.assertIn('"usage"', body)
|
||||||
|
self.assertIn("data: [DONE]", body)
|
||||||
|
|
||||||
|
async def test_anthropic_non_stream_bridges_tool_blocks(self) -> None:
|
||||||
|
fake_client = _FakeClient(
|
||||||
|
stream_events=[],
|
||||||
|
complete_result={
|
||||||
|
"text": "ok",
|
||||||
|
"toolEvents": [
|
||||||
|
{
|
||||||
|
"id": "toolu_1",
|
||||||
|
"name": "lookup",
|
||||||
|
"input": {"k": "v"},
|
||||||
|
"result": {"value": 1},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sessionId": "sess-2",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
req = AnthropicMessagesRequest(
|
||||||
|
model="claude-3-5-sonnet-20241022",
|
||||||
|
max_tokens=256,
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||||
|
patch.object(main, "chat_guard", _FakeGuard()),
|
||||||
|
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||||
|
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||||
|
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||||
|
):
|
||||||
|
response = await main.v1_messages(
|
||||||
|
req,
|
||||||
|
_make_request(
|
||||||
|
"/v1/messages",
|
||||||
|
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = json.loads(response.body)
|
||||||
|
types = [item["type"] for item in payload["content"]]
|
||||||
|
self.assertEqual(types, ["text", "tool_use", "tool_result"])
|
||||||
|
self.assertEqual(payload["content"][1]["name"], "lookup")
|
||||||
|
self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1")
|
||||||
|
|
||||||
|
async def test_anthropic_stream_bridges_tool_and_text_events(self) -> None:
|
||||||
|
fake_client = _FakeClient(
|
||||||
|
stream_events=[
|
||||||
|
{
|
||||||
|
"type": "tool",
|
||||||
|
"tool": {
|
||||||
|
"id": "toolu_stream_1",
|
||||||
|
"name": "read",
|
||||||
|
"input": {"file": "a.txt"},
|
||||||
|
"result": "done",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{"type": "text", "text": "world"},
|
||||||
|
],
|
||||||
|
complete_result={},
|
||||||
|
)
|
||||||
|
req = AnthropicMessagesRequest(
|
||||||
|
model="claude-3-5-sonnet-20241022",
|
||||||
|
max_tokens=256,
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||||
|
patch.object(main, "chat_guard", _FakeGuard()),
|
||||||
|
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||||
|
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||||
|
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||||
|
):
|
||||||
|
response = await main.v1_messages(
|
||||||
|
req,
|
||||||
|
_make_request(
|
||||||
|
"/v1/messages",
|
||||||
|
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
body = await _collect_stream(response)
|
||||||
|
|
||||||
|
self.assertIn("event: message_start", body)
|
||||||
|
self.assertIn('"type": "tool_use"', body)
|
||||||
|
self.assertIn('"type": "tool_result"', body)
|
||||||
|
self.assertIn('"type": "text_delta"', body)
|
||||||
|
self.assertIn("event: message_stop", body)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user