fix: harden tooling session reuse and event routing
Ensure session reuse is disabled for tooling contexts, include tool config in cache keys, and stabilize tool event merge/routing with expanded bridge tests. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -100,6 +100,7 @@ class LspWsRpcClient:
|
||||
self._reader_task: asyncio.Task | None = None
|
||||
self._rx_buffer = b""
|
||||
self._chat_streams: dict[str, dict] = {}
|
||||
self._tool_stream_map: dict[str, str] = {}
|
||||
self._on_disconnect = on_disconnect
|
||||
self._closed = False
|
||||
|
||||
@@ -202,6 +203,7 @@ class LspWsRpcClient:
|
||||
stream["done"].set()
|
||||
stream["chunks"].put_nowait(None)
|
||||
self._chat_streams.clear()
|
||||
self._tool_stream_map.clear()
|
||||
|
||||
async def _send(self, payload: dict):
|
||||
async with self._send_lock:
|
||||
@@ -251,6 +253,92 @@ class LspWsRpcClient:
|
||||
except Exception:
|
||||
logger.exception("on_disconnect callback failed")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tool_id(method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> str | None:
|
||||
event_id = None
|
||||
if isinstance(tool_event, dict):
|
||||
event_id = tool_event.get("id")
|
||||
if isinstance(event_id, str) and event_id.strip():
|
||||
return event_id.strip()
|
||||
|
||||
fallback_id = params.get("toolCallId") or params.get("tool_call_id")
|
||||
if isinstance(fallback_id, str) and fallback_id.strip():
|
||||
return fallback_id.strip()
|
||||
|
||||
req_id = params.get("requestId")
|
||||
name = None
|
||||
if isinstance(tool_event, dict):
|
||||
name = tool_event.get("name")
|
||||
if not name:
|
||||
name = params.get("name")
|
||||
if isinstance(req_id, str) and req_id.strip() and isinstance(name, str) and name.strip():
|
||||
return f"{req_id.strip()}:tool:{name.strip()}"
|
||||
if isinstance(req_id, str) and req_id.strip():
|
||||
return f"{req_id.strip()}:tool"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _merge_tool_event(existing: dict[str, Any] | None, incoming: dict[str, Any]) -> tuple[dict[str, Any], bool]:
|
||||
merged = dict(existing or {})
|
||||
changed = False
|
||||
|
||||
val = incoming.get("id")
|
||||
if val and merged.get("id") != val:
|
||||
merged["id"] = val
|
||||
changed = True
|
||||
|
||||
name = incoming.get("name")
|
||||
if name:
|
||||
existing_name = merged.get("name")
|
||||
if not existing_name:
|
||||
merged["name"] = name
|
||||
changed = True
|
||||
else:
|
||||
existing_norm = str(existing_name).strip().lower()
|
||||
incoming_norm = str(name).strip().lower()
|
||||
if existing_norm == "tool" and incoming_norm != "tool":
|
||||
merged["name"] = name
|
||||
changed = True
|
||||
elif existing_norm != "tool" and incoming_norm == "tool":
|
||||
pass
|
||||
elif merged.get("name") != name:
|
||||
merged["name"] = name
|
||||
changed = True
|
||||
|
||||
if "input" in incoming and incoming.get("input") is not None:
|
||||
incoming_input = incoming.get("input")
|
||||
should_update_input = incoming_input != {} or "input" not in merged
|
||||
if should_update_input and merged.get("input") != incoming_input:
|
||||
merged["input"] = incoming_input
|
||||
changed = True
|
||||
|
||||
if "result" in incoming and incoming.get("result") is not None:
|
||||
if merged.get("result") != incoming.get("result"):
|
||||
merged["result"] = incoming.get("result")
|
||||
changed = True
|
||||
|
||||
return merged, changed
|
||||
|
||||
|
||||
def _resolve_tool_stream(self, method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> dict | None:
|
||||
req_id = params.get("requestId")
|
||||
if isinstance(req_id, str) and req_id.strip():
|
||||
stream = self._chat_streams.get(req_id)
|
||||
if stream is not None and tool_event is not None:
|
||||
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||
if tool_id:
|
||||
self._tool_stream_map[tool_id] = req_id
|
||||
return stream
|
||||
|
||||
if tool_event is not None:
|
||||
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||
if tool_id:
|
||||
mapped_req = self._tool_stream_map.get(tool_id)
|
||||
if mapped_req:
|
||||
return self._chat_streams.get(mapped_req)
|
||||
|
||||
return None
|
||||
|
||||
async def _handle_server_message(self, msg: dict):
|
||||
method = msg.get("method")
|
||||
params = msg.get("params") or {}
|
||||
@@ -268,21 +356,29 @@ class LspWsRpcClient:
|
||||
|
||||
if method in {"tool/call/sync", "tool/invoke", "tool/call/approve", "tool/invokeResult"}:
|
||||
tool_event = self._extract_tool_event(params)
|
||||
req_id = params.get("requestId")
|
||||
stream = self._chat_streams.get(req_id) if req_id else None
|
||||
|
||||
if stream is None and tool_event is not None:
|
||||
for item in self._chat_streams.values():
|
||||
if any(evt.get("id") == tool_event["id"] for evt in item["tool_events"]):
|
||||
stream = item
|
||||
break
|
||||
|
||||
if stream is None and len(self._chat_streams) == 1:
|
||||
stream = next(iter(self._chat_streams.values()))
|
||||
stream = self._resolve_tool_stream(method, params, tool_event)
|
||||
|
||||
if stream is not None and tool_event is not None:
|
||||
stream["tool_events"].append(tool_event)
|
||||
stream["chunks"].put_nowait({"type": "tool", "tool": tool_event})
|
||||
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||
if not tool_id:
|
||||
logger.warning("drop unroutable tool event: method=%s missing tool id", method)
|
||||
else:
|
||||
tool_states = stream["tool_states"]
|
||||
order = stream["tool_order"]
|
||||
existing = tool_states.get(tool_id)
|
||||
merged, changed = self._merge_tool_event(existing, tool_event)
|
||||
if not existing:
|
||||
if "id" not in merged or not merged.get("id"):
|
||||
merged["id"] = tool_id
|
||||
tool_states[tool_id] = merged
|
||||
order.append(tool_id)
|
||||
stream["chunks"].put_nowait({"type": "tool", "tool": merged})
|
||||
elif changed:
|
||||
tool_states[tool_id] = merged
|
||||
stream["chunks"].put_nowait({"type": "tool", "tool": merged})
|
||||
elif tool_event is not None:
|
||||
logger.warning("drop unroutable tool event: method=%s requestId=%s", method, params.get("requestId"))
|
||||
|
||||
|
||||
if method == "chat/finish":
|
||||
req_id = params.get("requestId")
|
||||
@@ -321,7 +417,8 @@ class LspWsRpcClient:
|
||||
"chunks": asyncio.Queue(),
|
||||
"done": asyncio.Event(),
|
||||
"finish": None,
|
||||
"tool_events": [],
|
||||
"tool_states": {},
|
||||
"tool_order": [],
|
||||
"started_at": time.monotonic(),
|
||||
"first_chunk_at": None,
|
||||
"finish_at": None,
|
||||
@@ -331,6 +428,9 @@ class LspWsRpcClient:
|
||||
stream = self._chat_streams.pop(request_id, None)
|
||||
if stream is None:
|
||||
return
|
||||
for tool_id, mapped_req in list(self._tool_stream_map.items()):
|
||||
if mapped_req == request_id:
|
||||
self._tool_stream_map.pop(tool_id, None)
|
||||
# Drain queue so no stray future gets stuck if the consumer bailed early.
|
||||
if not stream["done"].is_set():
|
||||
stream["done"].set()
|
||||
@@ -359,12 +459,20 @@ class LspWsRpcClient:
|
||||
first_ms = int((stream["first_chunk_at"] - stream["started_at"]) * 1000)
|
||||
if stream.get("finish_at") is not None:
|
||||
total_ms = int((stream["finish_at"] - stream["started_at"]) * 1000)
|
||||
|
||||
ordered_tool_events: list[dict[str, Any]] = []
|
||||
tool_states = stream.get("tool_states") or {}
|
||||
for tool_id in stream.get("tool_order") or []:
|
||||
event = tool_states.get(tool_id)
|
||||
if isinstance(event, dict):
|
||||
ordered_tool_events.append(event)
|
||||
|
||||
return {
|
||||
"text": "".join(stream.get("parts") or []),
|
||||
"finish": stream.get("finish") or {},
|
||||
"firstTokenLatencyMs": first_ms,
|
||||
"totalLatencyMs": total_ms,
|
||||
"toolEvents": stream.get("tool_events") or [],
|
||||
"toolEvents": ordered_tool_events,
|
||||
}
|
||||
|
||||
|
||||
@@ -733,9 +841,10 @@ class LingmaGatewayClient:
|
||||
request_id: str,
|
||||
*,
|
||||
is_reply: bool = False,
|
||||
tool_config: dict[str, Any] | None = None,
|
||||
):
|
||||
session_type = "developer" if ask_mode == "agent" else "chat"
|
||||
return {
|
||||
payload = {
|
||||
"requestId": request_id,
|
||||
"sessionId": session_id,
|
||||
"sessionType": session_type,
|
||||
@@ -764,6 +873,9 @@ class LingmaGatewayClient:
|
||||
"localeLang": "zh-CN",
|
||||
},
|
||||
}
|
||||
if tool_config is not None:
|
||||
payload["toolConfig"] = tool_config
|
||||
return payload
|
||||
|
||||
async def _kick_chat_ask(self, payload: dict) -> None:
|
||||
"""Fire chat/ask as a notification.
|
||||
@@ -784,12 +896,19 @@ class LingmaGatewayClient:
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
is_reply: bool = False,
|
||||
tool_config: dict[str, Any] | None = None,
|
||||
) -> dict:
|
||||
await self.ensure_ready()
|
||||
request_id = str(uuid.uuid4())
|
||||
sid = session_id or str(uuid.uuid4())
|
||||
payload = self._build_payload(
|
||||
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
|
||||
prompt,
|
||||
model_key,
|
||||
ask_mode,
|
||||
sid,
|
||||
request_id,
|
||||
is_reply=is_reply,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
self.rpc.create_stream(request_id)
|
||||
try:
|
||||
@@ -820,6 +939,7 @@ class LingmaGatewayClient:
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
is_reply: bool = False,
|
||||
tool_config: dict[str, Any] | None = None,
|
||||
out_meta: dict | None = None,
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
"""Stream chat events.
|
||||
@@ -837,7 +957,13 @@ class LingmaGatewayClient:
|
||||
request_id = str(uuid.uuid4())
|
||||
sid = session_id or str(uuid.uuid4())
|
||||
payload = self._build_payload(
|
||||
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
|
||||
prompt,
|
||||
model_key,
|
||||
ask_mode,
|
||||
sid,
|
||||
request_id,
|
||||
is_reply=is_reply,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
self.rpc.create_stream(request_id)
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user