fix: harden tooling session reuse and event routing
Ensure session reuse is disabled for tooling contexts, include tool config in cache keys, and stabilize tool event merge/routing with expanded bridge tests. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -44,6 +44,7 @@ class Settings:
|
|||||||
session_reuse_enabled: bool = True
|
session_reuse_enabled: bool = True
|
||||||
session_cache_max_entries: int = 256
|
session_cache_max_entries: int = 256
|
||||||
session_cache_ttl_sec: float = 1800.0
|
session_cache_ttl_sec: float = 1800.0
|
||||||
|
tool_forward_enabled: bool = False
|
||||||
|
|
||||||
|
|
||||||
def _bool_env(name: str, default: bool) -> bool:
|
def _bool_env(name: str, default: bool) -> bool:
|
||||||
@@ -175,4 +176,5 @@ def load_settings() -> Settings:
|
|||||||
session_reuse_enabled=_bool_env("SESSION_REUSE_ENABLED", True),
|
session_reuse_enabled=_bool_env("SESSION_REUSE_ENABLED", True),
|
||||||
session_cache_max_entries=int(os.getenv("SESSION_CACHE_MAX_ENTRIES", "256")),
|
session_cache_max_entries=int(os.getenv("SESSION_CACHE_MAX_ENTRIES", "256")),
|
||||||
session_cache_ttl_sec=float(os.getenv("SESSION_CACHE_TTL_SEC", "1800")),
|
session_cache_ttl_sec=float(os.getenv("SESSION_CACHE_TTL_SEC", "1800")),
|
||||||
|
tool_forward_enabled=_bool_env("TOOL_FORWARD_ENABLED", False),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -100,6 +100,7 @@ class LspWsRpcClient:
|
|||||||
self._reader_task: asyncio.Task | None = None
|
self._reader_task: asyncio.Task | None = None
|
||||||
self._rx_buffer = b""
|
self._rx_buffer = b""
|
||||||
self._chat_streams: dict[str, dict] = {}
|
self._chat_streams: dict[str, dict] = {}
|
||||||
|
self._tool_stream_map: dict[str, str] = {}
|
||||||
self._on_disconnect = on_disconnect
|
self._on_disconnect = on_disconnect
|
||||||
self._closed = False
|
self._closed = False
|
||||||
|
|
||||||
@@ -202,6 +203,7 @@ class LspWsRpcClient:
|
|||||||
stream["done"].set()
|
stream["done"].set()
|
||||||
stream["chunks"].put_nowait(None)
|
stream["chunks"].put_nowait(None)
|
||||||
self._chat_streams.clear()
|
self._chat_streams.clear()
|
||||||
|
self._tool_stream_map.clear()
|
||||||
|
|
||||||
async def _send(self, payload: dict):
|
async def _send(self, payload: dict):
|
||||||
async with self._send_lock:
|
async with self._send_lock:
|
||||||
@@ -251,6 +253,92 @@ class LspWsRpcClient:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("on_disconnect callback failed")
|
logger.exception("on_disconnect callback failed")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_tool_id(method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> str | None:
|
||||||
|
event_id = None
|
||||||
|
if isinstance(tool_event, dict):
|
||||||
|
event_id = tool_event.get("id")
|
||||||
|
if isinstance(event_id, str) and event_id.strip():
|
||||||
|
return event_id.strip()
|
||||||
|
|
||||||
|
fallback_id = params.get("toolCallId") or params.get("tool_call_id")
|
||||||
|
if isinstance(fallback_id, str) and fallback_id.strip():
|
||||||
|
return fallback_id.strip()
|
||||||
|
|
||||||
|
req_id = params.get("requestId")
|
||||||
|
name = None
|
||||||
|
if isinstance(tool_event, dict):
|
||||||
|
name = tool_event.get("name")
|
||||||
|
if not name:
|
||||||
|
name = params.get("name")
|
||||||
|
if isinstance(req_id, str) and req_id.strip() and isinstance(name, str) and name.strip():
|
||||||
|
return f"{req_id.strip()}:tool:{name.strip()}"
|
||||||
|
if isinstance(req_id, str) and req_id.strip():
|
||||||
|
return f"{req_id.strip()}:tool"
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_tool_event(existing: dict[str, Any] | None, incoming: dict[str, Any]) -> tuple[dict[str, Any], bool]:
|
||||||
|
merged = dict(existing or {})
|
||||||
|
changed = False
|
||||||
|
|
||||||
|
val = incoming.get("id")
|
||||||
|
if val and merged.get("id") != val:
|
||||||
|
merged["id"] = val
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
name = incoming.get("name")
|
||||||
|
if name:
|
||||||
|
existing_name = merged.get("name")
|
||||||
|
if not existing_name:
|
||||||
|
merged["name"] = name
|
||||||
|
changed = True
|
||||||
|
else:
|
||||||
|
existing_norm = str(existing_name).strip().lower()
|
||||||
|
incoming_norm = str(name).strip().lower()
|
||||||
|
if existing_norm == "tool" and incoming_norm != "tool":
|
||||||
|
merged["name"] = name
|
||||||
|
changed = True
|
||||||
|
elif existing_norm != "tool" and incoming_norm == "tool":
|
||||||
|
pass
|
||||||
|
elif merged.get("name") != name:
|
||||||
|
merged["name"] = name
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
if "input" in incoming and incoming.get("input") is not None:
|
||||||
|
incoming_input = incoming.get("input")
|
||||||
|
should_update_input = incoming_input != {} or "input" not in merged
|
||||||
|
if should_update_input and merged.get("input") != incoming_input:
|
||||||
|
merged["input"] = incoming_input
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
if "result" in incoming and incoming.get("result") is not None:
|
||||||
|
if merged.get("result") != incoming.get("result"):
|
||||||
|
merged["result"] = incoming.get("result")
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
return merged, changed
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_tool_stream(self, method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> dict | None:
|
||||||
|
req_id = params.get("requestId")
|
||||||
|
if isinstance(req_id, str) and req_id.strip():
|
||||||
|
stream = self._chat_streams.get(req_id)
|
||||||
|
if stream is not None and tool_event is not None:
|
||||||
|
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||||
|
if tool_id:
|
||||||
|
self._tool_stream_map[tool_id] = req_id
|
||||||
|
return stream
|
||||||
|
|
||||||
|
if tool_event is not None:
|
||||||
|
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||||
|
if tool_id:
|
||||||
|
mapped_req = self._tool_stream_map.get(tool_id)
|
||||||
|
if mapped_req:
|
||||||
|
return self._chat_streams.get(mapped_req)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
async def _handle_server_message(self, msg: dict):
|
async def _handle_server_message(self, msg: dict):
|
||||||
method = msg.get("method")
|
method = msg.get("method")
|
||||||
params = msg.get("params") or {}
|
params = msg.get("params") or {}
|
||||||
@@ -268,21 +356,29 @@ class LspWsRpcClient:
|
|||||||
|
|
||||||
if method in {"tool/call/sync", "tool/invoke", "tool/call/approve", "tool/invokeResult"}:
|
if method in {"tool/call/sync", "tool/invoke", "tool/call/approve", "tool/invokeResult"}:
|
||||||
tool_event = self._extract_tool_event(params)
|
tool_event = self._extract_tool_event(params)
|
||||||
req_id = params.get("requestId")
|
stream = self._resolve_tool_stream(method, params, tool_event)
|
||||||
stream = self._chat_streams.get(req_id) if req_id else None
|
|
||||||
|
|
||||||
if stream is None and tool_event is not None:
|
|
||||||
for item in self._chat_streams.values():
|
|
||||||
if any(evt.get("id") == tool_event["id"] for evt in item["tool_events"]):
|
|
||||||
stream = item
|
|
||||||
break
|
|
||||||
|
|
||||||
if stream is None and len(self._chat_streams) == 1:
|
|
||||||
stream = next(iter(self._chat_streams.values()))
|
|
||||||
|
|
||||||
if stream is not None and tool_event is not None:
|
if stream is not None and tool_event is not None:
|
||||||
stream["tool_events"].append(tool_event)
|
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||||
stream["chunks"].put_nowait({"type": "tool", "tool": tool_event})
|
if not tool_id:
|
||||||
|
logger.warning("drop unroutable tool event: method=%s missing tool id", method)
|
||||||
|
else:
|
||||||
|
tool_states = stream["tool_states"]
|
||||||
|
order = stream["tool_order"]
|
||||||
|
existing = tool_states.get(tool_id)
|
||||||
|
merged, changed = self._merge_tool_event(existing, tool_event)
|
||||||
|
if not existing:
|
||||||
|
if "id" not in merged or not merged.get("id"):
|
||||||
|
merged["id"] = tool_id
|
||||||
|
tool_states[tool_id] = merged
|
||||||
|
order.append(tool_id)
|
||||||
|
stream["chunks"].put_nowait({"type": "tool", "tool": merged})
|
||||||
|
elif changed:
|
||||||
|
tool_states[tool_id] = merged
|
||||||
|
stream["chunks"].put_nowait({"type": "tool", "tool": merged})
|
||||||
|
elif tool_event is not None:
|
||||||
|
logger.warning("drop unroutable tool event: method=%s requestId=%s", method, params.get("requestId"))
|
||||||
|
|
||||||
|
|
||||||
if method == "chat/finish":
|
if method == "chat/finish":
|
||||||
req_id = params.get("requestId")
|
req_id = params.get("requestId")
|
||||||
@@ -321,7 +417,8 @@ class LspWsRpcClient:
|
|||||||
"chunks": asyncio.Queue(),
|
"chunks": asyncio.Queue(),
|
||||||
"done": asyncio.Event(),
|
"done": asyncio.Event(),
|
||||||
"finish": None,
|
"finish": None,
|
||||||
"tool_events": [],
|
"tool_states": {},
|
||||||
|
"tool_order": [],
|
||||||
"started_at": time.monotonic(),
|
"started_at": time.monotonic(),
|
||||||
"first_chunk_at": None,
|
"first_chunk_at": None,
|
||||||
"finish_at": None,
|
"finish_at": None,
|
||||||
@@ -331,6 +428,9 @@ class LspWsRpcClient:
|
|||||||
stream = self._chat_streams.pop(request_id, None)
|
stream = self._chat_streams.pop(request_id, None)
|
||||||
if stream is None:
|
if stream is None:
|
||||||
return
|
return
|
||||||
|
for tool_id, mapped_req in list(self._tool_stream_map.items()):
|
||||||
|
if mapped_req == request_id:
|
||||||
|
self._tool_stream_map.pop(tool_id, None)
|
||||||
# Drain queue so no stray future gets stuck if the consumer bailed early.
|
# Drain queue so no stray future gets stuck if the consumer bailed early.
|
||||||
if not stream["done"].is_set():
|
if not stream["done"].is_set():
|
||||||
stream["done"].set()
|
stream["done"].set()
|
||||||
@@ -359,12 +459,20 @@ class LspWsRpcClient:
|
|||||||
first_ms = int((stream["first_chunk_at"] - stream["started_at"]) * 1000)
|
first_ms = int((stream["first_chunk_at"] - stream["started_at"]) * 1000)
|
||||||
if stream.get("finish_at") is not None:
|
if stream.get("finish_at") is not None:
|
||||||
total_ms = int((stream["finish_at"] - stream["started_at"]) * 1000)
|
total_ms = int((stream["finish_at"] - stream["started_at"]) * 1000)
|
||||||
|
|
||||||
|
ordered_tool_events: list[dict[str, Any]] = []
|
||||||
|
tool_states = stream.get("tool_states") or {}
|
||||||
|
for tool_id in stream.get("tool_order") or []:
|
||||||
|
event = tool_states.get(tool_id)
|
||||||
|
if isinstance(event, dict):
|
||||||
|
ordered_tool_events.append(event)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"text": "".join(stream.get("parts") or []),
|
"text": "".join(stream.get("parts") or []),
|
||||||
"finish": stream.get("finish") or {},
|
"finish": stream.get("finish") or {},
|
||||||
"firstTokenLatencyMs": first_ms,
|
"firstTokenLatencyMs": first_ms,
|
||||||
"totalLatencyMs": total_ms,
|
"totalLatencyMs": total_ms,
|
||||||
"toolEvents": stream.get("tool_events") or [],
|
"toolEvents": ordered_tool_events,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -733,9 +841,10 @@ class LingmaGatewayClient:
|
|||||||
request_id: str,
|
request_id: str,
|
||||||
*,
|
*,
|
||||||
is_reply: bool = False,
|
is_reply: bool = False,
|
||||||
|
tool_config: dict[str, Any] | None = None,
|
||||||
):
|
):
|
||||||
session_type = "developer" if ask_mode == "agent" else "chat"
|
session_type = "developer" if ask_mode == "agent" else "chat"
|
||||||
return {
|
payload = {
|
||||||
"requestId": request_id,
|
"requestId": request_id,
|
||||||
"sessionId": session_id,
|
"sessionId": session_id,
|
||||||
"sessionType": session_type,
|
"sessionType": session_type,
|
||||||
@@ -764,6 +873,9 @@ class LingmaGatewayClient:
|
|||||||
"localeLang": "zh-CN",
|
"localeLang": "zh-CN",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
if tool_config is not None:
|
||||||
|
payload["toolConfig"] = tool_config
|
||||||
|
return payload
|
||||||
|
|
||||||
async def _kick_chat_ask(self, payload: dict) -> None:
|
async def _kick_chat_ask(self, payload: dict) -> None:
|
||||||
"""Fire chat/ask as a notification.
|
"""Fire chat/ask as a notification.
|
||||||
@@ -784,12 +896,19 @@ class LingmaGatewayClient:
|
|||||||
*,
|
*,
|
||||||
session_id: str | None = None,
|
session_id: str | None = None,
|
||||||
is_reply: bool = False,
|
is_reply: bool = False,
|
||||||
|
tool_config: dict[str, Any] | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
await self.ensure_ready()
|
await self.ensure_ready()
|
||||||
request_id = str(uuid.uuid4())
|
request_id = str(uuid.uuid4())
|
||||||
sid = session_id or str(uuid.uuid4())
|
sid = session_id or str(uuid.uuid4())
|
||||||
payload = self._build_payload(
|
payload = self._build_payload(
|
||||||
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
|
prompt,
|
||||||
|
model_key,
|
||||||
|
ask_mode,
|
||||||
|
sid,
|
||||||
|
request_id,
|
||||||
|
is_reply=is_reply,
|
||||||
|
tool_config=tool_config,
|
||||||
)
|
)
|
||||||
self.rpc.create_stream(request_id)
|
self.rpc.create_stream(request_id)
|
||||||
try:
|
try:
|
||||||
@@ -820,6 +939,7 @@ class LingmaGatewayClient:
|
|||||||
*,
|
*,
|
||||||
session_id: str | None = None,
|
session_id: str | None = None,
|
||||||
is_reply: bool = False,
|
is_reply: bool = False,
|
||||||
|
tool_config: dict[str, Any] | None = None,
|
||||||
out_meta: dict | None = None,
|
out_meta: dict | None = None,
|
||||||
) -> AsyncIterator[dict[str, Any]]:
|
) -> AsyncIterator[dict[str, Any]]:
|
||||||
"""Stream chat events.
|
"""Stream chat events.
|
||||||
@@ -837,7 +957,13 @@ class LingmaGatewayClient:
|
|||||||
request_id = str(uuid.uuid4())
|
request_id = str(uuid.uuid4())
|
||||||
sid = session_id or str(uuid.uuid4())
|
sid = session_id or str(uuid.uuid4())
|
||||||
payload = self._build_payload(
|
payload = self._build_payload(
|
||||||
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
|
prompt,
|
||||||
|
model_key,
|
||||||
|
ask_mode,
|
||||||
|
sid,
|
||||||
|
request_id,
|
||||||
|
is_reply=is_reply,
|
||||||
|
tool_config=tool_config,
|
||||||
)
|
)
|
||||||
self.rpc.create_stream(request_id)
|
self.rpc.create_stream(request_id)
|
||||||
try:
|
try:
|
||||||
|
|||||||
180
app/main.py
180
app/main.py
@@ -351,6 +351,70 @@ def _include_usage(stream_options: dict | None) -> bool:
|
|||||||
return bool(stream_options.get("include_usage"))
|
return bool(stream_options.get("include_usage"))
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None:
|
||||||
|
if not settings.tool_forward_enabled:
|
||||||
|
return None
|
||||||
|
has_tools = isinstance(req.tools, list) and len(req.tools) > 0
|
||||||
|
has_choice = req.tool_choice is not None
|
||||||
|
if not has_tools and not has_choice:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"provider": "openai",
|
||||||
|
"tools": req.tools or [],
|
||||||
|
"tool_choice": req.tool_choice,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _anthropic_tool_config(req: AnthropicMessagesRequest) -> dict[str, Any] | None:
|
||||||
|
if not settings.tool_forward_enabled:
|
||||||
|
return None
|
||||||
|
has_tools = isinstance(req.tools, list) and len(req.tools) > 0
|
||||||
|
has_choice = req.tool_choice is not None
|
||||||
|
if not has_tools and not has_choice:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"provider": "anthropic",
|
||||||
|
"tools": req.tools or [],
|
||||||
|
"tool_choice": req.tool_choice,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_has_tooling_context(req: ChatCompletionsRequest, messages: list[dict[str, Any]]) -> bool:
|
||||||
|
if isinstance(req.tools, list) and len(req.tools) > 0:
|
||||||
|
return True
|
||||||
|
if req.tool_choice is not None:
|
||||||
|
return True
|
||||||
|
for m in messages:
|
||||||
|
role = m.get("role")
|
||||||
|
if role == "tool":
|
||||||
|
return True
|
||||||
|
if role == "assistant" and m.get("tool_calls"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _anthropic_content_has_tool_blocks(content: Any) -> bool:
|
||||||
|
if not isinstance(content, list):
|
||||||
|
return False
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, dict) and item.get("type") in {"tool_use", "tool_result"}:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _anthropic_has_tooling_context(req: AnthropicMessagesRequest) -> bool:
|
||||||
|
if isinstance(req.tools, list) and len(req.tools) > 0:
|
||||||
|
return True
|
||||||
|
if req.tool_choice is not None:
|
||||||
|
return True
|
||||||
|
if _anthropic_content_has_tool_blocks(req.system):
|
||||||
|
return True
|
||||||
|
for m in req.messages:
|
||||||
|
if _anthropic_content_has_tool_blocks(m.content):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _stream_event_type(event: Any) -> str:
|
def _stream_event_type(event: Any) -> str:
|
||||||
if isinstance(event, dict):
|
if isinstance(event, dict):
|
||||||
t = event.get("type")
|
t = event.get("type")
|
||||||
@@ -388,9 +452,9 @@ def _json_string(value: Any) -> str:
|
|||||||
return "{}"
|
return "{}"
|
||||||
|
|
||||||
|
|
||||||
def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]:
|
def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"id": str(tool.get("id") or f"call_{uuid.uuid4().hex}"),
|
"id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"),
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": str(tool.get("name") or "tool"),
|
"name": str(tool.get("name") or "tool"),
|
||||||
@@ -399,16 +463,20 @@ def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _anthropic_tool_use_block(tool: dict[str, Any]) -> dict[str, Any]:
|
def _anthropic_tool_use_block(
|
||||||
|
tool: dict[str, Any], *, forced_id: str | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"type": "tool_use",
|
"type": "tool_use",
|
||||||
"id": str(tool.get("id") or f"toolu_{uuid.uuid4().hex}"),
|
"id": str(tool.get("id") or forced_id or f"toolu_{uuid.uuid4().hex}"),
|
||||||
"name": str(tool.get("name") or "tool"),
|
"name": str(tool.get("name") or "tool"),
|
||||||
"input": tool.get("input") if tool.get("input") is not None else {},
|
"input": tool.get("input") if tool.get("input") is not None else {},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None:
|
def _anthropic_tool_result_block(
|
||||||
|
tool: dict[str, Any], *, forced_id: str | None = None
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
if "result" not in tool:
|
if "result" not in tool:
|
||||||
return None
|
return None
|
||||||
result = tool.get("result")
|
result = tool.get("result")
|
||||||
@@ -418,7 +486,7 @@ def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None:
|
|||||||
content = _json_string(result)
|
content = _json_string(result)
|
||||||
return {
|
return {
|
||||||
"type": "tool_result",
|
"type": "tool_result",
|
||||||
"tool_use_id": str(tool.get("id") or ""),
|
"tool_use_id": str(tool.get("id") or forced_id or ""),
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -440,18 +508,22 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
if req.model.lower() in {"lingma-agent", "agent"}:
|
if req.model.lower() in {"lingma-agent", "agent"}:
|
||||||
ask_mode = "agent"
|
ask_mode = "agent"
|
||||||
|
|
||||||
|
tool_config = _openai_tool_config(req)
|
||||||
|
has_tooling_context = _openai_has_tooling_context(req, messages_dump)
|
||||||
|
|
||||||
reuse_eligible = (
|
reuse_eligible = (
|
||||||
session_cache.enabled
|
session_cache.enabled
|
||||||
and ask_mode == "chat"
|
and ask_mode == "chat"
|
||||||
and len(messages_dump) >= 2
|
and len(messages_dump) >= 2
|
||||||
|
and not has_tooling_context
|
||||||
)
|
)
|
||||||
lookup_key: str | None = None
|
lookup_key: str | None = None
|
||||||
write_key: str | None = None
|
write_key: str | None = None
|
||||||
cached_session_id: str | None = None
|
cached_session_id: str | None = None
|
||||||
cached_instance_name: str | None = None
|
cached_instance_name: str | None = None
|
||||||
if reuse_eligible:
|
if reuse_eligible:
|
||||||
lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
|
lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
|
||||||
write_key = session_cache.build_key(api_key, messages_dump)
|
write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config)
|
||||||
entry = await session_cache.get(lookup_key)
|
entry = await session_cache.get(lookup_key)
|
||||||
if entry is not None:
|
if entry is not None:
|
||||||
cached_session_id = entry.session_id
|
cached_session_id = entry.session_id
|
||||||
@@ -549,6 +621,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
|
|
||||||
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
|
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
|
||||||
success = False
|
success = False
|
||||||
|
tool_call_indexes: dict[str, int] = {}
|
||||||
|
saw_tool_call = False
|
||||||
try:
|
try:
|
||||||
async for chunk in _inst.client.chat_stream(
|
async for chunk in _inst.client.chat_stream(
|
||||||
prompt,
|
prompt,
|
||||||
@@ -556,12 +630,21 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
ask_mode,
|
ask_mode,
|
||||||
session_id=cached_session_id,
|
session_id=cached_session_id,
|
||||||
is_reply=is_reply,
|
is_reply=is_reply,
|
||||||
|
tool_config=tool_config,
|
||||||
out_meta=_meta,
|
out_meta=_meta,
|
||||||
):
|
):
|
||||||
if _stream_event_type(chunk) == "tool":
|
if _stream_event_type(chunk) == "tool":
|
||||||
tool = _stream_tool_event(chunk)
|
tool = _stream_tool_event(chunk)
|
||||||
if not tool:
|
if not tool:
|
||||||
continue
|
continue
|
||||||
|
tool_id = str(tool.get("id") or "")
|
||||||
|
if not tool_id:
|
||||||
|
tool_id = f"call_{len(tool_call_indexes)}"
|
||||||
|
idx = tool_call_indexes.get(tool_id)
|
||||||
|
if idx is None:
|
||||||
|
idx = len(tool_call_indexes)
|
||||||
|
tool_call_indexes[tool_id] = idx
|
||||||
|
saw_tool_call = True
|
||||||
payload = {
|
payload = {
|
||||||
"id": completion_id,
|
"id": completion_id,
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
@@ -573,8 +656,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
"delta": {
|
"delta": {
|
||||||
"tool_calls": [
|
"tool_calls": [
|
||||||
{
|
{
|
||||||
"index": 0,
|
"index": idx,
|
||||||
**_openai_tool_call(tool),
|
**_openai_tool_call(tool, forced_id=tool_id),
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -609,10 +692,17 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"created": created,
|
"created": created,
|
||||||
"model": model,
|
"model": model,
|
||||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {},
|
||||||
|
"finish_reason": "tool_calls" if saw_tool_call else "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
}
|
}
|
||||||
yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
if include_usage:
|
if include_usage:
|
||||||
usage_payload = {
|
usage_payload = {
|
||||||
"id": completion_id,
|
"id": completion_id,
|
||||||
@@ -670,6 +760,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
ask_mode,
|
ask_mode,
|
||||||
session_id=cached_session_id,
|
session_id=cached_session_id,
|
||||||
is_reply=is_reply,
|
is_reply=is_reply,
|
||||||
|
tool_config=tool_config,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
|
logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
|
||||||
@@ -702,10 +793,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
tool_events = result.get("toolEvents") or []
|
tool_events = result.get("toolEvents") or []
|
||||||
message_content = result.get("text") or ""
|
message_content = result.get("text") or ""
|
||||||
tool_calls: list[dict[str, Any]] = []
|
tool_calls: list[dict[str, Any]] = []
|
||||||
|
saw_tool_call = False
|
||||||
if isinstance(tool_events, list):
|
if isinstance(tool_events, list):
|
||||||
for item in tool_events:
|
for idx, item in enumerate(tool_events):
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
tool_calls.append(_openai_tool_call(item))
|
tool_id = str(item.get("id") or f"call_{idx}")
|
||||||
|
tool_calls.append(_openai_tool_call(item, forced_id=tool_id))
|
||||||
|
saw_tool_call = True
|
||||||
response = ChatCompletionResponse(
|
response = ChatCompletionResponse(
|
||||||
id=f"chatcmpl-{uuid.uuid4().hex}",
|
id=f"chatcmpl-{uuid.uuid4().hex}",
|
||||||
created=int(time.time()),
|
created=int(time.time()),
|
||||||
@@ -713,7 +807,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
choices=[
|
choices=[
|
||||||
ChatCompletionChoice(
|
ChatCompletionChoice(
|
||||||
index=0,
|
index=0,
|
||||||
finish_reason="stop",
|
finish_reason="tool_calls" if saw_tool_call else "stop",
|
||||||
message={
|
message={
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": message_content,
|
"content": message_content,
|
||||||
@@ -723,6 +817,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
data = response.model_dump()
|
data = response.model_dump()
|
||||||
data["latency"] = {
|
data["latency"] = {
|
||||||
"first_token_ms": result.get("firstTokenLatencyMs"),
|
"first_token_ms": result.get("firstTokenLatencyMs"),
|
||||||
@@ -749,13 +844,15 @@ def _anthropic_error(status_code: int, error_type: str, message: str) -> JSONRes
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _anthropic_stop_reason(completion_tokens: int, max_tokens: int) -> str:
|
def _anthropic_stop_reason(
|
||||||
"""Approximate Anthropic `stop_reason`.
|
completion_tokens: int,
|
||||||
|
max_tokens: int,
|
||||||
Lingma doesn't expose a `max_tokens` knob, so we can't truly enforce it;
|
*,
|
||||||
we report `max_tokens` only when the generated length meets or exceeds
|
has_pending_tool_use: bool = False,
|
||||||
the caller's stated ceiling. Everything else is `end_turn`.
|
) -> str:
|
||||||
"""
|
"""Approximate Anthropic `stop_reason`."""
|
||||||
|
if has_pending_tool_use:
|
||||||
|
return "tool_use"
|
||||||
if max_tokens and completion_tokens >= max_tokens:
|
if max_tokens and completion_tokens >= max_tokens:
|
||||||
return "max_tokens"
|
return "max_tokens"
|
||||||
return "end_turn"
|
return "end_turn"
|
||||||
@@ -818,16 +915,19 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
# Anthropic clients don't expose an ask_mode, so we always run in "chat".
|
# Anthropic clients don't expose an ask_mode, so we always run in "chat".
|
||||||
ask_mode = "chat"
|
ask_mode = "chat"
|
||||||
|
|
||||||
|
tool_config = _anthropic_tool_config(req)
|
||||||
|
has_tooling_context = _anthropic_has_tooling_context(req)
|
||||||
|
|
||||||
reuse_eligible = (
|
reuse_eligible = (
|
||||||
session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2
|
session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 and not has_tooling_context
|
||||||
)
|
)
|
||||||
lookup_key: str | None = None
|
lookup_key: str | None = None
|
||||||
write_key: str | None = None
|
write_key: str | None = None
|
||||||
cached_session_id: str | None = None
|
cached_session_id: str | None = None
|
||||||
cached_instance_name: str | None = None
|
cached_instance_name: str | None = None
|
||||||
if reuse_eligible:
|
if reuse_eligible:
|
||||||
lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
|
lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
|
||||||
write_key = session_cache.build_key(api_key, messages_dump)
|
write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config)
|
||||||
entry = await session_cache.get(lookup_key)
|
entry = await session_cache.get(lookup_key)
|
||||||
if entry is not None:
|
if entry is not None:
|
||||||
cached_session_id = entry.session_id
|
cached_session_id = entry.session_id
|
||||||
@@ -875,7 +975,6 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
return _anthropic_error(400, "invalid_request_error", "messages is empty")
|
return _anthropic_error(400, "invalid_request_error", "messages is empty")
|
||||||
|
|
||||||
prompt_tokens = estimate_tokens(prompt)
|
prompt_tokens = estimate_tokens(prompt)
|
||||||
|
|
||||||
# ------------------------------------------------------------- backpressure
|
# ------------------------------------------------------------- backpressure
|
||||||
try:
|
try:
|
||||||
ticket = await chat_guard.try_acquire()
|
ticket = await chat_guard.try_acquire()
|
||||||
@@ -927,6 +1026,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
success = False
|
success = False
|
||||||
block_index = 0
|
block_index = 0
|
||||||
text_block_open = False
|
text_block_open = False
|
||||||
|
saw_pending_tool_use = False
|
||||||
try:
|
try:
|
||||||
# 1) message_start — Anthropic SDKs read this first to get
|
# 1) message_start — Anthropic SDKs read this first to get
|
||||||
# the message envelope (id/model/initial usage).
|
# the message envelope (id/model/initial usage).
|
||||||
@@ -956,6 +1056,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
ask_mode,
|
ask_mode,
|
||||||
session_id=cached_session_id,
|
session_id=cached_session_id,
|
||||||
is_reply=is_reply,
|
is_reply=is_reply,
|
||||||
|
tool_config=tool_config,
|
||||||
out_meta=_meta,
|
out_meta=_meta,
|
||||||
):
|
):
|
||||||
if _stream_event_type(chunk) == "tool":
|
if _stream_event_type(chunk) == "tool":
|
||||||
@@ -970,8 +1071,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
tool = _stream_tool_event(chunk)
|
tool = _stream_tool_event(chunk)
|
||||||
if not tool:
|
if not tool:
|
||||||
continue
|
continue
|
||||||
|
tool_id = str(tool.get("id") or f"toolu_stream_{block_index}")
|
||||||
|
|
||||||
tool_use_block = _anthropic_tool_use_block(tool)
|
tool_use_block = _anthropic_tool_use_block(tool, forced_id=tool_id)
|
||||||
yield _sse(
|
yield _sse(
|
||||||
"content_block_start",
|
"content_block_start",
|
||||||
{
|
{
|
||||||
@@ -986,7 +1088,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
)
|
)
|
||||||
block_index += 1
|
block_index += 1
|
||||||
|
|
||||||
tool_result_block = _anthropic_tool_result_block(tool)
|
tool_result_block = _anthropic_tool_result_block(tool, forced_id=tool_id)
|
||||||
if tool_result_block is not None:
|
if tool_result_block is not None:
|
||||||
yield _sse(
|
yield _sse(
|
||||||
"content_block_start",
|
"content_block_start",
|
||||||
@@ -1001,6 +1103,8 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
{"type": "content_block_stop", "index": block_index},
|
{"type": "content_block_stop", "index": block_index},
|
||||||
)
|
)
|
||||||
block_index += 1
|
block_index += 1
|
||||||
|
else:
|
||||||
|
saw_pending_tool_use = True
|
||||||
continue
|
continue
|
||||||
|
|
||||||
text = _stream_text(chunk)
|
text = _stream_text(chunk)
|
||||||
@@ -1036,7 +1140,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
# 5) message_delta carries the terminal stop_reason and
|
# 5) message_delta carries the terminal stop_reason and
|
||||||
# the final cumulative output_tokens count.
|
# the final cumulative output_tokens count.
|
||||||
stop_reason = _anthropic_stop_reason(
|
stop_reason = _anthropic_stop_reason(
|
||||||
completion_tokens_holder["n"], max_tokens
|
completion_tokens_holder["n"],
|
||||||
|
max_tokens,
|
||||||
|
has_pending_tool_use=saw_pending_tool_use,
|
||||||
)
|
)
|
||||||
yield _sse(
|
yield _sse(
|
||||||
"message_delta",
|
"message_delta",
|
||||||
@@ -1050,6 +1156,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 6) message_stop — terminal event, no [DONE] sentinel.
|
# 6) message_stop — terminal event, no [DONE] sentinel.
|
||||||
yield _sse("message_stop", {"type": "message_stop"})
|
yield _sse("message_stop", {"type": "message_stop"})
|
||||||
success = True
|
success = True
|
||||||
@@ -1109,6 +1216,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
ask_mode,
|
ask_mode,
|
||||||
session_id=cached_session_id,
|
session_id=cached_session_id,
|
||||||
is_reply=is_reply,
|
is_reply=is_reply,
|
||||||
|
tool_config=tool_config,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("anthropic.complete error (inst=%s): %s", inst.name, exc)
|
logger.warning("anthropic.complete error (inst=%s): %s", inst.name, exc)
|
||||||
@@ -1139,14 +1247,18 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
if text:
|
if text:
|
||||||
content_blocks.append({"type": "text", "text": text})
|
content_blocks.append({"type": "text", "text": text})
|
||||||
tool_events = result.get("toolEvents") or []
|
tool_events = result.get("toolEvents") or []
|
||||||
|
saw_pending_tool_use = False
|
||||||
if isinstance(tool_events, list):
|
if isinstance(tool_events, list):
|
||||||
for item in tool_events:
|
for idx, item in enumerate(tool_events):
|
||||||
if not isinstance(item, dict):
|
if not isinstance(item, dict):
|
||||||
continue
|
continue
|
||||||
content_blocks.append(_anthropic_tool_use_block(item))
|
tool_id = str(item.get("id") or f"toolu_nonstream_{idx}")
|
||||||
tool_result = _anthropic_tool_result_block(item)
|
content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id))
|
||||||
|
tool_result = _anthropic_tool_result_block(item, forced_id=tool_id)
|
||||||
if tool_result is not None:
|
if tool_result is not None:
|
||||||
content_blocks.append(tool_result)
|
content_blocks.append(tool_result)
|
||||||
|
else:
|
||||||
|
saw_pending_tool_use = True
|
||||||
|
|
||||||
response_body: dict = {
|
response_body: dict = {
|
||||||
"id": message_id,
|
"id": message_id,
|
||||||
@@ -1154,7 +1266,11 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
|||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"model": model,
|
"model": model,
|
||||||
"content": content_blocks,
|
"content": content_blocks,
|
||||||
"stop_reason": _anthropic_stop_reason(completion_tokens, req.max_tokens),
|
"stop_reason": _anthropic_stop_reason(
|
||||||
|
completion_tokens,
|
||||||
|
req.max_tokens,
|
||||||
|
has_pending_tool_use=saw_pending_tool_use,
|
||||||
|
),
|
||||||
"stop_sequence": None,
|
"stop_sequence": None,
|
||||||
"usage": {
|
"usage": {
|
||||||
"input_tokens": prompt_tokens,
|
"input_tokens": prompt_tokens,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -42,6 +43,16 @@ def hash_user_context(messages: list[dict]) -> str:
|
|||||||
return h.hexdigest()
|
return h.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_fingerprint(tool_config: dict | None) -> str:
|
||||||
|
if not isinstance(tool_config, dict):
|
||||||
|
return "-"
|
||||||
|
try:
|
||||||
|
canonical = json.dumps(tool_config, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||||
|
except Exception:
|
||||||
|
canonical = str(tool_config)
|
||||||
|
return hashlib.sha1(canonical.encode("utf-8")).hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
class SessionCache:
|
class SessionCache:
|
||||||
"""LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId.
|
"""LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId.
|
||||||
|
|
||||||
@@ -79,11 +90,11 @@ class SessionCache:
|
|||||||
def enabled(self) -> bool:
|
def enabled(self) -> bool:
|
||||||
return self.max > 0
|
return self.max > 0
|
||||||
|
|
||||||
def build_key(self, api_key: str, messages: list[dict]) -> str:
|
def build_key(self, api_key: str, messages: list[dict], *, tool_config: dict | None = None) -> str:
|
||||||
# API key scoping prevents cross-tenant session leakage even when
|
# API key scoping prevents cross-tenant session leakage even when
|
||||||
# different clients happen to produce identical histories.
|
# different clients happen to produce identical histories.
|
||||||
key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12]
|
key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12]
|
||||||
return f"{key_scope}:{hash_user_context(messages)}"
|
return f"{key_scope}:{hash_user_context(messages)}:{_tool_fingerprint(tool_config)}"
|
||||||
|
|
||||||
async def get(self, key: str) -> SessionEntry | None:
|
async def get(self, key: str) -> SessionEntry | None:
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
|
|||||||
@@ -6,6 +6,31 @@ import types
|
|||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSessionCache:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.enabled = True
|
||||||
|
self.keys: list[str] = []
|
||||||
|
self.get_calls: list[str] = []
|
||||||
|
self.put_calls: list[tuple[str, str, str]] = []
|
||||||
|
self.invalidate_calls: list[str] = []
|
||||||
|
|
||||||
|
def build_key(self, api_key: str, messages: list[dict], *, tool_config=None) -> str:
|
||||||
|
marker = "with_tool" if tool_config is not None else "no_tool"
|
||||||
|
key = f"{api_key}:{len(messages)}:{marker}"
|
||||||
|
self.keys.append(key)
|
||||||
|
return key
|
||||||
|
|
||||||
|
async def get(self, key: str):
|
||||||
|
self.get_calls.append(key)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def put(self, key: str, session_id: str, instance_name: str = "") -> None:
|
||||||
|
self.put_calls.append((key, session_id, instance_name))
|
||||||
|
|
||||||
|
async def invalidate(self, key: str) -> None:
|
||||||
|
self.invalidate_calls.append(key)
|
||||||
|
|
||||||
# app.main imports playwright via auto_login; tests don't exercise that path.
|
# app.main imports playwright via auto_login; tests don't exercise that path.
|
||||||
# Inject a lightweight stub so unit tests run without installing playwright.
|
# Inject a lightweight stub so unit tests run without installing playwright.
|
||||||
_playwright = types.ModuleType("playwright")
|
_playwright = types.ModuleType("playwright")
|
||||||
@@ -119,6 +144,38 @@ async def _collect_stream(response) -> str:
|
|||||||
return "".join(chunks)
|
return "".join(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
class _SpyClient(_FakeClient):
|
||||||
|
def __init__(self, *, stream_events: list[dict], complete_result: dict) -> None:
|
||||||
|
super().__init__(stream_events=stream_events, complete_result=complete_result)
|
||||||
|
self.last_complete_kwargs: dict = {}
|
||||||
|
self.last_stream_kwargs: dict = {}
|
||||||
|
|
||||||
|
async def chat_complete(self, *args, **kwargs) -> dict:
|
||||||
|
self.last_complete_kwargs = dict(kwargs)
|
||||||
|
return await super().chat_complete(*args, **kwargs)
|
||||||
|
|
||||||
|
async def chat_stream(self, *args, **kwargs):
|
||||||
|
self.last_stream_kwargs = dict(kwargs)
|
||||||
|
async for event in super().chat_stream(*args, **kwargs):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
|
||||||
|
class _SettingsPatch:
|
||||||
|
def __init__(self, **kwargs) -> None:
|
||||||
|
self._kwargs = kwargs
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self._patchers = [patch.object(main.settings, k, v) for k, v in self._kwargs.items()]
|
||||||
|
for p in self._patchers:
|
||||||
|
p.start()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
for p in reversed(self._patchers):
|
||||||
|
p.stop()
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
||||||
async def test_openai_non_stream_bridges_tool_calls(self) -> None:
|
async def test_openai_non_stream_bridges_tool_calls(self) -> None:
|
||||||
fake_client = _FakeClient(
|
fake_client = _FakeClient(
|
||||||
@@ -156,6 +213,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
|||||||
message = payload["choices"][0]["message"]
|
message = payload["choices"][0]["message"]
|
||||||
self.assertEqual(message["content"], "done")
|
self.assertEqual(message["content"], "done")
|
||||||
self.assertIsInstance(message["tool_calls"], list)
|
self.assertIsInstance(message["tool_calls"], list)
|
||||||
|
self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls")
|
||||||
self.assertEqual(message["tool_calls"][0]["function"]["name"], "search_docs")
|
self.assertEqual(message["tool_calls"][0]["function"]["name"], "search_docs")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
json.loads(message["tool_calls"][0]["function"]["arguments"]),
|
json.loads(message["tool_calls"][0]["function"]["arguments"]),
|
||||||
@@ -195,6 +253,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
|||||||
|
|
||||||
self.assertIn('"tool_calls"', body)
|
self.assertIn('"tool_calls"', body)
|
||||||
self.assertIn('"content": "hello"', body)
|
self.assertIn('"content": "hello"', body)
|
||||||
|
self.assertIn('"finish_reason": "tool_calls"', body)
|
||||||
self.assertIn('"usage"', body)
|
self.assertIn('"usage"', body)
|
||||||
self.assertIn("data: [DONE]", body)
|
self.assertIn("data: [DONE]", body)
|
||||||
|
|
||||||
@@ -239,9 +298,132 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
|||||||
payload = json.loads(response.body)
|
payload = json.loads(response.body)
|
||||||
types = [item["type"] for item in payload["content"]]
|
types = [item["type"] for item in payload["content"]]
|
||||||
self.assertEqual(types, ["text", "tool_use", "tool_result"])
|
self.assertEqual(types, ["text", "tool_use", "tool_result"])
|
||||||
|
self.assertEqual(payload["stop_reason"], "end_turn")
|
||||||
self.assertEqual(payload["content"][1]["name"], "lookup")
|
self.assertEqual(payload["content"][1]["name"], "lookup")
|
||||||
self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1")
|
self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1")
|
||||||
|
|
||||||
|
async def test_openai_stream_tool_call_indices_are_stable(self) -> None:
|
||||||
|
fake_client = _FakeClient(
|
||||||
|
stream_events=[
|
||||||
|
{
|
||||||
|
"type": "tool",
|
||||||
|
"tool": {
|
||||||
|
"id": "call_a",
|
||||||
|
"name": "read_file",
|
||||||
|
"input": {"path": "README.md"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "tool",
|
||||||
|
"tool": {
|
||||||
|
"id": "call_b",
|
||||||
|
"name": "search_docs",
|
||||||
|
"input": {"query": "gateway"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
complete_result={},
|
||||||
|
)
|
||||||
|
req = ChatCompletionsRequest(
|
||||||
|
model="org_auto",
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||||
|
patch.object(main, "chat_guard", _FakeGuard()),
|
||||||
|
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||||
|
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||||
|
):
|
||||||
|
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||||
|
body = await _collect_stream(response)
|
||||||
|
|
||||||
|
self.assertIn('"id": "call_a"', body)
|
||||||
|
self.assertIn('"id": "call_b"', body)
|
||||||
|
self.assertIn('"index": 0', body)
|
||||||
|
self.assertIn('"index": 1', body)
|
||||||
|
|
||||||
|
async def test_anthropic_non_stream_returns_tool_use_stop_reason_when_result_missing(self) -> None:
|
||||||
|
fake_client = _FakeClient(
|
||||||
|
stream_events=[],
|
||||||
|
complete_result={
|
||||||
|
"text": "",
|
||||||
|
"toolEvents": [
|
||||||
|
{
|
||||||
|
"name": "lookup",
|
||||||
|
"input": {"k": "v"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sessionId": "sess-2",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
req = AnthropicMessagesRequest(
|
||||||
|
model="claude-3-5-sonnet-20241022",
|
||||||
|
max_tokens=256,
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||||
|
patch.object(main, "chat_guard", _FakeGuard()),
|
||||||
|
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||||
|
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||||
|
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||||
|
):
|
||||||
|
response = await main.v1_messages(
|
||||||
|
req,
|
||||||
|
_make_request(
|
||||||
|
"/v1/messages",
|
||||||
|
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = json.loads(response.body)
|
||||||
|
self.assertEqual(payload["stop_reason"], "tool_use")
|
||||||
|
self.assertEqual(len(payload["content"]), 1)
|
||||||
|
self.assertEqual(payload["content"][0]["type"], "tool_use")
|
||||||
|
|
||||||
|
async def test_anthropic_stream_returns_tool_use_stop_reason_when_result_missing(self) -> None:
|
||||||
|
fake_client = _FakeClient(
|
||||||
|
stream_events=[
|
||||||
|
{
|
||||||
|
"type": "tool",
|
||||||
|
"tool": {
|
||||||
|
"name": "read",
|
||||||
|
"input": {"file": "a.txt"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
complete_result={},
|
||||||
|
)
|
||||||
|
req = AnthropicMessagesRequest(
|
||||||
|
model="claude-3-5-sonnet-20241022",
|
||||||
|
max_tokens=256,
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||||
|
patch.object(main, "chat_guard", _FakeGuard()),
|
||||||
|
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||||
|
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||||
|
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||||
|
):
|
||||||
|
response = await main.v1_messages(
|
||||||
|
req,
|
||||||
|
_make_request(
|
||||||
|
"/v1/messages",
|
||||||
|
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
body = await _collect_stream(response)
|
||||||
|
|
||||||
|
self.assertIn('"type": "tool_use"', body)
|
||||||
|
self.assertIn('"stop_reason": "tool_use"', body)
|
||||||
|
|
||||||
async def test_anthropic_stream_bridges_tool_and_text_events(self) -> None:
|
async def test_anthropic_stream_bridges_tool_and_text_events(self) -> None:
|
||||||
fake_client = _FakeClient(
|
fake_client = _FakeClient(
|
||||||
stream_events=[
|
stream_events=[
|
||||||
@@ -284,11 +466,262 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
|
|||||||
self.assertIn("event: message_start", body)
|
self.assertIn("event: message_start", body)
|
||||||
self.assertIn('"type": "tool_use"', body)
|
self.assertIn('"type": "tool_use"', body)
|
||||||
self.assertIn('"type": "tool_result"', body)
|
self.assertIn('"type": "tool_result"', body)
|
||||||
|
self.assertIn('"stop_reason": "end_turn"', body)
|
||||||
self.assertIn('"type": "text_delta"', body)
|
self.assertIn('"type": "text_delta"', body)
|
||||||
self.assertIn("event: message_stop", body)
|
self.assertIn("event: message_stop", body)
|
||||||
|
|
||||||
|
|
||||||
class LingmaClientToolEventExtractionTests(unittest.TestCase):
|
|
||||||
|
async def test_openai_non_stream_forwards_tool_config_when_enabled(self) -> None:
|
||||||
|
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
|
||||||
|
req = ChatCompletionsRequest(
|
||||||
|
model="org_auto",
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
stream=False,
|
||||||
|
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
|
||||||
|
tool_choice={"type": "function", "function": {"name": "lookup"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
|
||||||
|
patch.object(main, "chat_guard", _FakeGuard()),
|
||||||
|
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||||
|
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||||
|
_SettingsPatch(tool_forward_enabled=True),
|
||||||
|
):
|
||||||
|
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||||
|
|
||||||
|
self.assertIn("tool_config", spy_client.last_complete_kwargs)
|
||||||
|
cfg = spy_client.last_complete_kwargs["tool_config"]
|
||||||
|
self.assertEqual(cfg["provider"], "openai")
|
||||||
|
self.assertEqual(len(cfg["tools"]), 1)
|
||||||
|
self.assertIsInstance(cfg["tool_choice"], dict)
|
||||||
|
|
||||||
|
async def test_openai_non_stream_does_not_forward_tool_config_when_disabled(self) -> None:
|
||||||
|
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
|
||||||
|
req = ChatCompletionsRequest(
|
||||||
|
model="org_auto",
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
stream=False,
|
||||||
|
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
|
||||||
|
tool_choice={"type": "function", "function": {"name": "lookup"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
|
||||||
|
patch.object(main, "chat_guard", _FakeGuard()),
|
||||||
|
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||||
|
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||||
|
_SettingsPatch(tool_forward_enabled=False),
|
||||||
|
):
|
||||||
|
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||||
|
|
||||||
|
self.assertIn("tool_config", spy_client.last_complete_kwargs)
|
||||||
|
self.assertIsNone(spy_client.last_complete_kwargs["tool_config"])
|
||||||
|
|
||||||
|
|
||||||
|
async def test_openai_tooling_context_disables_session_reuse_cache(self) -> None:
|
||||||
|
fake_cache = _FakeSessionCache()
|
||||||
|
fake_client = _FakeClient(
|
||||||
|
stream_events=[],
|
||||||
|
complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-3"},
|
||||||
|
)
|
||||||
|
req = ChatCompletionsRequest(
|
||||||
|
model="org_auto",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "turn-1"},
|
||||||
|
{"role": "user", "content": "turn-2"},
|
||||||
|
],
|
||||||
|
stream=False,
|
||||||
|
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
|
||||||
|
tool_choice={"type": "function", "function": {"name": "lookup"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(main, "session_cache", fake_cache),
|
||||||
|
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||||
|
patch.object(main, "chat_guard", _FakeGuard()),
|
||||||
|
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||||
|
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||||
|
_SettingsPatch(tool_forward_enabled=True),
|
||||||
|
):
|
||||||
|
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
|
||||||
|
|
||||||
|
self.assertEqual(fake_cache.keys, [])
|
||||||
|
self.assertEqual(fake_cache.get_calls, [])
|
||||||
|
self.assertEqual(fake_cache.put_calls, [])
|
||||||
|
|
||||||
|
async def test_anthropic_tooling_context_disables_session_reuse_cache(self) -> None:
|
||||||
|
fake_cache = _FakeSessionCache()
|
||||||
|
fake_client = _FakeClient(
|
||||||
|
stream_events=[],
|
||||||
|
complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-4"},
|
||||||
|
)
|
||||||
|
req = AnthropicMessagesRequest(
|
||||||
|
model="claude-3-5-sonnet-20241022",
|
||||||
|
max_tokens=128,
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "turn-1"},
|
||||||
|
{"role": "user", "content": "turn-2"},
|
||||||
|
],
|
||||||
|
stream=False,
|
||||||
|
tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}],
|
||||||
|
tool_choice={"type": "auto"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(main, "session_cache", fake_cache),
|
||||||
|
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
|
||||||
|
patch.object(main, "chat_guard", _FakeGuard()),
|
||||||
|
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
|
||||||
|
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
|
||||||
|
patch.object(main.settings, "api_keys", ["test-key"]),
|
||||||
|
):
|
||||||
|
await main.v1_messages(
|
||||||
|
req,
|
||||||
|
_make_request(
|
||||||
|
"/v1/messages",
|
||||||
|
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(fake_cache.keys, [])
|
||||||
|
self.assertEqual(fake_cache.get_calls, [])
|
||||||
|
self.assertEqual(fake_cache.put_calls, [])
|
||||||
|
|
||||||
|
|
||||||
|
class SessionCacheToolFingerprintTests(unittest.TestCase):
|
||||||
|
def test_build_key_changes_with_tool_config(self) -> None:
|
||||||
|
from app.session_cache import SessionCache
|
||||||
|
|
||||||
|
cache = SessionCache(max_entries=8, ttl_sec=60)
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "sys"},
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
]
|
||||||
|
|
||||||
|
cfg_a = {
|
||||||
|
"provider": "openai",
|
||||||
|
"tools": [{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
|
||||||
|
"tool_choice": {"type": "function", "function": {"name": "lookup"}},
|
||||||
|
}
|
||||||
|
cfg_a_reordered = {
|
||||||
|
"tool_choice": {"function": {"name": "lookup"}, "type": "function"},
|
||||||
|
"tools": [{"function": {"parameters": {}, "name": "lookup"}, "type": "function"}],
|
||||||
|
"provider": "openai",
|
||||||
|
}
|
||||||
|
cfg_b = {
|
||||||
|
"provider": "openai",
|
||||||
|
"tools": [{"type": "function", "function": {"name": "lookup_v2", "parameters": {}}}],
|
||||||
|
"tool_choice": {"type": "function", "function": {"name": "lookup_v2"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
key_no_tool = cache.build_key("api-key", messages)
|
||||||
|
key_a = cache.build_key("api-key", messages, tool_config=cfg_a)
|
||||||
|
key_a_reordered = cache.build_key("api-key", messages, tool_config=cfg_a_reordered)
|
||||||
|
key_b = cache.build_key("api-key", messages, tool_config=cfg_b)
|
||||||
|
|
||||||
|
self.assertNotEqual(key_no_tool, key_a)
|
||||||
|
self.assertEqual(key_a, key_a_reordered)
|
||||||
|
self.assertNotEqual(key_a, key_b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_server_message_drops_unroutable_tool_event_without_request_id(self) -> None:
|
||||||
|
from app.lingma_client import LspWsRpcClient
|
||||||
|
|
||||||
|
rpc = LspWsRpcClient("ws://127.0.0.1:1")
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
rpc.create_stream("req-1")
|
||||||
|
await rpc._handle_server_message(
|
||||||
|
{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "tool/invoke",
|
||||||
|
"params": {
|
||||||
|
"name": "lookup",
|
||||||
|
"parameters": {"q": "x"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
stream = rpc._chat_streams["req-1"]
|
||||||
|
self.assertEqual(stream["tool_order"], [])
|
||||||
|
self.assertEqual(stream["tool_states"], {})
|
||||||
|
self.assertTrue(stream["chunks"].empty())
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
def test_handle_server_message_routes_by_tool_map_without_request_id(self) -> None:
|
||||||
|
from app.lingma_client import LspWsRpcClient
|
||||||
|
|
||||||
|
rpc = LspWsRpcClient("ws://127.0.0.1:1")
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
rpc.create_stream("req-1")
|
||||||
|
|
||||||
|
await rpc._handle_server_message(
|
||||||
|
{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "tool/invoke",
|
||||||
|
"params": {
|
||||||
|
"requestId": "req-1",
|
||||||
|
"toolCallId": "call-1",
|
||||||
|
"name": "lookup",
|
||||||
|
"parameters": {"q": "a"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await rpc._handle_server_message(
|
||||||
|
{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "tool/invokeResult",
|
||||||
|
"params": {
|
||||||
|
"toolCallId": "call-1",
|
||||||
|
"result": {"ok": True},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = rpc.get_stream_result("req-1")
|
||||||
|
self.assertEqual(len(result["toolEvents"]), 1)
|
||||||
|
self.assertEqual(result["toolEvents"][0]["id"], "call-1")
|
||||||
|
self.assertEqual(result["toolEvents"][0]["input"], {"q": "a"})
|
||||||
|
self.assertEqual(result["toolEvents"][0]["result"], {"ok": True})
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
def test_handle_server_message_dedupes_identical_repeated_tool_events(self) -> None:
|
||||||
|
from app.lingma_client import LspWsRpcClient
|
||||||
|
|
||||||
|
rpc = LspWsRpcClient("ws://127.0.0.1:1")
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
rpc.create_stream("req-1")
|
||||||
|
msg = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "tool/invoke",
|
||||||
|
"params": {
|
||||||
|
"requestId": "req-1",
|
||||||
|
"toolCallId": "call-dup",
|
||||||
|
"name": "lookup",
|
||||||
|
"parameters": {"q": "dup"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
await rpc._handle_server_message(msg)
|
||||||
|
await rpc._handle_server_message(msg)
|
||||||
|
|
||||||
|
stream = rpc._chat_streams["req-1"]
|
||||||
|
self.assertEqual(stream["tool_order"], ["call-dup"])
|
||||||
|
self.assertEqual(stream["chunks"].qsize(), 1)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
def test_extracts_tool_event_from_results_and_parameters(self) -> None:
|
def test_extracts_tool_event_from_results_and_parameters(self) -> None:
|
||||||
from app.lingma_client import LspWsRpcClient
|
from app.lingma_client import LspWsRpcClient
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user