fix: harden tooling session reuse and event routing

Ensure session reuse is disabled for tooling contexts, include tool config in cache keys, and stabilize tool event merge/routing with expanded bridge tests.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
GitHub Actions
2026-04-19 19:29:30 +08:00
parent 5aa7fbfae5
commit e600bae27c
5 changed files with 741 additions and 53 deletions

View File

@@ -44,6 +44,7 @@ class Settings:
session_reuse_enabled: bool = True
session_cache_max_entries: int = 256
session_cache_ttl_sec: float = 1800.0
tool_forward_enabled: bool = False
def _bool_env(name: str, default: bool) -> bool:
@@ -175,4 +176,5 @@ def load_settings() -> Settings:
session_reuse_enabled=_bool_env("SESSION_REUSE_ENABLED", True),
session_cache_max_entries=int(os.getenv("SESSION_CACHE_MAX_ENTRIES", "256")),
session_cache_ttl_sec=float(os.getenv("SESSION_CACHE_TTL_SEC", "1800")),
tool_forward_enabled=_bool_env("TOOL_FORWARD_ENABLED", False),
)

View File

@@ -100,6 +100,7 @@ class LspWsRpcClient:
self._reader_task: asyncio.Task | None = None
self._rx_buffer = b""
self._chat_streams: dict[str, dict] = {}
self._tool_stream_map: dict[str, str] = {}
self._on_disconnect = on_disconnect
self._closed = False
@@ -202,6 +203,7 @@ class LspWsRpcClient:
stream["done"].set()
stream["chunks"].put_nowait(None)
self._chat_streams.clear()
self._tool_stream_map.clear()
async def _send(self, payload: dict):
async with self._send_lock:
@@ -251,6 +253,92 @@ class LspWsRpcClient:
except Exception:
logger.exception("on_disconnect callback failed")
@staticmethod
def _normalize_tool_id(method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> str | None:
event_id = None
if isinstance(tool_event, dict):
event_id = tool_event.get("id")
if isinstance(event_id, str) and event_id.strip():
return event_id.strip()
fallback_id = params.get("toolCallId") or params.get("tool_call_id")
if isinstance(fallback_id, str) and fallback_id.strip():
return fallback_id.strip()
req_id = params.get("requestId")
name = None
if isinstance(tool_event, dict):
name = tool_event.get("name")
if not name:
name = params.get("name")
if isinstance(req_id, str) and req_id.strip() and isinstance(name, str) and name.strip():
return f"{req_id.strip()}:tool:{name.strip()}"
if isinstance(req_id, str) and req_id.strip():
return f"{req_id.strip()}:tool"
return None
@staticmethod
def _merge_tool_event(existing: dict[str, Any] | None, incoming: dict[str, Any]) -> tuple[dict[str, Any], bool]:
merged = dict(existing or {})
changed = False
val = incoming.get("id")
if val and merged.get("id") != val:
merged["id"] = val
changed = True
name = incoming.get("name")
if name:
existing_name = merged.get("name")
if not existing_name:
merged["name"] = name
changed = True
else:
existing_norm = str(existing_name).strip().lower()
incoming_norm = str(name).strip().lower()
if existing_norm == "tool" and incoming_norm != "tool":
merged["name"] = name
changed = True
elif existing_norm != "tool" and incoming_norm == "tool":
pass
elif merged.get("name") != name:
merged["name"] = name
changed = True
if "input" in incoming and incoming.get("input") is not None:
incoming_input = incoming.get("input")
should_update_input = incoming_input != {} or "input" not in merged
if should_update_input and merged.get("input") != incoming_input:
merged["input"] = incoming_input
changed = True
if "result" in incoming and incoming.get("result") is not None:
if merged.get("result") != incoming.get("result"):
merged["result"] = incoming.get("result")
changed = True
return merged, changed
def _resolve_tool_stream(self, method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> dict | None:
req_id = params.get("requestId")
if isinstance(req_id, str) and req_id.strip():
stream = self._chat_streams.get(req_id)
if stream is not None and tool_event is not None:
tool_id = self._normalize_tool_id(method, params, tool_event)
if tool_id:
self._tool_stream_map[tool_id] = req_id
return stream
if tool_event is not None:
tool_id = self._normalize_tool_id(method, params, tool_event)
if tool_id:
mapped_req = self._tool_stream_map.get(tool_id)
if mapped_req:
return self._chat_streams.get(mapped_req)
return None
async def _handle_server_message(self, msg: dict):
method = msg.get("method")
params = msg.get("params") or {}
@@ -268,21 +356,29 @@ class LspWsRpcClient:
if method in {"tool/call/sync", "tool/invoke", "tool/call/approve", "tool/invokeResult"}:
tool_event = self._extract_tool_event(params)
req_id = params.get("requestId")
stream = self._chat_streams.get(req_id) if req_id else None
if stream is None and tool_event is not None:
for item in self._chat_streams.values():
if any(evt.get("id") == tool_event["id"] for evt in item["tool_events"]):
stream = item
break
if stream is None and len(self._chat_streams) == 1:
stream = next(iter(self._chat_streams.values()))
stream = self._resolve_tool_stream(method, params, tool_event)
if stream is not None and tool_event is not None:
stream["tool_events"].append(tool_event)
stream["chunks"].put_nowait({"type": "tool", "tool": tool_event})
tool_id = self._normalize_tool_id(method, params, tool_event)
if not tool_id:
logger.warning("drop unroutable tool event: method=%s missing tool id", method)
else:
tool_states = stream["tool_states"]
order = stream["tool_order"]
existing = tool_states.get(tool_id)
merged, changed = self._merge_tool_event(existing, tool_event)
if not existing:
if "id" not in merged or not merged.get("id"):
merged["id"] = tool_id
tool_states[tool_id] = merged
order.append(tool_id)
stream["chunks"].put_nowait({"type": "tool", "tool": merged})
elif changed:
tool_states[tool_id] = merged
stream["chunks"].put_nowait({"type": "tool", "tool": merged})
elif tool_event is not None:
logger.warning("drop unroutable tool event: method=%s requestId=%s", method, params.get("requestId"))
if method == "chat/finish":
req_id = params.get("requestId")
@@ -321,7 +417,8 @@ class LspWsRpcClient:
"chunks": asyncio.Queue(),
"done": asyncio.Event(),
"finish": None,
"tool_events": [],
"tool_states": {},
"tool_order": [],
"started_at": time.monotonic(),
"first_chunk_at": None,
"finish_at": None,
@@ -331,6 +428,9 @@ class LspWsRpcClient:
stream = self._chat_streams.pop(request_id, None)
if stream is None:
return
for tool_id, mapped_req in list(self._tool_stream_map.items()):
if mapped_req == request_id:
self._tool_stream_map.pop(tool_id, None)
# Drain queue so no stray future gets stuck if the consumer bailed early.
if not stream["done"].is_set():
stream["done"].set()
@@ -359,12 +459,20 @@ class LspWsRpcClient:
first_ms = int((stream["first_chunk_at"] - stream["started_at"]) * 1000)
if stream.get("finish_at") is not None:
total_ms = int((stream["finish_at"] - stream["started_at"]) * 1000)
ordered_tool_events: list[dict[str, Any]] = []
tool_states = stream.get("tool_states") or {}
for tool_id in stream.get("tool_order") or []:
event = tool_states.get(tool_id)
if isinstance(event, dict):
ordered_tool_events.append(event)
return {
"text": "".join(stream.get("parts") or []),
"finish": stream.get("finish") or {},
"firstTokenLatencyMs": first_ms,
"totalLatencyMs": total_ms,
"toolEvents": stream.get("tool_events") or [],
"toolEvents": ordered_tool_events,
}
@@ -733,9 +841,10 @@ class LingmaGatewayClient:
request_id: str,
*,
is_reply: bool = False,
tool_config: dict[str, Any] | None = None,
):
session_type = "developer" if ask_mode == "agent" else "chat"
return {
payload = {
"requestId": request_id,
"sessionId": session_id,
"sessionType": session_type,
@@ -764,6 +873,9 @@ class LingmaGatewayClient:
"localeLang": "zh-CN",
},
}
if tool_config is not None:
payload["toolConfig"] = tool_config
return payload
async def _kick_chat_ask(self, payload: dict) -> None:
"""Fire chat/ask as a notification.
@@ -784,12 +896,19 @@ class LingmaGatewayClient:
*,
session_id: str | None = None,
is_reply: bool = False,
tool_config: dict[str, Any] | None = None,
) -> dict:
await self.ensure_ready()
request_id = str(uuid.uuid4())
sid = session_id or str(uuid.uuid4())
payload = self._build_payload(
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
prompt,
model_key,
ask_mode,
sid,
request_id,
is_reply=is_reply,
tool_config=tool_config,
)
self.rpc.create_stream(request_id)
try:
@@ -820,6 +939,7 @@ class LingmaGatewayClient:
*,
session_id: str | None = None,
is_reply: bool = False,
tool_config: dict[str, Any] | None = None,
out_meta: dict | None = None,
) -> AsyncIterator[dict[str, Any]]:
"""Stream chat events.
@@ -837,7 +957,13 @@ class LingmaGatewayClient:
request_id = str(uuid.uuid4())
sid = session_id or str(uuid.uuid4())
payload = self._build_payload(
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
prompt,
model_key,
ask_mode,
sid,
request_id,
is_reply=is_reply,
tool_config=tool_config,
)
self.rpc.create_stream(request_id)
try:

View File

@@ -351,6 +351,70 @@ def _include_usage(stream_options: dict | None) -> bool:
return bool(stream_options.get("include_usage"))
def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None:
if not settings.tool_forward_enabled:
return None
has_tools = isinstance(req.tools, list) and len(req.tools) > 0
has_choice = req.tool_choice is not None
if not has_tools and not has_choice:
return None
return {
"provider": "openai",
"tools": req.tools or [],
"tool_choice": req.tool_choice,
}
def _anthropic_tool_config(req: AnthropicMessagesRequest) -> dict[str, Any] | None:
if not settings.tool_forward_enabled:
return None
has_tools = isinstance(req.tools, list) and len(req.tools) > 0
has_choice = req.tool_choice is not None
if not has_tools and not has_choice:
return None
return {
"provider": "anthropic",
"tools": req.tools or [],
"tool_choice": req.tool_choice,
}
def _openai_has_tooling_context(req: ChatCompletionsRequest, messages: list[dict[str, Any]]) -> bool:
if isinstance(req.tools, list) and len(req.tools) > 0:
return True
if req.tool_choice is not None:
return True
for m in messages:
role = m.get("role")
if role == "tool":
return True
if role == "assistant" and m.get("tool_calls"):
return True
return False
def _anthropic_content_has_tool_blocks(content: Any) -> bool:
if not isinstance(content, list):
return False
for item in content:
if isinstance(item, dict) and item.get("type") in {"tool_use", "tool_result"}:
return True
return False
def _anthropic_has_tooling_context(req: AnthropicMessagesRequest) -> bool:
if isinstance(req.tools, list) and len(req.tools) > 0:
return True
if req.tool_choice is not None:
return True
if _anthropic_content_has_tool_blocks(req.system):
return True
for m in req.messages:
if _anthropic_content_has_tool_blocks(m.content):
return True
return False
def _stream_event_type(event: Any) -> str:
if isinstance(event, dict):
t = event.get("type")
@@ -388,9 +452,9 @@ def _json_string(value: Any) -> str:
return "{}"
def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]:
def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]:
return {
"id": str(tool.get("id") or f"call_{uuid.uuid4().hex}"),
"id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"),
"type": "function",
"function": {
"name": str(tool.get("name") or "tool"),
@@ -399,16 +463,20 @@ def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]:
}
def _anthropic_tool_use_block(tool: dict[str, Any]) -> dict[str, Any]:
def _anthropic_tool_use_block(
tool: dict[str, Any], *, forced_id: str | None = None
) -> dict[str, Any]:
return {
"type": "tool_use",
"id": str(tool.get("id") or f"toolu_{uuid.uuid4().hex}"),
"id": str(tool.get("id") or forced_id or f"toolu_{uuid.uuid4().hex}"),
"name": str(tool.get("name") or "tool"),
"input": tool.get("input") if tool.get("input") is not None else {},
}
def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None:
def _anthropic_tool_result_block(
tool: dict[str, Any], *, forced_id: str | None = None
) -> dict[str, Any] | None:
if "result" not in tool:
return None
result = tool.get("result")
@@ -418,7 +486,7 @@ def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None:
content = _json_string(result)
return {
"type": "tool_result",
"tool_use_id": str(tool.get("id") or ""),
"tool_use_id": str(tool.get("id") or forced_id or ""),
"content": content,
}
@@ -440,18 +508,22 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
if req.model.lower() in {"lingma-agent", "agent"}:
ask_mode = "agent"
tool_config = _openai_tool_config(req)
has_tooling_context = _openai_has_tooling_context(req, messages_dump)
reuse_eligible = (
session_cache.enabled
and ask_mode == "chat"
and len(messages_dump) >= 2
and not has_tooling_context
)
lookup_key: str | None = None
write_key: str | None = None
cached_session_id: str | None = None
cached_instance_name: str | None = None
if reuse_eligible:
lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
write_key = session_cache.build_key(api_key, messages_dump)
lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config)
entry = await session_cache.get(lookup_key)
if entry is not None:
cached_session_id = entry.session_id
@@ -549,6 +621,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
success = False
tool_call_indexes: dict[str, int] = {}
saw_tool_call = False
try:
async for chunk in _inst.client.chat_stream(
prompt,
@@ -556,12 +630,21 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
ask_mode,
session_id=cached_session_id,
is_reply=is_reply,
tool_config=tool_config,
out_meta=_meta,
):
if _stream_event_type(chunk) == "tool":
tool = _stream_tool_event(chunk)
if not tool:
continue
tool_id = str(tool.get("id") or "")
if not tool_id:
tool_id = f"call_{len(tool_call_indexes)}"
idx = tool_call_indexes.get(tool_id)
if idx is None:
idx = len(tool_call_indexes)
tool_call_indexes[tool_id] = idx
saw_tool_call = True
payload = {
"id": completion_id,
"object": "chat.completion.chunk",
@@ -573,8 +656,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
"delta": {
"tool_calls": [
{
"index": 0,
**_openai_tool_call(tool),
"index": idx,
**_openai_tool_call(tool, forced_id=tool_id),
}
]
},
@@ -609,10 +692,17 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
"choices": [
{
"index": 0,
"delta": {},
"finish_reason": "tool_calls" if saw_tool_call else "stop",
}
],
}
yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n"
if include_usage:
usage_payload = {
"id": completion_id,
@@ -670,6 +760,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
ask_mode,
session_id=cached_session_id,
is_reply=is_reply,
tool_config=tool_config,
)
except Exception as exc:
logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
@@ -702,10 +793,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
tool_events = result.get("toolEvents") or []
message_content = result.get("text") or ""
tool_calls: list[dict[str, Any]] = []
saw_tool_call = False
if isinstance(tool_events, list):
for item in tool_events:
for idx, item in enumerate(tool_events):
if isinstance(item, dict):
tool_calls.append(_openai_tool_call(item))
tool_id = str(item.get("id") or f"call_{idx}")
tool_calls.append(_openai_tool_call(item, forced_id=tool_id))
saw_tool_call = True
response = ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex}",
created=int(time.time()),
@@ -713,7 +807,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
choices=[
ChatCompletionChoice(
index=0,
finish_reason="stop",
finish_reason="tool_calls" if saw_tool_call else "stop",
message={
"role": "assistant",
"content": message_content,
@@ -723,6 +817,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
],
)
data = response.model_dump()
data["latency"] = {
"first_token_ms": result.get("firstTokenLatencyMs"),
@@ -749,13 +844,15 @@ def _anthropic_error(status_code: int, error_type: str, message: str) -> JSONRes
)
def _anthropic_stop_reason(completion_tokens: int, max_tokens: int) -> str:
"""Approximate Anthropic `stop_reason`.
Lingma doesn't expose a `max_tokens` knob, so we can't truly enforce it;
we report `max_tokens` only when the generated length meets or exceeds
the caller's stated ceiling. Everything else is `end_turn`.
"""
def _anthropic_stop_reason(
completion_tokens: int,
max_tokens: int,
*,
has_pending_tool_use: bool = False,
) -> str:
"""Approximate Anthropic `stop_reason`."""
if has_pending_tool_use:
return "tool_use"
if max_tokens and completion_tokens >= max_tokens:
return "max_tokens"
return "end_turn"
@@ -818,16 +915,19 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
# Anthropic clients don't expose an ask_mode, so we always run in "chat".
ask_mode = "chat"
tool_config = _anthropic_tool_config(req)
has_tooling_context = _anthropic_has_tooling_context(req)
reuse_eligible = (
session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2
session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 and not has_tooling_context
)
lookup_key: str | None = None
write_key: str | None = None
cached_session_id: str | None = None
cached_instance_name: str | None = None
if reuse_eligible:
lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
write_key = session_cache.build_key(api_key, messages_dump)
lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config)
entry = await session_cache.get(lookup_key)
if entry is not None:
cached_session_id = entry.session_id
@@ -875,7 +975,6 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
return _anthropic_error(400, "invalid_request_error", "messages is empty")
prompt_tokens = estimate_tokens(prompt)
# ------------------------------------------------------------- backpressure
try:
ticket = await chat_guard.try_acquire()
@@ -927,6 +1026,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
success = False
block_index = 0
text_block_open = False
saw_pending_tool_use = False
try:
# 1) message_start — Anthropic SDKs read this first to get
# the message envelope (id/model/initial usage).
@@ -956,6 +1056,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
ask_mode,
session_id=cached_session_id,
is_reply=is_reply,
tool_config=tool_config,
out_meta=_meta,
):
if _stream_event_type(chunk) == "tool":
@@ -970,8 +1071,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
tool = _stream_tool_event(chunk)
if not tool:
continue
tool_id = str(tool.get("id") or f"toolu_stream_{block_index}")
tool_use_block = _anthropic_tool_use_block(tool)
tool_use_block = _anthropic_tool_use_block(tool, forced_id=tool_id)
yield _sse(
"content_block_start",
{
@@ -986,7 +1088,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
)
block_index += 1
tool_result_block = _anthropic_tool_result_block(tool)
tool_result_block = _anthropic_tool_result_block(tool, forced_id=tool_id)
if tool_result_block is not None:
yield _sse(
"content_block_start",
@@ -1001,6 +1103,8 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
{"type": "content_block_stop", "index": block_index},
)
block_index += 1
else:
saw_pending_tool_use = True
continue
text = _stream_text(chunk)
@@ -1036,7 +1140,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
# 5) message_delta carries the terminal stop_reason and
# the final cumulative output_tokens count.
stop_reason = _anthropic_stop_reason(
completion_tokens_holder["n"], max_tokens
completion_tokens_holder["n"],
max_tokens,
has_pending_tool_use=saw_pending_tool_use,
)
yield _sse(
"message_delta",
@@ -1050,6 +1156,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
},
)
# 6) message_stop — terminal event, no [DONE] sentinel.
yield _sse("message_stop", {"type": "message_stop"})
success = True
@@ -1109,6 +1216,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
ask_mode,
session_id=cached_session_id,
is_reply=is_reply,
tool_config=tool_config,
)
except Exception as exc:
logger.warning("anthropic.complete error (inst=%s): %s", inst.name, exc)
@@ -1139,14 +1247,18 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
if text:
content_blocks.append({"type": "text", "text": text})
tool_events = result.get("toolEvents") or []
saw_pending_tool_use = False
if isinstance(tool_events, list):
for item in tool_events:
for idx, item in enumerate(tool_events):
if not isinstance(item, dict):
continue
content_blocks.append(_anthropic_tool_use_block(item))
tool_result = _anthropic_tool_result_block(item)
tool_id = str(item.get("id") or f"toolu_nonstream_{idx}")
content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id))
tool_result = _anthropic_tool_result_block(item, forced_id=tool_id)
if tool_result is not None:
content_blocks.append(tool_result)
else:
saw_pending_tool_use = True
response_body: dict = {
"id": message_id,
@@ -1154,7 +1266,11 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
"role": "assistant",
"model": model,
"content": content_blocks,
"stop_reason": _anthropic_stop_reason(completion_tokens, req.max_tokens),
"stop_reason": _anthropic_stop_reason(
completion_tokens,
req.max_tokens,
has_pending_tool_use=saw_pending_tool_use,
),
"stop_sequence": None,
"usage": {
"input_tokens": prompt_tokens,

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import hashlib
import json
import time
from collections import OrderedDict
from dataclasses import dataclass
@@ -42,6 +43,16 @@ def hash_user_context(messages: list[dict]) -> str:
return h.hexdigest()
def _tool_fingerprint(tool_config: dict | None) -> str:
if not isinstance(tool_config, dict):
return "-"
try:
canonical = json.dumps(tool_config, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
except Exception:
canonical = str(tool_config)
return hashlib.sha1(canonical.encode("utf-8")).hexdigest()[:16]
class SessionCache:
"""LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId.
@@ -79,11 +90,11 @@ class SessionCache:
def enabled(self) -> bool:
return self.max > 0
def build_key(self, api_key: str, messages: list[dict]) -> str:
def build_key(self, api_key: str, messages: list[dict], *, tool_config: dict | None = None) -> str:
# API key scoping prevents cross-tenant session leakage even when
# different clients happen to produce identical histories.
key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12]
return f"{key_scope}:{hash_user_context(messages)}"
return f"{key_scope}:{hash_user_context(messages)}:{_tool_fingerprint(tool_config)}"
async def get(self, key: str) -> SessionEntry | None:
if not self.enabled:

View File

@@ -6,6 +6,31 @@ import types
import unittest
from unittest.mock import AsyncMock, patch
class _FakeSessionCache:
def __init__(self) -> None:
self.enabled = True
self.keys: list[str] = []
self.get_calls: list[str] = []
self.put_calls: list[tuple[str, str, str]] = []
self.invalidate_calls: list[str] = []
def build_key(self, api_key: str, messages: list[dict], *, tool_config=None) -> str:
marker = "with_tool" if tool_config is not None else "no_tool"
key = f"{api_key}:{len(messages)}:{marker}"
self.keys.append(key)
return key
async def get(self, key: str):
self.get_calls.append(key)
return None
async def put(self, key: str, session_id: str, instance_name: str = "") -> None:
self.put_calls.append((key, session_id, instance_name))
async def invalidate(self, key: str) -> None:
self.invalidate_calls.append(key)
# app.main imports playwright via auto_login; tests don't exercise that path.
# Inject a lightweight stub so unit tests run without installing playwright.
_playwright = types.ModuleType("playwright")
@@ -119,6 +144,38 @@ async def _collect_stream(response) -> str:
return "".join(chunks)
class _SpyClient(_FakeClient):
def __init__(self, *, stream_events: list[dict], complete_result: dict) -> None:
super().__init__(stream_events=stream_events, complete_result=complete_result)
self.last_complete_kwargs: dict = {}
self.last_stream_kwargs: dict = {}
async def chat_complete(self, *args, **kwargs) -> dict:
self.last_complete_kwargs = dict(kwargs)
return await super().chat_complete(*args, **kwargs)
async def chat_stream(self, *args, **kwargs):
self.last_stream_kwargs = dict(kwargs)
async for event in super().chat_stream(*args, **kwargs):
yield event
class _SettingsPatch:
def __init__(self, **kwargs) -> None:
self._kwargs = kwargs
def __enter__(self):
self._patchers = [patch.object(main.settings, k, v) for k, v in self._kwargs.items()]
for p in self._patchers:
p.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
for p in reversed(self._patchers):
p.stop()
return False
class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
async def test_openai_non_stream_bridges_tool_calls(self) -> None:
fake_client = _FakeClient(
@@ -156,6 +213,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
message = payload["choices"][0]["message"]
self.assertEqual(message["content"], "done")
self.assertIsInstance(message["tool_calls"], list)
self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls")
self.assertEqual(message["tool_calls"][0]["function"]["name"], "search_docs")
self.assertEqual(
json.loads(message["tool_calls"][0]["function"]["arguments"]),
@@ -195,6 +253,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
self.assertIn('"tool_calls"', body)
self.assertIn('"content": "hello"', body)
self.assertIn('"finish_reason": "tool_calls"', body)
self.assertIn('"usage"', body)
self.assertIn("data: [DONE]", body)
@@ -239,9 +298,132 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
payload = json.loads(response.body)
types = [item["type"] for item in payload["content"]]
self.assertEqual(types, ["text", "tool_use", "tool_result"])
self.assertEqual(payload["stop_reason"], "end_turn")
self.assertEqual(payload["content"][1]["name"], "lookup")
self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1")
async def test_openai_stream_tool_call_indices_are_stable(self) -> None:
fake_client = _FakeClient(
stream_events=[
{
"type": "tool",
"tool": {
"id": "call_a",
"name": "read_file",
"input": {"path": "README.md"},
},
},
{
"type": "tool",
"tool": {
"id": "call_b",
"name": "search_docs",
"input": {"query": "gateway"},
},
},
],
complete_result={},
)
req = ChatCompletionsRequest(
model="org_auto",
messages=[{"role": "user", "content": "hi"}],
stream=True,
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
):
response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
body = await _collect_stream(response)
self.assertIn('"id": "call_a"', body)
self.assertIn('"id": "call_b"', body)
self.assertIn('"index": 0', body)
self.assertIn('"index": 1', body)
async def test_anthropic_non_stream_returns_tool_use_stop_reason_when_result_missing(self) -> None:
fake_client = _FakeClient(
stream_events=[],
complete_result={
"text": "",
"toolEvents": [
{
"name": "lookup",
"input": {"k": "v"},
}
],
"sessionId": "sess-2",
},
)
req = AnthropicMessagesRequest(
model="claude-3-5-sonnet-20241022",
max_tokens=256,
messages=[{"role": "user", "content": "hi"}],
stream=False,
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
patch.object(main.settings, "api_keys", ["test-key"]),
):
response = await main.v1_messages(
req,
_make_request(
"/v1/messages",
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
),
)
payload = json.loads(response.body)
self.assertEqual(payload["stop_reason"], "tool_use")
self.assertEqual(len(payload["content"]), 1)
self.assertEqual(payload["content"][0]["type"], "tool_use")
async def test_anthropic_stream_returns_tool_use_stop_reason_when_result_missing(self) -> None:
fake_client = _FakeClient(
stream_events=[
{
"type": "tool",
"tool": {
"name": "read",
"input": {"file": "a.txt"},
},
}
],
complete_result={},
)
req = AnthropicMessagesRequest(
model="claude-3-5-sonnet-20241022",
max_tokens=256,
messages=[{"role": "user", "content": "hi"}],
stream=True,
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
patch.object(main.settings, "api_keys", ["test-key"]),
):
response = await main.v1_messages(
req,
_make_request(
"/v1/messages",
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
),
)
body = await _collect_stream(response)
self.assertIn('"type": "tool_use"', body)
self.assertIn('"stop_reason": "tool_use"', body)
async def test_anthropic_stream_bridges_tool_and_text_events(self) -> None:
fake_client = _FakeClient(
stream_events=[
@@ -284,11 +466,262 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase):
self.assertIn("event: message_start", body)
self.assertIn('"type": "tool_use"', body)
self.assertIn('"type": "tool_result"', body)
self.assertIn('"stop_reason": "end_turn"', body)
self.assertIn('"type": "text_delta"', body)
self.assertIn("event: message_stop", body)
class LingmaClientToolEventExtractionTests(unittest.TestCase):
async def test_openai_non_stream_forwards_tool_config_when_enabled(self) -> None:
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
req = ChatCompletionsRequest(
model="org_auto",
messages=[{"role": "user", "content": "hi"}],
stream=False,
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
tool_choice={"type": "function", "function": {"name": "lookup"}},
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
_SettingsPatch(tool_forward_enabled=True),
):
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
self.assertIn("tool_config", spy_client.last_complete_kwargs)
cfg = spy_client.last_complete_kwargs["tool_config"]
self.assertEqual(cfg["provider"], "openai")
self.assertEqual(len(cfg["tools"]), 1)
self.assertIsInstance(cfg["tool_choice"], dict)
async def test_openai_non_stream_does_not_forward_tool_config_when_disabled(self) -> None:
spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []})
req = ChatCompletionsRequest(
model="org_auto",
messages=[{"role": "user", "content": "hi"}],
stream=False,
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
tool_choice={"type": "function", "function": {"name": "lookup"}},
)
with (
patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
_SettingsPatch(tool_forward_enabled=False),
):
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
self.assertIn("tool_config", spy_client.last_complete_kwargs)
self.assertIsNone(spy_client.last_complete_kwargs["tool_config"])
async def test_openai_tooling_context_disables_session_reuse_cache(self) -> None:
fake_cache = _FakeSessionCache()
fake_client = _FakeClient(
stream_events=[],
complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-3"},
)
req = ChatCompletionsRequest(
model="org_auto",
messages=[
{"role": "user", "content": "turn-1"},
{"role": "user", "content": "turn-2"},
],
stream=False,
tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
tool_choice={"type": "function", "function": {"name": "lookup"}},
)
with (
patch.object(main, "session_cache", fake_cache),
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
_SettingsPatch(tool_forward_enabled=True),
):
await main.v1_chat_completions(req, _make_request("/v1/chat/completions"))
self.assertEqual(fake_cache.keys, [])
self.assertEqual(fake_cache.get_calls, [])
self.assertEqual(fake_cache.put_calls, [])
async def test_anthropic_tooling_context_disables_session_reuse_cache(self) -> None:
fake_cache = _FakeSessionCache()
fake_client = _FakeClient(
stream_events=[],
complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-4"},
)
req = AnthropicMessagesRequest(
model="claude-3-5-sonnet-20241022",
max_tokens=128,
messages=[
{"role": "user", "content": "turn-1"},
{"role": "user", "content": "turn-2"},
],
stream=False,
tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}],
tool_choice={"type": "auto"},
)
with (
patch.object(main, "session_cache", fake_cache),
patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))),
patch.object(main, "chat_guard", _FakeGuard()),
patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})),
patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)),
patch.object(main.settings, "api_keys", ["test-key"]),
):
await main.v1_messages(
req,
_make_request(
"/v1/messages",
headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"},
),
)
self.assertEqual(fake_cache.keys, [])
self.assertEqual(fake_cache.get_calls, [])
self.assertEqual(fake_cache.put_calls, [])
class SessionCacheToolFingerprintTests(unittest.TestCase):
def test_build_key_changes_with_tool_config(self) -> None:
from app.session_cache import SessionCache
cache = SessionCache(max_entries=8, ttl_sec=60)
messages = [
{"role": "system", "content": "sys"},
{"role": "user", "content": "hello"},
]
cfg_a = {
"provider": "openai",
"tools": [{"type": "function", "function": {"name": "lookup", "parameters": {}}}],
"tool_choice": {"type": "function", "function": {"name": "lookup"}},
}
cfg_a_reordered = {
"tool_choice": {"function": {"name": "lookup"}, "type": "function"},
"tools": [{"function": {"parameters": {}, "name": "lookup"}, "type": "function"}],
"provider": "openai",
}
cfg_b = {
"provider": "openai",
"tools": [{"type": "function", "function": {"name": "lookup_v2", "parameters": {}}}],
"tool_choice": {"type": "function", "function": {"name": "lookup_v2"}},
}
key_no_tool = cache.build_key("api-key", messages)
key_a = cache.build_key("api-key", messages, tool_config=cfg_a)
key_a_reordered = cache.build_key("api-key", messages, tool_config=cfg_a_reordered)
key_b = cache.build_key("api-key", messages, tool_config=cfg_b)
self.assertNotEqual(key_no_tool, key_a)
self.assertEqual(key_a, key_a_reordered)
self.assertNotEqual(key_a, key_b)
def test_handle_server_message_drops_unroutable_tool_event_without_request_id(self) -> None:
from app.lingma_client import LspWsRpcClient
rpc = LspWsRpcClient("ws://127.0.0.1:1")
async def run() -> None:
rpc.create_stream("req-1")
await rpc._handle_server_message(
{
"jsonrpc": "2.0",
"method": "tool/invoke",
"params": {
"name": "lookup",
"parameters": {"q": "x"},
},
}
)
stream = rpc._chat_streams["req-1"]
self.assertEqual(stream["tool_order"], [])
self.assertEqual(stream["tool_states"], {})
self.assertTrue(stream["chunks"].empty())
import asyncio
asyncio.run(run())
def test_handle_server_message_routes_by_tool_map_without_request_id(self) -> None:
from app.lingma_client import LspWsRpcClient
rpc = LspWsRpcClient("ws://127.0.0.1:1")
async def run() -> None:
rpc.create_stream("req-1")
await rpc._handle_server_message(
{
"jsonrpc": "2.0",
"method": "tool/invoke",
"params": {
"requestId": "req-1",
"toolCallId": "call-1",
"name": "lookup",
"parameters": {"q": "a"},
},
}
)
await rpc._handle_server_message(
{
"jsonrpc": "2.0",
"method": "tool/invokeResult",
"params": {
"toolCallId": "call-1",
"result": {"ok": True},
},
}
)
result = rpc.get_stream_result("req-1")
self.assertEqual(len(result["toolEvents"]), 1)
self.assertEqual(result["toolEvents"][0]["id"], "call-1")
self.assertEqual(result["toolEvents"][0]["input"], {"q": "a"})
self.assertEqual(result["toolEvents"][0]["result"], {"ok": True})
import asyncio
asyncio.run(run())
def test_handle_server_message_dedupes_identical_repeated_tool_events(self) -> None:
from app.lingma_client import LspWsRpcClient
rpc = LspWsRpcClient("ws://127.0.0.1:1")
async def run() -> None:
rpc.create_stream("req-1")
msg = {
"jsonrpc": "2.0",
"method": "tool/invoke",
"params": {
"requestId": "req-1",
"toolCallId": "call-dup",
"name": "lookup",
"parameters": {"q": "dup"},
},
}
await rpc._handle_server_message(msg)
await rpc._handle_server_message(msg)
stream = rpc._chat_streams["req-1"]
self.assertEqual(stream["tool_order"], ["call-dup"])
self.assertEqual(stream["chunks"].qsize(), 1)
import asyncio
asyncio.run(run())
def test_extracts_tool_event_from_results_and_parameters(self) -> None:
from app.lingma_client import LspWsRpcClient