Behavior hardening (M1):
- Fix `_chat_streams` memory leak: pop_stream on completion, error, and
client disconnect.
- Add WebSocket reconnect with state machine (stopped/starting/ready/
reconnecting/failed/closed) and exponential backoff, so a Lingma
restart no longer requires restarting the gateway.
- Lazy initialization: startup failure is non-fatal, first real request
triggers retry, `/healthz` reflects readiness.
- Migrate FastAPI on_event to lifespan.
- Structured JSON logging with request_id ContextVar; `x-request-id`
propagated to responses.
- SSE now sets `Cache-Control: no-cache`, `X-Accel-Buffering: no` to
defeat proxy buffering.
- OpenAI schema compatibility: `content` accepts str | list[parts] | None,
added `developer`/`function` roles, `tools/tool_choice/stream_options/
user/max_tokens` fields, and `stream_options.include_usage` emits final
usage chunk.
- `require_bearer` uses `hmac.compare_digest`; `/metrics` now requires
Bearer when `METRICS_TOKEN` or `API_KEYS` are set.
- Python 3.10/3.11 `TimeoutError` vs `asyncio.TimeoutError` unified.
- Error responses no longer leak `auto_login.status()` details.
Backpressure (M2 / A2):
- New `InFlightGuard` with per-request tickets and queue/rejection
accounting; once `GATEWAY_QUEUE_TIMEOUT_SEC` elapses the request is
rejected with `BackpressureRejected`, surfaced as 429 + `Retry-After`.
- Streaming ticket ownership transfers to the generator so CancelledError
from client disconnect still releases the slot.
- `/internal/stats.concurrency` and `/metrics` expose in_flight/queued/
accepted_total/rejected_total/max_in_flight.
Multi-instance pool (M2 / A1 + B3):
- New `LingmaPool` with N processes, each with its own workDir, socket
port (dynamic when N>1), and `AutoLoginManager`.
- Account parser supports CSV (`u1:p1,u2:p2`) and JSON formats via
`LINGMA_ACCOUNTS`; falls back to `LINGMA_USERNAME/LINGMA_PASSWORD` for
backwards compatibility (N=1 keeps legacy paths/ports).
- Routing: sticky affinity by `user` / system-prompt hash, then
least-in-flight, finally round-robin fallback for unhealthy pool.
- `/healthz` reports per-instance state and ready count.
- `/internal/stats.pool` and `/metrics` expose per-instance
`gateway_pool_instance_in_flight{name}` / `gateway_pool_instance_ready{name}`.
- `/internal/auto-login/start?instance=inst-N` targets a specific instance;
`/internal/auto-login/status` lists all instances.
Compat notes:
- `.env.example` adds `METRICS_TOKEN`, `LOG_LEVEL`, `GATEWAY_MAX_IN_FLIGHT`,
`GATEWAY_QUEUE_TIMEOUT_SEC`, `LINGMA_ACCOUNTS`, `LINGMA_INSTANCE_COUNT`.
- `.gitignore` cleaned up data/ duplication.
- Existing single-instance deployments keep working without config change.
Made-with: Cursor
615 lines
23 KiB
Python
615 lines
23 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import json
|
|
import os
|
|
import socket
|
|
import subprocess
|
|
import time
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import AsyncIterator, Callable, Optional
|
|
|
|
import websockets
|
|
|
|
from .logging_config import get_logger
|
|
|
|
|
|
logger = get_logger("lingma_gateway.client")
|
|
|
|
|
|
# Timeout exception classes to catch together. On Python 3.10
# asyncio.TimeoutError is a class distinct from the builtin TimeoutError, while
# 3.11+ unifies the two. Catching this tuple handles both interpreter versions.
TIMEOUT_EXCEPTIONS: tuple[type[BaseException], ...] = (
    asyncio.TimeoutError,
    TimeoutError,
)
|
|
|
|
|
|
def _is_port_open(host: str, port: int, timeout_sec: float = 0.5) -> bool:
    """Report whether a TCP connection to ``host:port`` can be established.

    The probe connection is closed immediately. Any ``OSError`` raised while
    connecting (refused, unreachable, timed out) or closing means "not open".
    """
    try:
        probe = socket.create_connection((host, port), timeout=timeout_sec)
        probe.close()
    except OSError:
        return False
    return True
|
|
|
|
|
|
def _read_info_file(info_path: Path):
    """Parse a Lingma ``.info`` file into a ``(port, pid)`` pair.

    The first two lines of the file are each parsed as an integer and returned
    in order as ``(port, pid)``.

    Returns:
        ``(port, pid)`` as ints, or ``(None, None)`` when the file is missing,
        unreadable, has fewer than two lines, or either line is not an integer.
    """
    # Read first and handle failure, rather than exists()-then-read: the file
    # can be deleted or rewritten between the two calls while it is polled.
    try:
        txt = info_path.read_text(encoding="utf-8", errors="ignore").strip()
    except OSError:
        return None, None

    lines = txt.splitlines()
    if len(lines) < 2:
        # Covers the empty file as well as a partially written one.
        return None, None

    try:
        return int(lines[0].strip()), int(lines[1].strip())
    except ValueError:
        return None, None
|
|
|
|
|
|
def _wait_info_any(info_paths: list[Path], timeout_sec: int):
    """Poll several candidate ``.info`` files until one yields a port and pid.

    This is a blocking helper (it uses ``time.sleep``); run it off the event
    loop when calling from async code.

    Args:
        info_paths: Candidate files, checked in order on every polling round.
        timeout_sec: Maximum time to keep polling, in seconds.

    Returns:
        ``(port, pid, path)`` for the first file that parses to a truthy port
        and pid.

    Raises:
        TimeoutError: No file became ready within ``timeout_sec``.
    """
    # Monotonic clock: wall-clock adjustments (NTP, manual changes) must not
    # shorten or stretch the startup timeout.
    deadline = time.monotonic() + timeout_sec
    while True:
        # Poll before checking the deadline so even a zero timeout inspects
        # the files once.
        for p in info_paths:
            port, pid = _read_info_file(p)
            if port and pid:
                return port, pid, p
        if time.monotonic() >= deadline:
            raise TimeoutError(".info not ready")
        time.sleep(0.2)
|
|
|
|
|
|
def _encode_lsp_frame(payload_obj: dict) -> bytes:
    """Serialize ``payload_obj`` as one ``Content-Length``-prefixed frame.

    The JSON body is UTF-8 encoded without ASCII escaping, and the header
    carries the body length in bytes (not characters).
    """
    encoded = json.dumps(payload_obj, ensure_ascii=False).encode("utf-8")
    prefix = b"Content-Length: %d\r\n\r\n" % len(encoded)
    return prefix + encoded
|
|
|
|
|
|
def _parse_content_length(header: bytes):
    """Return the value of the first ``Content-Length`` header line.

    Returns ``None`` when there is no such line or when its value is not a
    non-negative integer.
    """
    for line in header.split(b"\r\n"):
        if line.lower().startswith(b"content-length:"):
            try:
                value = int(line.split(b":", 1)[1].strip())
            except ValueError:
                return None
            return value if value >= 0 else None
    return None


def _parse_lsp_frames(buf: bytes):
    """Split every complete ``Content-Length`` frame off the front of ``buf``.

    Args:
        buf: Bytes received so far, possibly ending in a partial frame.

    Returns:
        ``(frames, rest)``: ``frames`` is the list of decoded frame bodies in
        arrival order, and ``rest`` is the unconsumed tail to prepend to the
        next chunk of received data.

    A header block without a usable ``Content-Length`` (missing, non-numeric or
    negative) is discarded up to and including its terminating blank line, so a
    single malformed header cannot raise, and a negative length cannot make the
    parser slice backwards and re-read the same bytes.
    """
    frames = []
    while True:
        header_end = buf.find(b"\r\n\r\n")
        if header_end < 0:
            # Header not fully received yet.
            break

        body_start = header_end + 4
        content_length = _parse_content_length(buf[:header_end])
        if content_length is None:
            buf = buf[body_start:]
            continue

        body_end = body_start + content_length
        if len(buf) < body_end:
            # Body not fully received yet; keep header + partial body.
            break

        frames.append(buf[body_start:body_end].decode("utf-8", errors="ignore"))
        buf = buf[body_end:]
    return frames, buf
|
|
|
|
|
|
class LspWsRpcClient:
    """JSON-RPC client over a WebSocket using ``Content-Length`` framing.

    A background reader task parses incoming frames, resolves the future of the
    matching pending request by ``id``, and routes ``chat/answer`` /
    ``chat/finish`` server messages into per-request stream records created
    with :meth:`create_stream`.

    Args:
        ws: Connected WebSocket-like object with awaitable ``send`` and ``recv``.
        on_disconnect: Optional callback invoked with the exception that ended
            the reader loop, unless :meth:`close` was called first.
    """

    def __init__(self, ws, on_disconnect: Optional[Callable[[BaseException], None]] = None):
        self.ws = ws
        self._id = 1  # next JSON-RPC request id
        self._pending: dict[int, asyncio.Future] = {}  # request id -> response future
        self._send_lock = asyncio.Lock()  # serializes writes to the socket
        self._reader_task: asyncio.Task | None = None
        self._rx_buffer = b""  # bytes received but not yet forming a full frame
        self._chat_streams: dict[str, dict] = {}  # chat requestId -> stream record
        self._on_disconnect = on_disconnect
        self._closed = False

    async def start(self):
        """Start the background task that reads and dispatches frames."""
        self._reader_task = asyncio.create_task(self._reader_loop())

    async def close(self):
        """Stop the reader and fail or terminate everything still in flight."""
        self._closed = True
        if self._reader_task:
            self._reader_task.cancel()
            # CancelledError derives from BaseException, so it must be listed
            # explicitly: a task cancelled before its first step re-raises it.
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await self._reader_task
        # Abort any pending futures so callers fail fast instead of hanging.
        self._fail_pending(ConnectionError("lingma client closed"))
        # Signal open streams to terminate.
        self._finish_all_streams()
        self._chat_streams.clear()

    def _fail_pending(self, exc: BaseException) -> None:
        """Fail every unresolved request future with ``exc`` and forget them."""
        for fut in self._pending.values():
            if not fut.done():
                fut.set_exception(exc)
        self._pending.clear()

    @staticmethod
    def _finish_stream(stream: dict) -> None:
        """Mark a stream done and enqueue the ``None`` end-of-stream sentinel once."""
        if not stream["done"].is_set():
            stream["done"].set()
            stream["chunks"].put_nowait(None)

    def _finish_all_streams(self) -> None:
        """Unblock the consumer of every registered chat stream."""
        for stream in self._chat_streams.values():
            self._finish_stream(stream)

    async def _send(self, payload: dict):
        """Frame and send one JSON-RPC message; writes are serialized by a lock."""
        async with self._send_lock:
            await self.ws.send(_encode_lsp_frame(payload))

    async def _reader_loop(self):
        """Receive frames until the socket fails, dispatching each message.

        Messages carrying ``method`` and neither ``result`` nor ``error`` are
        treated as server requests/notifications; the rest are responses matched
        to pending requests by ``id``. On any exception other than cancellation,
        pending requests are failed, open streams are terminated and, unless
        :meth:`close` was called, ``on_disconnect`` is invoked.
        """
        try:
            while True:
                raw = await self.ws.recv()
                chunk = raw if isinstance(raw, bytes) else raw.encode("utf-8", errors="ignore")
                self._rx_buffer += chunk
                bodies, self._rx_buffer = _parse_lsp_frames(self._rx_buffer)
                for body in bodies:
                    try:
                        msg = json.loads(body)
                    except (ValueError, RecursionError):
                        # Undecodable frame: drop it, keep the connection.
                        continue
                    if not isinstance(msg, dict):
                        # Valid JSON but not a JSON-RPC object; the membership
                        # tests and .get() below would raise and kill the loop.
                        continue

                    if "method" in msg and "result" not in msg and "error" not in msg:
                        await self._handle_server_message(msg)
                        continue

                    rid = msg.get("id")
                    if rid is None:
                        continue
                    fut = self._pending.pop(rid, None)
                    if fut and not fut.done():
                        fut.set_result(msg)
        except asyncio.CancelledError:
            # Cancellation comes from close(), which does its own cleanup.
            pass
        except Exception as exc:
            if not self._closed:
                logger.warning("lingma reader loop terminated: %s", exc)
            # Propagate failure to anyone waiting on an RPC.
            self._fail_pending(exc)
            # Also unblock any in-flight chat streams so consumers exit.
            self._finish_all_streams()
            if not self._closed and self._on_disconnect is not None:
                try:
                    self._on_disconnect(exc)
                except Exception:
                    logger.exception("on_disconnect callback failed")

    async def _handle_server_message(self, msg: dict):
        """Apply one server-initiated message to the matching chat stream.

        ``chat/answer`` appends its text to the stream and queues it as a chunk;
        ``chat/finish`` records the finish params and queues the end sentinel.
        A message that carries an ``id`` is acknowledged with an empty result.
        """
        method = msg.get("method")
        params = msg.get("params") or {}
        if not isinstance(params, dict):
            # Positional (list) params carry no requestId; treat as empty.
            params = {}

        if method == "chat/answer":
            stream = self._chat_streams.get(params.get("requestId"))
            if stream is not None:
                text = params.get("text") or params.get("content") or ""
                if text:
                    stream["parts"].append(text)
                    if stream["first_chunk_at"] is None:
                        stream["first_chunk_at"] = time.monotonic()
                    stream["chunks"].put_nowait(text)

        if method == "chat/finish":
            stream = self._chat_streams.get(params.get("requestId"))
            if stream is not None and not stream["done"].is_set():
                stream["finish"] = params
                stream["finish_at"] = time.monotonic()
                stream["done"].set()
                stream["chunks"].put_nowait(None)

        if "id" in msg:
            await self._send({"jsonrpc": "2.0", "id": msg.get("id"), "result": {}})

    async def request(self, method, params=None, timeout=20):
        """Send a JSON-RPC request and wait for its response.

        Args:
            method: RPC method name.
            params: Request params; ``None`` (or any falsy value) sends ``{}``.
            timeout: Seconds to wait for the response.

        Returns:
            The ``result`` member of the response (``None`` if absent).

        Raises:
            TimeoutError: No response arrived within ``timeout``.
            RuntimeError: The response carried an ``error`` member.
        """
        rid = self._id
        self._id += 1
        payload = {"jsonrpc": "2.0", "id": rid, "method": method, "params": params or {}}
        fut = asyncio.get_running_loop().create_future()
        self._pending[rid] = fut
        try:
            await self._send(payload)
            try:
                msg = await asyncio.wait_for(fut, timeout=timeout)
            except TIMEOUT_EXCEPTIONS:
                raise TimeoutError(f"RPC timeout: {method}") from None
        finally:
            # Always forget the future: a failed send or a cancelled caller
            # would otherwise leave the entry in _pending indefinitely.
            self._pending.pop(rid, None)
        if "error" in msg:
            raise RuntimeError(f"RPC {method} error: {msg['error']}")
        return msg.get("result")

    async def notify(self, method, params=None):
        """Send a JSON-RPC notification (no ``id``, no response expected)."""
        await self._send({"jsonrpc": "2.0", "method": method, "params": params or {}})

    def create_stream(self, request_id: str):
        """Register a fresh stream record for ``request_id``.

        Must be called before the request is sent so that answers arriving
        early are not dropped for lack of a registered stream.
        """
        self._chat_streams[request_id] = {
            "parts": [],  # every text chunk received, in order
            "chunks": asyncio.Queue(),  # chunks for the consumer; None ends it
            "done": asyncio.Event(),  # set once the stream has been terminated
            "finish": None,  # params of the chat/finish message
            "started_at": time.monotonic(),
            "first_chunk_at": None,
            "finish_at": None,
        }

    def pop_stream(self, request_id: str) -> None:
        """Unregister a stream; a no-op when ``request_id`` is unknown."""
        stream = self._chat_streams.pop(request_id, None)
        if stream is None:
            return
        # Nothing is drained here: if the stream was still open, enqueue the
        # end sentinel so a consumer still blocked on the queue wakes up.
        self._finish_stream(stream)

    async def consume_stream(self, request_id: str, timeout: float) -> AsyncIterator[str]:
        """Yield text chunks for ``request_id`` until the end sentinel arrives.

        ``timeout`` is a budget for the whole stream, not per chunk. Yields
        nothing when the stream is not registered.

        Raises:
            TimeoutError: The budget ran out. On Python 3.10 a timeout hit
                while waiting on the queue surfaces as ``asyncio.TimeoutError``.
        """
        stream = self._chat_streams.get(request_id)
        if stream is None:
            return
        start = time.monotonic()
        while True:
            remain = timeout - (time.monotonic() - start)
            if remain <= 0:
                raise TimeoutError("chat stream timeout")
            chunk = await asyncio.wait_for(stream["chunks"].get(), timeout=remain)
            if chunk is None:
                break
            yield chunk

    def get_stream_result(self, request_id: str) -> dict:
        """Summarize a stream: joined text, finish params and latencies in ms.

        Latencies are ``None`` when the corresponding event never happened; an
        unknown ``request_id`` yields empty text and finish params.
        """
        stream = self._chat_streams.get(request_id) or {}
        first_ms = None
        total_ms = None
        if stream.get("first_chunk_at") is not None:
            first_ms = int((stream["first_chunk_at"] - stream["started_at"]) * 1000)
        if stream.get("finish_at") is not None:
            total_ms = int((stream["finish_at"] - stream["started_at"]) * 1000)
        return {
            "text": "".join(stream.get("parts") or []),
            "finish": stream.get("finish") or {},
            "firstTokenLatencyMs": first_ms,
            "totalLatencyMs": total_ms,
        }
|
|
|
|
|
|
class LingmaGatewayClient:
|
|
"""Owns the Lingma subprocess and the LSP-over-WS connection.
|
|
|
|
Adds a small state machine + reconnect loop so the gateway can survive Lingma
|
|
restarts and slow cold starts without bringing down the FastAPI app.
|
|
"""
|
|
|
|
STATE_STOPPED = "stopped"
|
|
STATE_STARTING = "starting"
|
|
STATE_READY = "ready"
|
|
STATE_RECONNECTING = "reconnecting"
|
|
STATE_FAILED = "failed"
|
|
STATE_CLOSED = "closed"
|
|
|
|
def __init__(
|
|
self,
|
|
lingma_bin: str,
|
|
work_dir: str,
|
|
socket_port: int,
|
|
startup_timeout: int,
|
|
rpc_timeout: int,
|
|
default_model: str,
|
|
default_ask_mode: str,
|
|
*,
|
|
name: str = "lingma",
|
|
extra_info_paths: list[Path] | None = None,
|
|
):
|
|
self.name = name
|
|
self.lingma_bin = Path(lingma_bin)
|
|
self.work_dir = Path(work_dir)
|
|
self.socket_port = socket_port
|
|
self.startup_timeout = startup_timeout
|
|
self.rpc_timeout = rpc_timeout
|
|
self.default_model = default_model
|
|
self.default_ask_mode = default_ask_mode
|
|
# Each pool instance should only look at its own workDir .info to avoid
|
|
# cross-instance clobbering via the shared ~/.lingma/.info path.
|
|
if extra_info_paths is None:
|
|
extra_info_paths = [Path.home() / ".lingma" / ".info"]
|
|
self._extra_info_paths = list(extra_info_paths)
|
|
self._rpc: LspWsRpcClient | None = None
|
|
self._ws = None
|
|
self._state = self.STATE_STOPPED
|
|
self._state_lock = asyncio.Lock()
|
|
self._ready_event = asyncio.Event()
|
|
self._reconnect_task: asyncio.Task | None = None
|
|
self._last_error: str = ""
|
|
|
|
# ------------------------------------------------------------------ state
|
|
|
|
@property
|
|
def state(self) -> str:
|
|
return self._state
|
|
|
|
@property
|
|
def last_error(self) -> str:
|
|
return self._last_error
|
|
|
|
def _set_state(self, state: str, err: str = "") -> None:
|
|
if state != self._state:
|
|
logger.info("lingma client state %s -> %s", self._state, state, extra={"ctx_new_state": state})
|
|
self._state = state
|
|
if err:
|
|
self._last_error = err
|
|
if state == self.STATE_READY:
|
|
self._ready_event.set()
|
|
else:
|
|
self._ready_event.clear()
|
|
|
|
# -------------------------------------------------------------- lifecycle
|
|
|
|
async def start(self) -> None:
|
|
"""Initial start. Failure is non-fatal: ensure_ready() will retry later."""
|
|
try:
|
|
await self._connect(initial=True)
|
|
except Exception as exc:
|
|
self._set_state(self.STATE_FAILED, err=str(exc))
|
|
logger.exception("initial lingma start failed; will retry on demand")
|
|
|
|
async def close(self) -> None:
|
|
self._set_state(self.STATE_CLOSED)
|
|
if self._reconnect_task and not self._reconnect_task.done():
|
|
self._reconnect_task.cancel()
|
|
with contextlib.suppress(Exception):
|
|
await self._reconnect_task
|
|
if self._rpc:
|
|
await self._rpc.close()
|
|
if self._ws:
|
|
with contextlib.suppress(Exception):
|
|
await self._ws.close()
|
|
|
|
async def ensure_ready(self, timeout: float | None = None) -> None:
|
|
"""Block until the RPC connection is usable, (re)connecting on demand."""
|
|
if self._state == self.STATE_CLOSED:
|
|
raise RuntimeError("lingma client is closed")
|
|
if self._state == self.STATE_READY and self._ws is not None:
|
|
return
|
|
|
|
async with self._state_lock:
|
|
if self._state == self.STATE_READY and self._ws is not None:
|
|
return
|
|
if self._state in (self.STATE_STOPPED, self.STATE_FAILED):
|
|
try:
|
|
await self._connect(initial=False)
|
|
return
|
|
except Exception as exc:
|
|
self._set_state(self.STATE_FAILED, err=str(exc))
|
|
raise
|
|
|
|
wait_timeout = timeout if timeout is not None else max(
|
|
30.0, float(self.startup_timeout) + 10.0
|
|
)
|
|
try:
|
|
await asyncio.wait_for(self._ready_event.wait(), timeout=wait_timeout)
|
|
except TIMEOUT_EXCEPTIONS:
|
|
raise RuntimeError(f"lingma not ready (state={self._state}, err={self._last_error})")
|
|
|
|
# --------------------------------------------------------------- connect
|
|
|
|
async def _connect(self, *, initial: bool) -> None:
|
|
self._set_state(self.STATE_STARTING)
|
|
|
|
if not self.lingma_bin.exists():
|
|
raise FileNotFoundError(f"Lingma not found: {self.lingma_bin}")
|
|
|
|
info_paths = [self.work_dir / ".info", *self._extra_info_paths]
|
|
|
|
# socket_port <= 0 is the pool-friendly "always spawn and read .info" mode.
|
|
port_prewarmed = self.socket_port > 0 and _is_port_open(
|
|
"127.0.0.1", self.socket_port
|
|
)
|
|
if not port_prewarmed:
|
|
self.work_dir.mkdir(parents=True, exist_ok=True)
|
|
# Remove stale info files from host-mounted workspace before boot.
|
|
for p in info_paths:
|
|
with contextlib.suppress(Exception):
|
|
if p.exists():
|
|
p.unlink()
|
|
logger.info(
|
|
"[%s] spawning lingma: %s start --workDir %s",
|
|
self.name,
|
|
self.lingma_bin,
|
|
self.work_dir,
|
|
)
|
|
subprocess.Popen(
|
|
[str(self.lingma_bin), "start", "--workDir", str(self.work_dir)],
|
|
cwd=str(self.lingma_bin.parent),
|
|
stdout=subprocess.DEVNULL,
|
|
stderr=subprocess.DEVNULL,
|
|
start_new_session=True,
|
|
)
|
|
info, _, _ = _wait_info_any(info_paths, timeout_sec=self.startup_timeout)
|
|
self.socket_port = info
|
|
|
|
deadline = time.time() + self.startup_timeout
|
|
while time.time() < deadline:
|
|
if _is_port_open("127.0.0.1", self.socket_port, timeout_sec=0.3):
|
|
break
|
|
await asyncio.sleep(0.2)
|
|
else:
|
|
raise TimeoutError(f"Lingma socket not open on port {self.socket_port}")
|
|
|
|
# Close any stale ws/rpc before creating fresh ones (reconnect path).
|
|
if self._rpc is not None:
|
|
with contextlib.suppress(Exception):
|
|
await self._rpc.close()
|
|
self._rpc = None
|
|
if self._ws is not None:
|
|
with contextlib.suppress(Exception):
|
|
await self._ws.close()
|
|
self._ws = None
|
|
|
|
ws_url = f"ws://127.0.0.1:{self.socket_port}"
|
|
self._ws = await websockets.connect(ws_url, max_size=10 * 1024 * 1024)
|
|
self._rpc = LspWsRpcClient(self._ws, on_disconnect=self._on_disconnect)
|
|
await self._rpc.start()
|
|
await self._rpc.request(
|
|
"initialize",
|
|
{
|
|
"processId": os.getpid(),
|
|
"clientInfo": {"name": "lingma-openai-gateway", "version": "0.1.0"},
|
|
"capabilities": {},
|
|
"workspaceFolders": [],
|
|
"rootUri": None,
|
|
},
|
|
timeout=self.rpc_timeout,
|
|
)
|
|
await self._rpc.notify("initialized", {})
|
|
self._set_state(self.STATE_READY)
|
|
logger.info(
|
|
"[%s] lingma ready on port %d (initial=%s)",
|
|
self.name,
|
|
self.socket_port,
|
|
initial,
|
|
)
|
|
|
|
def _on_disconnect(self, exc: BaseException) -> None:
|
|
if self._state == self.STATE_CLOSED:
|
|
return
|
|
self._set_state(self.STATE_RECONNECTING, err=str(exc))
|
|
if self._reconnect_task and not self._reconnect_task.done():
|
|
return
|
|
try:
|
|
loop = asyncio.get_running_loop()
|
|
except RuntimeError:
|
|
return
|
|
self._reconnect_task = loop.create_task(self._reconnect_loop())
|
|
|
|
async def _reconnect_loop(self) -> None:
|
|
backoff = 1.0
|
|
max_backoff = 30.0
|
|
max_attempts = 20
|
|
for attempt in range(1, max_attempts + 1):
|
|
if self._state == self.STATE_CLOSED:
|
|
return
|
|
await asyncio.sleep(backoff)
|
|
try:
|
|
async with self._state_lock:
|
|
await self._connect(initial=False)
|
|
logger.info("lingma reconnected after %d attempt(s)", attempt)
|
|
return
|
|
except Exception as exc:
|
|
self._last_error = str(exc)
|
|
logger.warning("lingma reconnect attempt %d failed: %s", attempt, exc)
|
|
backoff = min(backoff * 2, max_backoff)
|
|
self._set_state(self.STATE_FAILED, err="reconnect exhausted")
|
|
|
|
# ------------------------------------------------------------------ RPC
|
|
|
|
@property
|
|
def rpc(self) -> LspWsRpcClient:
|
|
if self._rpc is None:
|
|
raise RuntimeError(f"Lingma RPC not initialized (state={self._state})")
|
|
return self._rpc
|
|
|
|
async def auth_status(self):
|
|
await self.ensure_ready()
|
|
return await self.rpc.request("auth/status", {}, timeout=self.rpc_timeout)
|
|
|
|
async def query_models(self):
|
|
await self.ensure_ready()
|
|
return await self.rpc.request("config/queryModels", {}, timeout=self.rpc_timeout)
|
|
|
|
async def get_endpoint(self):
|
|
await self.ensure_ready()
|
|
return await self.rpc.request("config/getEndpoint", {}, timeout=self.rpc_timeout)
|
|
|
|
async def update_endpoint(self, endpoint: str):
|
|
await self.ensure_ready()
|
|
return await self.rpc.request(
|
|
"config/updateEndpoint", {"endpoint": endpoint}, timeout=self.rpc_timeout
|
|
)
|
|
|
|
async def generate_login_url(self):
|
|
await self.ensure_ready()
|
|
result = await self.rpc.request("login/generateUrl", {}, timeout=self.rpc_timeout)
|
|
if isinstance(result, str):
|
|
return result, {"raw": result}
|
|
if isinstance(result, dict):
|
|
for key in ("loginUrl", "url", "login_url"):
|
|
if isinstance(result.get(key), str):
|
|
return result[key], result
|
|
return "", result
|
|
return "", {"raw": result}
|
|
|
|
# ------------------------------------------------------------------ chat
|
|
|
|
def _build_payload(self, prompt: str, model_key: str, ask_mode: str, session_id: str, request_id: str):
|
|
session_type = "developer" if ask_mode == "agent" else "chat"
|
|
return {
|
|
"requestId": request_id,
|
|
"sessionId": session_id,
|
|
"sessionType": session_type,
|
|
"chatTask": "FREE_INPUT",
|
|
"mode": ask_mode,
|
|
"stream": True,
|
|
"source": 1,
|
|
"isReply": False,
|
|
"taskDefinitionType": "system",
|
|
"content": prompt,
|
|
"text": prompt,
|
|
"message": prompt,
|
|
"questionText": prompt,
|
|
"extra": {
|
|
"modelConfig": {"key": model_key},
|
|
"workspacePath": str(Path.cwd()),
|
|
},
|
|
"pluginPayloadConfig": {
|
|
"isEnableAskAgent": ask_mode == "agent",
|
|
"isEnableAutoMemory": True,
|
|
},
|
|
"chatContext": {
|
|
"text": prompt,
|
|
"features": [],
|
|
"preferredLanguage": "zh-CN",
|
|
"localeLang": "zh-CN",
|
|
},
|
|
}
|
|
|
|
async def chat_complete(self, prompt: str, model_key: str, ask_mode: str) -> dict:
|
|
await self.ensure_ready()
|
|
request_id = str(uuid.uuid4())
|
|
session_id = str(uuid.uuid4())
|
|
payload = self._build_payload(prompt, model_key, ask_mode, session_id, request_id)
|
|
self.rpc.create_stream(request_id)
|
|
try:
|
|
try:
|
|
await self.rpc.request("chat/ask", payload, timeout=self.rpc_timeout)
|
|
except TIMEOUT_EXCEPTIONS:
|
|
# chat/ask often returns nothing until chat/finish arrives; tolerate.
|
|
pass
|
|
async for _ in self.rpc.consume_stream(
|
|
request_id, timeout=max(20.0, self.rpc_timeout + 20.0)
|
|
):
|
|
pass
|
|
result = self.rpc.get_stream_result(request_id)
|
|
finally:
|
|
self.rpc.pop_stream(request_id)
|
|
finish = result.get("finish") or {}
|
|
result["requestId"] = request_id
|
|
result["sessionId"] = finish.get("sessionId") or session_id
|
|
result["model"] = model_key
|
|
result["mode"] = ask_mode
|
|
return result
|
|
|
|
async def chat_stream(self, prompt: str, model_key: str, ask_mode: str) -> AsyncIterator[str]:
|
|
await self.ensure_ready()
|
|
request_id = str(uuid.uuid4())
|
|
session_id = str(uuid.uuid4())
|
|
payload = self._build_payload(prompt, model_key, ask_mode, session_id, request_id)
|
|
self.rpc.create_stream(request_id)
|
|
try:
|
|
try:
|
|
await self.rpc.request("chat/ask", payload, timeout=self.rpc_timeout)
|
|
except TIMEOUT_EXCEPTIONS:
|
|
pass
|
|
async for chunk in self.rpc.consume_stream(
|
|
request_id, timeout=max(20.0, self.rpc_timeout + 40.0)
|
|
):
|
|
yield chunk
|
|
finally:
|
|
# Runs on normal completion, exception, or consumer GeneratorExit (client disconnect).
|
|
self.rpc.pop_stream(request_id)
|