fix: harden tooling session reuse and event routing

Ensure session reuse is disabled for tooling contexts, include tool config in cache keys, and stabilize tool event merge/routing with expanded bridge tests.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
GitHub Actions
2026-04-19 19:29:30 +08:00
parent 5aa7fbfae5
commit e600bae27c
5 changed files with 741 additions and 53 deletions

View File

@@ -100,6 +100,7 @@ class LspWsRpcClient:
self._reader_task: asyncio.Task | None = None
self._rx_buffer = b""
self._chat_streams: dict[str, dict] = {}
self._tool_stream_map: dict[str, str] = {}
self._on_disconnect = on_disconnect
self._closed = False
@@ -202,6 +203,7 @@ class LspWsRpcClient:
stream["done"].set()
stream["chunks"].put_nowait(None)
self._chat_streams.clear()
self._tool_stream_map.clear()
async def _send(self, payload: dict):
async with self._send_lock:
@@ -251,6 +253,92 @@ class LspWsRpcClient:
except Exception:
logger.exception("on_disconnect callback failed")
@staticmethod
def _normalize_tool_id(method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> str | None:
event_id = None
if isinstance(tool_event, dict):
event_id = tool_event.get("id")
if isinstance(event_id, str) and event_id.strip():
return event_id.strip()
fallback_id = params.get("toolCallId") or params.get("tool_call_id")
if isinstance(fallback_id, str) and fallback_id.strip():
return fallback_id.strip()
req_id = params.get("requestId")
name = None
if isinstance(tool_event, dict):
name = tool_event.get("name")
if not name:
name = params.get("name")
if isinstance(req_id, str) and req_id.strip() and isinstance(name, str) and name.strip():
return f"{req_id.strip()}:tool:{name.strip()}"
if isinstance(req_id, str) and req_id.strip():
return f"{req_id.strip()}:tool"
return None
@staticmethod
def _merge_tool_event(existing: dict[str, Any] | None, incoming: dict[str, Any]) -> tuple[dict[str, Any], bool]:
merged = dict(existing or {})
changed = False
val = incoming.get("id")
if val and merged.get("id") != val:
merged["id"] = val
changed = True
name = incoming.get("name")
if name:
existing_name = merged.get("name")
if not existing_name:
merged["name"] = name
changed = True
else:
existing_norm = str(existing_name).strip().lower()
incoming_norm = str(name).strip().lower()
if existing_norm == "tool" and incoming_norm != "tool":
merged["name"] = name
changed = True
elif existing_norm != "tool" and incoming_norm == "tool":
pass
elif merged.get("name") != name:
merged["name"] = name
changed = True
if "input" in incoming and incoming.get("input") is not None:
incoming_input = incoming.get("input")
should_update_input = incoming_input != {} or "input" not in merged
if should_update_input and merged.get("input") != incoming_input:
merged["input"] = incoming_input
changed = True
if "result" in incoming and incoming.get("result") is not None:
if merged.get("result") != incoming.get("result"):
merged["result"] = incoming.get("result")
changed = True
return merged, changed
def _resolve_tool_stream(self, method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> dict | None:
req_id = params.get("requestId")
if isinstance(req_id, str) and req_id.strip():
stream = self._chat_streams.get(req_id)
if stream is not None and tool_event is not None:
tool_id = self._normalize_tool_id(method, params, tool_event)
if tool_id:
self._tool_stream_map[tool_id] = req_id
return stream
if tool_event is not None:
tool_id = self._normalize_tool_id(method, params, tool_event)
if tool_id:
mapped_req = self._tool_stream_map.get(tool_id)
if mapped_req:
return self._chat_streams.get(mapped_req)
return None
async def _handle_server_message(self, msg: dict):
method = msg.get("method")
params = msg.get("params") or {}
@@ -268,21 +356,29 @@ class LspWsRpcClient:
if method in {"tool/call/sync", "tool/invoke", "tool/call/approve", "tool/invokeResult"}:
tool_event = self._extract_tool_event(params)
req_id = params.get("requestId")
stream = self._chat_streams.get(req_id) if req_id else None
if stream is None and tool_event is not None:
for item in self._chat_streams.values():
if any(evt.get("id") == tool_event["id"] for evt in item["tool_events"]):
stream = item
break
if stream is None and len(self._chat_streams) == 1:
stream = next(iter(self._chat_streams.values()))
stream = self._resolve_tool_stream(method, params, tool_event)
if stream is not None and tool_event is not None:
stream["tool_events"].append(tool_event)
stream["chunks"].put_nowait({"type": "tool", "tool": tool_event})
tool_id = self._normalize_tool_id(method, params, tool_event)
if not tool_id:
logger.warning("drop unroutable tool event: method=%s missing tool id", method)
else:
tool_states = stream["tool_states"]
order = stream["tool_order"]
existing = tool_states.get(tool_id)
merged, changed = self._merge_tool_event(existing, tool_event)
if not existing:
if "id" not in merged or not merged.get("id"):
merged["id"] = tool_id
tool_states[tool_id] = merged
order.append(tool_id)
stream["chunks"].put_nowait({"type": "tool", "tool": merged})
elif changed:
tool_states[tool_id] = merged
stream["chunks"].put_nowait({"type": "tool", "tool": merged})
elif tool_event is not None:
logger.warning("drop unroutable tool event: method=%s requestId=%s", method, params.get("requestId"))
if method == "chat/finish":
req_id = params.get("requestId")
@@ -321,7 +417,8 @@ class LspWsRpcClient:
"chunks": asyncio.Queue(),
"done": asyncio.Event(),
"finish": None,
"tool_events": [],
"tool_states": {},
"tool_order": [],
"started_at": time.monotonic(),
"first_chunk_at": None,
"finish_at": None,
@@ -331,6 +428,9 @@ class LspWsRpcClient:
stream = self._chat_streams.pop(request_id, None)
if stream is None:
return
for tool_id, mapped_req in list(self._tool_stream_map.items()):
if mapped_req == request_id:
self._tool_stream_map.pop(tool_id, None)
# Drain queue so no stray future gets stuck if the consumer bailed early.
if not stream["done"].is_set():
stream["done"].set()
@@ -359,12 +459,20 @@ class LspWsRpcClient:
first_ms = int((stream["first_chunk_at"] - stream["started_at"]) * 1000)
if stream.get("finish_at") is not None:
total_ms = int((stream["finish_at"] - stream["started_at"]) * 1000)
ordered_tool_events: list[dict[str, Any]] = []
tool_states = stream.get("tool_states") or {}
for tool_id in stream.get("tool_order") or []:
event = tool_states.get(tool_id)
if isinstance(event, dict):
ordered_tool_events.append(event)
return {
"text": "".join(stream.get("parts") or []),
"finish": stream.get("finish") or {},
"firstTokenLatencyMs": first_ms,
"totalLatencyMs": total_ms,
"toolEvents": stream.get("tool_events") or [],
"toolEvents": ordered_tool_events,
}
@@ -733,9 +841,10 @@ class LingmaGatewayClient:
request_id: str,
*,
is_reply: bool = False,
tool_config: dict[str, Any] | None = None,
):
session_type = "developer" if ask_mode == "agent" else "chat"
return {
payload = {
"requestId": request_id,
"sessionId": session_id,
"sessionType": session_type,
@@ -764,6 +873,9 @@ class LingmaGatewayClient:
"localeLang": "zh-CN",
},
}
if tool_config is not None:
payload["toolConfig"] = tool_config
return payload
async def _kick_chat_ask(self, payload: dict) -> None:
"""Fire chat/ask as a notification.
@@ -784,12 +896,19 @@ class LingmaGatewayClient:
*,
session_id: str | None = None,
is_reply: bool = False,
tool_config: dict[str, Any] | None = None,
) -> dict:
await self.ensure_ready()
request_id = str(uuid.uuid4())
sid = session_id or str(uuid.uuid4())
payload = self._build_payload(
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
prompt,
model_key,
ask_mode,
sid,
request_id,
is_reply=is_reply,
tool_config=tool_config,
)
self.rpc.create_stream(request_id)
try:
@@ -820,6 +939,7 @@ class LingmaGatewayClient:
*,
session_id: str | None = None,
is_reply: bool = False,
tool_config: dict[str, Any] | None = None,
out_meta: dict | None = None,
) -> AsyncIterator[dict[str, Any]]:
"""Stream chat events.
@@ -837,7 +957,13 @@ class LingmaGatewayClient:
request_id = str(uuid.uuid4())
sid = session_id or str(uuid.uuid4())
payload = self._build_payload(
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
prompt,
model_key,
ask_mode,
sid,
request_id,
is_reply=is_reply,
tool_config=tool_config,
)
self.rpc.create_stream(request_id)
try: