Compare commits
5 Commits
v0.1.1
...
15cd5e8770
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
15cd5e8770 | ||
|
|
63583712a8 | ||
|
|
c67a9c3d61 | ||
|
|
e208025f35 | ||
|
|
3498b81fa2 |
@@ -46,6 +46,9 @@ DEFAULT_MODEL=org_auto
|
|||||||
# 默认模式:chat 或 agent
|
# 默认模式:chat 或 agent
|
||||||
DEFAULT_ASK_MODE=chat
|
DEFAULT_ASK_MODE=chat
|
||||||
|
|
||||||
|
# 请求侧 tools/tool_choice 透传到 Lingma(默认关闭,开启后可支持工具写文件等场景)
|
||||||
|
TOOL_FORWARD_ENABLED=false
|
||||||
|
|
||||||
# 专属域(可选)
|
# 专属域(可选)
|
||||||
DEDICATED_DOMAIN_URL=
|
DEDICATED_DOMAIN_URL=
|
||||||
|
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ curl -N http://127.0.0.1:8317/v1/messages \
|
|||||||
- **模型名兼容**:客户端可以继续传 `claude-3-*` 等名字;未识别的 model 会回退到 `DEFAULT_MODEL` 对应的 Lingma key,后端实际仍由 Lingma 提供(Qwen 系列)。如需显式选模型,直接传 Lingma key(`dashscope_qmodel` 等)。
|
- **模型名兼容**:客户端可以继续传 `claude-3-*` 等名字;未识别的 model 会回退到 `DEFAULT_MODEL` 对应的 Lingma key,后端实际仍由 Lingma 提供(Qwen 系列)。如需显式选模型,直接传 Lingma key(`dashscope_qmodel` 等)。
|
||||||
- **会话复用共享**:Anthropic 与 OpenAI 两个端点共用同一 `SessionCache`,只要 API key 相同、对话前缀相同,就会命中同一上游 `sessionId`。
|
- **会话复用共享**:Anthropic 与 OpenAI 两个端点共用同一 `SessionCache`,只要 API key 相同、对话前缀相同,就会命中同一上游 `sessionId`。
|
||||||
- **多模态**:`image` 块会被降级为 `[image]` 占位符(Lingma 不支持 vision)。
|
- **多模态**:`image` 块会被降级为 `[image]` 占位符(Lingma 不支持 vision)。
|
||||||
- **工具事件桥接**:当 Lingma 上游返回 `tool` 事件时,网关会输出为 OpenAI `tool_calls`(含 stream/non-stream)和 Anthropic `tool_use`/`tool_result` blocks(含 stream/non-stream);但请求侧 `tools`/`tool_choice` 仍不会透传到 Lingma。
|
- **工具事件桥接**:当 Lingma 上游返回 `tool` 事件时,网关会输出为 OpenAI `tool_calls`(含 stream/non-stream)和 Anthropic `tool_use`/`tool_result` blocks(含 stream/non-stream);请求侧 `tools`/`tool_choice` 在 `TOOL_FORWARD_ENABLED=true` 时会透传到 Lingma(默认关闭)。
|
||||||
- **鉴权**:优先 `x-api-key` 头(Anthropic 官方 SDK 默认),回退 `Authorization: Bearer`(方便 curl / OpenAI 风格客户端)。
|
- **鉴权**:优先 `x-api-key` 头(Anthropic 官方 SDK 默认),回退 `Authorization: Bearer`(方便 curl / OpenAI 风格客户端)。
|
||||||
|
|
||||||
### 3.2 观测(`METRICS_TOKEN` 或 `API_KEYS`)
|
### 3.2 观测(`METRICS_TOKEN` 或 `API_KEYS`)
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ class LspWsRpcClient:
|
|||||||
self._rx_buffer = b""
|
self._rx_buffer = b""
|
||||||
self._chat_streams: dict[str, dict] = {}
|
self._chat_streams: dict[str, dict] = {}
|
||||||
self._tool_stream_map: dict[str, str] = {}
|
self._tool_stream_map: dict[str, str] = {}
|
||||||
|
self._tool_roundtrip_done: set[str] = set()
|
||||||
self._on_disconnect = on_disconnect
|
self._on_disconnect = on_disconnect
|
||||||
self._closed = False
|
self._closed = False
|
||||||
|
|
||||||
@@ -204,6 +205,7 @@ class LspWsRpcClient:
|
|||||||
stream["chunks"].put_nowait(None)
|
stream["chunks"].put_nowait(None)
|
||||||
self._chat_streams.clear()
|
self._chat_streams.clear()
|
||||||
self._tool_stream_map.clear()
|
self._tool_stream_map.clear()
|
||||||
|
self._tool_roundtrip_done.clear()
|
||||||
|
|
||||||
async def _send(self, payload: dict):
|
async def _send(self, payload: dict):
|
||||||
async with self._send_lock:
|
async with self._send_lock:
|
||||||
@@ -320,6 +322,55 @@ class LspWsRpcClient:
|
|||||||
return merged, changed
|
return merged, changed
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_tool_roundtrip_method(method: str | None) -> bool:
|
||||||
|
return method in {"tool/call/sync", "tool/invoke"}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_tool_approve_params(params: dict[str, Any], tool_id: str) -> dict[str, Any] | None:
|
||||||
|
req_id = params.get("requestId")
|
||||||
|
session_id = params.get("sessionId")
|
||||||
|
if not isinstance(req_id, str) or not req_id.strip():
|
||||||
|
return None
|
||||||
|
if not isinstance(session_id, str) or not session_id.strip():
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"type": "tool_call",
|
||||||
|
"sessionId": session_id,
|
||||||
|
"requestId": req_id,
|
||||||
|
"toolCallId": tool_id,
|
||||||
|
"approval": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_tool_invoke_result_params(params: dict[str, Any], tool_event: dict[str, Any], tool_id: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"toolCallId": tool_id,
|
||||||
|
"name": str(tool_event.get("name") or params.get("name") or "tool"),
|
||||||
|
"success": True,
|
||||||
|
"errorMessage": "",
|
||||||
|
"result": tool_event.get("result") if "result" in tool_event else {},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _maybe_emit_tool_roundtrip(self, method: str, params: dict[str, Any], tool_event: dict[str, Any]) -> None:
|
||||||
|
if not self._is_tool_roundtrip_method(method):
|
||||||
|
return
|
||||||
|
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||||
|
if not tool_id:
|
||||||
|
return
|
||||||
|
if tool_id in self._tool_roundtrip_done:
|
||||||
|
return
|
||||||
|
|
||||||
|
approve_params = self._build_tool_approve_params(params, tool_id)
|
||||||
|
if approve_params is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._tool_roundtrip_done.add(tool_id)
|
||||||
|
await self.notify("tool/call/approve", approve_params)
|
||||||
|
invoke_result_params = self._build_tool_invoke_result_params(params, tool_event, tool_id)
|
||||||
|
await self.notify("tool/invokeResult", invoke_result_params)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_tool_stream(self, method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> dict | None:
|
def _resolve_tool_stream(self, method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> dict | None:
|
||||||
req_id = params.get("requestId")
|
req_id = params.get("requestId")
|
||||||
if isinstance(req_id, str) and req_id.strip():
|
if isinstance(req_id, str) and req_id.strip():
|
||||||
@@ -363,6 +414,7 @@ class LspWsRpcClient:
|
|||||||
if not tool_id:
|
if not tool_id:
|
||||||
logger.warning("drop unroutable tool event: method=%s missing tool id", method)
|
logger.warning("drop unroutable tool event: method=%s missing tool id", method)
|
||||||
else:
|
else:
|
||||||
|
await self._maybe_emit_tool_roundtrip(method, params, tool_event)
|
||||||
tool_states = stream["tool_states"]
|
tool_states = stream["tool_states"]
|
||||||
order = stream["tool_order"]
|
order = stream["tool_order"]
|
||||||
existing = tool_states.get(tool_id)
|
existing = tool_states.get(tool_id)
|
||||||
@@ -431,6 +483,7 @@ class LspWsRpcClient:
|
|||||||
for tool_id, mapped_req in list(self._tool_stream_map.items()):
|
for tool_id, mapped_req in list(self._tool_stream_map.items()):
|
||||||
if mapped_req == request_id:
|
if mapped_req == request_id:
|
||||||
self._tool_stream_map.pop(tool_id, None)
|
self._tool_stream_map.pop(tool_id, None)
|
||||||
|
self._tool_roundtrip_done.discard(tool_id)
|
||||||
# Drain queue so no stray future gets stuck if the consumer bailed early.
|
# Drain queue so no stray future gets stuck if the consumer bailed early.
|
||||||
if not stream["done"].is_set():
|
if not stream["done"].is_set():
|
||||||
stream["done"].set()
|
stream["done"].set()
|
||||||
@@ -843,12 +896,12 @@ class LingmaGatewayClient:
|
|||||||
is_reply: bool = False,
|
is_reply: bool = False,
|
||||||
tool_config: dict[str, Any] | None = None,
|
tool_config: dict[str, Any] | None = None,
|
||||||
):
|
):
|
||||||
session_type = "developer" if ask_mode == "agent" else "chat"
|
session_type = "ask" if ask_mode == "agent" else "chat"
|
||||||
payload = {
|
payload = {
|
||||||
"requestId": request_id,
|
"requestId": request_id,
|
||||||
"sessionId": session_id,
|
"sessionId": session_id,
|
||||||
"sessionType": session_type,
|
"sessionType": session_type,
|
||||||
"chatTask": "FREE_INPUT",
|
"chatTask": "chat" if ask_mode == "agent" else "FREE_INPUT",
|
||||||
"mode": ask_mode,
|
"mode": ask_mode,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"source": 1,
|
"source": 1,
|
||||||
|
|||||||
126
app/main.py
126
app/main.py
@@ -452,6 +452,93 @@ def _json_string(value: Any) -> str:
|
|||||||
return "{}"
|
return "{}"
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_forced_tool_name(tool_choice: Any) -> str | None:
|
||||||
|
if not isinstance(tool_choice, dict):
|
||||||
|
return None
|
||||||
|
fn = tool_choice.get("function")
|
||||||
|
if isinstance(fn, dict):
|
||||||
|
name = fn.get("name")
|
||||||
|
if isinstance(name, str) and name.strip():
|
||||||
|
return name.strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _anthropic_forced_tool_name(tool_choice: Any) -> str | None:
|
||||||
|
if not isinstance(tool_choice, dict):
|
||||||
|
return None
|
||||||
|
if tool_choice.get("type") == "tool":
|
||||||
|
name = tool_choice.get("name")
|
||||||
|
if isinstance(name, str) and name.strip():
|
||||||
|
return name.strip()
|
||||||
|
fn = tool_choice.get("function")
|
||||||
|
if isinstance(fn, dict):
|
||||||
|
name = fn.get("name")
|
||||||
|
if isinstance(name, str) and name.strip():
|
||||||
|
return name.strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _json_object_from_text(text: str) -> dict[str, Any] | None:
|
||||||
|
raw = text.strip()
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
if raw.startswith("```") and raw.endswith("```"):
|
||||||
|
raw = raw[3:-3].strip()
|
||||||
|
if raw.lower().startswith("json"):
|
||||||
|
raw = raw[4:].strip()
|
||||||
|
try:
|
||||||
|
parsed = json.loads(raw)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
return parsed if isinstance(parsed, dict) else None
|
||||||
|
|
||||||
|
|
||||||
|
def _forced_tool_event_from_text(text: str, forced_tool_name: str) -> dict[str, Any] | None:
|
||||||
|
parsed = _json_object_from_text(text)
|
||||||
|
if parsed is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
explicit_name: Any = parsed.get("name") or parsed.get("tool")
|
||||||
|
fn = parsed.get("function")
|
||||||
|
if explicit_name is None and isinstance(fn, dict):
|
||||||
|
explicit_name = fn.get("name")
|
||||||
|
if explicit_name is not None and str(explicit_name) != forced_tool_name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
tool_input: Any = None
|
||||||
|
if "input" in parsed:
|
||||||
|
tool_input = parsed.get("input")
|
||||||
|
elif "arguments" in parsed:
|
||||||
|
args = parsed.get("arguments")
|
||||||
|
if isinstance(args, str):
|
||||||
|
try:
|
||||||
|
tool_input = json.loads(args)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
tool_input = args
|
||||||
|
elif isinstance(fn, dict) and "arguments" in fn:
|
||||||
|
args = fn.get("arguments")
|
||||||
|
if isinstance(args, str):
|
||||||
|
try:
|
||||||
|
tool_input = json.loads(args)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
tool_input = args
|
||||||
|
else:
|
||||||
|
reserved = {"name", "tool", "function", "arguments", "input", "result"}
|
||||||
|
tool_input = {k: v for k, v in parsed.items() if k not in reserved}
|
||||||
|
|
||||||
|
event: dict[str, Any] = {
|
||||||
|
"name": forced_tool_name,
|
||||||
|
"input": tool_input if tool_input is not None else {},
|
||||||
|
}
|
||||||
|
if "result" in parsed:
|
||||||
|
event["result"] = parsed.get("result")
|
||||||
|
return event
|
||||||
|
|
||||||
|
|
||||||
def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]:
|
def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"),
|
"id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"),
|
||||||
@@ -504,13 +591,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
# 1. Reuse the upstream sessionId so Lingma/Qwen hits its KV cache.
|
# 1. Reuse the upstream sessionId so Lingma/Qwen hits its KV cache.
|
||||||
# 2. Send only the new user message instead of the whole history.
|
# 2. Send only the new user message instead of the whole history.
|
||||||
# 3. Stick the request to the pool instance that originally served it.
|
# 3. Stick the request to the pool instance that originally served it.
|
||||||
ask_mode = settings.default_ask_mode
|
|
||||||
if req.model.lower() in {"lingma-agent", "agent"}:
|
|
||||||
ask_mode = "agent"
|
|
||||||
|
|
||||||
tool_config = _openai_tool_config(req)
|
tool_config = _openai_tool_config(req)
|
||||||
has_tooling_context = _openai_has_tooling_context(req, messages_dump)
|
has_tooling_context = _openai_has_tooling_context(req, messages_dump)
|
||||||
|
|
||||||
|
ask_mode = settings.default_ask_mode
|
||||||
|
if req.model.lower() in {"lingma-agent", "agent"} or has_tooling_context:
|
||||||
|
ask_mode = "agent"
|
||||||
|
|
||||||
reuse_eligible = (
|
reuse_eligible = (
|
||||||
session_cache.enabled
|
session_cache.enabled
|
||||||
and ask_mode == "chat"
|
and ask_mode == "chat"
|
||||||
@@ -800,6 +887,14 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
tool_id = str(item.get("id") or f"call_{idx}")
|
tool_id = str(item.get("id") or f"call_{idx}")
|
||||||
tool_calls.append(_openai_tool_call(item, forced_id=tool_id))
|
tool_calls.append(_openai_tool_call(item, forced_id=tool_id))
|
||||||
saw_tool_call = True
|
saw_tool_call = True
|
||||||
|
if not saw_tool_call:
|
||||||
|
forced_tool_name = _openai_forced_tool_name(req.tool_choice)
|
||||||
|
if forced_tool_name:
|
||||||
|
fallback_event = _forced_tool_event_from_text(message_content, forced_tool_name)
|
||||||
|
if fallback_event is not None:
|
||||||
|
tool_calls.append(_openai_tool_call(fallback_event, forced_id="call_fallback_0"))
|
||||||
|
saw_tool_call = True
|
||||||
|
message_content = ""
|
||||||
response = ChatCompletionResponse(
|
response = ChatCompletionResponse(
|
||||||
id=f"chatcmpl-{uuid.uuid4().hex}",
|
id=f"chatcmpl-{uuid.uuid4().hex}",
|
||||||
created=int(time.time()),
|
created=int(time.time()),
|
||||||
@@ -912,12 +1007,13 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------- session reuse
|
# ------------------------------------------------------------- session reuse
|
||||||
# Anthropic clients don't expose an ask_mode, so we always run in "chat".
|
|
||||||
ask_mode = "chat"
|
|
||||||
|
|
||||||
tool_config = _anthropic_tool_config(req)
|
tool_config = _anthropic_tool_config(req)
|
||||||
has_tooling_context = _anthropic_has_tooling_context(req)
|
has_tooling_context = _anthropic_has_tooling_context(req)
|
||||||
|
|
||||||
|
ask_mode = settings.default_ask_mode
|
||||||
|
if req.model.lower() in {"lingma-agent", "agent"} or has_tooling_context:
|
||||||
|
ask_mode = "agent"
|
||||||
|
|
||||||
reuse_eligible = (
|
reuse_eligible = (
|
||||||
session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 and not has_tooling_context
|
session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 and not has_tooling_context
|
||||||
)
|
)
|
||||||
@@ -1248,10 +1344,12 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
content_blocks.append({"type": "text", "text": text})
|
content_blocks.append({"type": "text", "text": text})
|
||||||
tool_events = result.get("toolEvents") or []
|
tool_events = result.get("toolEvents") or []
|
||||||
saw_pending_tool_use = False
|
saw_pending_tool_use = False
|
||||||
|
saw_tool_event = False
|
||||||
if isinstance(tool_events, list):
|
if isinstance(tool_events, list):
|
||||||
for idx, item in enumerate(tool_events):
|
for idx, item in enumerate(tool_events):
|
||||||
if not isinstance(item, dict):
|
if not isinstance(item, dict):
|
||||||
continue
|
continue
|
||||||
|
saw_tool_event = True
|
||||||
tool_id = str(item.get("id") or f"toolu_nonstream_{idx}")
|
tool_id = str(item.get("id") or f"toolu_nonstream_{idx}")
|
||||||
content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id))
|
content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id))
|
||||||
tool_result = _anthropic_tool_result_block(item, forced_id=tool_id)
|
tool_result = _anthropic_tool_result_block(item, forced_id=tool_id)
|
||||||
@@ -1260,7 +1358,21 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
else:
|
else:
|
||||||
saw_pending_tool_use = True
|
saw_pending_tool_use = True
|
||||||
|
|
||||||
|
if not saw_tool_event:
|
||||||
|
forced_tool_name = _anthropic_forced_tool_name(req.tool_choice)
|
||||||
|
if forced_tool_name:
|
||||||
|
fallback_event = _forced_tool_event_from_text(text, forced_tool_name)
|
||||||
|
if fallback_event is not None:
|
||||||
|
content_blocks = []
|
||||||
|
tool_id = "toolu_fallback_0"
|
||||||
|
content_blocks.append(_anthropic_tool_use_block(fallback_event, forced_id=tool_id))
|
||||||
|
tool_result = _anthropic_tool_result_block(fallback_event, forced_id=tool_id)
|
||||||
|
saw_pending_tool_use = tool_result is None
|
||||||
|
if tool_result is not None:
|
||||||
|
content_blocks.append(tool_result)
|
||||||
|
|
||||||
response_body: dict = {
|
response_body: dict = {
|
||||||
|
|
||||||
"id": message_id,
|
"id": message_id,
|
||||||
"type": "message",
|
"type": "message",
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
|
|||||||
@@ -147,14 +147,18 @@ async def _collect_stream(response) -> str:
|
|||||||
class _SpyClient(_FakeClient):
|
class _SpyClient(_FakeClient):
|
||||||
def __init__(self, *, stream_events: list[dict], complete_result: dict) -> None:
|
def __init__(self, *, stream_events: list[dict], complete_result: dict) -> None:
|
||||||
super().__init__(stream_events=stream_events, complete_result=complete_result)
|
super().__init__(stream_events=stream_events, complete_result=complete_result)
|
||||||
|
self.last_complete_args: tuple = ()
|
||||||
|
self.last_stream_args: tuple = ()
|
||||||
self.last_complete_kwargs: dict = {}
|
self.last_complete_kwargs: dict = {}
|
||||||
self.last_stream_kwargs: dict = {}
|
self.last_stream_kwargs: dict = {}
|
||||||
|
|
||||||
async def chat_complete(self, *args, **kwargs) -> dict:
|
async def chat_complete(self, *args, **kwargs) -> dict:
|
||||||
|
self.last_complete_args = tuple(args)
|
||||||
self.last_complete_kwargs = dict(kwargs)
|
self.last_complete_kwargs = dict(kwargs)
|
||||||
return await super().chat_complete(*args, **kwargs)
|
return await super().chat_complete(*args, **kwargs)
|
||||||
|
|
||||||
async def chat_stream(self, *args, **kwargs):
|
async def chat_stream(self, *args, **kwargs):
|
||||||
|
self.last_stream_args = tuple(args)
|
||||||
self.last_stream_kwargs = dict(kwargs)
|
self.last_stream_kwargs = dict(kwargs)
|
||||||
async for event in super().chat_stream(*args, **kwargs):
|
async for event in super().chat_stream(*args, **kwargs):
|
||||||
yield event
|
yield event
|
||||||
@@ -220,6 +224,42 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
|||||||
{"query": "gateway"},
|
{"query": "gateway"},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def test_openai_non_stream_fallbacks_to_structured_tool_call_for_forced_tool(self) -> None:
|
||||||
|
fake_client = _FakeClient(
|
||||||
|
stream_events=[],
|
||||||
|
complete_result={
|
||||||
|
"text": "```json\n{\"arguments\": {\"query\": \"gateway\"}}\n```",
|
||||||
|
"toolEvents": [],
|
||||||
|
"sessionId": "sess-fallback-openai",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
req = ChatCompletionsRequest(
|
||||||
|
model="org_auto",
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
stream=False,
|
||||||
|
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
|
||||||
|
tool_choice={"type": "function", "function": {"name": "lookup"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||||
|
patch.object(main, "chat_guard", _FakeGuard()),
|
||||||
|
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||||
|
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||||
|
):
|
||||||
|
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||||
|
|
||||||
|
payload = json.loads(response.body)
|
||||||
|
message = payload["choices"][0]["message"]
|
||||||
|
self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls")
|
||||||
|
self.assertEqual(message["content"], "")
|
||||||
|
self.assertIsInstance(message["tool_calls"], list)
|
||||||
|
self.assertEqual(message["tool_calls"][0]["function"]["name"], "lookup")
|
||||||
|
self.assertEqual(
|
||||||
|
json.loads(message["tool_calls"][0]["function"]["arguments"]),
|
||||||
|
{"query": "gateway"},
|
||||||
|
)
|
||||||
|
|
||||||
async def test_openai_stream_bridges_tool_and_text_events(self) -> None:
|
async def test_openai_stream_bridges_tool_and_text_events(self) -> None:
|
||||||
fake_client = _FakeClient(
|
fake_client = _FakeClient(
|
||||||
stream_events=[
|
stream_events=[
|
||||||
@@ -302,6 +342,46 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.assertEqual(payload["content"][1]["name"], "lookup")
|
self.assertEqual(payload["content"][1]["name"], "lookup")
|
||||||
self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1")
|
self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1")
|
||||||
|
|
||||||
|
async def test_anthropic_non_stream_fallbacks_to_structured_tool_blocks_for_forced_tool(self) -> None:
|
||||||
|
fake_client = _FakeClient(
|
||||||
|
stream_events=[],
|
||||||
|
complete_result={
|
||||||
|
"text": "{\"input\": {\"k\": \"v\"}, \"result\": {\"value\": 1}}",
|
||||||
|
"toolEvents": [],
|
||||||
|
"sessionId": "sess-fallback-anthropic",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
req = AnthropicMessagesRequest(
|
||||||
|
model="claude-3-5-sonnet-20241022",
|
||||||
|
max_tokens=256,
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
stream=False,
|
||||||
|
tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}],
|
||||||
|
tool_choice={"type": "tool", "name": "lookup"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||||
|
patch.object(main, "chat_guard", _FakeGuard()),
|
||||||
|
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||||
|
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||||
|
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||||
|
):
|
||||||
|
response = await main.v1_messages(
|
||||||
|
req,
|
||||||
|
_make_request(
|
||||||
|
"/v1/messages",
|
||||||
|
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = json.loads(response.body)
|
||||||
|
types = [item["type"] for item in payload["content"]]
|
||||||
|
self.assertEqual(types, ["tool_use", "tool_result"])
|
||||||
|
self.assertEqual(payload["stop_reason"], "end_turn")
|
||||||
|
self.assertEqual(payload["content"][0]["name"], "lookup")
|
||||||
|
self.assertEqual(payload["content"][1]["tool_use_id"], "toolu_fallback_0")
|
||||||
|
|
||||||
async def test_openai_stream_tool_call_indices_are_stable(self) -> None:
|
async def test_openai_stream_tool_call_indices_are_stable(self) -> None:
|
||||||
fake_client = _FakeClient(
|
fake_client = _FakeClient(
|
||||||
stream_events=[
|
stream_events=[
|
||||||
@@ -496,6 +576,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.assertEqual(cfg["provider"], "openai")
|
self.assertEqual(cfg["provider"], "openai")
|
||||||
self.assertEqual(len(cfg["tools"]), 1)
|
self.assertEqual(len(cfg["tools"]), 1)
|
||||||
self.assertIsInstance(cfg["tool_choice"], dict)
|
self.assertIsInstance(cfg["tool_choice"], dict)
|
||||||
|
self.assertEqual(spy_client.last_complete_args[2], "agent")
|
||||||
|
|
||||||
async def test_openai_non_stream_does_not_forward_tool_config_when_disabled(self) -> None:
|
async def test_openai_non_stream_does_not_forward_tool_config_when_disabled(self) -> None:
|
||||||
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
|
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
|
||||||
@@ -518,6 +599,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
|||||||
|
|
||||||
self.assertIn("tool_config", spy_client.last_complete_kwargs)
|
self.assertIn("tool_config", spy_client.last_complete_kwargs)
|
||||||
self.assertIsNone(spy_client.last_complete_kwargs["tool_config"])
|
self.assertIsNone(spy_client.last_complete_kwargs["tool_config"])
|
||||||
|
self.assertEqual(spy_client.last_complete_args[2], "agent")
|
||||||
|
|
||||||
|
|
||||||
async def test_openai_tooling_context_disables_session_reuse_cache(self) -> None:
|
async def test_openai_tooling_context_disables_session_reuse_cache(self) -> None:
|
||||||
@@ -551,6 +633,40 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.assertEqual(fake_cache.get_calls, [])
|
self.assertEqual(fake_cache.get_calls, [])
|
||||||
self.assertEqual(fake_cache.put_calls, [])
|
self.assertEqual(fake_cache.put_calls, [])
|
||||||
|
|
||||||
|
|
||||||
|
async def test_anthropic_non_stream_with_tools_uses_agent_mode(self) -> None:
|
||||||
|
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
|
||||||
|
req = AnthropicMessagesRequest(
|
||||||
|
model="claude-3-5-sonnet-20241022",
|
||||||
|
max_tokens=128,
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
stream=False,
|
||||||
|
tools=[{"name": "write_file", "input_schema": {"type": "object", "properties": {}}}],
|
||||||
|
tool_choice={"type": "auto"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
|
||||||
|
patch.object(main, "chat_guard", _FakeGuard()),
|
||||||
|
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||||
|
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||||
|
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||||
|
_SettingsPatch(tool_forward_enabled=True, default_ask_mode="chat"),
|
||||||
|
):
|
||||||
|
await main.v1_messages(
|
||||||
|
req,
|
||||||
|
_make_request(
|
||||||
|
"/v1/messages",
|
||||||
|
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIn("tool_config", spy_client.last_complete_kwargs)
|
||||||
|
cfg = spy_client.last_complete_kwargs["tool_config"]
|
||||||
|
self.assertEqual(cfg["provider"], "anthropic")
|
||||||
|
self.assertEqual(len(cfg["tools"]), 1)
|
||||||
|
self.assertEqual(spy_client.last_complete_args[2], "agent")
|
||||||
|
|
||||||
async def test_anthropic_tooling_context_disables_session_reuse_cache(self) -> None:
|
async def test_anthropic_tooling_context_disables_session_reuse_cache(self) -> None:
|
||||||
fake_cache = _FakeSessionCache()
|
fake_cache = _FakeSessionCache()
|
||||||
fake_client = _FakeClient(
|
fake_client = _FakeClient(
|
||||||
@@ -760,7 +876,6 @@ class SessionCacheToolFingerprintTests(unittest.TestCase):
|
|||||||
"result": {"hits": 3},
|
"result": {"hits": 3},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
event,
|
event,
|
||||||
{
|
{
|
||||||
@@ -770,3 +885,97 @@ class SessionCacheToolFingerprintTests(unittest.TestCase):
|
|||||||
"result": {"hits": 3},
|
"result": {"hits": 3},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_sync_triggers_approve_and_invoke_result_requests(self) -> None:
|
||||||
|
from app.lingma_client import LspWsRpcClient
|
||||||
|
|
||||||
|
class _WsStub:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.frames: list[bytes] = []
|
||||||
|
|
||||||
|
async def send(self, data: bytes) -> None:
|
||||||
|
self.frames.append(data)
|
||||||
|
|
||||||
|
def _decode(frame: bytes) -> dict:
|
||||||
|
body = frame.split(b"\r\n\r\n", 1)[1]
|
||||||
|
return json.loads(body.decode("utf-8"))
|
||||||
|
|
||||||
|
ws = _WsStub()
|
||||||
|
rpc = LspWsRpcClient(ws)
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
rpc.create_stream("req-1")
|
||||||
|
await rpc._handle_server_message(
|
||||||
|
{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "tool/call/sync",
|
||||||
|
"params": {
|
||||||
|
"sessionId": "sess-1",
|
||||||
|
"requestId": "req-1",
|
||||||
|
"toolCallId": "call-1",
|
||||||
|
"name": "run_in_terminal",
|
||||||
|
"parameters": {"command": "pwd"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
decoded = [_decode(frame) for frame in ws.frames]
|
||||||
|
methods = [item.get("method") for item in decoded]
|
||||||
|
self.assertIn("tool/call/approve", methods)
|
||||||
|
self.assertIn("tool/invokeResult", methods)
|
||||||
|
|
||||||
|
approve = next(item for item in decoded if item.get("method") == "tool/call/approve")
|
||||||
|
self.assertEqual(
|
||||||
|
approve["params"],
|
||||||
|
{
|
||||||
|
"type": "tool_call",
|
||||||
|
"sessionId": "sess-1",
|
||||||
|
"requestId": "req-1",
|
||||||
|
"toolCallId": "call-1",
|
||||||
|
"approval": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
invoke_result = next(item for item in decoded if item.get("method") == "tool/invokeResult")
|
||||||
|
self.assertEqual(invoke_result["params"]["toolCallId"], "call-1")
|
||||||
|
self.assertEqual(invoke_result["params"]["name"], "run_in_terminal")
|
||||||
|
self.assertTrue(invoke_result["params"]["success"])
|
||||||
|
self.assertEqual(invoke_result["params"]["errorMessage"], "")
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
def test_tool_sync_does_not_emit_roundtrip_without_request_id(self) -> None:
|
||||||
|
from app.lingma_client import LspWsRpcClient
|
||||||
|
|
||||||
|
class _WsStub:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.frames: list[bytes] = []
|
||||||
|
|
||||||
|
async def send(self, data: bytes) -> None:
|
||||||
|
self.frames.append(data)
|
||||||
|
|
||||||
|
ws = _WsStub()
|
||||||
|
rpc = LspWsRpcClient(ws)
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
rpc.create_stream("req-1")
|
||||||
|
await rpc._handle_server_message(
|
||||||
|
{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "tool/call/sync",
|
||||||
|
"params": {
|
||||||
|
"sessionId": "sess-1",
|
||||||
|
"toolCallId": "call-1",
|
||||||
|
"name": "run_in_terminal",
|
||||||
|
"parameters": {"command": "pwd"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(ws.frames, [])
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|||||||
Reference in New Issue
Block a user