fix: emit Lingma tool approve/invoke roundtrip
Forward tool/call/sync and tool/invoke events to Lingma with auto-approve and invokeResult so tool calls can complete end-to-end. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -101,6 +101,7 @@ class LspWsRpcClient:
|
|||||||
self._rx_buffer = b""
|
self._rx_buffer = b""
|
||||||
self._chat_streams: dict[str, dict] = {}
|
self._chat_streams: dict[str, dict] = {}
|
||||||
self._tool_stream_map: dict[str, str] = {}
|
self._tool_stream_map: dict[str, str] = {}
|
||||||
|
self._tool_roundtrip_done: set[str] = set()
|
||||||
self._on_disconnect = on_disconnect
|
self._on_disconnect = on_disconnect
|
||||||
self._closed = False
|
self._closed = False
|
||||||
|
|
||||||
@@ -204,6 +205,7 @@ class LspWsRpcClient:
|
|||||||
stream["chunks"].put_nowait(None)
|
stream["chunks"].put_nowait(None)
|
||||||
self._chat_streams.clear()
|
self._chat_streams.clear()
|
||||||
self._tool_stream_map.clear()
|
self._tool_stream_map.clear()
|
||||||
|
self._tool_roundtrip_done.clear()
|
||||||
|
|
||||||
async def _send(self, payload: dict):
|
async def _send(self, payload: dict):
|
||||||
async with self._send_lock:
|
async with self._send_lock:
|
||||||
@@ -320,6 +322,55 @@ class LspWsRpcClient:
|
|||||||
return merged, changed
|
return merged, changed
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_tool_roundtrip_method(method: str | None) -> bool:
|
||||||
|
return method in {"tool/call/sync", "tool/invoke"}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_tool_approve_params(params: dict[str, Any], tool_id: str) -> dict[str, Any] | None:
|
||||||
|
req_id = params.get("requestId")
|
||||||
|
session_id = params.get("sessionId")
|
||||||
|
if not isinstance(req_id, str) or not req_id.strip():
|
||||||
|
return None
|
||||||
|
if not isinstance(session_id, str) or not session_id.strip():
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"type": "tool_call",
|
||||||
|
"sessionId": session_id,
|
||||||
|
"requestId": req_id,
|
||||||
|
"toolCallId": tool_id,
|
||||||
|
"approval": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_tool_invoke_result_params(params: dict[str, Any], tool_event: dict[str, Any], tool_id: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"toolCallId": tool_id,
|
||||||
|
"name": str(tool_event.get("name") or params.get("name") or "tool"),
|
||||||
|
"success": True,
|
||||||
|
"errorMessage": "",
|
||||||
|
"result": tool_event.get("result") if "result" in tool_event else {},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _maybe_emit_tool_roundtrip(self, method: str, params: dict[str, Any], tool_event: dict[str, Any]) -> None:
|
||||||
|
if not self._is_tool_roundtrip_method(method):
|
||||||
|
return
|
||||||
|
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||||
|
if not tool_id:
|
||||||
|
return
|
||||||
|
if tool_id in self._tool_roundtrip_done:
|
||||||
|
return
|
||||||
|
|
||||||
|
approve_params = self._build_tool_approve_params(params, tool_id)
|
||||||
|
if approve_params is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._tool_roundtrip_done.add(tool_id)
|
||||||
|
await self.notify("tool/call/approve", approve_params)
|
||||||
|
invoke_result_params = self._build_tool_invoke_result_params(params, tool_event, tool_id)
|
||||||
|
await self.notify("tool/invokeResult", invoke_result_params)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_tool_stream(self, method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> dict | None:
|
def _resolve_tool_stream(self, method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> dict | None:
|
||||||
req_id = params.get("requestId")
|
req_id = params.get("requestId")
|
||||||
if isinstance(req_id, str) and req_id.strip():
|
if isinstance(req_id, str) and req_id.strip():
|
||||||
@@ -363,6 +414,7 @@ class LspWsRpcClient:
|
|||||||
if not tool_id:
|
if not tool_id:
|
||||||
logger.warning("drop unroutable tool event: method=%s missing tool id", method)
|
logger.warning("drop unroutable tool event: method=%s missing tool id", method)
|
||||||
else:
|
else:
|
||||||
|
await self._maybe_emit_tool_roundtrip(method, params, tool_event)
|
||||||
tool_states = stream["tool_states"]
|
tool_states = stream["tool_states"]
|
||||||
order = stream["tool_order"]
|
order = stream["tool_order"]
|
||||||
existing = tool_states.get(tool_id)
|
existing = tool_states.get(tool_id)
|
||||||
@@ -431,6 +483,7 @@ class LspWsRpcClient:
|
|||||||
for tool_id, mapped_req in list(self._tool_stream_map.items()):
|
for tool_id, mapped_req in list(self._tool_stream_map.items()):
|
||||||
if mapped_req == request_id:
|
if mapped_req == request_id:
|
||||||
self._tool_stream_map.pop(tool_id, None)
|
self._tool_stream_map.pop(tool_id, None)
|
||||||
|
self._tool_roundtrip_done.discard(tool_id)
|
||||||
# Drain queue so no stray future gets stuck if the consumer bailed early.
|
# Drain queue so no stray future gets stuck if the consumer bailed early.
|
||||||
if not stream["done"].is_set():
|
if not stream["done"].is_set():
|
||||||
stream["done"].set()
|
stream["done"].set()
|
||||||
|
|||||||
@@ -798,7 +798,6 @@ class SessionCacheToolFingerprintTests(unittest.TestCase):
|
|||||||
"result": {"hits": 3},
|
"result": {"hits": 3},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
event,
|
event,
|
||||||
{
|
{
|
||||||
@@ -808,3 +807,97 @@ class SessionCacheToolFingerprintTests(unittest.TestCase):
|
|||||||
"result": {"hits": 3},
|
"result": {"hits": 3},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_sync_triggers_approve_and_invoke_result_requests(self) -> None:
|
||||||
|
from app.lingma_client import LspWsRpcClient
|
||||||
|
|
||||||
|
class _WsStub:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.frames: list[bytes] = []
|
||||||
|
|
||||||
|
async def send(self, data: bytes) -> None:
|
||||||
|
self.frames.append(data)
|
||||||
|
|
||||||
|
def _decode(frame: bytes) -> dict:
|
||||||
|
body = frame.split(b"\r\n\r\n", 1)[1]
|
||||||
|
return json.loads(body.decode("utf-8"))
|
||||||
|
|
||||||
|
ws = _WsStub()
|
||||||
|
rpc = LspWsRpcClient(ws)
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
rpc.create_stream("req-1")
|
||||||
|
await rpc._handle_server_message(
|
||||||
|
{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "tool/call/sync",
|
||||||
|
"params": {
|
||||||
|
"sessionId": "sess-1",
|
||||||
|
"requestId": "req-1",
|
||||||
|
"toolCallId": "call-1",
|
||||||
|
"name": "run_in_terminal",
|
||||||
|
"parameters": {"command": "pwd"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
decoded = [_decode(frame) for frame in ws.frames]
|
||||||
|
methods = [item.get("method") for item in decoded]
|
||||||
|
self.assertIn("tool/call/approve", methods)
|
||||||
|
self.assertIn("tool/invokeResult", methods)
|
||||||
|
|
||||||
|
approve = next(item for item in decoded if item.get("method") == "tool/call/approve")
|
||||||
|
self.assertEqual(
|
||||||
|
approve["params"],
|
||||||
|
{
|
||||||
|
"type": "tool_call",
|
||||||
|
"sessionId": "sess-1",
|
||||||
|
"requestId": "req-1",
|
||||||
|
"toolCallId": "call-1",
|
||||||
|
"approval": True,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
invoke_result = next(item for item in decoded if item.get("method") == "tool/invokeResult")
|
||||||
|
self.assertEqual(invoke_result["params"]["toolCallId"], "call-1")
|
||||||
|
self.assertEqual(invoke_result["params"]["name"], "run_in_terminal")
|
||||||
|
self.assertTrue(invoke_result["params"]["success"])
|
||||||
|
self.assertEqual(invoke_result["params"]["errorMessage"], "")
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
def test_tool_sync_does_not_emit_roundtrip_without_request_id(self) -> None:
|
||||||
|
from app.lingma_client import LspWsRpcClient
|
||||||
|
|
||||||
|
class _WsStub:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.frames: list[bytes] = []
|
||||||
|
|
||||||
|
async def send(self, data: bytes) -> None:
|
||||||
|
self.frames.append(data)
|
||||||
|
|
||||||
|
ws = _WsStub()
|
||||||
|
rpc = LspWsRpcClient(ws)
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
rpc.create_stream("req-1")
|
||||||
|
await rpc._handle_server_message(
|
||||||
|
{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "tool/call/sync",
|
||||||
|
"params": {
|
||||||
|
"sessionId": "sess-1",
|
||||||
|
"toolCallId": "call-1",
|
||||||
|
"name": "run_in_terminal",
|
||||||
|
"parameters": {"command": "pwd"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.assertEqual(ws.frames, [])
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|||||||
Reference in New Issue
Block a user