diff --git a/app/lingma_client.py b/app/lingma_client.py index 2ff98fb..93fa106 100644 --- a/app/lingma_client.py +++ b/app/lingma_client.py @@ -101,6 +101,7 @@ class LspWsRpcClient: self._rx_buffer = b"" self._chat_streams: dict[str, dict] = {} self._tool_stream_map: dict[str, str] = {} + self._tool_roundtrip_done: set[str] = set() self._on_disconnect = on_disconnect self._closed = False @@ -204,6 +205,7 @@ class LspWsRpcClient: stream["chunks"].put_nowait(None) self._chat_streams.clear() self._tool_stream_map.clear() + self._tool_roundtrip_done.clear() async def _send(self, payload: dict): async with self._send_lock: @@ -320,6 +322,55 @@ class LspWsRpcClient: return merged, changed + @staticmethod + def _is_tool_roundtrip_method(method: str | None) -> bool: + return method in {"tool/call/sync", "tool/invoke"} + + @staticmethod + def _build_tool_approve_params(params: dict[str, Any], tool_id: str) -> dict[str, Any] | None: + req_id = params.get("requestId") + session_id = params.get("sessionId") + if not isinstance(req_id, str) or not req_id.strip(): + return None + if not isinstance(session_id, str) or not session_id.strip(): + return None + return { + "type": "tool_call", + "sessionId": session_id, + "requestId": req_id, + "toolCallId": tool_id, + "approval": True, + } + + @staticmethod + def _build_tool_invoke_result_params(params: dict[str, Any], tool_event: dict[str, Any], tool_id: str) -> dict[str, Any]: + return { + "toolCallId": tool_id, + "name": str(tool_event.get("name") or params.get("name") or "tool"), + "success": True, + "errorMessage": "", + "result": tool_event.get("result") if "result" in tool_event else {}, + } + + async def _maybe_emit_tool_roundtrip(self, method: str, params: dict[str, Any], tool_event: dict[str, Any]) -> None: + if not self._is_tool_roundtrip_method(method): + return + tool_id = self._normalize_tool_id(method, params, tool_event) + if not tool_id: + return + if tool_id in self._tool_roundtrip_done: + return + + approve_params = self._build_tool_approve_params(params, tool_id) + if approve_params is None: + return + + self._tool_roundtrip_done.add(tool_id) + await self.notify("tool/call/approve", approve_params) + invoke_result_params = self._build_tool_invoke_result_params(params, tool_event, tool_id) + await self.notify("tool/invokeResult", invoke_result_params) + + def _resolve_tool_stream(self, method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> dict | None: req_id = params.get("requestId") if isinstance(req_id, str) and req_id.strip(): @@ -363,6 +414,7 @@ class LspWsRpcClient: if not tool_id: logger.warning("drop unroutable tool event: method=%s missing tool id", method) else: + await self._maybe_emit_tool_roundtrip(method, params, tool_event) tool_states = stream["tool_states"] order = stream["tool_order"] existing = tool_states.get(tool_id) @@ -431,6 +483,7 @@ class LspWsRpcClient: for tool_id, mapped_req in list(self._tool_stream_map.items()): if mapped_req == request_id: self._tool_stream_map.pop(tool_id, None) + self._tool_roundtrip_done.discard(tool_id) # Drain queue so no stray future gets stuck if the consumer bailed early. if not stream["done"].is_set(): stream["done"].set() diff --git a/tests/test_tool_call_bridge.py b/tests/test_tool_call_bridge.py index 3b02791..b10a5cb 100644 --- a/tests/test_tool_call_bridge.py +++ b/tests/test_tool_call_bridge.py @@ -798,7 +798,6 @@ class SessionCacheToolFingerprintTests(unittest.TestCase): "result": {"hits": 3}, } ) - self.assertEqual( event, { @@ -808,3 +807,97 @@ class SessionCacheToolFingerprintTests(unittest.TestCase): "result": {"hits": 3}, }, ) + + + def test_tool_sync_triggers_approve_and_invoke_result_requests(self) -> None: + from app.lingma_client import LspWsRpcClient + + class _WsStub: + def __init__(self) -> None: + self.frames: list[bytes] = [] + + async def send(self, data: bytes) -> None: + self.frames.append(data) + + def _decode(frame: bytes) -> dict: + body = frame.split(b"\r\n\r\n", 1)[1] + return json.loads(body.decode("utf-8")) + + ws = _WsStub() + rpc = LspWsRpcClient(ws) + + async def run() -> None: + rpc.create_stream("req-1") + await rpc._handle_server_message( + { + "jsonrpc": "2.0", + "method": "tool/call/sync", + "params": { + "sessionId": "sess-1", + "requestId": "req-1", + "toolCallId": "call-1", + "name": "run_in_terminal", + "parameters": {"command": "pwd"}, + }, + } + ) + + decoded = [_decode(frame) for frame in ws.frames] + methods = [item.get("method") for item in decoded] + self.assertIn("tool/call/approve", methods) + self.assertIn("tool/invokeResult", methods) + + approve = next(item for item in decoded if item.get("method") == "tool/call/approve") + self.assertEqual( + approve["params"], + { + "type": "tool_call", + "sessionId": "sess-1", + "requestId": "req-1", + "toolCallId": "call-1", + "approval": True, + }, + ) + + invoke_result = next(item for item in decoded if item.get("method") == "tool/invokeResult") + self.assertEqual(invoke_result["params"]["toolCallId"], "call-1") + self.assertEqual(invoke_result["params"]["name"], "run_in_terminal") + self.assertTrue(invoke_result["params"]["success"]) + self.assertEqual(invoke_result["params"]["errorMessage"], "") + + import asyncio + + asyncio.run(run()) + + def test_tool_sync_does_not_emit_roundtrip_without_request_id(self) -> None: + from app.lingma_client import LspWsRpcClient + + class _WsStub: + def __init__(self) -> None: + self.frames: list[bytes] = [] + + async def send(self, data: bytes) -> None: + self.frames.append(data) + + ws = _WsStub() + rpc = LspWsRpcClient(ws) + + async def run() -> None: + rpc.create_stream("req-1") + await rpc._handle_server_message( + { + "jsonrpc": "2.0", + "method": "tool/call/sync", + "params": { + "sessionId": "sess-1", + "toolCallId": "call-1", + "name": "run_in_terminal", + "parameters": {"command": "pwd"}, + }, + } + ) + self.assertEqual(ws.frames, []) + + import asyncio + + asyncio.run(run())