Compare commits
10 Commits
v0.1.0
...
56c57a4901
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
56c57a4901 | ||
|
|
df80a86310 | ||
|
|
15cd5e8770 | ||
|
|
63583712a8 | ||
|
|
c67a9c3d61 | ||
|
|
e208025f35 | ||
|
|
3498b81fa2 | ||
|
|
e600bae27c | ||
|
|
5aa7fbfae5 | ||
|
|
1c7b86e2c0 |
@@ -46,6 +46,9 @@ DEFAULT_MODEL=org_auto
|
||||
# 默认模式:chat 或 agent
|
||||
DEFAULT_ASK_MODE=chat
|
||||
|
||||
# 请求侧 tools/tool_choice 透传到 Lingma(默认关闭,开启后可支持工具写文件等场景)
|
||||
TOOL_FORWARD_ENABLED=false
|
||||
|
||||
# 专属域(可选)
|
||||
DEDICATED_DOMAIN_URL=
|
||||
|
||||
|
||||
95
CLAUDE.md
Normal file
95
CLAUDE.md
Normal file
@@ -0,0 +1,95 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Primary docs to read first
|
||||
- `README.md` (runtime commands, env model, API examples)
|
||||
- `DESIGN.md` (architecture decisions, module boundaries, request lifecycle)
|
||||
- `.env.example` (authoritative env var reference)
|
||||
|
||||
No Cursor/Copilot rule files were found in this repo (`.cursorrules`, `.cursor/rules/`, `.github/copilot-instructions.md`).
|
||||
|
||||
## Common development commands
|
||||
|
||||
### Start locally
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
uvicorn app.main:app --reload --port 8317
|
||||
```
|
||||
|
||||
### Start with Docker Compose
|
||||
```bash
|
||||
cp .env.example .env
|
||||
mkdir -p data secrets
|
||||
docker compose up -d --build
|
||||
docker compose logs -f
|
||||
```
|
||||
|
||||
### Run tests
|
||||
```bash
|
||||
# current focused suite
|
||||
python3 -m unittest tests/test_tool_call_bridge.py
|
||||
|
||||
# discover all unittest tests under tests/
|
||||
python3 -m unittest discover -s tests -p "test_*.py"
|
||||
|
||||
# run a single test method
|
||||
python3 -m unittest tests.test_tool_call_bridge.ToolCallBridgeTests.test_openai_non_stream_bridges_tool_calls
|
||||
```
|
||||
|
||||
### Smoke-check running gateway
|
||||
```bash
|
||||
API_KEY=$(grep '^API_KEYS=' .env | cut -d= -f2 | cut -d, -f1)
|
||||
curl -s http://127.0.0.1:8317/healthz
|
||||
curl -s http://127.0.0.1:8317/v1/models -H "Authorization: Bearer $API_KEY"
|
||||
```
|
||||
|
||||
### Linting/type-checking status
|
||||
- There is currently no repo-configured lint/type command (no `ruff`/`flake8`/`mypy` config found).
|
||||
- Do not invent tooling commands; if linting is needed, add tooling in a dedicated change first.
|
||||
|
||||
## Architecture (big picture)
|
||||
|
||||
### What this service is
|
||||
A FastAPI gateway that fronts Lingma and exposes:
|
||||
- OpenAI-compatible API (`/v1/models`, `/v1/chat/completions`)
|
||||
- Anthropic Messages-compatible API (`/v1/messages`, `/v1/messages/count_tokens`)
|
||||
|
||||
Both protocols share the same backend pool, backpressure guard, stats, and session reuse logic.
|
||||
|
||||
### Request lifecycle (important for most changes)
|
||||
1. Authenticate request (`app/auth.py`)
|
||||
2. Normalize inbound protocol payload to internal message shape (`openai_schema.py` / `anthropic_schema.py`)
|
||||
3. Session-cache lookup (`app/session_cache.py`) for prefix-based reuse
|
||||
4. Pick backend instance (`app/lingma_pool.py`) with affinity + least-in-flight
|
||||
5. Acquire concurrency ticket (`app/concurrency.py`)
|
||||
6. Call Lingma via websocket/LSP client (`app/lingma_client.py`)
|
||||
7. Map upstream result/stream back to wire protocol in `app/main.py`
|
||||
8. Record stats and release ticket (including stream-finally paths)
|
||||
|
||||
### Core module boundaries
|
||||
- `app/main.py`: API entrypoint + orchestration + wire-format adapters
|
||||
- `app/lingma_pool.py`: multi-instance lifecycle, selection, health-aware fallback
|
||||
- `app/lingma_client.py`: subprocess + LSP-over-WebSocket transport to Lingma
|
||||
- `app/session_cache.py`: LRU+TTL cache of conversation-prefix -> upstream session id (+ instance binding)
|
||||
- `app/concurrency.py`: in-flight guard and queue timeout/backpressure behavior
|
||||
- `app/stats.py`: usage counters and Prometheus text
|
||||
|
||||
### Protocol-specific notes
|
||||
- Anthropic and OpenAI endpoints are separate adapters over shared internals.
|
||||
- Response-side tool bridge is implemented: upstream Lingma tool events are surfaced as:
|
||||
- OpenAI: `tool_calls` (stream + non-stream)
|
||||
- Anthropic: `tool_use` / `tool_result` blocks (stream + non-stream)
|
||||
- Request-side `tools` / `tool_choice` are accepted by schemas but not forwarded to Lingma.
|
||||
|
||||
### Operational invariants to preserve
|
||||
- One request must stay on one Lingma instance for session continuity.
|
||||
- Session cache entries include instance identity; invalidate on unhealthy instance mismatch.
|
||||
- Streaming paths must always release in-flight tickets in `finally`.
|
||||
- Multi-instance mode must use isolated workdirs per instance.
|
||||
|
||||
### Deployment/runtime model
|
||||
- Container startup runs `python /app/app/bootstrap_lingma.py` before uvicorn.
|
||||
- Compose mounts:
|
||||
- `./data -> /app/data` (persistent Lingma binary/cache/workdirs)
|
||||
- `./secrets -> /secrets:ro` (session bundles, secrets)
|
||||
53
DESIGN.md
53
DESIGN.md
@@ -47,8 +47,9 @@
|
||||
|
||||
- **逆向 Lingma 后端协议**:之前评估过(曾经的"B1 终极方案"),需要反编译二进制,维护成本高、政策风险大,放弃。
|
||||
- **多租户 / 水平扩缩**:单容器即可;真要大规模部署 → 套层反代 + N 个网关副本就够,不在进程内解决。
|
||||
- **完整 function calling / tools**:OpenAI schema 里保留了字段,但目前不透传给 Lingma(Lingma 侧没有等价能力)。
|
||||
- **多模态**:请求里的 image/audio 会被降级成占位符 `[image]` / `[audio]`,因为 Lingma chat 不支持。
|
||||
- **请求侧完整 function calling / tools 语义**:仍不是当前目标;现阶段仅支持 `tools`/`tool_choice` 在 `TOOL_FORWARD_ENABLED` 开关下灰度透传(默认关闭)。
|
||||
- **响应侧工具事件桥接**:若 Lingma 上游产出 tool 事件,网关会向 OpenAI 输出 `tool_calls`,向 Anthropic 输出 `tool_use` / `tool_result`(stream + non-stream)。
|
||||
- **强制工具回退闭环(non-stream)**:当上游未返回 tool 事件且请求为强制 `tool_choice` 时,网关会从文本里解析严格 JSON,合成 OpenAI `tool_calls` 与 Anthropic `tool_use` / `tool_result`。
|
||||
|
||||
---
|
||||
|
||||
@@ -591,7 +592,7 @@ FastAPI `lifespan` 退出 → `pool.close()` → 每个 `client.close()` → 进
|
||||
| 需求 | 改哪些文件 | 关键入口 |
|
||||
|---|---|---|
|
||||
| 加一个新的 OpenAI 端点(如 embeddings) | `main.py`, `openai_schema.py` | 仿照 `v1_models` 加 `@app.post("/v1/embeddings", dependencies=[Depends(auth_guard)])` |
|
||||
| 扩展 Anthropic 端点(如 count_tokens / tool_use 贯通) | `main.py::v1_messages`, `anthropic_schema.py` | count_tokens 只读:复用 `estimate_tokens`;tool_use 需要 Lingma 上游支持,payload 转发点在 `chat_stream` / `chat_complete` |
|
||||
| 扩展 Anthropic 端点(如 count_tokens / tool_use 相关能力) | `main.py::v1_messages`, `anthropic_schema.py` | count_tokens 只读:复用 `estimate_tokens`;响应侧 `tool_use/tool_result` 桥接已支持;请求侧 `tools/tool_choice` 透传由 `TOOL_FORWARD_ENABLED` 控制并经 `lingma_client.py` payload 下发 |
|
||||
| 加一种新的实例调度策略(如加权轮询) | `lingma_pool.py::pick()` | 当前是 affinity → least-in-flight → round-robin |
|
||||
| 改认证为 JWT / OAuth | `auth.py` | 三个 `require_*` 函数是全部入口;`main.py` 里只有 `*_guard` 代理 |
|
||||
| 增加限流(按 api_key 配额) | `concurrency.py` 加 `PerKeyGuard`;`main.py` 在 `chat_guard.try_acquire()` 后再来一层 | 注意 ticket 释放顺序(内层先释放) |
|
||||
@@ -599,7 +600,7 @@ FastAPI `lifespan` 退出 → `pool.close()` → 每个 `client.close()` → 进
|
||||
| 改 Prometheus 指标名 | 所有 `prometheus_lines()` 或 `prometheus_text()` | 注意生态兼容;更名要在 README 留 alias |
|
||||
| 接入 Jaeger / OpenTelemetry | `logging_config.py` 加 OTel instrumentation;`main.py::request_id_middleware` 注入 traceid | request_id 可以复用为 span_id |
|
||||
| 加一个 Lingma 新方法调用(比如 code/complete) | `lingma_client.py` 仿照 `query_models`:`await self.ensure_ready(); return await self.rpc.request("code/complete", ...)` | 原始上游响应形态需抓包确认 |
|
||||
| 支持 function calling(假设 Lingma 将来支持) | `openai_schema.py` 已保留 `tools` / `tool_choice` 字段;`lingma_client.py::_build_payload` 加 `extra.tools` | 上游协议 TBD |
|
||||
| 支持 function calling(假设 Lingma 将来支持) | `openai_schema.py` / `anthropic_schema.py` / `main.py` / `lingma_client.py` | 当前仅支持请求侧 `tools/tool_choice` 在开关控制下透传与响应侧桥接;若要完整 function calling 语义仍需按上游协议补齐 |
|
||||
| 多模态穿透 | `openai_schema.py::flatten_content` 不再降级;`lingma_client.py` payload 传 url | 前提:Lingma 支持(目前不支持) |
|
||||
| 换 session_cache 后端(如 Redis) | 实现同样接口的 `RedisSessionCache`,`main.py` 初始化换实现 | 接口是 `get / put / invalidate / stats / prometheus_lines / build_key / enabled`,内存换远端成本不高 |
|
||||
| 多容器副本(水平扩) | 外面套反代 + sticky session(根据 `Authorization` 或 `x-user` 做 hash);session cache 改 Redis | 或直接接受多副本 cache 独立,轻微浪费 KV cache 命中率 |
|
||||
@@ -611,7 +612,8 @@ pip install -r requirements.txt
|
||||
# 在容器外跑,需要自己准备 Lingma 二进制
|
||||
export LINGMA_BIN=/path/to/Lingma
|
||||
export API_KEYS=sk-dev
|
||||
uvicorn app.main:app --reload --port 8317
|
||||
export PORT=8317
|
||||
uvicorn app.main:app --reload --port ${PORT}
|
||||
```
|
||||
|
||||
主要断点位置:
|
||||
@@ -627,7 +629,7 @@ uvicorn app.main:app --reload --port 8317
|
||||
| 标签 | 描述 | 影响 | 计划 |
|
||||
|---|---|---|---|
|
||||
| D1 | `config.py` 还是纯 `dataclass` + `os.getenv`,未迁 `pydantic-settings` | 类型校验靠自己 cast | 低优,收益有限,有精力再做 |
|
||||
| D3 | 无单元测试骨架 | 重构要靠 deploy 验证 | 想加 CI 时优先补 |
|
||||
| D3 | 已有基础单测覆盖 tool-call bridge(OpenAI/Anthropic,stream + non-stream),但整体测试矩阵仍不完整 | 回归仍依赖手工验证与定向测试 | 后续补充会话复用、背压、鉴权和异常路径用例 |
|
||||
| Docker non-root | 容器还是 root 跑 | 容器逃逸时影响宿主 | 需要加 `gosu` + chown entrypoint,涉及数据迁移,谨慎推进 |
|
||||
| ADMIN_TOKEN 轮换 | 没有过期机制,只能重启 | 自用场景不影响 | 接 Vault / sops 时一并做 |
|
||||
| Lingma 版本漂移 | 新版 Lingma 改 LSP 方法或新增必需 cache 文件时会无声崩 | 注入失败会 fallback,但 chat 不回话题型的错误不易定位 | 加一个 `/internal/smoke` 端点做端到端自检 |
|
||||
@@ -707,6 +709,45 @@ uvicorn app.main:app --reload --port 8317
|
||||
| → | `chat/ask` (notify!) | 见 `_build_payload` | 不回 result;通过 server push 下推 |
|
||||
| ← | `chat/answer` | `{requestId, text, content}` | 流式 token |
|
||||
| ← | `chat/finish` | `{requestId, sessionId, ...其它元数据}` | 结束信号,含上游真实 sessionId |
|
||||
| ← | `tool/call/sync` | `{requestId?, toolCallId, toolCallStatus, parameters, results?}` | 工具状态与结果回流 |
|
||||
| ← | `tool/invoke` | `{requestId?, toolCallId, ...}` | 工具调用中间事件(兼容旧链路) |
|
||||
| ← | `tool/call/approve` | `{requestId?, toolCallId, approval, ...}` | 工具审批事件 |
|
||||
| ← | `tool/invokeResult` | `{requestId?, toolCallId, name, success, errorMessage, result}` | 工具执行结果事件 |
|
||||
|
||||
### 9.1 Tool call 监控 SOP(VSCode 真实环境)
|
||||
|
||||
目标:拿到 Lingma 扩展真实 method/字段,避免猜测协议。
|
||||
|
||||
1. 确认入口文件
|
||||
- `~/.vscode/extensions/alibaba-cloud.tongyi-lingma-*/package.json`
|
||||
- 查 `main`(当前是 `dist/extension.js`)
|
||||
|
||||
2. 在发送侧打点
|
||||
- 在 `sendRequest` / `sendNotification` 处记录 method 与参数 keys
|
||||
- 优先写文件,不依赖 console
|
||||
|
||||
3. 在入站 `tool/call/sync` handler 打点
|
||||
- 记录 `toolCallId`、`toolCallStatus`、是否包含 `results`
|
||||
|
||||
4. 用真实交互触发
|
||||
- VSCode 内发起会话并触发工具
|
||||
- 点击 Accept/Reject,观察事件闭环
|
||||
|
||||
5. 验证闭环
|
||||
- `tool/call/sync(pending|processing)`
|
||||
- `tool/call/approve`
|
||||
- `tool/invokeResult`
|
||||
- `tool/call/sync(results)`
|
||||
|
||||
6. 回滚
|
||||
- 用备份文件恢复 `dist/extension.js`
|
||||
- 避免长期携带探针到日常环境
|
||||
|
||||
**建议日志位置**:
|
||||
- `~/.lingma/vscode/sharedClientCache/logs/lingma-probe.log`
|
||||
- `~/.lingma/vscode/sharedClientCache/logs/lingma-extension.log`
|
||||
|
||||
**注意**:优先使用 VSCode,不混用 Cursor 扩展环境;`pipe` 连接模式下,扩展层探针最稳定。
|
||||
|
||||
**`chat/ask` payload 关键字段**:
|
||||
|
||||
|
||||
480
README.md
480
README.md
@@ -1,395 +1,215 @@
|
||||
# Lingma OpenAI Gateway
|
||||
|
||||
把本地 Lingma 插件封装成 OpenAI 兼容接口。任何能调 OpenAI 的客户端(Cursor、Dify、LangChain、curl…)都能直接接入。
|
||||
将 Lingma 封装为 OpenAI / Anthropic 兼容网关,便于现有客户端直接接入。
|
||||
|
||||
**支持:**
|
||||
- OpenAI 兼容:`GET /v1/models` / `POST /v1/chat/completions`(含 SSE 流式) / Bearer 鉴权
|
||||
- **Anthropic 兼容**:`POST /v1/messages`(含 Anthropic SSE 事件流) / `x-api-key` 鉴权
|
||||
- Prometheus / 多账号实例池 / 会话复用(跨两种协议共享) / 免浏览器登录态注入
|
||||
- OpenAI:`/v1/models`、`/v1/chat/completions`(含 stream)
|
||||
- Anthropic:`/v1/messages`、`/v1/messages/count_tokens`(含 stream)
|
||||
- 内置:多实例池、会话复用、Prometheus 指标、登录态 bundle 注入
|
||||
|
||||
> 想看架构、模块划分、设计决策、二开路线图 → 直接读 [`DESIGN.md`](./DESIGN.md)。
|
||||
> 架构设计与二开细节请看 [`DESIGN.md`](./DESIGN.md)。
|
||||
|
||||
---
|
||||
|
||||
## 架构速览
|
||||
## 目录
|
||||
|
||||
```
|
||||
┌─────────────┐ OpenAI 协议 ┌─────────────────────────────────────────┐
|
||||
│ 任意客户端 │ ───────────▶ │ FastAPI (app/main.py) │
|
||||
│ (curl/ │ │ ├─ auth_guard / admin_guard │
|
||||
│ Cursor/ │ │ ├─ chat_guard (InFlightGuard 背压) │
|
||||
│ Dify…) │ │ ├─ SessionCache (LRU+TTL, KV 复用) │
|
||||
└─────────────┘ │ └─ StatsCollector + Prometheus │
|
||||
└────────────────┬────────────────────────┘
|
||||
│ 选实例 (least-in-flight + affinity)
|
||||
┌────────────────▼────────────────────────┐
|
||||
│ LingmaPool (app/lingma_pool.py) │
|
||||
│ ├─ inst-0 inst-1 inst-N … │
|
||||
│ └─ 启动前自动 restore session bundle │
|
||||
└────────────────┬────────────────────────┘
|
||||
│
|
||||
┌───────────────────────┼───────────────────────┐
|
||||
▼ ▼ ▼
|
||||
┌────────────────────┐ ┌────────────────────┐ ┌────────────────────┐
|
||||
│ LingmaGatewayClient│ │ … │ │ … │
|
||||
│ (LSP over WS) │ │ │ │ │
|
||||
│ ├─ Popen (PID管理) │ │ │ │ │
|
||||
│ ├─ reconnect loop │ │ │ │ │
|
||||
│ └─ ws://:PORT │ │ │ │ │
|
||||
└──────────┬─────────┘ └────────────────────┘ └────────────────────┘
|
||||
│ spawn + ws
|
||||
┌──────────▼─────────┐
|
||||
│ Lingma 二进制 │
|
||||
│ --workDir /… │
|
||||
└────────────────────┘
|
||||
```
|
||||
1. [5 分钟启动](#5-分钟启动)
|
||||
2. [常用命令](#常用命令)
|
||||
3. [最小 API 示例](#最小-api-示例)
|
||||
4. [部署与更新](#部署与更新)
|
||||
5. [排障速查](#排障速查)
|
||||
6. [文档入口](#文档入口)
|
||||
|
||||
---
|
||||
|
||||
## 一、快速开始
|
||||
## 5 分钟启动
|
||||
|
||||
### 1) 准备配置
|
||||
|
||||
```bash
|
||||
git clone <repo>
|
||||
cd lingma-openai-gateway
|
||||
cp .env.example .env
|
||||
# 至少填 API_KEYS + LINGMA_USERNAME + LINGMA_PASSWORD(或 session bundle)
|
||||
```
|
||||
|
||||
至少配置这些变量(在 `.env`):
|
||||
|
||||
- `API_KEYS`
|
||||
- `LINGMA_USERNAME` / `LINGMA_PASSWORD`(或 `LINGMA_SESSION_BUNDLE(_FILE)`)
|
||||
|
||||
### 2) Docker 启动(推荐)
|
||||
|
||||
```bash
|
||||
mkdir -p data secrets
|
||||
docker compose up -d --build
|
||||
docker compose logs -f # 看到 "Uvicorn running on..." 就 OK
|
||||
docker compose logs -f
|
||||
```
|
||||
|
||||
冒烟测试:
|
||||
### 3) 冒烟检查
|
||||
|
||||
```bash
|
||||
PORT=$(grep '^PORT=' .env | cut -d= -f2)
|
||||
API_KEY=$(grep '^API_KEYS=' .env | cut -d= -f2 | cut -d, -f1)
|
||||
curl -s http://127.0.0.1:8317/healthz
|
||||
curl -s http://127.0.0.1:8317/v1/models -H "Authorization: Bearer $API_KEY"
|
||||
curl -s http://127.0.0.1:8317/v1/chat/completions \
|
||||
-H "Authorization: Bearer $API_KEY" \
|
||||
|
||||
curl -s "http://127.0.0.1:${PORT}/healthz"
|
||||
curl -s "http://127.0.0.1:${PORT}/v1/models" \
|
||||
-H "Authorization: Bearer ${API_KEY}"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 常用命令
|
||||
|
||||
### 本地开发运行
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
uvicorn app.main:app --reload --port 8317
|
||||
```
|
||||
|
||||
### Docker 常用
|
||||
|
||||
```bash
|
||||
docker compose up -d --build
|
||||
docker compose logs -f
|
||||
docker compose ps
|
||||
docker compose down
|
||||
```
|
||||
|
||||
### 测试
|
||||
|
||||
```bash
|
||||
# 重点回归套件
|
||||
python3 -m unittest tests/test_tool_call_bridge.py
|
||||
|
||||
# 全量 unittest
|
||||
python3 -m unittest discover -s tests -p "test_*.py"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 最小 API 示例
|
||||
|
||||
先取 key:
|
||||
|
||||
```bash
|
||||
PORT=$(grep '^PORT=' .env | cut -d= -f2)
|
||||
API_KEY=$(grep '^API_KEYS=' .env | cut -d= -f2 | cut -d, -f1)
|
||||
```
|
||||
|
||||
### OpenAI:非流式
|
||||
|
||||
```bash
|
||||
curl -s "http://127.0.0.1:${PORT}/v1/chat/completions" \
|
||||
-H "Authorization: Bearer ${API_KEY}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model":"org_auto","messages":[{"role":"user","content":"hi"}]}'
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 二、配置参考
|
||||
|
||||
`.env.example` 是权威说明,这里按主题分组。
|
||||
|
||||
### 2.1 核心
|
||||
|
||||
| 变量 | 默认 | 说明 |
|
||||
|---|---|---|
|
||||
| `HOST` / `PORT` | `0.0.0.0` / `8317` | 网关监听地址与端口 |
|
||||
| `API_KEYS` | — | Bearer key,多个逗号分隔;**留空则 /v1/\* 无鉴权**,启动会 warn |
|
||||
| `LOG_LEVEL` | `INFO` | `DEBUG`/`INFO`/`WARNING`/`ERROR`,日志为结构化 JSON,含 `request_id` |
|
||||
| `DEFAULT_MODEL` | `org_auto` | 模型无法映射时兜底 |
|
||||
| `DEFAULT_ASK_MODE` | `chat` | `chat` 或 `agent`(传 `model: "agent"` 时自动切) |
|
||||
| `DEDICATED_DOMAIN_URL` | — | 企业专属域(可空) |
|
||||
|
||||
### 2.2 权限分层(生产建议全配)
|
||||
|
||||
| 变量 | 默认 | 说明 |
|
||||
|---|---|---|
|
||||
| `ADMIN_TOKEN` | — | `/internal/*` 专属 token;未配置时 fallback 到 `API_KEYS`(兼容);都为空 → 503 |
|
||||
| `METRICS_TOKEN` | — | `/metrics` 专属 token;未配置时 fallback 到 `API_KEYS` |
|
||||
| `METRICS_PUBLIC` | `false` | 显式公开 `/metrics`(仅用于私网采集器) |
|
||||
|
||||
> `ADMIN_TOKEN` / `METRICS_TOKEN` / `API_KEYS` 三者都为空时,`/metrics` 和 `/internal/*` 会返回 503(拒绝裸奔)。
|
||||
|
||||
### 2.3 并发与背压
|
||||
|
||||
| 变量 | 默认 | 说明 |
|
||||
|---|---|---|
|
||||
| `GATEWAY_MAX_IN_FLIGHT` | `4` | 并发上限;`<=0` 表示不限 |
|
||||
| `GATEWAY_QUEUE_TIMEOUT_SEC` | `30` | 排队超时;超时直接返回 `429 + Retry-After` |
|
||||
|
||||
### 2.4 Lingma 进程
|
||||
|
||||
| 变量 | 默认 | 说明 |
|
||||
|---|---|---|
|
||||
| `LINGMA_BIN` | `/app/data/bin/Lingma` | 容器内二进制路径 |
|
||||
| `LINGMA_SOURCE_TYPE` | `marketplace` | `marketplace` 或 `vsix` |
|
||||
| `LINGMA_MARKETPLACE_PUBLISHER` | `Alibaba-Cloud` | Marketplace 发布者 |
|
||||
| `LINGMA_MARKETPLACE_EXTENSION` | `tongyi-lingma` | Marketplace 扩展名 |
|
||||
| `LINGMA_VSIX_URL` | 官方地址 | 兜底 VSIX 下载地址 |
|
||||
| `LINGMA_BOOTSTRAP_ALWAYS` | `true` | 启动时总是尝试刷新二进制 |
|
||||
| `LINGMA_FORCE_REFRESH` | `false` | 强制忽略本地缓存重新下载 |
|
||||
| `LINGMA_WORK_DIR` | `/app/data/.lingma/vscode/sharedClientCache` | 登录态/缓存所在目录 |
|
||||
| `LINGMA_SOCKET_PORT` | `36510` | 单实例模式下的 Lingma WS 端口 |
|
||||
| `LINGMA_STARTUP_TIMEOUT` | `40` | 启动超时秒 |
|
||||
| `LINGMA_RPC_TIMEOUT` | `30` | 单次 RPC 超时秒 |
|
||||
|
||||
### 2.5 多账号 / 多实例池
|
||||
|
||||
| 变量 | 默认 | 说明 |
|
||||
|---|---|---|
|
||||
| `LINGMA_ACCOUNTS` | — | `u1:p1,u2:p2` 或 JSON 数组;配置后每个账号 = 一个独立 Lingma 子进程 |
|
||||
| `LINGMA_INSTANCE_COUNT` | 账号数 | 显式指定实例数;不足账号循环复用并打 warn |
|
||||
| `LINGMA_USERNAME` / `LINGMA_PASSWORD` | — | 单实例兼容模式(仅 `LINGMA_ACCOUNTS` 为空时生效) |
|
||||
|
||||
### 2.6 会话复用(KV cache 优化)
|
||||
|
||||
| 变量 | 默认 | 说明 |
|
||||
|---|---|---|
|
||||
| `SESSION_REUSE_ENABLED` | `true` | 多轮对话命中时只发增量 user 消息 + 复用上游 `sessionId` |
|
||||
| `SESSION_CACHE_MAX_ENTRIES` | `256` | LRU 容量 |
|
||||
| `SESSION_CACHE_TTL_SEC` | `1800` | TTL(秒),避免命中已回收的 session |
|
||||
|
||||
### 2.7 登录态注入(跳过 Playwright)
|
||||
|
||||
| 变量 | 默认 | 说明 |
|
||||
|---|---|---|
|
||||
| `LINGMA_SESSION_BUNDLE` | — | base64 格式的 bundle(inline,适合短字符串) |
|
||||
| `LINGMA_SESSION_BUNDLE_FILE` | — | bundle 文件路径(推荐,避免 env 过长) |
|
||||
|
||||
### 2.8 自动登录
|
||||
|
||||
| 变量 | 默认 | 说明 |
|
||||
|---|---|---|
|
||||
| `AUTO_LOGIN_ENABLED` | `true` | 未登录时自动启 Playwright |
|
||||
| `AUTO_LOGIN_HEADLESS` | `true` | 无头浏览器 |
|
||||
| `AUTO_LOGIN_TIMEOUT` | `180` | 登录超时秒 |
|
||||
| `AUTO_LOGIN_MAX_RETRY` | `2` | 登录失败重试次数 |
|
||||
|
||||
---
|
||||
|
||||
## 三、API 参考
|
||||
|
||||
### 3.1 公共(`API_KEYS`)
|
||||
|
||||
| 方法 | 路径 | 说明 |
|
||||
|---|---|---|
|
||||
| GET | `/healthz` | 免鉴权;返回 `ok` / `pool_size` / `pool_ready` / 每实例状态 |
|
||||
| GET | `/v1/models` | OpenAI 兼容;`id` 是 Lingma 原 key,`name` 是可读名 |
|
||||
| POST | `/v1/chat/completions` | OpenAI 兼容;`stream=true` 走 SSE;`model: "agent"` 切 agent 模式 |
|
||||
| POST | `/v1/messages` | **Anthropic Messages 兼容**;`x-api-key` 或 `Authorization: Bearer`;`stream=true` 走 Anthropic 命名事件 SSE |
|
||||
|
||||
**chat 请求示例(非流式)**
|
||||
|
||||
```bash
|
||||
curl -s http://127.0.0.1:8317/v1/chat/completions \
|
||||
-H "Authorization: Bearer $API_KEY" -H "Content-Type: application/json" \
|
||||
-d '{"model":"dashscope_qmodel","messages":[{"role":"user","content":"你好"}]}'
|
||||
```
|
||||
|
||||
**chat 请求示例(流式 + usage)**
|
||||
|
||||
```bash
|
||||
curl -N http://127.0.0.1:8317/v1/chat/completions \
|
||||
-H "Authorization: Bearer $API_KEY" -H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model":"dashscope_qmodel",
|
||||
"stream":true,
|
||||
"stream_options":{"include_usage":true},
|
||||
"messages":[{"role":"user","content":"介绍一下你自己"}]
|
||||
"model": "org_auto",
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
**Anthropic Messages 示例(非流式)**
|
||||
### OpenAI:流式
|
||||
|
||||
```bash
|
||||
curl -s http://127.0.0.1:8317/v1/messages \
|
||||
-H "x-api-key: $API_KEY" \
|
||||
curl -N "http://127.0.0.1:${PORT}/v1/chat/completions" \
|
||||
-H "Authorization: Bearer ${API_KEY}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "org_auto",
|
||||
"messages": [{"role": "user", "content": "say hi"}],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
### Anthropic:非流式
|
||||
|
||||
```bash
|
||||
curl -s "http://127.0.0.1:${PORT}/v1/messages" \
|
||||
-H "x-api-key: ${API_KEY}" \
|
||||
-H "anthropic-version: 2023-06-01" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model":"claude-3-5-sonnet-20241022",
|
||||
"max_tokens":256,
|
||||
"system":"你是一个简洁的助手",
|
||||
"messages":[{"role":"user","content":"你好"}]
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"max_tokens": 256,
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
**Anthropic Messages 示例(流式)**
|
||||
### Anthropic:流式
|
||||
|
||||
```bash
|
||||
curl -N http://127.0.0.1:8317/v1/messages \
|
||||
-H "x-api-key: $API_KEY" \
|
||||
curl -N "http://127.0.0.1:${PORT}/v1/messages" \
|
||||
-H "x-api-key: ${API_KEY}" \
|
||||
-H "anthropic-version: 2023-06-01" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model":"claude-3-5-sonnet-20241022",
|
||||
"max_tokens":256,
|
||||
"stream":true,
|
||||
"messages":[{"role":"user","content":"写一首四行诗"}]
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"max_tokens": 256,
|
||||
"messages": [{"role": "user", "content": "say hi"}],
|
||||
"stream": true
|
||||
}'
|
||||
# 返回 message_start / content_block_start / content_block_delta* /
|
||||
# content_block_stop / message_delta / message_stop
|
||||
```
|
||||
|
||||
说明:
|
||||
- **模型名兼容**:客户端可以继续传 `claude-3-*` 等名字;未识别的 model 会回退到 `DEFAULT_MODEL` 对应的 Lingma key,后端实际仍由 Lingma 提供(Qwen 系列)。如需显式选模型,直接传 Lingma key(`dashscope_qmodel` 等)。
|
||||
- **会话复用共享**:Anthropic 与 OpenAI 两个端点共用同一 `SessionCache`,只要 API key 相同、对话前缀相同,就会命中同一上游 `sessionId`。
|
||||
- **多模态**:`image` 块会被降级为 `[image]` 占位符(Lingma 不支持 vision);`tool_use` / `tool_result` 会以纯文本形式保留语义。
|
||||
- **鉴权**:优先 `x-api-key` 头(Anthropic 官方 SDK 默认),回退 `Authorization: Bearer`(方便 curl / OpenAI 风格客户端)。
|
||||
|
||||
### 3.2 观测(`METRICS_TOKEN` 或 `API_KEYS`)
|
||||
|
||||
| 方法 | 路径 | 说明 |
|
||||
|---|---|---|
|
||||
| GET | `/metrics` | Prometheus 文本;含池每实例 gauge、并发、session cache 命中率、token 计数 |
|
||||
|
||||
### 3.3 管理(`ADMIN_TOKEN` 或 fallback 到 `API_KEYS`)
|
||||
|
||||
| 方法 | 路径 | 说明 |
|
||||
|---|---|---|
|
||||
| GET | `/internal/stats` | JSON:`stats` + `concurrency` + `pool` + `session_cache` |
|
||||
| GET | `/internal/auto-login/status` | 每实例登录态与 auto_login 状态 |
|
||||
| POST | `/internal/auto-login/start?instance=inst-0` | 主动触发某实例登录(可不传,由 pool.pick 选) |
|
||||
| POST | `/internal/session/export?instance=inst-0` | 把已登录实例的 cache 打包成 base64 bundle |
|
||||
| GET | `/internal/models/raw?instance=inst-0` | Lingma 原始 `config/queryModels` 响应(displayName / isReasoning / isVl 等) |
|
||||
|
||||
---
|
||||
|
||||
## 四、常用场景
|
||||
|
||||
### 4.1 多账号池
|
||||
|
||||
```env
|
||||
LINGMA_ACCOUNTS=user1:pass1,user2:pass2,user3:pass3
|
||||
# LINGMA_INSTANCE_COUNT=3 # 不写默认=账号数
|
||||
```
|
||||
|
||||
- 每个账号一个独立 Lingma 子进程 + 独立 `workDir`(`data/.lingma/pool/inst-<i>/`)。
|
||||
- 路由:同 `user` 字段或同 system prompt 的请求**粘性**分到同一实例;其他按**最小在途**分配。
|
||||
- 一个实例挂掉不影响整体,`/healthz.pool_ready` 下降,自动重连。
|
||||
|
||||
### 4.2 跳过 Playwright(session bundle)
|
||||
|
||||
**从已登录实例导出:**
|
||||
### Anthropic:count_tokens
|
||||
|
||||
```bash
|
||||
curl -sS -X POST \
|
||||
-H "Authorization: Bearer $ADMIN_TOKEN" \
|
||||
"http://host:port/internal/session/export" \
|
||||
| jq -r '.bundle_b64' > secrets/lingma-session.b64
|
||||
chmod 600 secrets/lingma-session.b64
|
||||
curl -s "http://127.0.0.1:${PORT}/v1/messages/count_tokens" \
|
||||
-H "x-api-key: ${API_KEY}" \
|
||||
-H "anthropic-version: 2023-06-01" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "claude-3-5-sonnet-20241022",
|
||||
"max_tokens": 64,
|
||||
"messages": [{"role": "user", "content": "count me"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**在新部署注入(选一种):**
|
||||
---
|
||||
|
||||
```env
|
||||
# 文件注入(推荐)—— 需要在 docker-compose.yml 挂载 secrets 目录
|
||||
LINGMA_SESSION_BUNDLE_FILE=/secrets/lingma-session.b64
|
||||
## 部署与更新
|
||||
|
||||
# 或 inline(适合小 bundle)
|
||||
LINGMA_SESSION_BUNDLE=H4sIAAAA...
|
||||
### 服务器更新到最新 main
|
||||
|
||||
# 多账号 JSON 模式,每账号独立 bundle
|
||||
LINGMA_ACCOUNTS=[
|
||||
{"username":"u1","password":"p1","session_bundle_file":"/secrets/u1.b64"},
|
||||
{"username":"u2","password":"p2","session_bundle":"H4sIAAAA..."}
|
||||
]
|
||||
```bash
|
||||
cd /root/lingma-openai-gateway
|
||||
git fetch origin
|
||||
git checkout -B main origin/main
|
||||
git reset --hard origin/main
|
||||
git clean -fd
|
||||
docker compose up -d --build
|
||||
docker compose ps
|
||||
```
|
||||
|
||||
**行为保证:**
|
||||
### 健康检查
|
||||
|
||||
- 只在目标 `workDir` 空(`cache/user` 不存在或 empty)时才注入;不会覆盖活跃登录态。
|
||||
- 注入失败(损坏/权限)自动 fallback 到 Playwright。
|
||||
- bundle 只含 `cache/{id,user,quota,config.json}` 4 个文件;大小上限 4 MiB,实际通常 < 10 KB。
|
||||
- **bundle 等同于密钥**,落盘需 `chmod 600`,不要进 git。
|
||||
|
||||
### 4.3 Prometheus 接入
|
||||
|
||||
```yaml
|
||||
# prometheus scrape_configs 片段
|
||||
- job_name: lingma-gateway
|
||||
bearer_token: <METRICS_TOKEN>
|
||||
static_configs: [{targets: ['host:8317']}]
|
||||
metrics_path: /metrics
|
||||
```bash
|
||||
PORT=$(grep '^PORT=' .env | cut -d= -f2)
|
||||
curl -s "http://127.0.0.1:${PORT}/healthz"
|
||||
```
|
||||
|
||||
关键指标:
|
||||
---
|
||||
|
||||
| 指标 | 类型 | 意义 |
|
||||
## 排障速查
|
||||
|
||||
| 现象 | 常见原因 | 处理 |
|
||||
|---|---|---|
|
||||
| `gateway_in_flight` / `gateway_queued` | gauge | 并发 / 排队 |
|
||||
| `gateway_rejected_total` | counter | 背压拒绝(429)累计 |
|
||||
| `gateway_pool_instance_ready{name}` | gauge | 每实例是否就绪(0/1) |
|
||||
| `gateway_pool_instance_in_flight{name}` | gauge | 每实例在途 |
|
||||
| `gateway_session_cache_hit_total` / `_miss_total` | counter | 会话复用命中率原料 |
|
||||
| `gateway_chat_requests_success` / `_error` | counter | chat 成功率 |
|
||||
| `/v1/*` 返回 401 | 缺失或错误 API key | 检查 `Authorization: Bearer` 或 `x-api-key` |
|
||||
| `healthz` 正常但请求失败 | 用错端口 | 以 `.env` 的 `PORT` 为准,`docker compose ps` 再确认 |
|
||||
| `git pull` 提示 not on a branch | 处于 detached HEAD | 执行 `git checkout -B main origin/main` |
|
||||
| 自动登录不稳定 | 浏览器流程波动 | 优先使用 `LINGMA_SESSION_BUNDLE(_FILE)` |
|
||||
| 工具调用未触发 | 模型未选择工具 | 使用 `tool_choice` 强制,必要时约束输出 JSON |
|
||||
|
||||
---
|
||||
|
||||
## 五、升级注意事项
|
||||
## 文档入口
|
||||
|
||||
从旧版本升级时注意**破坏性变更**(每一项都有 fallback,默认不会炸,但建议显式配置):
|
||||
|
||||
| 版本 | 变更 | 应对 |
|
||||
|---|---|---|
|
||||
| v0.3 | `/metrics` 裸奔时(无 token / 无 key)由公开改为 503 | 显式配 `METRICS_PUBLIC=true` 或 `METRICS_TOKEN` |
|
||||
| v0.3 | `/internal/*` 引入 `ADMIN_TOKEN` | 未配置自动 fallback 到 `API_KEYS`,生产建议单独配 |
|
||||
| v0.2 | 默认会话复用(多轮对话只发增量) | 如果你的客户端裁剪了历史导致语义不连续,设 `SESSION_REUSE_ENABLED=false` |
|
||||
| v0.2 | Chat 请求走 JSON-RPC `notify` 而非 `request`(修复 30s TTFB bug) | 无需行动 |
|
||||
| v0.2 | 多实例池(`LINGMA_ACCOUNTS` 存在时启用) | 不配则保持单实例行为 |
|
||||
|
||||
---
|
||||
|
||||
## 六、故障排查(FAQ)
|
||||
|
||||
| 症状 | 排查方向 |
|
||||
|---|---|
|
||||
| `/healthz` 返回 `ok=false` / `pool_ready=0` | 查 `docker logs`,关键字 `lingma spawned` / `state ... -> ready`;若卡在 `starting` → Lingma 二进制或 workDir 权限问题 |
|
||||
| 返回 `401` 且带 `Invalid admin token` | 你用了 `API_KEYS` 去打 `/internal/*`,但服务端已设了 `ADMIN_TOKEN`;用 `ADMIN_TOKEN` 或清空 `ADMIN_TOKEN` |
|
||||
| 返回 `503 metrics scraping disabled` | 三个 env 全空,按 "权限分层" 章节配任一 |
|
||||
| 返回 `429 Too many in-flight` | 并发超过 `GATEWAY_MAX_IN_FLIGHT`;增大或客户端加重试 |
|
||||
| 首 token 延迟 2-3 秒 | Lingma 侧常态;多轮对话第二轮起,会话复用命中后 TTFB 明显降低(看 `gateway_session_cache_hit_total`) |
|
||||
| Playwright 登录失败 | 导出一个已登录 bundle 注入(见 4.2),彻底跳过浏览器 |
|
||||
| 容器重启后 Lingma 要重新登录 | `data/` 没挂在卷上或被清过;确认 `./data:/app/data` 挂载 + bundle fallback |
|
||||
| 升级后 `/metrics` 返回 503 | v0.3 默认严格;按表格 5.1 配置 |
|
||||
|
||||
开 `LOG_LEVEL=DEBUG` 可以看到 Lingma 子进程的 stderr 输出,便于定位 native 崩溃。
|
||||
|
||||
---
|
||||
|
||||
## 七、开发与二开
|
||||
|
||||
项目本身是单仓 FastAPI,3400 行 Python。推荐阅读路径:
|
||||
|
||||
1. **先读 [`DESIGN.md`](./DESIGN.md)** —— 架构、模块职责、关键设计决策、二开指引。
|
||||
2. 再按需读对应模块:
|
||||
- 想改请求入口 / 路由 → `app/main.py`
|
||||
- 想加实例调度策略 → `app/lingma_pool.py::pick()`
|
||||
- 想改 Lingma 通信协议 → `app/lingma_client.py`
|
||||
- 想扩展会话复用 → `app/session_cache.py` + `main.py` 的 reuse 块
|
||||
- 想做认证改造 → `app/auth.py` + `main.py::*_guard`
|
||||
3. 本地跑:`pip install -r requirements.txt && uvicorn app.main:app --reload`。
|
||||
|
||||
---
|
||||
|
||||
## 八、目录结构
|
||||
|
||||
```
|
||||
lingma-openai-gateway/
|
||||
├── app/ # 主代码(见 DESIGN.md 模块一览)
|
||||
│ ├── main.py # FastAPI 入口 + 路由
|
||||
│ ├── lingma_pool.py # N 实例池
|
||||
│ ├── lingma_client.py # LSP over WS + 子进程管理
|
||||
│ ├── session_cache.py # 多轮对话 sessionId 复用
|
||||
│ ├── session_bundle.py # 登录态 export/import
|
||||
│ ├── concurrency.py # InFlightGuard 背压
|
||||
│ ├── auto_login.py # Playwright 登录
|
||||
│ ├── auth.py # Bearer / admin / metrics 三档鉴权
|
||||
│ ├── config.py # 环境变量 → dataclass
|
||||
│ ├── model_map.py # 模型 key ↔ displayName
|
||||
│ ├── openai_schema.py # OpenAI 请求/响应 Pydantic
|
||||
│ ├── stats.py # StatsCollector + Prometheus
|
||||
│ ├── logging_config.py # 结构化 JSON log + request_id 上下文
|
||||
│ └── bootstrap_lingma.py # 启动时下载/提取 Lingma 二进制
|
||||
├── data/ # 持久化(Lingma 二进制 + workDir),不进 git
|
||||
├── secrets/ # 注入的 bundle 等敏感文件,不进 git
|
||||
├── Dockerfile # Playwright base + HEALTHCHECK
|
||||
├── docker-compose.yml
|
||||
├── .env.example # 配置权威文档
|
||||
├── requirements.txt
|
||||
├── README.md # 本文件
|
||||
└── DESIGN.md # 架构与二开手册
|
||||
```
|
||||
|
||||
---
|
||||
- 配置权威:[`/.env.example`](./.env.example)
|
||||
- 架构/模块边界/设计决策:[`/DESIGN.md`](./DESIGN.md)
|
||||
- 主要入口代码:[`/app/main.py`](./app/main.py)
|
||||
- 测试:[`/tests/test_tool_call_bridge.py`](./tests/test_tool_call_bridge.py)
|
||||
|
||||
## License
|
||||
|
||||
内部使用,按需调整。
|
||||
MIT
|
||||
|
||||
@@ -44,6 +44,7 @@ class Settings:
|
||||
session_reuse_enabled: bool = True
|
||||
session_cache_max_entries: int = 256
|
||||
session_cache_ttl_sec: float = 1800.0
|
||||
tool_forward_enabled: bool = False
|
||||
|
||||
|
||||
def _bool_env(name: str, default: bool) -> bool:
|
||||
@@ -175,4 +176,5 @@ def load_settings() -> Settings:
|
||||
session_reuse_enabled=_bool_env("SESSION_REUSE_ENABLED", True),
|
||||
session_cache_max_entries=int(os.getenv("SESSION_CACHE_MAX_ENTRIES", "256")),
|
||||
session_cache_ttl_sec=float(os.getenv("SESSION_CACHE_TTL_SEC", "1800")),
|
||||
tool_forward_enabled=_bool_env("TOOL_FORWARD_ENABLED", False),
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Callable, Optional
|
||||
from typing import Any, AsyncIterator, Callable, Optional
|
||||
|
||||
import websockets
|
||||
|
||||
@@ -100,9 +100,90 @@ class LspWsRpcClient:
|
||||
self._reader_task: asyncio.Task | None = None
|
||||
self._rx_buffer = b""
|
||||
self._chat_streams: dict[str, dict] = {}
|
||||
self._tool_stream_map: dict[str, str] = {}
|
||||
self._tool_roundtrip_done: set[str] = set()
|
||||
self._on_disconnect = on_disconnect
|
||||
self._closed = False
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_event(params: dict[str, Any]) -> dict[str, Any] | None:
|
||||
candidates: list[dict[str, Any]] = []
|
||||
|
||||
def add_candidate(obj: Any) -> None:
|
||||
if isinstance(obj, dict):
|
||||
candidates.append(obj)
|
||||
|
||||
add_candidate(params.get("toolCall"))
|
||||
add_candidate(params.get("tool_call"))
|
||||
add_candidate(params.get("tool"))
|
||||
|
||||
data = params.get("data")
|
||||
if isinstance(data, dict):
|
||||
add_candidate(data.get("toolCall"))
|
||||
add_candidate(data.get("tool_call"))
|
||||
add_candidate(data.get("tool"))
|
||||
|
||||
results = params.get("results")
|
||||
if isinstance(results, list):
|
||||
for item in results:
|
||||
add_candidate(item)
|
||||
|
||||
if not candidates:
|
||||
fallback_id = params.get("toolCallId") or params.get("tool_call_id")
|
||||
if not fallback_id:
|
||||
return None
|
||||
return {
|
||||
"id": str(fallback_id),
|
||||
"name": str(params.get("name") or "tool"),
|
||||
"input": params.get("parameters") or {},
|
||||
"result": params.get("result"),
|
||||
}
|
||||
|
||||
raw = candidates[0]
|
||||
tool_id = (
|
||||
raw.get("toolCallId")
|
||||
or raw.get("tool_call_id")
|
||||
or raw.get("id")
|
||||
or params.get("toolCallId")
|
||||
or params.get("tool_call_id")
|
||||
)
|
||||
name = (
|
||||
raw.get("name")
|
||||
or raw.get("toolName")
|
||||
or raw.get("tool_name")
|
||||
or params.get("name")
|
||||
)
|
||||
|
||||
call_input = raw.get("input")
|
||||
if call_input is None:
|
||||
call_input = raw.get("arguments")
|
||||
if call_input is None:
|
||||
call_input = raw.get("args")
|
||||
if call_input is None:
|
||||
call_input = raw.get("parameters")
|
||||
if call_input is None:
|
||||
call_input = params.get("parameters")
|
||||
|
||||
result_payload = raw.get("result")
|
||||
if result_payload is None:
|
||||
result_payload = params.get("result")
|
||||
if result_payload is None and isinstance(data, dict):
|
||||
result_payload = data.get("result")
|
||||
if result_payload is None and isinstance(raw.get("results"), list):
|
||||
result_payload = raw.get("results")
|
||||
|
||||
if not tool_id:
|
||||
return None
|
||||
|
||||
event: dict[str, Any] = {
|
||||
"id": str(tool_id),
|
||||
"name": str(name or "tool"),
|
||||
"input": call_input if call_input is not None else {},
|
||||
}
|
||||
if result_payload is not None:
|
||||
event["result"] = result_payload
|
||||
return event
|
||||
|
||||
async def start(self):
|
||||
self._reader_task = asyncio.create_task(self._reader_loop())
|
||||
|
||||
@@ -123,6 +204,8 @@ class LspWsRpcClient:
|
||||
stream["done"].set()
|
||||
stream["chunks"].put_nowait(None)
|
||||
self._chat_streams.clear()
|
||||
self._tool_stream_map.clear()
|
||||
self._tool_roundtrip_done.clear()
|
||||
|
||||
async def _send(self, payload: dict):
|
||||
async with self._send_lock:
|
||||
@@ -172,6 +255,141 @@ class LspWsRpcClient:
|
||||
except Exception:
|
||||
logger.exception("on_disconnect callback failed")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tool_id(method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> str | None:
|
||||
event_id = None
|
||||
if isinstance(tool_event, dict):
|
||||
event_id = tool_event.get("id")
|
||||
if isinstance(event_id, str) and event_id.strip():
|
||||
return event_id.strip()
|
||||
|
||||
fallback_id = params.get("toolCallId") or params.get("tool_call_id")
|
||||
if isinstance(fallback_id, str) and fallback_id.strip():
|
||||
return fallback_id.strip()
|
||||
|
||||
req_id = params.get("requestId")
|
||||
name = None
|
||||
if isinstance(tool_event, dict):
|
||||
name = tool_event.get("name")
|
||||
if not name:
|
||||
name = params.get("name")
|
||||
if isinstance(req_id, str) and req_id.strip() and isinstance(name, str) and name.strip():
|
||||
return f"{req_id.strip()}:tool:{name.strip()}"
|
||||
if isinstance(req_id, str) and req_id.strip():
|
||||
return f"{req_id.strip()}:tool"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _merge_tool_event(existing: dict[str, Any] | None, incoming: dict[str, Any]) -> tuple[dict[str, Any], bool]:
|
||||
merged = dict(existing or {})
|
||||
changed = False
|
||||
|
||||
val = incoming.get("id")
|
||||
if val and merged.get("id") != val:
|
||||
merged["id"] = val
|
||||
changed = True
|
||||
|
||||
name = incoming.get("name")
|
||||
if name:
|
||||
existing_name = merged.get("name")
|
||||
if not existing_name:
|
||||
merged["name"] = name
|
||||
changed = True
|
||||
else:
|
||||
existing_norm = str(existing_name).strip().lower()
|
||||
incoming_norm = str(name).strip().lower()
|
||||
if existing_norm == "tool" and incoming_norm != "tool":
|
||||
merged["name"] = name
|
||||
changed = True
|
||||
elif existing_norm != "tool" and incoming_norm == "tool":
|
||||
pass
|
||||
elif merged.get("name") != name:
|
||||
merged["name"] = name
|
||||
changed = True
|
||||
|
||||
if "input" in incoming and incoming.get("input") is not None:
|
||||
incoming_input = incoming.get("input")
|
||||
should_update_input = incoming_input != {} or "input" not in merged
|
||||
if should_update_input and merged.get("input") != incoming_input:
|
||||
merged["input"] = incoming_input
|
||||
changed = True
|
||||
|
||||
if "result" in incoming and incoming.get("result") is not None:
|
||||
if merged.get("result") != incoming.get("result"):
|
||||
merged["result"] = incoming.get("result")
|
||||
changed = True
|
||||
|
||||
return merged, changed
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _is_tool_roundtrip_method(method: str | None) -> bool:
|
||||
return method in {"tool/call/sync", "tool/invoke"}
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_approve_params(params: dict[str, Any], tool_id: str) -> dict[str, Any] | None:
|
||||
req_id = params.get("requestId")
|
||||
session_id = params.get("sessionId")
|
||||
if not isinstance(req_id, str) or not req_id.strip():
|
||||
return None
|
||||
if not isinstance(session_id, str) or not session_id.strip():
|
||||
return None
|
||||
return {
|
||||
"type": "tool_call",
|
||||
"sessionId": session_id,
|
||||
"requestId": req_id,
|
||||
"toolCallId": tool_id,
|
||||
"approval": True,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_tool_invoke_result_params(params: dict[str, Any], tool_event: dict[str, Any], tool_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"toolCallId": tool_id,
|
||||
"name": str(tool_event.get("name") or params.get("name") or "tool"),
|
||||
"success": True,
|
||||
"errorMessage": "",
|
||||
"result": tool_event.get("result") if "result" in tool_event else {},
|
||||
}
|
||||
|
||||
async def _maybe_emit_tool_roundtrip(self, method: str, params: dict[str, Any], tool_event: dict[str, Any]) -> None:
|
||||
if not self._is_tool_roundtrip_method(method):
|
||||
return
|
||||
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||
if not tool_id:
|
||||
return
|
||||
if tool_id in self._tool_roundtrip_done:
|
||||
return
|
||||
|
||||
approve_params = self._build_tool_approve_params(params, tool_id)
|
||||
if approve_params is None:
|
||||
return
|
||||
|
||||
self._tool_roundtrip_done.add(tool_id)
|
||||
await self.notify("tool/call/approve", approve_params)
|
||||
invoke_result_params = self._build_tool_invoke_result_params(params, tool_event, tool_id)
|
||||
await self.notify("tool/invokeResult", invoke_result_params)
|
||||
|
||||
|
||||
def _resolve_tool_stream(self, method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> dict | None:
|
||||
req_id = params.get("requestId")
|
||||
if isinstance(req_id, str) and req_id.strip():
|
||||
stream = self._chat_streams.get(req_id)
|
||||
if stream is not None and tool_event is not None:
|
||||
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||
if tool_id:
|
||||
self._tool_stream_map[tool_id] = req_id
|
||||
return stream
|
||||
|
||||
if tool_event is not None:
|
||||
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||
if tool_id:
|
||||
mapped_req = self._tool_stream_map.get(tool_id)
|
||||
if mapped_req:
|
||||
return self._chat_streams.get(mapped_req)
|
||||
|
||||
return None
|
||||
|
||||
async def _handle_server_message(self, msg: dict):
|
||||
method = msg.get("method")
|
||||
params = msg.get("params") or {}
|
||||
@@ -185,7 +403,34 @@ class LspWsRpcClient:
|
||||
stream["parts"].append(text)
|
||||
if stream["first_chunk_at"] is None:
|
||||
stream["first_chunk_at"] = time.monotonic()
|
||||
stream["chunks"].put_nowait(text)
|
||||
stream["chunks"].put_nowait({"type": "text", "text": text})
|
||||
|
||||
if method in {"tool/call/sync", "tool/invoke", "tool/call/approve", "tool/invokeResult"}:
|
||||
tool_event = self._extract_tool_event(params)
|
||||
stream = self._resolve_tool_stream(method, params, tool_event)
|
||||
|
||||
if stream is not None and tool_event is not None:
|
||||
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||
if not tool_id:
|
||||
logger.warning("drop unroutable tool event: method=%s missing tool id", method)
|
||||
else:
|
||||
await self._maybe_emit_tool_roundtrip(method, params, tool_event)
|
||||
tool_states = stream["tool_states"]
|
||||
order = stream["tool_order"]
|
||||
existing = tool_states.get(tool_id)
|
||||
merged, changed = self._merge_tool_event(existing, tool_event)
|
||||
if not existing:
|
||||
if "id" not in merged or not merged.get("id"):
|
||||
merged["id"] = tool_id
|
||||
tool_states[tool_id] = merged
|
||||
order.append(tool_id)
|
||||
stream["chunks"].put_nowait({"type": "tool", "tool": merged})
|
||||
elif changed:
|
||||
tool_states[tool_id] = merged
|
||||
stream["chunks"].put_nowait({"type": "tool", "tool": merged})
|
||||
elif tool_event is not None:
|
||||
logger.warning("drop unroutable tool event: method=%s requestId=%s", method, params.get("requestId"))
|
||||
|
||||
|
||||
if method == "chat/finish":
|
||||
req_id = params.get("requestId")
|
||||
@@ -224,6 +469,8 @@ class LspWsRpcClient:
|
||||
"chunks": asyncio.Queue(),
|
||||
"done": asyncio.Event(),
|
||||
"finish": None,
|
||||
"tool_states": {},
|
||||
"tool_order": [],
|
||||
"started_at": time.monotonic(),
|
||||
"first_chunk_at": None,
|
||||
"finish_at": None,
|
||||
@@ -233,13 +480,17 @@ class LspWsRpcClient:
|
||||
stream = self._chat_streams.pop(request_id, None)
|
||||
if stream is None:
|
||||
return
|
||||
for tool_id, mapped_req in list(self._tool_stream_map.items()):
|
||||
if mapped_req == request_id:
|
||||
self._tool_stream_map.pop(tool_id, None)
|
||||
self._tool_roundtrip_done.discard(tool_id)
|
||||
# Drain queue so no stray future gets stuck if the consumer bailed early.
|
||||
if not stream["done"].is_set():
|
||||
stream["done"].set()
|
||||
with contextlib.suppress(Exception):
|
||||
stream["chunks"].put_nowait(None)
|
||||
|
||||
async def consume_stream(self, request_id: str, timeout: float) -> AsyncIterator[str]:
|
||||
async def consume_stream(self, request_id: str, timeout: float) -> AsyncIterator[dict[str, Any]]:
|
||||
stream = self._chat_streams.get(request_id)
|
||||
if stream is None:
|
||||
return
|
||||
@@ -261,11 +512,20 @@ class LspWsRpcClient:
|
||||
first_ms = int((stream["first_chunk_at"] - stream["started_at"]) * 1000)
|
||||
if stream.get("finish_at") is not None:
|
||||
total_ms = int((stream["finish_at"] - stream["started_at"]) * 1000)
|
||||
|
||||
ordered_tool_events: list[dict[str, Any]] = []
|
||||
tool_states = stream.get("tool_states") or {}
|
||||
for tool_id in stream.get("tool_order") or []:
|
||||
event = tool_states.get(tool_id)
|
||||
if isinstance(event, dict):
|
||||
ordered_tool_events.append(event)
|
||||
|
||||
return {
|
||||
"text": "".join(stream.get("parts") or []),
|
||||
"finish": stream.get("finish") or {},
|
||||
"firstTokenLatencyMs": first_ms,
|
||||
"totalLatencyMs": total_ms,
|
||||
"toolEvents": ordered_tool_events,
|
||||
}
|
||||
|
||||
|
||||
@@ -634,13 +894,14 @@ class LingmaGatewayClient:
|
||||
request_id: str,
|
||||
*,
|
||||
is_reply: bool = False,
|
||||
tool_config: dict[str, Any] | None = None,
|
||||
):
|
||||
session_type = "developer" if ask_mode == "agent" else "chat"
|
||||
return {
|
||||
session_type = "ask" if ask_mode == "agent" else "chat"
|
||||
payload = {
|
||||
"requestId": request_id,
|
||||
"sessionId": session_id,
|
||||
"sessionType": session_type,
|
||||
"chatTask": "FREE_INPUT",
|
||||
"chatTask": "chat" if ask_mode == "agent" else "FREE_INPUT",
|
||||
"mode": ask_mode,
|
||||
"stream": True,
|
||||
"source": 1,
|
||||
@@ -665,6 +926,9 @@ class LingmaGatewayClient:
|
||||
"localeLang": "zh-CN",
|
||||
},
|
||||
}
|
||||
if tool_config is not None:
|
||||
payload["toolConfig"] = tool_config
|
||||
return payload
|
||||
|
||||
async def _kick_chat_ask(self, payload: dict) -> None:
|
||||
"""Fire chat/ask as a notification.
|
||||
@@ -685,12 +949,19 @@ class LingmaGatewayClient:
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
is_reply: bool = False,
|
||||
tool_config: dict[str, Any] | None = None,
|
||||
) -> dict:
|
||||
await self.ensure_ready()
|
||||
request_id = str(uuid.uuid4())
|
||||
sid = session_id or str(uuid.uuid4())
|
||||
payload = self._build_payload(
|
||||
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
|
||||
prompt,
|
||||
model_key,
|
||||
ask_mode,
|
||||
sid,
|
||||
request_id,
|
||||
is_reply=is_reply,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
self.rpc.create_stream(request_id)
|
||||
try:
|
||||
@@ -721,9 +992,14 @@ class LingmaGatewayClient:
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
is_reply: bool = False,
|
||||
tool_config: dict[str, Any] | None = None,
|
||||
out_meta: dict | None = None,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Stream `chat/answer` chunks.
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
"""Stream chat events.
|
||||
|
||||
Yields structured events:
|
||||
* {"type": "text", "text": "..."}
|
||||
* {"type": "tool", "tool": {...}}
|
||||
|
||||
If `out_meta` is provided, the final `chat/finish` payload's sessionId
|
||||
(and the raw finish dict) is written into it when the stream ends or is
|
||||
@@ -734,15 +1010,21 @@ class LingmaGatewayClient:
|
||||
request_id = str(uuid.uuid4())
|
||||
sid = session_id or str(uuid.uuid4())
|
||||
payload = self._build_payload(
|
||||
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
|
||||
prompt,
|
||||
model_key,
|
||||
ask_mode,
|
||||
sid,
|
||||
request_id,
|
||||
is_reply=is_reply,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
self.rpc.create_stream(request_id)
|
||||
try:
|
||||
await self._kick_chat_ask(payload)
|
||||
async for chunk in self.rpc.consume_stream(
|
||||
async for event in self.rpc.consume_stream(
|
||||
request_id, timeout=max(60.0, self.rpc_timeout + 60.0)
|
||||
):
|
||||
yield chunk
|
||||
yield event
|
||||
finally:
|
||||
# Runs on normal completion, exception, or consumer GeneratorExit (client disconnect).
|
||||
if out_meta is not None:
|
||||
@@ -753,6 +1035,7 @@ class LingmaGatewayClient:
|
||||
out_meta["finish"] = finish
|
||||
out_meta["request_id"] = request_id
|
||||
out_meta["chars"] = len(stream_result.get("text") or "")
|
||||
out_meta["tool_events"] = stream_result.get("toolEvents") or []
|
||||
except Exception:
|
||||
pass
|
||||
self.rpc.pop_stream(request_id)
|
||||
|
||||
486
app/main.py
486
app/main.py
@@ -6,6 +6,7 @@ import json
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
@@ -350,6 +351,233 @@ def _include_usage(stream_options: dict | None) -> bool:
|
||||
return bool(stream_options.get("include_usage"))
|
||||
|
||||
|
||||
def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None:
|
||||
if not settings.tool_forward_enabled:
|
||||
return None
|
||||
has_tools = isinstance(req.tools, list) and len(req.tools) > 0
|
||||
has_choice = req.tool_choice is not None
|
||||
if not has_tools and not has_choice:
|
||||
return None
|
||||
return {
|
||||
"provider": "openai",
|
||||
"tools": req.tools or [],
|
||||
"tool_choice": req.tool_choice,
|
||||
}
|
||||
|
||||
|
||||
def _anthropic_tool_config(req: AnthropicMessagesRequest) -> dict[str, Any] | None:
|
||||
if not settings.tool_forward_enabled:
|
||||
return None
|
||||
has_tools = isinstance(req.tools, list) and len(req.tools) > 0
|
||||
has_choice = req.tool_choice is not None
|
||||
if not has_tools and not has_choice:
|
||||
return None
|
||||
return {
|
||||
"provider": "anthropic",
|
||||
"tools": req.tools or [],
|
||||
"tool_choice": req.tool_choice,
|
||||
}
|
||||
|
||||
|
||||
def _openai_has_tooling_context(req: ChatCompletionsRequest, messages: list[dict[str, Any]]) -> bool:
|
||||
if isinstance(req.tools, list) and len(req.tools) > 0:
|
||||
return True
|
||||
if req.tool_choice is not None:
|
||||
return True
|
||||
for m in messages:
|
||||
role = m.get("role")
|
||||
if role == "tool":
|
||||
return True
|
||||
if role == "assistant" and m.get("tool_calls"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _anthropic_content_has_tool_blocks(content: Any) -> bool:
|
||||
if not isinstance(content, list):
|
||||
return False
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") in {"tool_use", "tool_result"}:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _anthropic_has_tooling_context(req: AnthropicMessagesRequest) -> bool:
|
||||
if isinstance(req.tools, list) and len(req.tools) > 0:
|
||||
return True
|
||||
if req.tool_choice is not None:
|
||||
return True
|
||||
if _anthropic_content_has_tool_blocks(req.system):
|
||||
return True
|
||||
for m in req.messages:
|
||||
if _anthropic_content_has_tool_blocks(m.content):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _stream_event_type(event: Any) -> str:
|
||||
if isinstance(event, dict):
|
||||
t = event.get("type")
|
||||
if t in {"text", "tool"}:
|
||||
return t
|
||||
return "text"
|
||||
|
||||
|
||||
def _stream_text(event: Any) -> str:
|
||||
if isinstance(event, dict):
|
||||
if event.get("type") == "text":
|
||||
text = event.get("text")
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
return ""
|
||||
if isinstance(event, str):
|
||||
return event
|
||||
return ""
|
||||
|
||||
|
||||
def _stream_tool_event(event: Any) -> dict[str, Any] | None:
|
||||
if isinstance(event, dict) and event.get("type") == "tool":
|
||||
tool = event.get("tool")
|
||||
if isinstance(tool, dict):
|
||||
return tool
|
||||
return None
|
||||
|
||||
|
||||
def _json_string(value: Any) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.dumps(value if value is not None else {}, ensure_ascii=False)
|
||||
except Exception:
|
||||
return "{}"
|
||||
|
||||
|
||||
def _openai_forced_tool_name(tool_choice: Any) -> str | None:
|
||||
if not isinstance(tool_choice, dict):
|
||||
return None
|
||||
fn = tool_choice.get("function")
|
||||
if isinstance(fn, dict):
|
||||
name = fn.get("name")
|
||||
if isinstance(name, str) and name.strip():
|
||||
return name.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _anthropic_forced_tool_name(tool_choice: Any) -> str | None:
|
||||
if not isinstance(tool_choice, dict):
|
||||
return None
|
||||
if tool_choice.get("type") == "tool":
|
||||
name = tool_choice.get("name")
|
||||
if isinstance(name, str) and name.strip():
|
||||
return name.strip()
|
||||
fn = tool_choice.get("function")
|
||||
if isinstance(fn, dict):
|
||||
name = fn.get("name")
|
||||
if isinstance(name, str) and name.strip():
|
||||
return name.strip()
|
||||
return None
|
||||
|
||||
|
||||
def _json_object_from_text(text: str) -> dict[str, Any] | None:
|
||||
raw = text.strip()
|
||||
if not raw:
|
||||
return None
|
||||
if raw.startswith("```") and raw.endswith("```"):
|
||||
raw = raw[3:-3].strip()
|
||||
if raw.lower().startswith("json"):
|
||||
raw = raw[4:].strip()
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except Exception:
|
||||
return None
|
||||
return parsed if isinstance(parsed, dict) else None
|
||||
|
||||
|
||||
def _forced_tool_event_from_text(text: str, forced_tool_name: str) -> dict[str, Any] | None:
|
||||
parsed = _json_object_from_text(text)
|
||||
if parsed is None:
|
||||
return None
|
||||
|
||||
explicit_name: Any = parsed.get("name") or parsed.get("tool")
|
||||
fn = parsed.get("function")
|
||||
if explicit_name is None and isinstance(fn, dict):
|
||||
explicit_name = fn.get("name")
|
||||
if explicit_name is not None and str(explicit_name) != forced_tool_name:
|
||||
return None
|
||||
|
||||
tool_input: Any = None
|
||||
if "input" in parsed:
|
||||
tool_input = parsed.get("input")
|
||||
elif "arguments" in parsed:
|
||||
args = parsed.get("arguments")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
tool_input = json.loads(args)
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
tool_input = args
|
||||
elif isinstance(fn, dict) and "arguments" in fn:
|
||||
args = fn.get("arguments")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
tool_input = json.loads(args)
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
tool_input = args
|
||||
else:
|
||||
reserved = {"name", "tool", "function", "arguments", "input", "result"}
|
||||
tool_input = {k: v for k, v in parsed.items() if k not in reserved}
|
||||
|
||||
event: dict[str, Any] = {
|
||||
"name": forced_tool_name,
|
||||
"input": tool_input if tool_input is not None else {},
|
||||
}
|
||||
if "result" in parsed:
|
||||
event["result"] = parsed.get("result")
|
||||
return event
|
||||
|
||||
|
||||
def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]:
|
||||
return {
|
||||
"id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": str(tool.get("name") or "tool"),
|
||||
"arguments": _json_string(tool.get("input")),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _anthropic_tool_use_block(
|
||||
tool: dict[str, Any], *, forced_id: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "tool_use",
|
||||
"id": str(tool.get("id") or forced_id or f"toolu_{uuid.uuid4().hex}"),
|
||||
"name": str(tool.get("name") or "tool"),
|
||||
"input": tool.get("input") if tool.get("input") is not None else {},
|
||||
}
|
||||
|
||||
|
||||
def _anthropic_tool_result_block(
|
||||
tool: dict[str, Any], *, forced_id: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
if "result" not in tool:
|
||||
return None
|
||||
result = tool.get("result")
|
||||
if isinstance(result, str):
|
||||
content: Any = result
|
||||
else:
|
||||
content = _json_string(result)
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": str(tool.get("id") or forced_id or ""),
|
||||
"content": content,
|
||||
}
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions", dependencies=[Depends(auth_guard)])
|
||||
async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
p = _require_pool()
|
||||
@@ -363,22 +591,26 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
# 1. Reuse the upstream sessionId so Lingma/Qwen hits its KV cache.
|
||||
# 2. Send only the new user message instead of the whole history.
|
||||
# 3. Stick the request to the pool instance that originally served it.
|
||||
tool_config = _openai_tool_config(req)
|
||||
has_tooling_context = _openai_has_tooling_context(req, messages_dump)
|
||||
|
||||
ask_mode = settings.default_ask_mode
|
||||
if req.model.lower() in {"lingma-agent", "agent"}:
|
||||
if req.model.lower() in {"lingma-agent", "agent"} or has_tooling_context:
|
||||
ask_mode = "agent"
|
||||
|
||||
reuse_eligible = (
|
||||
session_cache.enabled
|
||||
and ask_mode == "chat"
|
||||
and len(messages_dump) >= 2
|
||||
and not has_tooling_context
|
||||
)
|
||||
lookup_key: str | None = None
|
||||
write_key: str | None = None
|
||||
cached_session_id: str | None = None
|
||||
cached_instance_name: str | None = None
|
||||
if reuse_eligible:
|
||||
lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
|
||||
write_key = session_cache.build_key(api_key, messages_dump)
|
||||
lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
|
||||
write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config)
|
||||
entry = await session_cache.get(lookup_key)
|
||||
if entry is not None:
|
||||
cached_session_id = entry.session_id
|
||||
@@ -476,6 +708,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
|
||||
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
|
||||
success = False
|
||||
tool_call_indexes: dict[str, int] = {}
|
||||
saw_tool_call = False
|
||||
try:
|
||||
async for chunk in _inst.client.chat_stream(
|
||||
prompt,
|
||||
@@ -483,9 +717,21 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
ask_mode,
|
||||
session_id=cached_session_id,
|
||||
is_reply=is_reply,
|
||||
tool_config=tool_config,
|
||||
out_meta=_meta,
|
||||
):
|
||||
completion_tokens_holder["n"] += estimate_tokens(chunk)
|
||||
if _stream_event_type(chunk) == "tool":
|
||||
tool = _stream_tool_event(chunk)
|
||||
if not tool:
|
||||
continue
|
||||
tool_id = str(tool.get("id") or "")
|
||||
if not tool_id:
|
||||
tool_id = f"call_{len(tool_call_indexes)}"
|
||||
idx = tool_call_indexes.get(tool_id)
|
||||
if idx is None:
|
||||
idx = len(tool_call_indexes)
|
||||
tool_call_indexes[tool_id] = idx
|
||||
saw_tool_call = True
|
||||
payload = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
@@ -494,7 +740,34 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": chunk},
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": idx,
|
||||
**_openai_tool_call(tool, forced_id=tool_id),
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
continue
|
||||
|
||||
text = _stream_text(chunk)
|
||||
if not text:
|
||||
continue
|
||||
completion_tokens_holder["n"] += estimate_tokens(text)
|
||||
payload = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": text},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
@@ -506,10 +779,17 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": "tool_calls" if saw_tool_call else "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
if include_usage:
|
||||
usage_payload = {
|
||||
"id": completion_id,
|
||||
@@ -567,6 +847,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
ask_mode,
|
||||
session_id=cached_session_id,
|
||||
is_reply=is_reply,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
|
||||
@@ -596,6 +877,24 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
sid = result.get("sessionId")
|
||||
if sid:
|
||||
await session_cache.put(write_key, sid, inst.name)
|
||||
tool_events = result.get("toolEvents") or []
|
||||
message_content = result.get("text") or ""
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
saw_tool_call = False
|
||||
if isinstance(tool_events, list):
|
||||
for idx, item in enumerate(tool_events):
|
||||
if isinstance(item, dict):
|
||||
tool_id = str(item.get("id") or f"call_{idx}")
|
||||
tool_calls.append(_openai_tool_call(item, forced_id=tool_id))
|
||||
saw_tool_call = True
|
||||
if not saw_tool_call:
|
||||
forced_tool_name = _openai_forced_tool_name(req.tool_choice)
|
||||
if forced_tool_name:
|
||||
fallback_event = _forced_tool_event_from_text(message_content, forced_tool_name)
|
||||
if fallback_event is not None:
|
||||
tool_calls.append(_openai_tool_call(fallback_event, forced_id="call_fallback_0"))
|
||||
saw_tool_call = True
|
||||
message_content = ""
|
||||
response = ChatCompletionResponse(
|
||||
id=f"chatcmpl-{uuid.uuid4().hex}",
|
||||
created=int(time.time()),
|
||||
@@ -603,11 +902,17 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
finish_reason="stop",
|
||||
message={"role": "assistant", "content": result.get("text") or ""},
|
||||
finish_reason="tool_calls" if saw_tool_call else "stop",
|
||||
message={
|
||||
"role": "assistant",
|
||||
"content": message_content,
|
||||
"tool_calls": tool_calls or None,
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
data = response.model_dump()
|
||||
data["latency"] = {
|
||||
"first_token_ms": result.get("firstTokenLatencyMs"),
|
||||
@@ -634,13 +939,15 @@ def _anthropic_error(status_code: int, error_type: str, message: str) -> JSONRes
|
||||
)
|
||||
|
||||
|
||||
def _anthropic_stop_reason(completion_tokens: int, max_tokens: int) -> str:
|
||||
"""Approximate Anthropic `stop_reason`.
|
||||
|
||||
Lingma doesn't expose a `max_tokens` knob, so we can't truly enforce it;
|
||||
we report `max_tokens` only when the generated length meets or exceeds
|
||||
the caller's stated ceiling. Everything else is `end_turn`.
|
||||
"""
|
||||
def _anthropic_stop_reason(
|
||||
completion_tokens: int,
|
||||
max_tokens: int,
|
||||
*,
|
||||
has_pending_tool_use: bool = False,
|
||||
) -> str:
|
||||
"""Approximate Anthropic `stop_reason`."""
|
||||
if has_pending_tool_use:
|
||||
return "tool_use"
|
||||
if max_tokens and completion_tokens >= max_tokens:
|
||||
return "max_tokens"
|
||||
return "end_turn"
|
||||
@@ -700,19 +1007,23 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------- session reuse
|
||||
# Anthropic clients don't expose an ask_mode, so we always run in "chat".
|
||||
ask_mode = "chat"
|
||||
tool_config = _anthropic_tool_config(req)
|
||||
has_tooling_context = _anthropic_has_tooling_context(req)
|
||||
|
||||
ask_mode = settings.default_ask_mode
|
||||
if req.model.lower() in {"lingma-agent", "agent"} or has_tooling_context:
|
||||
ask_mode = "agent"
|
||||
|
||||
reuse_eligible = (
|
||||
session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2
|
||||
session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 and not has_tooling_context
|
||||
)
|
||||
lookup_key: str | None = None
|
||||
write_key: str | None = None
|
||||
cached_session_id: str | None = None
|
||||
cached_instance_name: str | None = None
|
||||
if reuse_eligible:
|
||||
lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
|
||||
write_key = session_cache.build_key(api_key, messages_dump)
|
||||
lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
|
||||
write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config)
|
||||
entry = await session_cache.get(lookup_key)
|
||||
if entry is not None:
|
||||
cached_session_id = entry.session_id
|
||||
@@ -760,7 +1071,6 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
return _anthropic_error(400, "invalid_request_error", "messages is empty")
|
||||
|
||||
prompt_tokens = estimate_tokens(prompt)
|
||||
|
||||
# ------------------------------------------------------------- backpressure
|
||||
try:
|
||||
ticket = await chat_guard.try_acquire()
|
||||
@@ -810,6 +1120,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
|
||||
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
|
||||
success = False
|
||||
block_index = 0
|
||||
text_block_open = False
|
||||
saw_pending_tool_use = False
|
||||
try:
|
||||
# 1) message_start — Anthropic SDKs read this first to get
|
||||
# the message envelope (id/model/initial usage).
|
||||
@@ -833,47 +1146,99 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
}
|
||||
yield _sse("message_start", start_payload)
|
||||
|
||||
# 2) content_block_start for a single text block (index 0).
|
||||
yield _sse(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
},
|
||||
)
|
||||
|
||||
# 3) content_block_delta stream of text tokens.
|
||||
async for chunk in _inst.client.chat_stream(
|
||||
prompt,
|
||||
model,
|
||||
ask_mode,
|
||||
session_id=cached_session_id,
|
||||
is_reply=is_reply,
|
||||
tool_config=tool_config,
|
||||
out_meta=_meta,
|
||||
):
|
||||
if not chunk:
|
||||
if _stream_event_type(chunk) == "tool":
|
||||
if text_block_open:
|
||||
yield _sse(
|
||||
"content_block_stop",
|
||||
{"type": "content_block_stop", "index": block_index},
|
||||
)
|
||||
block_index += 1
|
||||
text_block_open = False
|
||||
|
||||
tool = _stream_tool_event(chunk)
|
||||
if not tool:
|
||||
continue
|
||||
completion_tokens_holder["n"] += estimate_tokens(chunk)
|
||||
tool_id = str(tool.get("id") or f"toolu_stream_{block_index}")
|
||||
|
||||
tool_use_block = _anthropic_tool_use_block(tool, forced_id=tool_id)
|
||||
yield _sse(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": block_index,
|
||||
"content_block": tool_use_block,
|
||||
},
|
||||
)
|
||||
yield _sse(
|
||||
"content_block_stop",
|
||||
{"type": "content_block_stop", "index": block_index},
|
||||
)
|
||||
block_index += 1
|
||||
|
||||
tool_result_block = _anthropic_tool_result_block(tool, forced_id=tool_id)
|
||||
if tool_result_block is not None:
|
||||
yield _sse(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": block_index,
|
||||
"content_block": tool_result_block,
|
||||
},
|
||||
)
|
||||
yield _sse(
|
||||
"content_block_stop",
|
||||
{"type": "content_block_stop", "index": block_index},
|
||||
)
|
||||
block_index += 1
|
||||
else:
|
||||
saw_pending_tool_use = True
|
||||
continue
|
||||
|
||||
text = _stream_text(chunk)
|
||||
if not text:
|
||||
continue
|
||||
completion_tokens_holder["n"] += estimate_tokens(text)
|
||||
if not text_block_open:
|
||||
yield _sse(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": block_index,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
},
|
||||
)
|
||||
text_block_open = True
|
||||
|
||||
yield _sse(
|
||||
"content_block_delta",
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": chunk},
|
||||
"index": block_index,
|
||||
"delta": {"type": "text_delta", "text": text},
|
||||
},
|
||||
)
|
||||
|
||||
# 4) content_block_stop closes the single text block.
|
||||
if text_block_open:
|
||||
yield _sse(
|
||||
"content_block_stop",
|
||||
{"type": "content_block_stop", "index": 0},
|
||||
{"type": "content_block_stop", "index": block_index},
|
||||
)
|
||||
|
||||
# 5) message_delta carries the terminal stop_reason and
|
||||
# the final cumulative output_tokens count.
|
||||
stop_reason = _anthropic_stop_reason(
|
||||
completion_tokens_holder["n"], max_tokens
|
||||
completion_tokens_holder["n"],
|
||||
max_tokens,
|
||||
has_pending_tool_use=saw_pending_tool_use,
|
||||
)
|
||||
yield _sse(
|
||||
"message_delta",
|
||||
@@ -887,6 +1252,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# 6) message_stop — terminal event, no [DONE] sentinel.
|
||||
yield _sse("message_stop", {"type": "message_stop"})
|
||||
success = True
|
||||
@@ -946,6 +1312,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
ask_mode,
|
||||
session_id=cached_session_id,
|
||||
is_reply=is_reply,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("anthropic.complete error (inst=%s): %s", inst.name, exc)
|
||||
@@ -972,13 +1339,50 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
if sid:
|
||||
await session_cache.put(write_key, sid, inst.name)
|
||||
|
||||
content_blocks: list[dict[str, Any]] = []
|
||||
if text:
|
||||
content_blocks.append({"type": "text", "text": text})
|
||||
tool_events = result.get("toolEvents") or []
|
||||
saw_pending_tool_use = False
|
||||
saw_tool_event = False
|
||||
if isinstance(tool_events, list):
|
||||
for idx, item in enumerate(tool_events):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
saw_tool_event = True
|
||||
tool_id = str(item.get("id") or f"toolu_nonstream_{idx}")
|
||||
content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id))
|
||||
tool_result = _anthropic_tool_result_block(item, forced_id=tool_id)
|
||||
if tool_result is not None:
|
||||
content_blocks.append(tool_result)
|
||||
else:
|
||||
saw_pending_tool_use = True
|
||||
|
||||
if not saw_tool_event:
|
||||
forced_tool_name = _anthropic_forced_tool_name(req.tool_choice)
|
||||
if forced_tool_name:
|
||||
fallback_event = _forced_tool_event_from_text(text, forced_tool_name)
|
||||
if fallback_event is not None:
|
||||
content_blocks = []
|
||||
tool_id = "toolu_fallback_0"
|
||||
content_blocks.append(_anthropic_tool_use_block(fallback_event, forced_id=tool_id))
|
||||
tool_result = _anthropic_tool_result_block(fallback_event, forced_id=tool_id)
|
||||
saw_pending_tool_use = tool_result is None
|
||||
if tool_result is not None:
|
||||
content_blocks.append(tool_result)
|
||||
|
||||
response_body: dict = {
|
||||
|
||||
"id": message_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": [{"type": "text", "text": text}],
|
||||
"stop_reason": _anthropic_stop_reason(completion_tokens, req.max_tokens),
|
||||
"content": content_blocks,
|
||||
"stop_reason": _anthropic_stop_reason(
|
||||
completion_tokens,
|
||||
req.max_tokens,
|
||||
has_pending_tool_use=saw_pending_tool_use,
|
||||
),
|
||||
"stop_sequence": None,
|
||||
"usage": {
|
||||
"input_tokens": prompt_tokens,
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
@@ -42,6 +43,16 @@ def hash_user_context(messages: list[dict]) -> str:
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def _tool_fingerprint(tool_config: dict | None) -> str:
|
||||
if not isinstance(tool_config, dict):
|
||||
return "-"
|
||||
try:
|
||||
canonical = json.dumps(tool_config, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
except Exception:
|
||||
canonical = str(tool_config)
|
||||
return hashlib.sha1(canonical.encode("utf-8")).hexdigest()[:16]
|
||||
|
||||
|
||||
class SessionCache:
|
||||
"""LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId.
|
||||
|
||||
@@ -79,11 +90,11 @@ class SessionCache:
|
||||
def enabled(self) -> bool:
|
||||
return self.max > 0
|
||||
|
||||
def build_key(self, api_key: str, messages: list[dict]) -> str:
|
||||
def build_key(self, api_key: str, messages: list[dict], *, tool_config: dict | None = None) -> str:
|
||||
# API key scoping prevents cross-tenant session leakage even when
|
||||
# different clients happen to produce identical histories.
|
||||
key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12]
|
||||
return f"{key_scope}:{hash_user_context(messages)}"
|
||||
return f"{key_scope}:{hash_user_context(messages)}:{_tool_fingerprint(tool_config)}"
|
||||
|
||||
async def get(self, key: str) -> SessionEntry | None:
|
||||
if not self.enabled:
|
||||
|
||||
981
tests/test_tool_call_bridge.py
Normal file
981
tests/test_tool_call_bridge.py
Normal file
@@ -0,0 +1,981 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
|
||||
class _FakeSessionCache:
|
||||
def __init__(self) -> None:
|
||||
self.enabled = True
|
||||
self.keys: list[str] = []
|
||||
self.get_calls: list[str] = []
|
||||
self.put_calls: list[tuple[str, str, str]] = []
|
||||
self.invalidate_calls: list[str] = []
|
||||
|
||||
def build_key(self, api_key: str, messages: list[dict], *, tool_config=None) -> str:
|
||||
marker = "with_tool" if tool_config is not None else "no_tool"
|
||||
key = f"{api_key}:{len(messages)}:{marker}"
|
||||
self.keys.append(key)
|
||||
return key
|
||||
|
||||
async def get(self, key: str):
|
||||
self.get_calls.append(key)
|
||||
return None
|
||||
|
||||
async def put(self, key: str, session_id: str, instance_name: str = "") -> None:
|
||||
self.put_calls.append((key, session_id, instance_name))
|
||||
|
||||
async def invalidate(self, key: str) -> None:
|
||||
self.invalidate_calls.append(key)
|
||||
|
||||
# app.main imports playwright via auto_login; tests don't exercise that path.
|
||||
# Inject a lightweight stub so unit tests run without installing playwright.
|
||||
_playwright = types.ModuleType("playwright")
|
||||
_playwright_async = types.ModuleType("playwright.async_api")
|
||||
|
||||
|
||||
class _StubPlaywrightTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
async def _stub_async_playwright():
|
||||
raise RuntimeError("playwright is stubbed in unit tests")
|
||||
|
||||
|
||||
_playwright_async.TimeoutError = _StubPlaywrightTimeoutError
|
||||
_playwright_async.async_playwright = _stub_async_playwright
|
||||
sys.modules.setdefault("playwright", _playwright)
|
||||
sys.modules.setdefault("playwright.async_api", _playwright_async)
|
||||
|
||||
from starlette.requests import Request
|
||||
|
||||
from app.anthropic_schema import AnthropicMessagesRequest
|
||||
from app.openai_schema import ChatCompletionsRequest
|
||||
import app.main as main
|
||||
|
||||
|
||||
class _FakeTicket:
|
||||
def __init__(self) -> None:
|
||||
self.released = False
|
||||
|
||||
def release(self) -> None:
|
||||
self.released = True
|
||||
|
||||
|
||||
class _FakeGuard:
|
||||
def __init__(self) -> None:
|
||||
self.in_flight = 0
|
||||
|
||||
async def try_acquire(self) -> _FakeTicket:
|
||||
return _FakeTicket()
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self, *, stream_events: list[dict], complete_result: dict) -> None:
|
||||
self._stream_events = stream_events
|
||||
self._complete_result = complete_result
|
||||
|
||||
async def query_models(self) -> dict:
|
||||
return {
|
||||
"chat": [
|
||||
{
|
||||
"key": "org_auto",
|
||||
"displayName": "Auto",
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
async def chat_complete(self, *args, **kwargs) -> dict:
|
||||
return self._complete_result
|
||||
|
||||
async def chat_stream(self, *args, **kwargs):
|
||||
out_meta = kwargs.get("out_meta")
|
||||
if isinstance(out_meta, dict):
|
||||
out_meta["session_id"] = "sess-stream"
|
||||
for event in self._stream_events:
|
||||
yield event
|
||||
|
||||
|
||||
class _FakeInstance:
|
||||
def __init__(self, client: _FakeClient) -> None:
|
||||
self.name = "inst-test"
|
||||
self.client = client
|
||||
self.in_flight = 0
|
||||
|
||||
|
||||
class _FakePool:
|
||||
def __init__(self, inst: _FakeInstance) -> None:
|
||||
self._inst = inst
|
||||
|
||||
def pick(self, affinity_key: str | None = None) -> _FakeInstance:
|
||||
return self._inst
|
||||
|
||||
|
||||
def _make_request(path: str, headers: dict[str, str] | None = None) -> Request:
|
||||
header_pairs = []
|
||||
for k, v in (headers or {}).items():
|
||||
header_pairs.append((k.lower().encode("latin-1"), v.encode("latin-1")))
|
||||
scope = {
|
||||
"type": "http",
|
||||
"http_version": "1.1",
|
||||
"method": "POST",
|
||||
"scheme": "http",
|
||||
"path": path,
|
||||
"raw_path": path.encode("latin-1"),
|
||||
"query_string": b"",
|
||||
"headers": header_pairs,
|
||||
"client": ("testclient", 12345),
|
||||
"server": ("testserver", 80),
|
||||
"root_path": "",
|
||||
}
|
||||
return Request(scope)
|
||||
|
||||
|
||||
async def _collect_stream(response) -> str:
|
||||
chunks: list[str] = []
|
||||
async for part in response.body_iterator:
|
||||
if isinstance(part, bytes):
|
||||
chunks.append(part.decode("utf-8"))
|
||||
else:
|
||||
chunks.append(str(part))
|
||||
return "".join(chunks)
|
||||
|
||||
|
||||
class _SpyClient(_FakeClient):
|
||||
def __init__(self, *, stream_events: list[dict], complete_result: dict) -> None:
|
||||
super().__init__(stream_events=stream_events, complete_result=complete_result)
|
||||
self.last_complete_args: tuple = ()
|
||||
self.last_stream_args: tuple = ()
|
||||
self.last_complete_kwargs: dict = {}
|
||||
self.last_stream_kwargs: dict = {}
|
||||
|
||||
async def chat_complete(self, *args, **kwargs) -> dict:
|
||||
self.last_complete_args = tuple(args)
|
||||
self.last_complete_kwargs = dict(kwargs)
|
||||
return await super().chat_complete(*args, **kwargs)
|
||||
|
||||
async def chat_stream(self, *args, **kwargs):
|
||||
self.last_stream_args = tuple(args)
|
||||
self.last_stream_kwargs = dict(kwargs)
|
||||
async for event in super().chat_stream(*args, **kwargs):
|
||||
yield event
|
||||
|
||||
|
||||
class _SettingsPatch:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self._kwargs = kwargs
|
||||
|
||||
def __enter__(self):
|
||||
self._patchers = [patch.object(main.settings, k, v) for k, v in self._kwargs.items()]
|
||||
for p in self._patchers:
|
||||
p.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
for p in reversed(self._patchers):
|
||||
p.stop()
|
||||
return False
|
||||
|
||||
|
||||
class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_openai_non_stream_bridges_tool_calls(self) -> None:
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[],
|
||||
complete_result={
|
||||
"text": "done",
|
||||
"toolEvents": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"name": "search_docs",
|
||||
"input": {"query": "gateway"},
|
||||
"result": {"ok": True},
|
||||
}
|
||||
],
|
||||
"sessionId": "sess-1",
|
||||
"firstTokenLatencyMs": 12,
|
||||
"totalLatencyMs": 34,
|
||||
},
|
||||
)
|
||||
req = ChatCompletionsRequest(
|
||||
model="org_auto",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
):
|
||||
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||
|
||||
payload = json.loads(response.body)
|
||||
message = payload["choices"][0]["message"]
|
||||
self.assertEqual(message["content"], "done")
|
||||
self.assertIsInstance(message["tool_calls"], list)
|
||||
self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls")
|
||||
self.assertEqual(message["tool_calls"][0]["function"]["name"], "search_docs")
|
||||
self.assertEqual(
|
||||
json.loads(message["tool_calls"][0]["function"]["arguments"]),
|
||||
{"query": "gateway"},
|
||||
)
|
||||
|
||||
async def test_openai_non_stream_fallbacks_to_structured_tool_call_for_forced_tool(self) -> None:
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[],
|
||||
complete_result={
|
||||
"text": "```json\n{\"arguments\": {\"query\": \"gateway\"}}\n```",
|
||||
"toolEvents": [],
|
||||
"sessionId": "sess-fallback-openai",
|
||||
},
|
||||
)
|
||||
req = ChatCompletionsRequest(
|
||||
model="org_auto",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=False,
|
||||
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
|
||||
tool_choice={"type": "function", "function": {"name": "lookup"}},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
):
|
||||
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||
|
||||
payload = json.loads(response.body)
|
||||
message = payload["choices"][0]["message"]
|
||||
self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls")
|
||||
self.assertEqual(message["content"], "")
|
||||
self.assertIsInstance(message["tool_calls"], list)
|
||||
self.assertEqual(message["tool_calls"][0]["function"]["name"], "lookup")
|
||||
self.assertEqual(
|
||||
json.loads(message["tool_calls"][0]["function"]["arguments"]),
|
||||
{"query": "gateway"},
|
||||
)
|
||||
|
||||
async def test_openai_stream_bridges_tool_and_text_events(self) -> None:
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[
|
||||
{
|
||||
"type": "tool",
|
||||
"tool": {
|
||||
"id": "call_stream_1",
|
||||
"name": "read_file",
|
||||
"input": {"path": "README.md"},
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "hello"},
|
||||
],
|
||||
complete_result={},
|
||||
)
|
||||
req = ChatCompletionsRequest(
|
||||
model="org_auto",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
):
|
||||
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||
body = await _collect_stream(response)
|
||||
|
||||
self.assertIn('"tool_calls"', body)
|
||||
self.assertIn('"content": "hello"', body)
|
||||
self.assertIn('"finish_reason": "tool_calls"', body)
|
||||
self.assertIn('"usage"', body)
|
||||
self.assertIn("data: [DONE]", body)
|
||||
|
||||
async def test_anthropic_non_stream_bridges_tool_blocks(self) -> None:
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[],
|
||||
complete_result={
|
||||
"text": "ok",
|
||||
"toolEvents": [
|
||||
{
|
||||
"id": "toolu_1",
|
||||
"name": "lookup",
|
||||
"input": {"k": "v"},
|
||||
"result": {"value": 1},
|
||||
}
|
||||
],
|
||||
"sessionId": "sess-2",
|
||||
},
|
||||
)
|
||||
req = AnthropicMessagesRequest(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=256,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||
):
|
||||
response = await main.v1_messages(
|
||||
req,
|
||||
_make_request(
|
||||
"/v1/messages",
|
||||
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||
),
|
||||
)
|
||||
|
||||
payload = json.loads(response.body)
|
||||
types = [item["type"] for item in payload["content"]]
|
||||
self.assertEqual(types, ["text", "tool_use", "tool_result"])
|
||||
self.assertEqual(payload["stop_reason"], "end_turn")
|
||||
self.assertEqual(payload["content"][1]["name"], "lookup")
|
||||
self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1")
|
||||
|
||||
async def test_anthropic_non_stream_fallbacks_to_structured_tool_blocks_for_forced_tool(self) -> None:
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[],
|
||||
complete_result={
|
||||
"text": "{\"input\": {\"k\": \"v\"}, \"result\": {\"value\": 1}}",
|
||||
"toolEvents": [],
|
||||
"sessionId": "sess-fallback-anthropic",
|
||||
},
|
||||
)
|
||||
req = AnthropicMessagesRequest(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=256,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=False,
|
||||
tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}],
|
||||
tool_choice={"type": "tool", "name": "lookup"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||
):
|
||||
response = await main.v1_messages(
|
||||
req,
|
||||
_make_request(
|
||||
"/v1/messages",
|
||||
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||
),
|
||||
)
|
||||
|
||||
payload = json.loads(response.body)
|
||||
types = [item["type"] for item in payload["content"]]
|
||||
self.assertEqual(types, ["tool_use", "tool_result"])
|
||||
self.assertEqual(payload["stop_reason"], "end_turn")
|
||||
self.assertEqual(payload["content"][0]["name"], "lookup")
|
||||
self.assertEqual(payload["content"][1]["tool_use_id"], "toolu_fallback_0")
|
||||
|
||||
async def test_openai_stream_tool_call_indices_are_stable(self) -> None:
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[
|
||||
{
|
||||
"type": "tool",
|
||||
"tool": {
|
||||
"id": "call_a",
|
||||
"name": "read_file",
|
||||
"input": {"path": "README.md"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "tool",
|
||||
"tool": {
|
||||
"id": "call_b",
|
||||
"name": "search_docs",
|
||||
"input": {"query": "gateway"},
|
||||
},
|
||||
},
|
||||
],
|
||||
complete_result={},
|
||||
)
|
||||
req = ChatCompletionsRequest(
|
||||
model="org_auto",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
):
|
||||
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||
body = await _collect_stream(response)
|
||||
|
||||
self.assertIn('"id": "call_a"', body)
|
||||
self.assertIn('"id": "call_b"', body)
|
||||
self.assertIn('"index": 0', body)
|
||||
self.assertIn('"index": 1', body)
|
||||
|
||||
async def test_anthropic_non_stream_returns_tool_use_stop_reason_when_result_missing(self) -> None:
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[],
|
||||
complete_result={
|
||||
"text": "",
|
||||
"toolEvents": [
|
||||
{
|
||||
"name": "lookup",
|
||||
"input": {"k": "v"},
|
||||
}
|
||||
],
|
||||
"sessionId": "sess-2",
|
||||
},
|
||||
)
|
||||
req = AnthropicMessagesRequest(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=256,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||
):
|
||||
response = await main.v1_messages(
|
||||
req,
|
||||
_make_request(
|
||||
"/v1/messages",
|
||||
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||
),
|
||||
)
|
||||
|
||||
payload = json.loads(response.body)
|
||||
self.assertEqual(payload["stop_reason"], "tool_use")
|
||||
self.assertEqual(len(payload["content"]), 1)
|
||||
self.assertEqual(payload["content"][0]["type"], "tool_use")
|
||||
|
||||
async def test_anthropic_stream_returns_tool_use_stop_reason_when_result_missing(self) -> None:
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[
|
||||
{
|
||||
"type": "tool",
|
||||
"tool": {
|
||||
"name": "read",
|
||||
"input": {"file": "a.txt"},
|
||||
},
|
||||
}
|
||||
],
|
||||
complete_result={},
|
||||
)
|
||||
req = AnthropicMessagesRequest(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=256,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||
):
|
||||
response = await main.v1_messages(
|
||||
req,
|
||||
_make_request(
|
||||
"/v1/messages",
|
||||
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||
),
|
||||
)
|
||||
body = await _collect_stream(response)
|
||||
|
||||
self.assertIn('"type": "tool_use"', body)
|
||||
self.assertIn('"stop_reason": "tool_use"', body)
|
||||
|
||||
async def test_anthropic_stream_bridges_tool_and_text_events(self) -> None:
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[
|
||||
{
|
||||
"type": "tool",
|
||||
"tool": {
|
||||
"id": "toolu_stream_1",
|
||||
"name": "read",
|
||||
"input": {"file": "a.txt"},
|
||||
"result": "done",
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "world"},
|
||||
],
|
||||
complete_result={},
|
||||
)
|
||||
req = AnthropicMessagesRequest(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=256,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||
):
|
||||
response = await main.v1_messages(
|
||||
req,
|
||||
_make_request(
|
||||
"/v1/messages",
|
||||
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||
),
|
||||
)
|
||||
body = await _collect_stream(response)
|
||||
|
||||
self.assertIn("event: message_start", body)
|
||||
self.assertIn('"type": "tool_use"', body)
|
||||
self.assertIn('"type": "tool_result"', body)
|
||||
self.assertIn('"stop_reason": "end_turn"', body)
|
||||
self.assertIn('"type": "text_delta"', body)
|
||||
self.assertIn("event: message_stop", body)
|
||||
|
||||
|
||||
|
||||
async def test_openai_non_stream_forwards_tool_config_when_enabled(self) -> None:
|
||||
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
|
||||
req = ChatCompletionsRequest(
|
||||
model="org_auto",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=False,
|
||||
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
|
||||
tool_choice={"type": "function", "function": {"name": "lookup"}},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
_SettingsPatch(tool_forward_enabled=True),
|
||||
):
|
||||
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||
|
||||
self.assertIn("tool_config", spy_client.last_complete_kwargs)
|
||||
cfg = spy_client.last_complete_kwargs["tool_config"]
|
||||
self.assertEqual(cfg["provider"], "openai")
|
||||
self.assertEqual(len(cfg["tools"]), 1)
|
||||
self.assertIsInstance(cfg["tool_choice"], dict)
|
||||
self.assertEqual(spy_client.last_complete_args[2], "agent")
|
||||
|
||||
async def test_openai_non_stream_does_not_forward_tool_config_when_disabled(self) -> None:
|
||||
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
|
||||
req = ChatCompletionsRequest(
|
||||
model="org_auto",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=False,
|
||||
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
|
||||
tool_choice={"type": "function", "function": {"name": "lookup"}},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
_SettingsPatch(tool_forward_enabled=False),
|
||||
):
|
||||
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||
|
||||
self.assertIn("tool_config", spy_client.last_complete_kwargs)
|
||||
self.assertIsNone(spy_client.last_complete_kwargs["tool_config"])
|
||||
self.assertEqual(spy_client.last_complete_args[2], "agent")
|
||||
|
||||
|
||||
async def test_openai_tooling_context_disables_session_reuse_cache(self) -> None:
|
||||
fake_cache = _FakeSessionCache()
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[],
|
||||
complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-3"},
|
||||
)
|
||||
req = ChatCompletionsRequest(
|
||||
model="org_auto",
|
||||
messages=[
|
||||
{"role": "user", "content": "turn-1"},
|
||||
{"role": "user", "content": "turn-2"},
|
||||
],
|
||||
stream=False,
|
||||
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
|
||||
tool_choice={"type": "function", "function": {"name": "lookup"}},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "session_cache", fake_cache),
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
_SettingsPatch(tool_forward_enabled=True),
|
||||
):
|
||||
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||
|
||||
self.assertEqual(fake_cache.keys, [])
|
||||
self.assertEqual(fake_cache.get_calls, [])
|
||||
self.assertEqual(fake_cache.put_calls, [])
|
||||
|
||||
|
||||
async def test_anthropic_non_stream_with_tools_uses_agent_mode(self) -> None:
|
||||
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
|
||||
req = AnthropicMessagesRequest(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=128,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
stream=False,
|
||||
tools=[{"name": "write_file", "input_schema": {"type": "object", "properties": {}}}],
|
||||
tool_choice={"type": "auto"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||
_SettingsPatch(tool_forward_enabled=True, default_ask_mode="chat"),
|
||||
):
|
||||
await main.v1_messages(
|
||||
req,
|
||||
_make_request(
|
||||
"/v1/messages",
|
||||
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||
),
|
||||
)
|
||||
|
||||
self.assertIn("tool_config", spy_client.last_complete_kwargs)
|
||||
cfg = spy_client.last_complete_kwargs["tool_config"]
|
||||
self.assertEqual(cfg["provider"], "anthropic")
|
||||
self.assertEqual(len(cfg["tools"]), 1)
|
||||
self.assertEqual(spy_client.last_complete_args[2], "agent")
|
||||
|
||||
async def test_anthropic_tooling_context_disables_session_reuse_cache(self) -> None:
|
||||
fake_cache = _FakeSessionCache()
|
||||
fake_client = _FakeClient(
|
||||
stream_events=[],
|
||||
complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-4"},
|
||||
)
|
||||
req = AnthropicMessagesRequest(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=128,
|
||||
messages=[
|
||||
{"role": "user", "content": "turn-1"},
|
||||
{"role": "user", "content": "turn-2"},
|
||||
],
|
||||
stream=False,
|
||||
tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}],
|
||||
tool_choice={"type": "auto"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(main, "session_cache", fake_cache),
|
||||
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||
patch.object(main, "chat_guard", _FakeGuard()),
|
||||
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||
):
|
||||
await main.v1_messages(
|
||||
req,
|
||||
_make_request(
|
||||
"/v1/messages",
|
||||
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(fake_cache.keys, [])
|
||||
self.assertEqual(fake_cache.get_calls, [])
|
||||
self.assertEqual(fake_cache.put_calls, [])
|
||||
|
||||
|
||||
class SessionCacheToolFingerprintTests(unittest.TestCase):
|
||||
def test_build_key_changes_with_tool_config(self) -> None:
|
||||
from app.session_cache import SessionCache
|
||||
|
||||
cache = SessionCache(max_entries=8, ttl_sec=60)
|
||||
messages = [
|
||||
{"role": "system", "content": "sys"},
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
|
||||
cfg_a = {
|
||||
"provider": "openai",
|
||||
"tools": [{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
|
||||
"tool_choice": {"type": "function", "function": {"name": "lookup"}},
|
||||
}
|
||||
cfg_a_reordered = {
|
||||
"tool_choice": {"function": {"name": "lookup"}, "type": "function"},
|
||||
"tools": [{"function": {"parameters": {}, "name": "lookup"}, "type": "function"}],
|
||||
"provider": "openai",
|
||||
}
|
||||
cfg_b = {
|
||||
"provider": "openai",
|
||||
"tools": [{"type": "function", "function": {"name": "lookup_v2", "parameters": {}}}],
|
||||
"tool_choice": {"type": "function", "function": {"name": "lookup_v2"}},
|
||||
}
|
||||
|
||||
key_no_tool = cache.build_key("api-key", messages)
|
||||
key_a = cache.build_key("api-key", messages, tool_config=cfg_a)
|
||||
key_a_reordered = cache.build_key("api-key", messages, tool_config=cfg_a_reordered)
|
||||
key_b = cache.build_key("api-key", messages, tool_config=cfg_b)
|
||||
|
||||
self.assertNotEqual(key_no_tool, key_a)
|
||||
self.assertEqual(key_a, key_a_reordered)
|
||||
self.assertNotEqual(key_a, key_b)
|
||||
|
||||
|
||||
def test_handle_server_message_drops_unroutable_tool_event_without_request_id(self) -> None:
|
||||
from app.lingma_client import LspWsRpcClient
|
||||
|
||||
rpc = LspWsRpcClient("ws://127.0.0.1:1")
|
||||
|
||||
async def run() -> None:
|
||||
rpc.create_stream("req-1")
|
||||
await rpc._handle_server_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tool/invoke",
|
||||
"params": {
|
||||
"name": "lookup",
|
||||
"parameters": {"q": "x"},
|
||||
},
|
||||
}
|
||||
)
|
||||
stream = rpc._chat_streams["req-1"]
|
||||
self.assertEqual(stream["tool_order"], [])
|
||||
self.assertEqual(stream["tool_states"], {})
|
||||
self.assertTrue(stream["chunks"].empty())
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_handle_server_message_routes_by_tool_map_without_request_id(self) -> None:
|
||||
from app.lingma_client import LspWsRpcClient
|
||||
|
||||
rpc = LspWsRpcClient("ws://127.0.0.1:1")
|
||||
|
||||
async def run() -> None:
|
||||
rpc.create_stream("req-1")
|
||||
|
||||
await rpc._handle_server_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tool/invoke",
|
||||
"params": {
|
||||
"requestId": "req-1",
|
||||
"toolCallId": "call-1",
|
||||
"name": "lookup",
|
||||
"parameters": {"q": "a"},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
await rpc._handle_server_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tool/invokeResult",
|
||||
"params": {
|
||||
"toolCallId": "call-1",
|
||||
"result": {"ok": True},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = rpc.get_stream_result("req-1")
|
||||
self.assertEqual(len(result["toolEvents"]), 1)
|
||||
self.assertEqual(result["toolEvents"][0]["id"], "call-1")
|
||||
self.assertEqual(result["toolEvents"][0]["input"], {"q": "a"})
|
||||
self.assertEqual(result["toolEvents"][0]["result"], {"ok": True})
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_handle_server_message_dedupes_identical_repeated_tool_events(self) -> None:
|
||||
from app.lingma_client import LspWsRpcClient
|
||||
|
||||
rpc = LspWsRpcClient("ws://127.0.0.1:1")
|
||||
|
||||
async def run() -> None:
|
||||
rpc.create_stream("req-1")
|
||||
msg = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tool/invoke",
|
||||
"params": {
|
||||
"requestId": "req-1",
|
||||
"toolCallId": "call-dup",
|
||||
"name": "lookup",
|
||||
"parameters": {"q": "dup"},
|
||||
},
|
||||
}
|
||||
await rpc._handle_server_message(msg)
|
||||
await rpc._handle_server_message(msg)
|
||||
|
||||
stream = rpc._chat_streams["req-1"]
|
||||
self.assertEqual(stream["tool_order"], ["call-dup"])
|
||||
self.assertEqual(stream["chunks"].qsize(), 1)
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_extracts_tool_event_from_results_and_parameters(self) -> None:
|
||||
from app.lingma_client import LspWsRpcClient
|
||||
|
||||
event = LspWsRpcClient._extract_tool_event(
|
||||
{
|
||||
"toolCallId": "call_sync_1",
|
||||
"parameters": {"path": "README.md"},
|
||||
"results": [
|
||||
{
|
||||
"toolCallId": "call_sync_1",
|
||||
"name": "read_file",
|
||||
"result": {"ok": True},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
event,
|
||||
{
|
||||
"id": "call_sync_1",
|
||||
"name": "read_file",
|
||||
"input": {"path": "README.md"},
|
||||
"result": {"ok": True},
|
||||
},
|
||||
)
|
||||
|
||||
def test_extracts_tool_event_from_invoke_result_payload(self) -> None:
|
||||
from app.lingma_client import LspWsRpcClient
|
||||
|
||||
event = LspWsRpcClient._extract_tool_event(
|
||||
{
|
||||
"toolCallId": "call_inv_1",
|
||||
"name": "search_docs",
|
||||
"parameters": {"query": "gateway"},
|
||||
"result": {"hits": 3},
|
||||
}
|
||||
)
|
||||
self.assertEqual(
|
||||
event,
|
||||
{
|
||||
"id": "call_inv_1",
|
||||
"name": "search_docs",
|
||||
"input": {"query": "gateway"},
|
||||
"result": {"hits": 3},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_tool_sync_triggers_approve_and_invoke_result_requests(self) -> None:
|
||||
from app.lingma_client import LspWsRpcClient
|
||||
|
||||
class _WsStub:
|
||||
def __init__(self) -> None:
|
||||
self.frames: list[bytes] = []
|
||||
|
||||
async def send(self, data: bytes) -> None:
|
||||
self.frames.append(data)
|
||||
|
||||
def _decode(frame: bytes) -> dict:
|
||||
body = frame.split(b"\r\n\r\n", 1)[1]
|
||||
return json.loads(body.decode("utf-8"))
|
||||
|
||||
ws = _WsStub()
|
||||
rpc = LspWsRpcClient(ws)
|
||||
|
||||
async def run() -> None:
|
||||
rpc.create_stream("req-1")
|
||||
await rpc._handle_server_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tool/call/sync",
|
||||
"params": {
|
||||
"sessionId": "sess-1",
|
||||
"requestId": "req-1",
|
||||
"toolCallId": "call-1",
|
||||
"name": "run_in_terminal",
|
||||
"parameters": {"command": "pwd"},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
decoded = [_decode(frame) for frame in ws.frames]
|
||||
methods = [item.get("method") for item in decoded]
|
||||
self.assertIn("tool/call/approve", methods)
|
||||
self.assertIn("tool/invokeResult", methods)
|
||||
|
||||
approve = next(item for item in decoded if item.get("method") == "tool/call/approve")
|
||||
self.assertEqual(
|
||||
approve["params"],
|
||||
{
|
||||
"type": "tool_call",
|
||||
"sessionId": "sess-1",
|
||||
"requestId": "req-1",
|
||||
"toolCallId": "call-1",
|
||||
"approval": True,
|
||||
},
|
||||
)
|
||||
|
||||
invoke_result = next(item for item in decoded if item.get("method") == "tool/invokeResult")
|
||||
self.assertEqual(invoke_result["params"]["toolCallId"], "call-1")
|
||||
self.assertEqual(invoke_result["params"]["name"], "run_in_terminal")
|
||||
self.assertTrue(invoke_result["params"]["success"])
|
||||
self.assertEqual(invoke_result["params"]["errorMessage"], "")
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_tool_sync_does_not_emit_roundtrip_without_request_id(self) -> None:
|
||||
from app.lingma_client import LspWsRpcClient
|
||||
|
||||
class _WsStub:
|
||||
def __init__(self) -> None:
|
||||
self.frames: list[bytes] = []
|
||||
|
||||
async def send(self, data: bytes) -> None:
|
||||
self.frames.append(data)
|
||||
|
||||
ws = _WsStub()
|
||||
rpc = LspWsRpcClient(ws)
|
||||
|
||||
async def run() -> None:
|
||||
rpc.create_stream("req-1")
|
||||
await rpc._handle_server_message(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tool/call/sync",
|
||||
"params": {
|
||||
"sessionId": "sess-1",
|
||||
"toolCallId": "call-1",
|
||||
"name": "run_in_terminal",
|
||||
"parameters": {"command": "pwd"},
|
||||
},
|
||||
}
|
||||
)
|
||||
self.assertEqual(ws.frames, [])
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(run())
|
||||
Reference in New Issue
Block a user