from __future__ import annotations

import asyncio
import contextlib
import json
import os
import socket
import subprocess
import time
import uuid
from pathlib import Path
from typing import Any, AsyncIterator, Callable, Optional

import websockets

from .logging_config import get_logger

logger = get_logger("lingma_gateway.client")


def _tool_config_summary(tool_config: dict[str, Any] | None) -> dict[str, Any]:
    """Return a compact, log-friendly summary of a tool configuration.

    The summary carries a ``present`` flag, the ``provider`` and
    ``tool_choice`` values copied from ``tool_config``, and the stripped
    names of every tool entry that has a usable string name.
    """
    if not isinstance(tool_config, dict):
        return {"present": False, "provider": None, "tool_names": [], "tool_choice": None}
    tools = tool_config.get("tools")
    tool_names: list[str] = []
    if isinstance(tools, list):
        for tool in tools:
            if not isinstance(tool, dict):
                continue
            if tool.get("type") == "function":
                # Nested shape: {"type": "function", "function": {"name": ...}}.
                fn = tool.get("function")
                if isinstance(fn, dict) and isinstance(fn.get("name"), str) and fn.get("name").strip():
                    tool_names.append(fn.get("name").strip())
                    continue
            # Flat shape: {"name": ...}.
            name = tool.get("name")
            if isinstance(name, str) and name.strip():
                tool_names.append(name.strip())
    return {
        "present": True,
        "provider": tool_config.get("provider"),
        "tool_names": tool_names,
        "tool_choice": tool_config.get("tool_choice"),
    }


# Some callers live on Python 3.10 where asyncio.TimeoutError is a distinct class,
# while 3.11+ unifies it with the builtin TimeoutError. Always catch both.
TIMEOUT_EXCEPTIONS: tuple[type[BaseException], ...] = (
    asyncio.TimeoutError,
    TimeoutError,
)


def _is_port_open(host: str, port: int, timeout_sec: float = 0.5) -> bool:
    """Return True if a TCP connection to ``host:port`` succeeds within the timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout_sec):
            return True
    except OSError:
        return False


def _read_info_file(info_path: Path) -> tuple[int | None, int | None]:
    """Parse an ``.info`` file into ``(port, pid)``.

    The first two lines of the file are read as integers. Returns
    ``(None, None)`` when the file is missing, unreadable, has fewer than
    two lines, or does not contain integers.
    """
    # EAFP: the file can vanish or be mid-write between an exists() check
    # and the read, so just attempt the read and treat any OS error as
    # "not ready yet".
    try:
        txt = info_path.read_text(encoding="utf-8", errors="ignore").strip()
    except OSError:
        return None, None
    lines = txt.splitlines()
    if len(lines) < 2:
        return None, None
    try:
        return int(lines[0].strip()), int(lines[1].strip())
    except ValueError:
        return None, None


def _wait_info_any(info_paths: list[Path], timeout_sec: int):
    """Poll ``info_paths`` until one yields a non-zero port and pid.

    This function blocks the calling thread (it sleeps between polls), so
    async callers should run it in a worker thread.

    Returns:
        ``(port, pid, path)`` for the first usable info file.

    Raises:
        TimeoutError: if no usable info file appears within ``timeout_sec``.
    """
    # Monotonic clock so wall-clock adjustments cannot stretch or cut the wait.
    start = time.monotonic()
    while time.monotonic() - start < timeout_sec:
        for p in info_paths:
            port, pid = _read_info_file(p)
            if port and pid:
                return port, pid, p
        time.sleep(0.2)
    raise TimeoutError(".info not ready")


def _encode_lsp_frame(payload_obj: dict) -> bytes:
    """Serialize ``payload_obj`` as JSON behind a ``Content-Length`` header."""
    body = json.dumps(payload_obj, ensure_ascii=False).encode("utf-8")
    header = f"Content-Length: {len(body)}\r\n\r\n".encode("ascii")
    return header + body


def _parse_lsp_frames(buf: bytes):
    """Split ``buf`` into complete ``Content-Length`` framed bodies.

    Returns:
        ``(frames, rest)`` where ``frames`` is the list of decoded bodies and
        ``rest`` is the unconsumed tail (an incomplete header or body).

    A header block without a usable ``Content-Length`` (missing, non-numeric
    or negative) is discarded and parsing continues after it.
    """
    frames = []
    while True:
        header_end = buf.find(b"\r\n\r\n")
        if header_end < 0:
            break
        header = buf[:header_end]
        body_start = header_end + 4
        content_length = None
        for line in header.split(b"\r\n"):
            if line.lower().startswith(b"content-length:"):
                try:
                    content_length = int(line.split(b":", 1)[1].strip())
                except ValueError:
                    # Malformed length: treat like a missing header instead
                    # of letting the error propagate to the caller.
                    content_length = None
                break
        if content_length is None or content_length < 0:
            buf = buf[body_start:]
            continue
        if len(buf) < body_start + content_length:
            # Body not fully received yet; keep everything for the next call.
            break
        body = buf[body_start : body_start + content_length]
        frames.append(body.decode("utf-8", errors="ignore"))
        buf = buf[body_start + content_length :]
    return frames, buf


class LspWsRpcClient:
    """JSON-RPC client speaking ``Content-Length`` framed messages over a websocket.

    Tracks pending request futures by id, per-request chat streams (text and
    tool events delivered through an ``asyncio.Queue``), and a mapping from
    tool-call ids to the request that owns them.
    """

    def __init__(self, ws, on_disconnect: Optional[Callable[[BaseException], None]] = None):
        self.ws = ws
        self._id = 1
        self._pending: dict[int, asyncio.Future] = {}
        self._send_lock = asyncio.Lock()
        self._reader_task: asyncio.Task | None = None
        self._rx_buffer = b""
        self._chat_streams: dict[str, dict] = {}
        # tool-call id -> request id of the stream that first carried it.
        self._tool_stream_map: dict[str, str] = {}
        # tool-call ids for which approve + invokeResult were already sent.
        self._tool_roundtrip_done: set[str] = set()
        self._on_disconnect = on_disconnect
        self._closed = False

    @staticmethod
    def _extract_tool_event(params: dict[str, Any]) -> dict[str, Any] | None:
        """Normalize a tool notification's params into ``{id, name, input[, result]}``.

        Candidate tool objects are looked up under ``toolCall`` / ``tool_call``
        / ``tool`` (top level and inside ``data``) and in the ``results``
        list; the first dict found wins. Returns ``None`` when no tool id can
        be determined.
        """
        candidates: list[dict[str, Any]] = []

        def add_candidate(obj: Any) -> None:
            if isinstance(obj, dict):
                candidates.append(obj)

        add_candidate(params.get("toolCall"))
        add_candidate(params.get("tool_call"))
        add_candidate(params.get("tool"))
        data = params.get("data")
        if isinstance(data, dict):
            add_candidate(data.get("toolCall"))
            add_candidate(data.get("tool_call"))
            add_candidate(data.get("tool"))
        results = params.get("results")
        if isinstance(results, list):
            for item in results:
                add_candidate(item)

        if not candidates:
            # No nested tool object: fall back to flat fields on params.
            fallback_id = params.get("toolCallId") or params.get("tool_call_id")
            if not fallback_id:
                return None
            return {
                "id": str(fallback_id),
                "name": str(params.get("name") or "tool"),
                "input": params.get("parameters") or {},
                "result": params.get("result"),
            }

        raw = candidates[0]
        tool_id = (
            raw.get("toolCallId")
            or raw.get("tool_call_id")
            or raw.get("id")
            or params.get("toolCallId")
            or params.get("tool_call_id")
        )
        name = (
            raw.get("name")
            or raw.get("toolName")
            or raw.get("tool_name")
            or params.get("name")
        )

        # First non-None wins; explicit None checks keep falsy inputs such as {}.
        call_input = raw.get("input")
        if call_input is None:
            call_input = raw.get("arguments")
        if call_input is None:
            call_input = raw.get("args")
        if call_input is None:
            call_input = raw.get("parameters")
        if call_input is None:
            call_input = params.get("parameters")

        result_payload = raw.get("result")
        if result_payload is None:
            result_payload = params.get("result")
        if result_payload is None and isinstance(data, dict):
            result_payload = data.get("result")
        if result_payload is None and isinstance(raw.get("results"), list):
            result_payload = raw.get("results")

        if not tool_id:
            return None
        event: dict[str, Any] = {
            "id": str(tool_id),
            "name": str(name or "tool"),
            "input": call_input if call_input is not None else {},
        }
        if result_payload is not None:
            event["result"] = result_payload
        return event

    @staticmethod
    def _signal_stream_end(stream: dict) -> None:
        """Mark ``stream`` done and enqueue the ``None`` end sentinel, once."""
        if not stream["done"].is_set():
            stream["done"].set()
            stream["chunks"].put_nowait(None)

    def _lookup_stream(self, request_id: Any) -> dict | None:
        """Return the stream registered for ``request_id``, or ``None``."""
        try:
            return self._chat_streams.get(request_id)
        except TypeError:
            # Unhashable id (e.g. a list/dict sent by the peer) cannot match.
            return None

    async def start(self):
        """Start the background task that reads and dispatches incoming frames."""
        self._reader_task = asyncio.create_task(self._reader_loop())

    async def close(self):
        """Stop the reader, fail pending requests and terminate open streams."""
        self._closed = True
        if self._reader_task:
            self._reader_task.cancel()
            # CancelledError is a BaseException since Python 3.8, so it must
            # be listed explicitly; it surfaces here when the task is
            # cancelled before it ever ran.
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await self._reader_task
        # Abort any pending futures so callers fail fast instead of hanging.
        for fut in self._pending.values():
            if not fut.done():
                fut.set_exception(ConnectionError("lingma client closed"))
        self._pending.clear()
        # Signal open streams to terminate.
        for stream in self._chat_streams.values():
            self._signal_stream_end(stream)
        self._chat_streams.clear()
        self._tool_stream_map.clear()
        self._tool_roundtrip_done.clear()

    async def _send(self, payload: dict):
        """Frame and send ``payload``; the lock keeps concurrent sends from interleaving."""
        async with self._send_lock:
            await self.ws.send(_encode_lsp_frame(payload))

    async def _reader_loop(self):
        """Receive frames, resolve pending requests and dispatch server messages.

        On any non-cancellation error the loop fails all pending requests,
        ends all open streams and, unless the client was closed on purpose,
        invokes the ``on_disconnect`` callback.
        """
        try:
            while True:
                raw = await self.ws.recv()
                chunk = raw if isinstance(raw, bytes) else raw.encode("utf-8", errors="ignore")
                self._rx_buffer += chunk
                bodies, self._rx_buffer = _parse_lsp_frames(self._rx_buffer)
                for body in bodies:
                    try:
                        msg = json.loads(body)
                    except Exception:
                        # Best effort: one undecodable frame must not end the session.
                        continue
                    if not isinstance(msg, dict):
                        # Only JSON objects are handled; lists/scalars would
                        # raise on the key lookups below.
                        continue
                    if "method" in msg and "result" not in msg and "error" not in msg:
                        await self._handle_server_message(msg)
                        continue
                    rid = msg.get("id")
                    if rid is None:
                        continue
                    fut = self._pending.pop(rid, None)
                    if fut and not fut.done():
                        fut.set_result(msg)
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            if not self._closed:
                logger.warning("lingma reader loop terminated: %s", exc)
            # Propagate failure to anyone waiting on an RPC.
            for fut in self._pending.values():
                if not fut.done():
                    fut.set_exception(exc)
            self._pending.clear()
            # Also unblock any in-flight chat streams so consumers exit.
            for stream in self._chat_streams.values():
                self._signal_stream_end(stream)
            if not self._closed and self._on_disconnect is not None:
                try:
                    self._on_disconnect(exc)
                except Exception:
                    logger.exception("on_disconnect callback failed")

    @staticmethod
    def _normalize_tool_id(method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> str | None:
        """Pick a stable identifier for a tool event.

        Preference order: the event's own ``id``, then ``toolCallId`` /
        ``tool_call_id`` from ``params``, then a synthetic id derived from
        ``requestId`` (with the tool name appended when known). Returns
        ``None`` if none of these is available. ``method`` is not used.
        """
        event_id = None
        if isinstance(tool_event, dict):
            event_id = tool_event.get("id")
        if isinstance(event_id, str) and event_id.strip():
            return event_id.strip()

        fallback_id = params.get("toolCallId") or params.get("tool_call_id")
        if isinstance(fallback_id, str) and fallback_id.strip():
            return fallback_id.strip()

        req_id = params.get("requestId")
        name = None
        if isinstance(tool_event, dict):
            name = tool_event.get("name")
        if not name:
            name = params.get("name")
        if isinstance(req_id, str) and req_id.strip() and isinstance(name, str) and name.strip():
            return f"{req_id.strip()}:tool:{name.strip()}"
        if isinstance(req_id, str) and req_id.strip():
            return f"{req_id.strip()}:tool"
        return None

    @staticmethod
    def _merge_tool_event(existing: dict[str, Any] | None, incoming: dict[str, Any]) -> tuple[dict[str, Any], bool]:
        """Merge ``incoming`` into a copy of ``existing``.

        Returns:
            ``(merged, changed)`` where ``changed`` tells whether any field
            differs from ``existing``.

        The placeholder name ``"tool"`` never overwrites a real name, and an
        empty ``{}`` input never overwrites an input that is already present.
        """
        merged = dict(existing or {})
        changed = False

        val = incoming.get("id")
        if val and merged.get("id") != val:
            merged["id"] = val
            changed = True

        name = incoming.get("name")
        if name:
            existing_name = merged.get("name")
            if not existing_name:
                merged["name"] = name
                changed = True
            else:
                existing_norm = str(existing_name).strip().lower()
                incoming_norm = str(name).strip().lower()
                if existing_norm == "tool" and incoming_norm != "tool":
                    merged["name"] = name
                    changed = True
                elif existing_norm != "tool" and incoming_norm == "tool":
                    # Keep the real name; the incoming one is only the placeholder.
                    pass
                elif merged.get("name") != name:
                    merged["name"] = name
                    changed = True

        if "input" in incoming and incoming.get("input") is not None:
            incoming_input = incoming.get("input")
            should_update_input = incoming_input != {} or "input" not in merged
            if should_update_input and merged.get("input") != incoming_input:
                merged["input"] = incoming_input
                changed = True

        if "result" in incoming and incoming.get("result") is not None:
            if merged.get("result") != incoming.get("result"):
                merged["result"] = incoming.get("result")
                changed = True

        return merged, changed

    @staticmethod
    def _is_tool_roundtrip_method(method: str | None) -> bool:
        """Return True for the methods that get an approve + invokeResult reply."""
        return method in {"tool/call/sync", "tool/invoke"}

    @staticmethod
    def _build_tool_approve_params(params: dict[str, Any], tool_id: str) -> dict[str, Any] | None:
        """Build ``tool/call/approve`` params, or ``None`` without request/session ids."""
        req_id = params.get("requestId")
        session_id = params.get("sessionId")
        if not isinstance(req_id, str) or not req_id.strip():
            return None
        if not isinstance(session_id, str) or not session_id.strip():
            return None
        return {
            "type": "tool_call",
            "sessionId": session_id,
            "requestId": req_id,
            "toolCallId": tool_id,
            "approval": True,
        }

    @staticmethod
    def _build_tool_invoke_result_params(params: dict[str, Any], tool_event: dict[str, Any], tool_id: str) -> dict[str, Any]:
        """Build ``tool/invokeResult`` params reporting success for ``tool_id``."""
        return {
            "toolCallId": tool_id,
            "name": str(tool_event.get("name") or params.get("name") or "tool"),
            "success": True,
            "errorMessage": "",
            "result": tool_event.get("result") if "result" in tool_event else {},
        }

    async def _maybe_emit_tool_roundtrip(self, method: str, params: dict[str, Any], tool_event: dict[str, Any]) -> None:
        """Send approve + invokeResult notifications once per tool-call id.

        Does nothing for non-roundtrip methods, when no tool id can be
        derived, when the id was already handled, or when approve params
        cannot be built.
        """
        if not self._is_tool_roundtrip_method(method):
            return
        tool_id = self._normalize_tool_id(method, params, tool_event)
        if not tool_id:
            return
        if tool_id in self._tool_roundtrip_done:
            return
        approve_params = self._build_tool_approve_params(params, tool_id)
        if approve_params is None:
            return
        # Record before sending so a repeated event is not answered twice.
        self._tool_roundtrip_done.add(tool_id)
        await self.notify("tool/call/approve", approve_params)
        invoke_result_params = self._build_tool_invoke_result_params(params, tool_event, tool_id)
        await self.notify("tool/invokeResult", invoke_result_params)

    def _resolve_tool_stream(self, method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> dict | None:
        """Find the chat stream a tool notification belongs to.

        With a usable ``requestId`` the stream is looked up directly and the
        tool id is remembered for later notifications; without one, the
        remembered tool-id mapping is used.
        """
        req_id = params.get("requestId")
        if isinstance(req_id, str) and req_id.strip():
            stream = self._chat_streams.get(req_id)
            if stream is not None and tool_event is not None:
                tool_id = self._normalize_tool_id(method, params, tool_event)
                if tool_id:
                    self._tool_stream_map[tool_id] = req_id
            return stream

        if tool_event is not None:
            tool_id = self._normalize_tool_id(method, params, tool_event)
            if tool_id:
                mapped_req = self._tool_stream_map.get(tool_id)
                if mapped_req:
                    return self._chat_streams.get(mapped_req)
        return None

    async def _handle_server_message(self, msg: dict):
        """Dispatch a server-initiated request or notification.

        Handles ``chat/answer`` (text chunks), the tool methods (tool events)
        and ``chat/finish`` (end of stream). If the message carries an ``id``
        an empty result is sent back.
        """
        method = msg.get("method")
        params = msg.get("params")
        if not isinstance(params, dict):
            # Missing or non-object params: every handler below needs .get().
            params = {}

        if method == "chat/answer":
            stream = self._lookup_stream(params.get("requestId"))
            if stream is not None:
                text = params.get("text") or params.get("content") or ""
                if text:
                    stream["parts"].append(text)
                    if stream["first_chunk_at"] is None:
                        stream["first_chunk_at"] = time.monotonic()
                    stream["chunks"].put_nowait({"type": "text", "text": text})

        if method in {"tool/call/sync", "tool/invoke", "tool/call/approve", "tool/invokeResult"}:
            tool_event = self._extract_tool_event(params)
            logger.info(
                "lingma tool event method=%s request_id=%s tool=%s",
                method,
                params.get("requestId"),
                tool_event,
            )
            stream = self._resolve_tool_stream(method, params, tool_event)
            if stream is not None and tool_event is not None:
                tool_id = self._normalize_tool_id(method, params, tool_event)
                if not tool_id:
                    logger.warning("drop unroutable tool event: method=%s missing tool id", method)
                else:
                    await self._maybe_emit_tool_roundtrip(method, params, tool_event)
                    tool_states = stream["tool_states"]
                    order = stream["tool_order"]
                    existing = tool_states.get(tool_id)
                    merged, changed = self._merge_tool_event(existing, tool_event)
                    if not existing:
                        if "id" not in merged or not merged.get("id"):
                            merged["id"] = tool_id
                        tool_states[tool_id] = merged
                        order.append(tool_id)
                        stream["chunks"].put_nowait({"type": "tool", "tool": merged})
                    elif changed:
                        # Only re-emit when the merge changed something.
                        tool_states[tool_id] = merged
                        stream["chunks"].put_nowait({"type": "tool", "tool": merged})
            elif tool_event is not None:
                logger.warning("drop unroutable tool event: method=%s requestId=%s", method, params.get("requestId"))

        if method == "chat/finish":
            logger.info(
                "lingma finish request_id=%s session_id=%s",
                params.get("requestId"),
                params.get("sessionId"),
            )
            stream = self._lookup_stream(params.get("requestId"))
            if stream is not None and not stream["done"].is_set():
                stream["finish"] = params
                stream["finish_at"] = time.monotonic()
                self._signal_stream_end(stream)

        if "id" in msg:
            # The message is a request, not a notification: acknowledge it.
            await self._send({"jsonrpc": "2.0", "id": msg.get("id"), "result": {}})

    async def request(self, method, params=None, timeout=20):
        """Send a JSON-RPC request and return its ``result``.

        Raises:
            TimeoutError: if no response arrives within ``timeout`` seconds.
            RuntimeError: if the response carries an ``error`` member.
        """
        rid = self._id
        self._id += 1
        payload = {"jsonrpc": "2.0", "id": rid, "method": method, "params": params or {}}
        fut = asyncio.get_running_loop().create_future()
        self._pending[rid] = fut
        try:
            await self._send(payload)
            try:
                msg = await asyncio.wait_for(fut, timeout=timeout)
            except TIMEOUT_EXCEPTIONS as exc:
                raise TimeoutError(f"RPC timeout: {method}") from exc
        finally:
            # Always unregister: covers send failures and caller cancellation
            # as well as timeouts (a no-op when the reader already popped it).
            self._pending.pop(rid, None)
        if "error" in msg:
            raise RuntimeError(f"RPC {method} error: {msg['error']}")
        return msg.get("result")

    async def notify(self, method, params=None):
        """Send a JSON-RPC notification (no id, no response expected)."""
        await self._send({"jsonrpc": "2.0", "method": method, "params": params or {}})

    def create_stream(self, request_id: str):
        """Register a fresh stream record for ``request_id``."""
        self._chat_streams[request_id] = {
            "parts": [],
            "chunks": asyncio.Queue(),
            "done": asyncio.Event(),
            "finish": None,
            "tool_states": {},
            "tool_order": [],
            "started_at": time.monotonic(),
            "first_chunk_at": None,
            "finish_at": None,
        }

    def pop_stream(self, request_id: str) -> None:
        """Remove the stream for ``request_id`` together with its tool-id bookkeeping."""
        stream = self._chat_streams.pop(request_id, None)
        if stream is None:
            return
        for tool_id, mapped_req in list(self._tool_stream_map.items()):
            if mapped_req == request_id:
                self._tool_stream_map.pop(tool_id, None)
                self._tool_roundtrip_done.discard(tool_id)
        # Make sure a consumer still waiting on the queue gets the end sentinel.
        with contextlib.suppress(Exception):
            self._signal_stream_end(stream)

    async def consume_stream(self, request_id: str, timeout: float) -> AsyncIterator[dict[str, Any]]:
        """Yield queued events for ``request_id`` until the end sentinel.

        ``timeout`` is the total budget in seconds measured from the first
        call, not a per-chunk idle timeout. Yields nothing if no stream is
        registered for ``request_id``.

        Raises:
            TimeoutError: when the budget is exhausted before the stream ends.
        """
        stream = self._chat_streams.get(request_id)
        if stream is None:
            return
        start = time.monotonic()
        last_chunk_at = start

        def timeout_error() -> TimeoutError:
            first_chunk_at = stream.get("first_chunk_at")
            return TimeoutError(
                "chat stream timeout "
                f"request_id={request_id} timeout={timeout:.1f}s "
                f"first_chunk_at={None if first_chunk_at is None else round(first_chunk_at - start, 3)}s "
                f"last_chunk_at={round(last_chunk_at - start, 3)}s"
            )

        while True:
            remain = timeout - (time.monotonic() - start)
            if remain <= 0:
                raise timeout_error()
            try:
                chunk = await asyncio.wait_for(stream["chunks"].get(), timeout=remain)
            except TIMEOUT_EXCEPTIONS as exc:
                # Convert to the builtin TimeoutError with diagnostics so the
                # caller sees the same exception type on every Python version.
                raise timeout_error() from exc
            if chunk is None:
                break
            last_chunk_at = time.monotonic()
            yield chunk

    def get_stream_result(self, request_id: str) -> dict:
        """Summarize a stream: joined text, finish payload, latencies and tool events.

        Returns empty/``None`` values when no stream is registered for
        ``request_id``.
        """
        stream = self._chat_streams.get(request_id) or {}
        first_ms = None
        total_ms = None
        if stream.get("first_chunk_at") is not None:
            first_ms = int((stream["first_chunk_at"] - stream["started_at"]) * 1000)
        if stream.get("finish_at") is not None:
            total_ms = int((stream["finish_at"] - stream["started_at"]) * 1000)
        ordered_tool_events: list[dict[str, Any]] = []
        tool_states = stream.get("tool_states") or {}
        for tool_id in stream.get("tool_order") or []:
            event = tool_states.get(tool_id)
            if isinstance(event, dict):
                ordered_tool_events.append(event)
        return {
            "text": "".join(stream.get("parts") or []),
            "finish": stream.get("finish") or {},
            "firstTokenLatencyMs": first_ms,
            "totalLatencyMs": total_ms,
            "toolEvents": ordered_tool_events,
        }


class LingmaGatewayClient:
    """Owns the Lingma subprocess and the LSP-over-WS connection.

    Adds a small state machine + reconnect loop so the gateway can survive
    Lingma restarts and slow cold starts without bringing down the FastAPI app.
    """

    STATE_STOPPED = "stopped"
    STATE_STARTING = "starting"
    STATE_READY = "ready"
    STATE_RECONNECTING = "reconnecting"
    STATE_FAILED = "failed"
    STATE_CLOSED = "closed"

    def __init__(
        self,
        lingma_bin: str,
        work_dir: str,
        socket_port: int,
        startup_timeout: int,
        rpc_timeout: int,
        default_model: str,
        default_ask_mode: str,
        *,
        name: str = "lingma",
        extra_info_paths: list[Path] | None = None,
    ):
        self.name = name
        self.lingma_bin = Path(lingma_bin)
        self.work_dir = Path(work_dir)
        self.socket_port = socket_port
        self.startup_timeout = startup_timeout
        self.rpc_timeout = rpc_timeout
        self.default_model = default_model
        self.default_ask_mode = default_ask_mode
        # Each pool instance should only look at its own workDir .info to avoid
        # cross-instance clobbering via the shared ~/.lingma/.info path.
        # NOTE(review): when extra_info_paths is omitted the shared path is
        # still consulted; pass an empty list to opt out.
        if extra_info_paths is None:
            extra_info_paths = [Path.home() / ".lingma" / ".info"]
        self._extra_info_paths = list(extra_info_paths)

        self._rpc: LspWsRpcClient | None = None
        self._ws = None
        self._state = self.STATE_STOPPED
        self._state_lock = asyncio.Lock()
        self._ready_event = asyncio.Event()
        self._reconnect_task: asyncio.Task | None = None
        self._last_error: str = ""
        # Lingma subprocess handle. Kept so we can reap on shutdown and read
        # stderr for debugging (pre-v0.4 we forked with DEVNULL + new_session
        # which orphaned the process and hid crash logs).
        self._proc: subprocess.Popen | None = None
        self._stderr_task: asyncio.Task | None = None

    # ------------------------------------------------------------------ state
    @property
    def state(self) -> str:
        """Current state, one of the ``STATE_*`` constants."""
        return self._state

    @property
    def last_error(self) -> str:
        """Most recent error message recorded by the state machine."""
        return self._last_error

    def _set_state(self, state: str, err: str = "") -> None:
        """Switch to ``state``, record ``err`` if given, and update the ready event."""
        if state != self._state:
            logger.info("lingma client state %s -> %s", self._state, state, extra={"ctx_new_state": state})
        self._state = state
        if err:
            self._last_error = err
        if state == self.STATE_READY:
            self._ready_event.set()
        else:
            self._ready_event.clear()

    # -------------------------------------------------------------- lifecycle
    async def start(self) -> None:
        """Initial start. Failure is non-fatal: ensure_ready() will retry later."""
        try:
            await self._connect(initial=True)
        except Exception as exc:
            self._set_state(self.STATE_FAILED, err=str(exc))
            logger.exception("initial lingma start failed; will retry on demand")

    async def close(self) -> None:
        """Shut down: stop reconnecting, close rpc/websocket, reap the subprocess."""
        self._set_state(self.STATE_CLOSED)
        if self._reconnect_task and not self._reconnect_task.done():
            self._reconnect_task.cancel()
            # Awaiting a cancelled task raises CancelledError, which is a
            # BaseException; without listing it the rest of close() would be
            # skipped and the subprocess left running.
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await self._reconnect_task
        if self._rpc:
            await self._rpc.close()
        if self._ws:
            with contextlib.suppress(Exception):
                await self._ws.close()
        await self._terminate_proc()
        if self._stderr_task and not self._stderr_task.done():
            self._stderr_task.cancel()
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await self._stderr_task

    async def _drain_stderr(self, proc: subprocess.Popen) -> None:
        """Mirror Lingma stderr to the logger at DEBUG level.

        Running in a worker thread (readline is blocking) and dumping lines
        through logger.debug means crashes like native-module load failures
        are visible when LOG_LEVEL=DEBUG but don't spam production logs.
        """
        if proc.stderr is None:
            return
        name = self.name

        def reader() -> None:
            try:
                # iter() with a sentinel stops at EOF (readline returns b"").
                for line in iter(proc.stderr.readline, b""):
                    text = line.decode("utf-8", errors="replace").rstrip()
                    if text:
                        logger.debug("[%s] lingma stderr: %s", name, text)
            except Exception as exc:  # pragma: no cover -- defensive
                logger.debug("[%s] stderr drain aborted: %s", name, exc)

        try:
            await asyncio.to_thread(reader)
        except asyncio.CancelledError:
            pass

    async def _terminate_proc(self) -> None:
        """Reap the Lingma subprocess we spawned.

        SIGTERM first with a short grace period, then SIGKILL. Blocking waits
        are off-loaded to a thread so they don't stall the FastAPI shutdown
        event loop. Idempotent: safe to call even if nothing was spawned.
        """
        proc = self._proc
        if proc is None:
            return
        self._proc = None
        try:
            if proc.poll() is None:
                try:
                    proc.terminate()
                except Exception as exc:
                    logger.warning("[%s] proc.terminate failed: %s", self.name, exc)
                try:
                    await asyncio.wait_for(asyncio.to_thread(proc.wait), timeout=5.0)
                except TIMEOUT_EXCEPTIONS:
                    logger.warning(
                        "[%s] lingma (pid=%s) didn't exit in 5s, sending SIGKILL",
                        self.name,
                        proc.pid,
                    )
                    with contextlib.suppress(Exception):
                        proc.kill()
                    with contextlib.suppress(Exception):
                        await asyncio.wait_for(
                            asyncio.to_thread(proc.wait), timeout=3.0
                        )
        finally:
            # Close stderr pipe so the drain thread can exit cleanly.
            if proc.stderr is not None:
                with contextlib.suppress(Exception):
                    proc.stderr.close()

    async def ensure_ready(self, timeout: float | None = None) -> None:
        """Block until the RPC connection is usable, (re)connecting on demand.

        Raises:
            RuntimeError: if the client is closed, or it does not become
                ready within the wait timeout.
            Exception: whatever ``_connect`` raised when an on-demand connect
                from the stopped/failed state fails.
        """
        if self._state == self.STATE_CLOSED:
            raise RuntimeError("lingma client is closed")
        if self._state == self.STATE_READY and self._ws is not None:
            return
        async with self._state_lock:
            # Re-check: another waiter may have connected while we queued.
            if self._state == self.STATE_READY and self._ws is not None:
                return
            if self._state in (self.STATE_STOPPED, self.STATE_FAILED):
                try:
                    await self._connect(initial=False)
                    return
                except Exception as exc:
                    self._set_state(self.STATE_FAILED, err=str(exc))
                    raise
        # Someone else is starting/reconnecting: wait without holding the
        # lock so the reconnect loop can acquire it.
        wait_timeout = timeout if timeout is not None else max(
            30.0, float(self.startup_timeout) + 10.0
        )
        try:
            await asyncio.wait_for(self._ready_event.wait(), timeout=wait_timeout)
        except TIMEOUT_EXCEPTIONS as exc:
            raise RuntimeError(
                f"lingma not ready (state={self._state}, err={self._last_error})"
            ) from exc

    # --------------------------------------------------------------- connect
    async def _connect(self, *, initial: bool) -> None:
        """Spawn Lingma if needed, open the websocket and run the initialize handshake.

        Raises:
            FileNotFoundError: if the Lingma binary does not exist.
            TimeoutError: if the ``.info`` file or the socket does not show
                up within ``startup_timeout``.
        """
        self._set_state(self.STATE_STARTING)
        if not self.lingma_bin.exists():
            raise FileNotFoundError(f"Lingma not found: {self.lingma_bin}")

        info_paths = [self.work_dir / ".info", *self._extra_info_paths]

        # socket_port <= 0 is the pool-friendly "always spawn and read .info" mode.
        port_prewarmed = self.socket_port > 0 and _is_port_open(
            "127.0.0.1", self.socket_port
        )

        if not port_prewarmed:
            self.work_dir.mkdir(parents=True, exist_ok=True)
            # Remove stale info files from host-mounted workspace before boot.
            for p in info_paths:
                with contextlib.suppress(Exception):
                    if p.exists():
                        p.unlink()
            logger.info(
                "[%s] spawning lingma: %s start --workDir %s",
                self.name,
                self.lingma_bin,
                self.work_dir,
            )
            # Reap any old proc from a previous connect attempt before spawning
            # a fresh one so we never accumulate zombie Lingma instances.
            await self._terminate_proc()
            if self._stderr_task and not self._stderr_task.done():
                self._stderr_task.cancel()
                with contextlib.suppress(asyncio.CancelledError, Exception):
                    await self._stderr_task
            self._stderr_task = None
            self._proc = subprocess.Popen(
                [str(self.lingma_bin), "start", "--workDir", str(self.work_dir)],
                cwd=str(self.lingma_bin.parent),
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
            )
            logger.info(
                "[%s] lingma spawned (pid=%s)", self.name, self._proc.pid
            )
            self._stderr_task = asyncio.create_task(
                self._drain_stderr(self._proc)
            )
            # _wait_info_any polls with time.sleep; run it in a worker thread
            # so the event loop stays responsive during startup.
            port, _, _ = await asyncio.to_thread(
                _wait_info_any, info_paths, self.startup_timeout
            )
            self.socket_port = port
            deadline = time.time() + self.startup_timeout
            while time.time() < deadline:
                if _is_port_open("127.0.0.1", self.socket_port, timeout_sec=0.3):
                    break
                await asyncio.sleep(0.2)
            else:
                raise TimeoutError(f"Lingma socket not open on port {self.socket_port}")

        # Close any stale ws/rpc before creating fresh ones (reconnect path).
        if self._rpc is not None:
            with contextlib.suppress(Exception):
                await self._rpc.close()
            self._rpc = None
        if self._ws is not None:
            with contextlib.suppress(Exception):
                await self._ws.close()
            self._ws = None

        ws_url = f"ws://127.0.0.1:{self.socket_port}"
        self._ws = await websockets.connect(ws_url, max_size=10 * 1024 * 1024)
        self._rpc = LspWsRpcClient(self._ws, on_disconnect=self._on_disconnect)
        await self._rpc.start()
        await self._rpc.request(
            "initialize",
            {
                "processId": os.getpid(),
                "clientInfo": {"name": "lingma-openai-gateway", "version": "0.1.0"},
                "capabilities": {},
                "workspaceFolders": [],
                "rootUri": None,
            },
            timeout=self.rpc_timeout,
        )
        await self._rpc.notify("initialized", {})
        self._set_state(self.STATE_READY)
        logger.info(
            "[%s] lingma ready on port %d (initial=%s)",
            self.name,
            self.socket_port,
            initial,
        )

    def _on_disconnect(self, exc: BaseException) -> None:
        """Reader-loop callback: enter RECONNECTING and start the reconnect task once."""
        if self._state == self.STATE_CLOSED:
            return
        self._set_state(self.STATE_RECONNECTING, err=str(exc))
        if self._reconnect_task and not self._reconnect_task.done():
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: nothing can schedule the reconnect task.
            return
        self._reconnect_task = loop.create_task(self._reconnect_loop())

    async def _reconnect_loop(self) -> None:
        """Retry ``_connect`` with exponential backoff (1s doubling to 30s, 20 attempts)."""
        backoff = 1.0
        max_backoff = 30.0
        max_attempts = 20
        for attempt in range(1, max_attempts + 1):
            if self._state == self.STATE_CLOSED:
                return
            await asyncio.sleep(backoff)
            try:
                async with self._state_lock:
                    await self._connect(initial=False)
                logger.info("lingma reconnected after %d attempt(s)", attempt)
                return
            except Exception as exc:
                self._last_error = str(exc)
                logger.warning("lingma reconnect attempt %d failed: %s", attempt, exc)
                backoff = min(backoff * 2, max_backoff)
        self._set_state(self.STATE_FAILED, err="reconnect exhausted")

    # ------------------------------------------------------------------ RPC
    @property
    def rpc(self) -> LspWsRpcClient:
        """The active RPC client; raises RuntimeError if none exists yet."""
        if self._rpc is None:
            raise RuntimeError(f"Lingma RPC not initialized (state={self._state})")
        return self._rpc

    async def auth_status(self):
        """Return the result of the ``auth/status`` RPC."""
        await self.ensure_ready()
        return await self.rpc.request("auth/status", {}, timeout=self.rpc_timeout)

    async def query_models(self):
        """Return the result of the ``config/queryModels`` RPC."""
        await self.ensure_ready()
        return await self.rpc.request("config/queryModels", {}, timeout=self.rpc_timeout)

    async def get_endpoint(self):
        """Return the result of the ``config/getEndpoint`` RPC."""
        await self.ensure_ready()
        return await self.rpc.request("config/getEndpoint", {}, timeout=self.rpc_timeout)

    async def update_endpoint(self, endpoint: str):
        """Send ``config/updateEndpoint`` with ``endpoint`` and return its result."""
        await self.ensure_ready()
        return await self.rpc.request(
            "config/updateEndpoint", {"endpoint": endpoint}, timeout=self.rpc_timeout
        )

    async def generate_login_url(self):
        """Call ``login/generateUrl`` and return ``(url, raw_result_dict)``.

        ``url`` is an empty string when the result contains no recognizable
        URL field.
        """
        await self.ensure_ready()
        result = await self.rpc.request("login/generateUrl", {}, timeout=self.rpc_timeout)
        if isinstance(result, str):
            return result, {"raw": result}
        if isinstance(result, dict):
            for key in ("loginUrl", "url", "login_url"):
                if isinstance(result.get(key), str):
                    return result[key], result
            return "", result
        return "", {"raw": result}

    # ------------------------------------------------------------------ chat
    def _build_payload(
        self,
        prompt: str,
        model_key: str,
        ask_mode: str,
        session_id: str,
        request_id: str,
        *,
        is_reply: bool = False,
        tool_config: dict[str, Any] | None = None,
    ):
        """Assemble the ``chat/ask`` params and log a summary of the tool config.

        The prompt is duplicated under several keys (``content``, ``text``,
        ``message``, ``questionText``, ``chatContext.text``). ``tools`` and
        ``tool_choice`` are copied from ``tool_config`` only when truthy.
        """
        session_type = "ask" if ask_mode == "agent" else "chat"
        payload = {
            "requestId": request_id,
            "sessionId": session_id,
            "sessionType": session_type,
            "chatTask": "chat" if ask_mode == "agent" else "FREE_INPUT",
            "mode": ask_mode,
            "stream": True,
            "source": 1,
            "isReply": is_reply,
            "taskDefinitionType": "system",
            "content": prompt,
            "text": prompt,
            "message": prompt,
            "questionText": prompt,
            "extra": {
                "modelConfig": {"key": model_key},
                "workspacePath": str(Path.cwd()),
            },
            "pluginPayloadConfig": {
                "isEnableAskAgent": ask_mode == "agent",
                "isEnableAutoMemory": True,
            },
            "chatContext": {
                "text": prompt,
                "features": [],
                "preferredLanguage": "zh-CN",
                "localeLang": "zh-CN",
            },
        }
        if tool_config is not None:
            if "tools" in tool_config and tool_config["tools"]:
                payload["tools"] = tool_config["tools"]
            if "tool_choice" in tool_config and tool_config["tool_choice"]:
                payload["tool_choice"] = tool_config["tool_choice"]
        logger.info(
            "lingma payload request_id=%s session_id=%s mode=%s tool_config=%s",
            request_id,
            session_id,
            ask_mode,
            _tool_config_summary(tool_config),
        )
        return payload

    async def _kick_chat_ask(self, payload: dict, rpc: LspWsRpcClient | None = None) -> None:
        """Fire chat/ask as a notification.

        Lingma streams answers back via `chat/answer` + `chat/finish` and never
        returns a JSON-RPC `result` for `chat/ask`. Waiting for one wasted
        `rpc_timeout` seconds before the first byte could leave the gateway —
        matching our previous 30s TTFB bug. `notify` sidesteps that entirely
        by not registering a pending future.

        ``rpc`` selects the client to send on; it defaults to the current
        ``self.rpc``.
        """
        await (rpc if rpc is not None else self.rpc).notify("chat/ask", payload)

    async def chat_complete(
        self,
        prompt: str,
        model_key: str,
        ask_mode: str,
        *,
        session_id: str | None = None,
        is_reply: bool = False,
        tool_config: dict[str, Any] | None = None,
    ) -> dict:
        """Run one chat turn to completion and return the aggregated result.

        The returned dict is ``get_stream_result`` output extended with
        ``requestId``, ``sessionId``, ``model``, ``mode`` and ``isReply``.
        """
        await self.ensure_ready()
        request_id = str(uuid.uuid4())
        sid = session_id or str(uuid.uuid4())
        payload = self._build_payload(
            prompt,
            model_key,
            ask_mode,
            sid,
            request_id,
            is_reply=is_reply,
            tool_config=tool_config,
        )
        # Bind the client once: a reconnect replaces self._rpc, and the
        # stream only exists on the instance it was created on.
        rpc = self.rpc
        rpc.create_stream(request_id)
        try:
            await self._kick_chat_ask(payload, rpc)
            # Consume until chat/finish closes the stream or the overall
            # timeout elapses.
            async for _ in rpc.consume_stream(
                request_id, timeout=max(60.0, self.rpc_timeout + 30.0)
            ):
                pass
            result = rpc.get_stream_result(request_id)
        finally:
            rpc.pop_stream(request_id)

        finish = result.get("finish") or {}
        result["requestId"] = request_id
        # Prefer upstream-reported sessionId so the next turn binds to whatever
        # Lingma actually allocated (sometimes differs from our hint).
        result["sessionId"] = finish.get("sessionId") or sid
        result["model"] = model_key
        result["mode"] = ask_mode
        result["isReply"] = is_reply
        return result

    async def chat_stream(
        self,
        prompt: str,
        model_key: str,
        ask_mode: str,
        *,
        session_id: str | None = None,
        is_reply: bool = False,
        tool_config: dict[str, Any] | None = None,
        out_meta: dict | None = None,
    ) -> AsyncIterator[dict[str, Any]]:
        """Stream chat events.

        Yields structured events:
          * {"type": "text", "text": "..."}
          * {"type": "tool", "tool": {...}}

        If `out_meta` is provided, the final `chat/finish` payload's sessionId
        (and the raw finish dict) is written into it when the stream ends or
        is cancelled. This is the hook the session cache uses to record the
        upstream sessionId without holding a second reference to the RPC.
        """
        await self.ensure_ready()
        request_id = str(uuid.uuid4())
        sid = session_id or str(uuid.uuid4())
        payload = self._build_payload(
            prompt,
            model_key,
            ask_mode,
            sid,
            request_id,
            is_reply=is_reply,
            tool_config=tool_config,
        )
        # Bind the client once so create/consume/pop all hit the same
        # instance even if a reconnect swaps self._rpc mid-stream.
        rpc = self.rpc
        rpc.create_stream(request_id)
        try:
            await self._kick_chat_ask(payload, rpc)
            async for event in rpc.consume_stream(
                request_id, timeout=max(60.0, self.rpc_timeout + 60.0)
            ):
                yield event
        finally:
            # Runs on normal completion, exception, or consumer GeneratorExit (client disconnect).
            if out_meta is not None:
                try:
                    stream_result = rpc.get_stream_result(request_id)
                    finish = stream_result.get("finish") or {}
                    out_meta["session_id"] = finish.get("sessionId") or sid
                    out_meta["finish"] = finish
                    out_meta["request_id"] = request_id
                    out_meta["chars"] = len(stream_result.get("text") or "")
                    out_meta["tool_events"] = stream_result.get("toolEvents") or []
                except Exception:
                    # Metadata is best effort; never mask the stream outcome.
                    logger.debug("failed to populate chat_stream out_meta", exc_info=True)
            rpc.pop_stream(request_id)