fix: harden tooling session reuse and event routing

Ensure session reuse is disabled for tooling contexts, include tool config in cache keys, and stabilize tool event merge/routing with expanded bridge tests.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
GitHub Actions
2026-04-19 19:29:30 +08:00
parent 5aa7fbfae5
commit e600bae27c
5 changed files with 741 additions and 53 deletions

View File

@@ -44,6 +44,7 @@ class Settings:
session_reuse_enabled: bool = True session_reuse_enabled: bool = True
session_cache_max_entries: int = 256 session_cache_max_entries: int = 256
session_cache_ttl_sec: float = 1800.0 session_cache_ttl_sec: float = 1800.0
tool_forward_enabled: bool = False
def _bool_env(name: str, default: bool) -> bool: def _bool_env(name: str, default: bool) -> bool:
@@ -175,4 +176,5 @@ def load_settings() -> Settings:
session_reuse_enabled=_bool_env("SESSION_REUSE_ENABLED", True), session_reuse_enabled=_bool_env("SESSION_REUSE_ENABLED", True),
session_cache_max_entries=int(os.getenv("SESSION_CACHE_MAX_ENTRIES", "256")), session_cache_max_entries=int(os.getenv("SESSION_CACHE_MAX_ENTRIES", "256")),
session_cache_ttl_sec=float(os.getenv("SESSION_CACHE_TTL_SEC", "1800")), session_cache_ttl_sec=float(os.getenv("SESSION_CACHE_TTL_SEC", "1800")),
tool_forward_enabled=_bool_env("TOOL_FORWARD_ENABLED", False),
) )

View File

@@ -100,6 +100,7 @@ class LspWsRpcClient:
self._reader_task: asyncio.Task | None = None self._reader_task: asyncio.Task | None = None
self._rx_buffer = b"" self._rx_buffer = b""
self._chat_streams: dict[str, dict] = {} self._chat_streams: dict[str, dict] = {}
self._tool_stream_map: dict[str, str] = {}
self._on_disconnect = on_disconnect self._on_disconnect = on_disconnect
self._closed = False self._closed = False
@@ -202,6 +203,7 @@ class LspWsRpcClient:
stream["done"].set() stream["done"].set()
stream["chunks"].put_nowait(None) stream["chunks"].put_nowait(None)
self._chat_streams.clear() self._chat_streams.clear()
self._tool_stream_map.clear()
async def _send(self, payload: dict): async def _send(self, payload: dict):
async with self._send_lock: async with self._send_lock:
@@ -251,6 +253,92 @@ class LspWsRpcClient:
except Exception: except Exception:
logger.exception("on_disconnect callback failed") logger.exception("on_disconnect callback failed")
@staticmethod
def _normalize_tool_id(method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> str | None:
event_id = None
if isinstance(tool_event, dict):
event_id = tool_event.get("id")
if isinstance(event_id, str) and event_id.strip():
return event_id.strip()
fallback_id = params.get("toolCallId") or params.get("tool_call_id")
if isinstance(fallback_id, str) and fallback_id.strip():
return fallback_id.strip()
req_id = params.get("requestId")
name = None
if isinstance(tool_event, dict):
name = tool_event.get("name")
if not name:
name = params.get("name")
if isinstance(req_id, str) and req_id.strip() and isinstance(name, str) and name.strip():
return f"{req_id.strip()}:tool:{name.strip()}"
if isinstance(req_id, str) and req_id.strip():
return f"{req_id.strip()}:tool"
return None
@staticmethod
def _merge_tool_event(existing: dict[str, Any] | None, incoming: dict[str, Any]) -> tuple[dict[str, Any], bool]:
merged = dict(existing or {})
changed = False
val = incoming.get("id")
if val and merged.get("id") != val:
merged["id"] = val
changed = True
name = incoming.get("name")
if name:
existing_name = merged.get("name")
if not existing_name:
merged["name"] = name
changed = True
else:
existing_norm = str(existing_name).strip().lower()
incoming_norm = str(name).strip().lower()
if existing_norm == "tool" and incoming_norm != "tool":
merged["name"] = name
changed = True
elif existing_norm != "tool" and incoming_norm == "tool":
pass
elif merged.get("name") != name:
merged["name"] = name
changed = True
if "input" in incoming and incoming.get("input") is not None:
incoming_input = incoming.get("input")
should_update_input = incoming_input != {} or "input" not in merged
if should_update_input and merged.get("input") != incoming_input:
merged["input"] = incoming_input
changed = True
if "result" in incoming and incoming.get("result") is not None:
if merged.get("result") != incoming.get("result"):
merged["result"] = incoming.get("result")
changed = True
return merged, changed
def _resolve_tool_stream(self, method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> dict | None:
req_id = params.get("requestId")
if isinstance(req_id, str) and req_id.strip():
stream = self._chat_streams.get(req_id)
if stream is not None and tool_event is not None:
tool_id = self._normalize_tool_id(method, params, tool_event)
if tool_id:
self._tool_stream_map[tool_id] = req_id
return stream
if tool_event is not None:
tool_id = self._normalize_tool_id(method, params, tool_event)
if tool_id:
mapped_req = self._tool_stream_map.get(tool_id)
if mapped_req:
return self._chat_streams.get(mapped_req)
return None
async def _handle_server_message(self, msg: dict): async def _handle_server_message(self, msg: dict):
method = msg.get("method") method = msg.get("method")
params = msg.get("params") or {} params = msg.get("params") or {}
@@ -268,21 +356,29 @@ class LspWsRpcClient:
if method in {"tool/call/sync", "tool/invoke", "tool/call/approve", "tool/invokeResult"}: if method in {"tool/call/sync", "tool/invoke", "tool/call/approve", "tool/invokeResult"}:
tool_event = self._extract_tool_event(params) tool_event = self._extract_tool_event(params)
req_id = params.get("requestId") stream = self._resolve_tool_stream(method, params, tool_event)
stream = self._chat_streams.get(req_id) if req_id else None
if stream is None and tool_event is not None:
for item in self._chat_streams.values():
if any(evt.get("id") == tool_event["id"] for evt in item["tool_events"]):
stream = item
break
if stream is None and len(self._chat_streams) == 1:
stream = next(iter(self._chat_streams.values()))
if stream is not None and tool_event is not None: if stream is not None and tool_event is not None:
stream["tool_events"].append(tool_event) tool_id = self._normalize_tool_id(method, params, tool_event)
stream["chunks"].put_nowait({"type": "tool", "tool": tool_event}) if not tool_id:
logger.warning("drop unroutable tool event: method=%s missing tool id", method)
else:
tool_states = stream["tool_states"]
order = stream["tool_order"]
existing = tool_states.get(tool_id)
merged, changed = self._merge_tool_event(existing, tool_event)
if not existing:
if "id" not in merged or not merged.get("id"):
merged["id"] = tool_id
tool_states[tool_id] = merged
order.append(tool_id)
stream["chunks"].put_nowait({"type": "tool", "tool": merged})
elif changed:
tool_states[tool_id] = merged
stream["chunks"].put_nowait({"type": "tool", "tool": merged})
elif tool_event is not None:
logger.warning("drop unroutable tool event: method=%s requestId=%s", method, params.get("requestId"))
if method == "chat/finish": if method == "chat/finish":
req_id = params.get("requestId") req_id = params.get("requestId")
@@ -321,7 +417,8 @@ class LspWsRpcClient:
"chunks": asyncio.Queue(), "chunks": asyncio.Queue(),
"done": asyncio.Event(), "done": asyncio.Event(),
"finish": None, "finish": None,
"tool_events": [], "tool_states": {},
"tool_order": [],
"started_at": time.monotonic(), "started_at": time.monotonic(),
"first_chunk_at": None, "first_chunk_at": None,
"finish_at": None, "finish_at": None,
@@ -331,6 +428,9 @@ class LspWsRpcClient:
stream = self._chat_streams.pop(request_id, None) stream = self._chat_streams.pop(request_id, None)
if stream is None: if stream is None:
return return
for tool_id, mapped_req in list(self._tool_stream_map.items()):
if mapped_req == request_id:
self._tool_stream_map.pop(tool_id, None)
# Drain queue so no stray future gets stuck if the consumer bailed early. # Drain queue so no stray future gets stuck if the consumer bailed early.
if not stream["done"].is_set(): if not stream["done"].is_set():
stream["done"].set() stream["done"].set()
@@ -359,12 +459,20 @@ class LspWsRpcClient:
first_ms = int((stream["first_chunk_at"] - stream["started_at"]) * 1000) first_ms = int((stream["first_chunk_at"] - stream["started_at"]) * 1000)
if stream.get("finish_at") is not None: if stream.get("finish_at") is not None:
total_ms = int((stream["finish_at"] - stream["started_at"]) * 1000) total_ms = int((stream["finish_at"] - stream["started_at"]) * 1000)
ordered_tool_events: list[dict[str, Any]] = []
tool_states = stream.get("tool_states") or {}
for tool_id in stream.get("tool_order") or []:
event = tool_states.get(tool_id)
if isinstance(event, dict):
ordered_tool_events.append(event)
return { return {
"text": "".join(stream.get("parts") or []), "text": "".join(stream.get("parts") or []),
"finish": stream.get("finish") or {}, "finish": stream.get("finish") or {},
"firstTokenLatencyMs": first_ms, "firstTokenLatencyMs": first_ms,
"totalLatencyMs": total_ms, "totalLatencyMs": total_ms,
"toolEvents": stream.get("tool_events") or [], "toolEvents": ordered_tool_events,
} }
@@ -733,9 +841,10 @@ class LingmaGatewayClient:
request_id: str, request_id: str,
*, *,
is_reply: bool = False, is_reply: bool = False,
tool_config: dict[str, Any] | None = None,
): ):
session_type = "developer" if ask_mode == "agent" else "chat" session_type = "developer" if ask_mode == "agent" else "chat"
return { payload = {
"requestId": request_id, "requestId": request_id,
"sessionId": session_id, "sessionId": session_id,
"sessionType": session_type, "sessionType": session_type,
@@ -764,6 +873,9 @@ class LingmaGatewayClient:
"localeLang": "zh-CN", "localeLang": "zh-CN",
}, },
} }
if tool_config is not None:
payload["toolConfig"] = tool_config
return payload
async def _kick_chat_ask(self, payload: dict) -> None: async def _kick_chat_ask(self, payload: dict) -> None:
"""Fire chat/ask as a notification. """Fire chat/ask as a notification.
@@ -784,12 +896,19 @@ class LingmaGatewayClient:
*, *,
session_id: str | None = None, session_id: str | None = None,
is_reply: bool = False, is_reply: bool = False,
tool_config: dict[str, Any] | None = None,
) -> dict: ) -> dict:
await self.ensure_ready() await self.ensure_ready()
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
sid = session_id or str(uuid.uuid4()) sid = session_id or str(uuid.uuid4())
payload = self._build_payload( payload = self._build_payload(
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply prompt,
model_key,
ask_mode,
sid,
request_id,
is_reply=is_reply,
tool_config=tool_config,
) )
self.rpc.create_stream(request_id) self.rpc.create_stream(request_id)
try: try:
@@ -820,6 +939,7 @@ class LingmaGatewayClient:
*, *,
session_id: str | None = None, session_id: str | None = None,
is_reply: bool = False, is_reply: bool = False,
tool_config: dict[str, Any] | None = None,
out_meta: dict | None = None, out_meta: dict | None = None,
) -> AsyncIterator[dict[str, Any]]: ) -> AsyncIterator[dict[str, Any]]:
"""Stream chat events. """Stream chat events.
@@ -837,7 +957,13 @@ class LingmaGatewayClient:
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
sid = session_id or str(uuid.uuid4()) sid = session_id or str(uuid.uuid4())
payload = self._build_payload( payload = self._build_payload(
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply prompt,
model_key,
ask_mode,
sid,
request_id,
is_reply=is_reply,
tool_config=tool_config,
) )
self.rpc.create_stream(request_id) self.rpc.create_stream(request_id)
try: try:

View File

@@ -351,6 +351,70 @@ def _include_usage(stream_options: dict | None) -> bool:
return bool(stream_options.get("include_usage")) return bool(stream_options.get("include_usage"))
def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None:
if not settings.tool_forward_enabled:
return None
has_tools = isinstance(req.tools, list) and len(req.tools) > 0
has_choice = req.tool_choice is not None
if not has_tools and not has_choice:
return None
return {
"provider": "openai",
"tools": req.tools or [],
"tool_choice": req.tool_choice,
}
def _anthropic_tool_config(req: AnthropicMessagesRequest) -> dict[str, Any] | None:
if not settings.tool_forward_enabled:
return None
has_tools = isinstance(req.tools, list) and len(req.tools) > 0
has_choice = req.tool_choice is not None
if not has_tools and not has_choice:
return None
return {
"provider": "anthropic",
"tools": req.tools or [],
"tool_choice": req.tool_choice,
}
def _openai_has_tooling_context(req: ChatCompletionsRequest, messages: list[dict[str, Any]]) -> bool:
if isinstance(req.tools, list) and len(req.tools) > 0:
return True
if req.tool_choice is not None:
return True
for m in messages:
role = m.get("role")
if role == "tool":
return True
if role == "assistant" and m.get("tool_calls"):
return True
return False
def _anthropic_content_has_tool_blocks(content: Any) -> bool:
if not isinstance(content, list):
return False
for item in content:
if isinstance(item, dict) and item.get("type") in {"tool_use", "tool_result"}:
return True
return False
def _anthropic_has_tooling_context(req: AnthropicMessagesRequest) -> bool:
if isinstance(req.tools, list) and len(req.tools) > 0:
return True
if req.tool_choice is not None:
return True
if _anthropic_content_has_tool_blocks(req.system):
return True
for m in req.messages:
if _anthropic_content_has_tool_blocks(m.content):
return True
return False
def _stream_event_type(event: Any) -> str: def _stream_event_type(event: Any) -> str:
if isinstance(event, dict): if isinstance(event, dict):
t = event.get("type") t = event.get("type")
@@ -388,9 +452,9 @@ def _json_string(value: Any) -> str:
return "{}" return "{}"
def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]: def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]:
return { return {
"id": str(tool.get("id") or f"call_{uuid.uuid4().hex}"), "id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"),
"type": "function", "type": "function",
"function": { "function": {
"name": str(tool.get("name") or "tool"), "name": str(tool.get("name") or "tool"),
@@ -399,16 +463,20 @@ def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]:
} }
def _anthropic_tool_use_block(tool: dict[str, Any]) -> dict[str, Any]: def _anthropic_tool_use_block(
tool: dict[str, Any], *, forced_id: str | None = None
) -> dict[str, Any]:
return { return {
"type": "tool_use", "type": "tool_use",
"id": str(tool.get("id") or f"toolu_{uuid.uuid4().hex}"), "id": str(tool.get("id") or forced_id or f"toolu_{uuid.uuid4().hex}"),
"name": str(tool.get("name") or "tool"), "name": str(tool.get("name") or "tool"),
"input": tool.get("input") if tool.get("input") is not None else {}, "input": tool.get("input") if tool.get("input") is not None else {},
} }
def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None: def _anthropic_tool_result_block(
tool: dict[str, Any], *, forced_id: str | None = None
) -> dict[str, Any] | None:
if "result" not in tool: if "result" not in tool:
return None return None
result = tool.get("result") result = tool.get("result")
@@ -418,7 +486,7 @@ def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None:
content = _json_string(result) content = _json_string(result)
return { return {
"type": "tool_result", "type": "tool_result",
"tool_use_id": str(tool.get("id") or ""), "tool_use_id": str(tool.get("id") or forced_id or ""),
"content": content, "content": content,
} }
@@ -440,18 +508,22 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
if req.model.lower() in {"lingma-agent", "agent"}: if req.model.lower() in {"lingma-agent", "agent"}:
ask_mode = "agent" ask_mode = "agent"
tool_config = _openai_tool_config(req)
has_tooling_context = _openai_has_tooling_context(req, messages_dump)
reuse_eligible = ( reuse_eligible = (
session_cache.enabled session_cache.enabled
and ask_mode == "chat" and ask_mode == "chat"
and len(messages_dump) >= 2 and len(messages_dump) >= 2
and not has_tooling_context
) )
lookup_key: str | None = None lookup_key: str | None = None
write_key: str | None = None write_key: str | None = None
cached_session_id: str | None = None cached_session_id: str | None = None
cached_instance_name: str | None = None cached_instance_name: str | None = None
if reuse_eligible: if reuse_eligible:
lookup_key = session_cache.build_key(api_key, messages_dump[:-1]) lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
write_key = session_cache.build_key(api_key, messages_dump) write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config)
entry = await session_cache.get(lookup_key) entry = await session_cache.get(lookup_key)
if entry is not None: if entry is not None:
cached_session_id = entry.session_id cached_session_id = entry.session_id
@@ -549,6 +621,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta): async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
success = False success = False
tool_call_indexes: dict[str, int] = {}
saw_tool_call = False
try: try:
async for chunk in _inst.client.chat_stream( async for chunk in _inst.client.chat_stream(
prompt, prompt,
@@ -556,12 +630,21 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
ask_mode, ask_mode,
session_id=cached_session_id, session_id=cached_session_id,
is_reply=is_reply, is_reply=is_reply,
tool_config=tool_config,
out_meta=_meta, out_meta=_meta,
): ):
if _stream_event_type(chunk) == "tool": if _stream_event_type(chunk) == "tool":
tool = _stream_tool_event(chunk) tool = _stream_tool_event(chunk)
if not tool: if not tool:
continue continue
tool_id = str(tool.get("id") or "")
if not tool_id:
tool_id = f"call_{len(tool_call_indexes)}"
idx = tool_call_indexes.get(tool_id)
if idx is None:
idx = len(tool_call_indexes)
tool_call_indexes[tool_id] = idx
saw_tool_call = True
payload = { payload = {
"id": completion_id, "id": completion_id,
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
@@ -573,8 +656,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
"delta": { "delta": {
"tool_calls": [ "tool_calls": [
{ {
"index": 0, "index": idx,
**_openai_tool_call(tool), **_openai_tool_call(tool, forced_id=tool_id),
} }
] ]
}, },
@@ -609,10 +692,17 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"created": created, "created": created,
"model": model, "model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], "choices": [
{
"index": 0,
"delta": {},
"finish_reason": "tool_calls" if saw_tool_call else "stop",
}
],
} }
yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n" yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n"
if include_usage: if include_usage:
usage_payload = { usage_payload = {
"id": completion_id, "id": completion_id,
@@ -670,6 +760,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
ask_mode, ask_mode,
session_id=cached_session_id, session_id=cached_session_id,
is_reply=is_reply, is_reply=is_reply,
tool_config=tool_config,
) )
except Exception as exc: except Exception as exc:
logger.warning("chat.complete error (inst=%s): %s", inst.name, exc) logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
@@ -702,10 +793,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
tool_events = result.get("toolEvents") or [] tool_events = result.get("toolEvents") or []
message_content = result.get("text") or "" message_content = result.get("text") or ""
tool_calls: list[dict[str, Any]] = [] tool_calls: list[dict[str, Any]] = []
saw_tool_call = False
if isinstance(tool_events, list): if isinstance(tool_events, list):
for item in tool_events: for idx, item in enumerate(tool_events):
if isinstance(item, dict): if isinstance(item, dict):
tool_calls.append(_openai_tool_call(item)) tool_id = str(item.get("id") or f"call_{idx}")
tool_calls.append(_openai_tool_call(item, forced_id=tool_id))
saw_tool_call = True
response = ChatCompletionResponse( response = ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex}", id=f"chatcmpl-{uuid.uuid4().hex}",
created=int(time.time()), created=int(time.time()),
@@ -713,7 +807,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
choices=[ choices=[
ChatCompletionChoice( ChatCompletionChoice(
index=0, index=0,
finish_reason="stop", finish_reason="tool_calls" if saw_tool_call else "stop",
message={ message={
"role": "assistant", "role": "assistant",
"content": message_content, "content": message_content,
@@ -723,6 +817,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
], ],
) )
data = response.model_dump() data = response.model_dump()
data["latency"] = { data["latency"] = {
"first_token_ms": result.get("firstTokenLatencyMs"), "first_token_ms": result.get("firstTokenLatencyMs"),
@@ -749,13 +844,15 @@ def _anthropic_error(status_code: int, error_type: str, message: str) -> JSONRes
) )
def _anthropic_stop_reason(completion_tokens: int, max_tokens: int) -> str: def _anthropic_stop_reason(
"""Approximate Anthropic `stop_reason`. completion_tokens: int,
max_tokens: int,
Lingma doesn't expose a `max_tokens` knob, so we can't truly enforce it; *,
we report `max_tokens` only when the generated length meets or exceeds has_pending_tool_use: bool = False,
the caller's stated ceiling. Everything else is `end_turn`. ) -> str:
""" """Approximate Anthropic `stop_reason`."""
if has_pending_tool_use:
return "tool_use"
if max_tokens and completion_tokens >= max_tokens: if max_tokens and completion_tokens >= max_tokens:
return "max_tokens" return "max_tokens"
return "end_turn" return "end_turn"
@@ -818,16 +915,19 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
# Anthropic clients don't expose an ask_mode, so we always run in "chat". # Anthropic clients don't expose an ask_mode, so we always run in "chat".
ask_mode = "chat" ask_mode = "chat"
tool_config = _anthropic_tool_config(req)
has_tooling_context = _anthropic_has_tooling_context(req)
reuse_eligible = ( reuse_eligible = (
session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 and not has_tooling_context
) )
lookup_key: str | None = None lookup_key: str | None = None
write_key: str | None = None write_key: str | None = None
cached_session_id: str | None = None cached_session_id: str | None = None
cached_instance_name: str | None = None cached_instance_name: str | None = None
if reuse_eligible: if reuse_eligible:
lookup_key = session_cache.build_key(api_key, messages_dump[:-1]) lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
write_key = session_cache.build_key(api_key, messages_dump) write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config)
entry = await session_cache.get(lookup_key) entry = await session_cache.get(lookup_key)
if entry is not None: if entry is not None:
cached_session_id = entry.session_id cached_session_id = entry.session_id
@@ -875,7 +975,6 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
return _anthropic_error(400, "invalid_request_error", "messages is empty") return _anthropic_error(400, "invalid_request_error", "messages is empty")
prompt_tokens = estimate_tokens(prompt) prompt_tokens = estimate_tokens(prompt)
# ------------------------------------------------------------- backpressure # ------------------------------------------------------------- backpressure
try: try:
ticket = await chat_guard.try_acquire() ticket = await chat_guard.try_acquire()
@@ -927,6 +1026,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
success = False success = False
block_index = 0 block_index = 0
text_block_open = False text_block_open = False
saw_pending_tool_use = False
try: try:
# 1) message_start — Anthropic SDKs read this first to get # 1) message_start — Anthropic SDKs read this first to get
# the message envelope (id/model/initial usage). # the message envelope (id/model/initial usage).
@@ -956,6 +1056,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
ask_mode, ask_mode,
session_id=cached_session_id, session_id=cached_session_id,
is_reply=is_reply, is_reply=is_reply,
tool_config=tool_config,
out_meta=_meta, out_meta=_meta,
): ):
if _stream_event_type(chunk) == "tool": if _stream_event_type(chunk) == "tool":
@@ -970,8 +1071,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
tool = _stream_tool_event(chunk) tool = _stream_tool_event(chunk)
if not tool: if not tool:
continue continue
tool_id = str(tool.get("id") or f"toolu_stream_{block_index}")
tool_use_block = _anthropic_tool_use_block(tool) tool_use_block = _anthropic_tool_use_block(tool, forced_id=tool_id)
yield _sse( yield _sse(
"content_block_start", "content_block_start",
{ {
@@ -986,7 +1088,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
) )
block_index += 1 block_index += 1
tool_result_block = _anthropic_tool_result_block(tool) tool_result_block = _anthropic_tool_result_block(tool, forced_id=tool_id)
if tool_result_block is not None: if tool_result_block is not None:
yield _sse( yield _sse(
"content_block_start", "content_block_start",
@@ -1001,6 +1103,8 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
{"type": "content_block_stop", "index": block_index}, {"type": "content_block_stop", "index": block_index},
) )
block_index += 1 block_index += 1
else:
saw_pending_tool_use = True
continue continue
text = _stream_text(chunk) text = _stream_text(chunk)
@@ -1036,7 +1140,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
# 5) message_delta carries the terminal stop_reason and # 5) message_delta carries the terminal stop_reason and
# the final cumulative output_tokens count. # the final cumulative output_tokens count.
stop_reason = _anthropic_stop_reason( stop_reason = _anthropic_stop_reason(
completion_tokens_holder["n"], max_tokens completion_tokens_holder["n"],
max_tokens,
has_pending_tool_use=saw_pending_tool_use,
) )
yield _sse( yield _sse(
"message_delta", "message_delta",
@@ -1050,6 +1156,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
}, },
) )
# 6) message_stop — terminal event, no [DONE] sentinel. # 6) message_stop — terminal event, no [DONE] sentinel.
yield _sse("message_stop", {"type": "message_stop"}) yield _sse("message_stop", {"type": "message_stop"})
success = True success = True
@@ -1109,6 +1216,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
ask_mode, ask_mode,
session_id=cached_session_id, session_id=cached_session_id,
is_reply=is_reply, is_reply=is_reply,
tool_config=tool_config,
) )
except Exception as exc: except Exception as exc:
logger.warning("anthropic.complete error (inst=%s): %s", inst.name, exc) logger.warning("anthropic.complete error (inst=%s): %s", inst.name, exc)
@@ -1139,14 +1247,18 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
if text: if text:
content_blocks.append({"type": "text", "text": text}) content_blocks.append({"type": "text", "text": text})
tool_events = result.get("toolEvents") or [] tool_events = result.get("toolEvents") or []
saw_pending_tool_use = False
if isinstance(tool_events, list): if isinstance(tool_events, list):
for item in tool_events: for idx, item in enumerate(tool_events):
if not isinstance(item, dict): if not isinstance(item, dict):
continue continue
content_blocks.append(_anthropic_tool_use_block(item)) tool_id = str(item.get("id") or f"toolu_nonstream_{idx}")
tool_result = _anthropic_tool_result_block(item) content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id))
tool_result = _anthropic_tool_result_block(item, forced_id=tool_id)
if tool_result is not None: if tool_result is not None:
content_blocks.append(tool_result) content_blocks.append(tool_result)
else:
saw_pending_tool_use = True
response_body: dict = { response_body: dict = {
"id": message_id, "id": message_id,
@@ -1154,7 +1266,11 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
"role": "assistant", "role": "assistant",
"model": model, "model": model,
"content": content_blocks, "content": content_blocks,
"stop_reason": _anthropic_stop_reason(completion_tokens, req.max_tokens), "stop_reason": _anthropic_stop_reason(
completion_tokens,
req.max_tokens,
has_pending_tool_use=saw_pending_tool_use,
),
"stop_sequence": None, "stop_sequence": None,
"usage": { "usage": {
"input_tokens": prompt_tokens, "input_tokens": prompt_tokens,

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
import hashlib import hashlib
import json
import time import time
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass from dataclasses import dataclass
@@ -42,6 +43,16 @@ def hash_user_context(messages: list[dict]) -> str:
return h.hexdigest() return h.hexdigest()
def _tool_fingerprint(tool_config: dict | None) -> str:
if not isinstance(tool_config, dict):
return "-"
try:
canonical = json.dumps(tool_config, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
except Exception:
canonical = str(tool_config)
return hashlib.sha1(canonical.encode("utf-8")).hexdigest()[:16]
class SessionCache: class SessionCache:
"""LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId. """LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId.
@@ -79,11 +90,11 @@ class SessionCache:
def enabled(self) -> bool: def enabled(self) -> bool:
return self.max > 0 return self.max > 0
def build_key(self, api_key: str, messages: list[dict]) -> str: def build_key(self, api_key: str, messages: list[dict], *, tool_config: dict | None = None) -> str:
# API key scoping prevents cross-tenant session leakage even when # API key scoping prevents cross-tenant session leakage even when
# different clients happen to produce identical histories. # different clients happen to produce identical histories.
key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12] key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12]
return f"{key_scope}:{hash_user_context(messages)}" return f"{key_scope}:{hash_user_context(messages)}:{_tool_fingerprint(tool_config)}"
async def get(self, key: str) -> SessionEntry | None: async def get(self, key: str) -> SessionEntry | None:
if not self.enabled: if not self.enabled:

View File

@@ -6,6 +6,31 @@ import types
import unittest import unittest
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
class _FakeSessionCache:
def __init__(self) -> None:
self.enabled = True
self.keys: list[str] = []
self.get_calls: list[str] = []
self.put_calls: list[tuple[str, str, str]] = []
self.invalidate_calls: list[str] = []
def build_key(self, api_key: str, messages: list[dict], *, tool_config=None) -> str:
marker = "with_tool" if tool_config is not None else "no_tool"
key = f"{api_key}:{len(messages)}:{marker}"
self.keys.append(key)
return key
async def get(self, key: str):
self.get_calls.append(key)
return None
async def put(self, key: str, session_id: str, instance_name: str = "") -> None:
self.put_calls.append((key, session_id, instance_name))
async def invalidate(self, key: str) -> None:
self.invalidate_calls.append(key)
# app.main imports playwright via auto_login; tests don't exercise that path. # app.main imports playwright via auto_login; tests don't exercise that path.
# Inject a lightweight stub so unit tests run without installing playwright. # Inject a lightweight stub so unit tests run without installing playwright.
_playwright = types.ModuleType("playwright") _playwright = types.ModuleType("playwright")
@@ -119,6 +144,38 @@ async def _collect_stream(response) -> str:
return "".join(chunks) return "".join(chunks)
class _SpyClient(_FakeClient):
def __init__(self, *, stream_events: list[dict], complete_result: dict) -> None:
super().__init__(stream_events=stream_events, complete_result=complete_result)
self.last_complete_kwargs: dict = {}
self.last_stream_kwargs: dict = {}
async def chat_complete(self, *args, **kwargs) -> dict:
self.last_complete_kwargs = dict(kwargs)
return await super().chat_complete(*args, **kwargs)
async def chat_stream(self, *args, **kwargs):
self.last_stream_kwargs = dict(kwargs)
async for event in super().chat_stream(*args, **kwargs):
yield event
class _SettingsPatch:
def __init__(self, **kwargs) -> None:
self._kwargs = kwargs
def __enter__(self):
self._patchers = [patch.object(main.settings, k, v) for k, v in self._kwargs.items()]
for p in self._patchers:
p.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
for p in reversed(self._patchers):
p.stop()
return False
class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
async def test_openai_non_stream_bridges_tool_calls(self) -> None: async def test_openai_non_stream_bridges_tool_calls(self) -> None:
fake_client = _FakeClient( fake_client = _FakeClient(
@@ -156,6 +213,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
message = payload["choices"][0]["message"] message = payload["choices"][0]["message"]
self.assertEqual(message["content"], "done") self.assertEqual(message["content"], "done")
self.assertIsInstance(message["tool_calls"], list) self.assertIsInstance(message["tool_calls"], list)
self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls")
self.assertEqual(message["tool_calls"][0]["function"]["name"], "search_docs") self.assertEqual(message["tool_calls"][0]["function"]["name"], "search_docs")
self.assertEqual( self.assertEqual(
json.loads(message["tool_calls"][0]["function"]["arguments"]), json.loads(message["tool_calls"][0]["function"]["arguments"]),
@@ -195,6 +253,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
self.assertIn('"tool_calls"', body) self.assertIn('"tool_calls"', body)
self.assertIn('"content": "hello"', body) self.assertIn('"content": "hello"', body)
self.assertIn('"finish_reason": "tool_calls"', body)
self.assertIn('"usage"', body) self.assertIn('"usage"', body)
self.assertIn("data: [DONE]", body) self.assertIn("data: [DONE]", body)
@@ -239,9 +298,132 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
payload = json.loads(response.body) payload = json.loads(response.body)
types = [item["type"] for item in payload["content"]] types = [item["type"] for item in payload["content"]]
self.assertEqual(types, ["text", "tool_use", "tool_result"]) self.assertEqual(types, ["text", "tool_use", "tool_result"])
self.assertEqual(payload["stop_reason"], "end_turn")
self.assertEqual(payload["content"][1]["name"], "lookup") self.assertEqual(payload["content"][1]["name"], "lookup")
self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1") self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1")
async def test_openai_stream_tool_call_indices_are_stable(self) -> None:
fake_client = _FakeClient(
stream_events=[
{
"type": "tool",
"tool": {
"id": "call_a",
"name": "read_file",
"input": {"path": "README.md"},
},
},
{
"type": "tool",
"tool": {
"id": "call_b",
"name": "search_docs",
"input": {"query": "gateway"},
},
},
],
complete_result={},
)
req = ChatCompletionsRequest(
model="org_auto",
messages=[{"role": "user", "content": "hi"}],
stream=True,
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
):
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
body = await _collect_stream(response)
self.assertIn('"id": "call_a"', body)
self.assertIn('"id": "call_b"', body)
self.assertIn('"index": 0', body)
self.assertIn('"index": 1', body)
async def test_anthropic_non_stream_returns_tool_use_stop_reason_when_result_missing(self) -> None:
fake_client = _FakeClient(
stream_events=[],
complete_result={
"text": "",
"toolEvents": [
{
"name": "lookup",
"input": {"k": "v"},
}
],
"sessionId": "sess-2",
},
)
req = AnthropicMessagesRequest(
model="claude-3-5-sonnet-20241022",
max_tokens=256,
messages=[{"role": "user", "content": "hi"}],
stream=False,
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
patch.object(main.settings, "api_keys", ["test-key"]),
):
response = await main.v1_messages(
req,
_make_request(
"/v1/messages",
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
),
)
payload = json.loads(response.body)
self.assertEqual(payload["stop_reason"], "tool_use")
self.assertEqual(len(payload["content"]), 1)
self.assertEqual(payload["content"][0]["type"], "tool_use")
async def test_anthropic_stream_returns_tool_use_stop_reason_when_result_missing(self) -> None:
fake_client = _FakeClient(
stream_events=[
{
"type": "tool",
"tool": {
"name": "read",
"input": {"file": "a.txt"},
},
}
],
complete_result={},
)
req = AnthropicMessagesRequest(
model="claude-3-5-sonnet-20241022",
max_tokens=256,
messages=[{"role": "user", "content": "hi"}],
stream=True,
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
patch.object(main.settings, "api_keys", ["test-key"]),
):
response = await main.v1_messages(
req,
_make_request(
"/v1/messages",
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
),
)
body = await _collect_stream(response)
self.assertIn('"type": "tool_use"', body)
self.assertIn('"stop_reason": "tool_use"', body)
async def test_anthropic_stream_bridges_tool_and_text_events(self) -> None: async def test_anthropic_stream_bridges_tool_and_text_events(self) -> None:
fake_client = _FakeClient( fake_client = _FakeClient(
stream_events=[ stream_events=[
@@ -284,11 +466,262 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
self.assertIn("event: message_start", body) self.assertIn("event: message_start", body)
self.assertIn('"type": "tool_use"', body) self.assertIn('"type": "tool_use"', body)
self.assertIn('"type": "tool_result"', body) self.assertIn('"type": "tool_result"', body)
self.assertIn('"stop_reason": "end_turn"', body)
self.assertIn('"type": "text_delta"', body) self.assertIn('"type": "text_delta"', body)
self.assertIn("event: message_stop", body) self.assertIn("event: message_stop", body)
class LingmaClientToolEventExtractionTests(unittest.TestCase):
async def test_openai_non_stream_forwards_tool_config_when_enabled(self) -> None:
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
req = ChatCompletionsRequest(
model="org_auto",
messages=[{"role": "user", "content": "hi"}],
stream=False,
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
tool_choice={"type": "function", "function": {"name": "lookup"}},
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
_SettingsPatch(tool_forward_enabled=True),
):
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
self.assertIn("tool_config", spy_client.last_complete_kwargs)
cfg = spy_client.last_complete_kwargs["tool_config"]
self.assertEqual(cfg["provider"], "openai")
self.assertEqual(len(cfg["tools"]), 1)
self.assertIsInstance(cfg["tool_choice"], dict)
async def test_openai_non_stream_does_not_forward_tool_config_when_disabled(self) -> None:
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
req = ChatCompletionsRequest(
model="org_auto",
messages=[{"role": "user", "content": "hi"}],
stream=False,
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
tool_choice={"type": "function", "function": {"name": "lookup"}},
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
_SettingsPatch(tool_forward_enabled=False),
):
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
self.assertIn("tool_config", spy_client.last_complete_kwargs)
self.assertIsNone(spy_client.last_complete_kwargs["tool_config"])
async def test_openai_tooling_context_disables_session_reuse_cache(self) -> None:
fake_cache = _FakeSessionCache()
fake_client = _FakeClient(
stream_events=[],
complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-3"},
)
req = ChatCompletionsRequest(
model="org_auto",
messages=[
{"role": "user", "content": "turn-1"},
{"role": "user", "content": "turn-2"},
],
stream=False,
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
tool_choice={"type": "function", "function": {"name": "lookup"}},
)
with (
patch.object(main, "session_cache", fake_cache),
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
_SettingsPatch(tool_forward_enabled=True),
):
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
self.assertEqual(fake_cache.keys, [])
self.assertEqual(fake_cache.get_calls, [])
self.assertEqual(fake_cache.put_calls, [])
async def test_anthropic_tooling_context_disables_session_reuse_cache(self) -> None:
fake_cache = _FakeSessionCache()
fake_client = _FakeClient(
stream_events=[],
complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-4"},
)
req = AnthropicMessagesRequest(
model="claude-3-5-sonnet-20241022",
max_tokens=128,
messages=[
{"role": "user", "content": "turn-1"},
{"role": "user", "content": "turn-2"},
],
stream=False,
tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}],
tool_choice={"type": "auto"},
)
with (
patch.object(main, "session_cache", fake_cache),
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
patch.object(main.settings, "api_keys", ["test-key"]),
):
await main.v1_messages(
req,
_make_request(
"/v1/messages",
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
),
)
self.assertEqual(fake_cache.keys, [])
self.assertEqual(fake_cache.get_calls, [])
self.assertEqual(fake_cache.put_calls, [])
class SessionCacheToolFingerprintTests(unittest.TestCase):
def test_build_key_changes_with_tool_config(self) -> None:
from app.session_cache import SessionCache
cache = SessionCache(max_entries=8, ttl_sec=60)
messages = [
{"role": "system", "content": "sys"},
{"role": "user", "content": "hello"},
]
cfg_a = {
"provider": "openai",
"tools": [{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
"tool_choice": {"type": "function", "function": {"name": "lookup"}},
}
cfg_a_reordered = {
"tool_choice": {"function": {"name": "lookup"}, "type": "function"},
"tools": [{"function": {"parameters": {}, "name": "lookup"}, "type": "function"}],
"provider": "openai",
}
cfg_b = {
"provider": "openai",
"tools": [{"type": "function", "function": {"name": "lookup_v2", "parameters": {}}}],
"tool_choice": {"type": "function", "function": {"name": "lookup_v2"}},
}
key_no_tool = cache.build_key("api-key", messages)
key_a = cache.build_key("api-key", messages, tool_config=cfg_a)
key_a_reordered = cache.build_key("api-key", messages, tool_config=cfg_a_reordered)
key_b = cache.build_key("api-key", messages, tool_config=cfg_b)
self.assertNotEqual(key_no_tool, key_a)
self.assertEqual(key_a, key_a_reordered)
self.assertNotEqual(key_a, key_b)
def test_handle_server_message_drops_unroutable_tool_event_without_request_id(self) -> None:
from app.lingma_client import LspWsRpcClient
rpc = LspWsRpcClient("ws://127.0.0.1:1")
async def run() -> None:
rpc.create_stream("req-1")
await rpc._handle_server_message(
{
"jsonrpc": "2.0",
"method": "tool/invoke",
"params": {
"name": "lookup",
"parameters": {"q": "x"},
},
}
)
stream = rpc._chat_streams["req-1"]
self.assertEqual(stream["tool_order"], [])
self.assertEqual(stream["tool_states"], {})
self.assertTrue(stream["chunks"].empty())
import asyncio
asyncio.run(run())
def test_handle_server_message_routes_by_tool_map_without_request_id(self) -> None:
from app.lingma_client import LspWsRpcClient
rpc = LspWsRpcClient("ws://127.0.0.1:1")
async def run() -> None:
rpc.create_stream("req-1")
await rpc._handle_server_message(
{
"jsonrpc": "2.0",
"method": "tool/invoke",
"params": {
"requestId": "req-1",
"toolCallId": "call-1",
"name": "lookup",
"parameters": {"q": "a"},
},
}
)
await rpc._handle_server_message(
{
"jsonrpc": "2.0",
"method": "tool/invokeResult",
"params": {
"toolCallId": "call-1",
"result": {"ok": True},
},
}
)
result = rpc.get_stream_result("req-1")
self.assertEqual(len(result["toolEvents"]), 1)
self.assertEqual(result["toolEvents"][0]["id"], "call-1")
self.assertEqual(result["toolEvents"][0]["input"], {"q": "a"})
self.assertEqual(result["toolEvents"][0]["result"], {"ok": True})
import asyncio
asyncio.run(run())
def test_handle_server_message_dedupes_identical_repeated_tool_events(self) -> None:
from app.lingma_client import LspWsRpcClient
rpc = LspWsRpcClient("ws://127.0.0.1:1")
async def run() -> None:
rpc.create_stream("req-1")
msg = {
"jsonrpc": "2.0",
"method": "tool/invoke",
"params": {
"requestId": "req-1",
"toolCallId": "call-dup",
"name": "lookup",
"parameters": {"q": "dup"},
},
}
await rpc._handle_server_message(msg)
await rpc._handle_server_message(msg)
stream = rpc._chat_streams["req-1"]
self.assertEqual(stream["tool_order"], ["call-dup"])
self.assertEqual(stream["chunks"].qsize(), 1)
import asyncio
asyncio.run(run())
def test_extracts_tool_event_from_results_and_parameters(self) -> None: def test_extracts_tool_event_from_results_and_parameters(self) -> None:
from app.lingma_client import LspWsRpcClient from app.lingma_client import LspWsRpcClient