fix: harden tooling session reuse and event routing
Ensure session reuse is disabled for tooling contexts, include tool config in cache keys, and stabilize tool event merge/routing with expanded bridge tests. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -44,6 +44,7 @@ class Settings:
|
||||
session_reuse_enabled: bool = True
|
||||
session_cache_max_entries: int = 256
|
||||
session_cache_ttl_sec: float = 1800.0
|
||||
tool_forward_enabled: bool = False
|
||||
|
||||
|
||||
def _bool_env(name: str, default: bool) -> bool:
|
||||
@@ -175,4 +176,5 @@ def load_settings() -> Settings:
|
||||
session_reuse_enabled=_bool_env("SESSION_REUSE_ENABLED", True),
|
||||
session_cache_max_entries=int(os.getenv("SESSION_CACHE_MAX_ENTRIES", "256")),
|
||||
session_cache_ttl_sec=float(os.getenv("SESSION_CACHE_TTL_SEC", "1800")),
|
||||
tool_forward_enabled=_bool_env("TOOL_FORWARD_ENABLED", False),
|
||||
)
|
||||
|
||||
@@ -100,6 +100,7 @@ class LspWsRpcClient:
|
||||
self._reader_task: asyncio.Task | None = None
|
||||
self._rx_buffer = b""
|
||||
self._chat_streams: dict[str, dict] = {}
|
||||
self._tool_stream_map: dict[str, str] = {}
|
||||
self._on_disconnect = on_disconnect
|
||||
self._closed = False
|
||||
|
||||
@@ -202,6 +203,7 @@ class LspWsRpcClient:
|
||||
stream["done"].set()
|
||||
stream["chunks"].put_nowait(None)
|
||||
self._chat_streams.clear()
|
||||
self._tool_stream_map.clear()
|
||||
|
||||
async def _send(self, payload: dict):
|
||||
async with self._send_lock:
|
||||
@@ -251,6 +253,92 @@ class LspWsRpcClient:
|
||||
except Exception:
|
||||
logger.exception("on_disconnect callback failed")
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tool_id(method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> str | None:
|
||||
event_id = None
|
||||
if isinstance(tool_event, dict):
|
||||
event_id = tool_event.get("id")
|
||||
if isinstance(event_id, str) and event_id.strip():
|
||||
return event_id.strip()
|
||||
|
||||
fallback_id = params.get("toolCallId") or params.get("tool_call_id")
|
||||
if isinstance(fallback_id, str) and fallback_id.strip():
|
||||
return fallback_id.strip()
|
||||
|
||||
req_id = params.get("requestId")
|
||||
name = None
|
||||
if isinstance(tool_event, dict):
|
||||
name = tool_event.get("name")
|
||||
if not name:
|
||||
name = params.get("name")
|
||||
if isinstance(req_id, str) and req_id.strip() and isinstance(name, str) and name.strip():
|
||||
return f"{req_id.strip()}:tool:{name.strip()}"
|
||||
if isinstance(req_id, str) and req_id.strip():
|
||||
return f"{req_id.strip()}:tool"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _merge_tool_event(existing: dict[str, Any] | None, incoming: dict[str, Any]) -> tuple[dict[str, Any], bool]:
|
||||
merged = dict(existing or {})
|
||||
changed = False
|
||||
|
||||
val = incoming.get("id")
|
||||
if val and merged.get("id") != val:
|
||||
merged["id"] = val
|
||||
changed = True
|
||||
|
||||
name = incoming.get("name")
|
||||
if name:
|
||||
existing_name = merged.get("name")
|
||||
if not existing_name:
|
||||
merged["name"] = name
|
||||
changed = True
|
||||
else:
|
||||
existing_norm = str(existing_name).strip().lower()
|
||||
incoming_norm = str(name).strip().lower()
|
||||
if existing_norm == "tool" and incoming_norm != "tool":
|
||||
merged["name"] = name
|
||||
changed = True
|
||||
elif existing_norm != "tool" and incoming_norm == "tool":
|
||||
pass
|
||||
elif merged.get("name") != name:
|
||||
merged["name"] = name
|
||||
changed = True
|
||||
|
||||
if "input" in incoming and incoming.get("input") is not None:
|
||||
incoming_input = incoming.get("input")
|
||||
should_update_input = incoming_input != {} or "input" not in merged
|
||||
if should_update_input and merged.get("input") != incoming_input:
|
||||
merged["input"] = incoming_input
|
||||
changed = True
|
||||
|
||||
if "result" in incoming and incoming.get("result") is not None:
|
||||
if merged.get("result") != incoming.get("result"):
|
||||
merged["result"] = incoming.get("result")
|
||||
changed = True
|
||||
|
||||
return merged, changed
|
||||
|
||||
|
||||
def _resolve_tool_stream(self, method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> dict | None:
|
||||
req_id = params.get("requestId")
|
||||
if isinstance(req_id, str) and req_id.strip():
|
||||
stream = self._chat_streams.get(req_id)
|
||||
if stream is not None and tool_event is not None:
|
||||
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||
if tool_id:
|
||||
self._tool_stream_map[tool_id] = req_id
|
||||
return stream
|
||||
|
||||
if tool_event is not None:
|
||||
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||
if tool_id:
|
||||
mapped_req = self._tool_stream_map.get(tool_id)
|
||||
if mapped_req:
|
||||
return self._chat_streams.get(mapped_req)
|
||||
|
||||
return None
|
||||
|
||||
async def _handle_server_message(self, msg: dict):
|
||||
method = msg.get("method")
|
||||
params = msg.get("params") or {}
|
||||
@@ -268,21 +356,29 @@ class LspWsRpcClient:
|
||||
|
||||
if method in {"tool/call/sync", "tool/invoke", "tool/call/approve", "tool/invokeResult"}:
|
||||
tool_event = self._extract_tool_event(params)
|
||||
req_id = params.get("requestId")
|
||||
stream = self._chat_streams.get(req_id) if req_id else None
|
||||
|
||||
if stream is None and tool_event is not None:
|
||||
for item in self._chat_streams.values():
|
||||
if any(evt.get("id") == tool_event["id"] for evt in item["tool_events"]):
|
||||
stream = item
|
||||
break
|
||||
|
||||
if stream is None and len(self._chat_streams) == 1:
|
||||
stream = next(iter(self._chat_streams.values()))
|
||||
stream = self._resolve_tool_stream(method, params, tool_event)
|
||||
|
||||
if stream is not None and tool_event is not None:
|
||||
stream["tool_events"].append(tool_event)
|
||||
stream["chunks"].put_nowait({"type": "tool", "tool": tool_event})
|
||||
tool_id = self._normalize_tool_id(method, params, tool_event)
|
||||
if not tool_id:
|
||||
logger.warning("drop unroutable tool event: method=%s missing tool id", method)
|
||||
else:
|
||||
tool_states = stream["tool_states"]
|
||||
order = stream["tool_order"]
|
||||
existing = tool_states.get(tool_id)
|
||||
merged, changed = self._merge_tool_event(existing, tool_event)
|
||||
if not existing:
|
||||
if "id" not in merged or not merged.get("id"):
|
||||
merged["id"] = tool_id
|
||||
tool_states[tool_id] = merged
|
||||
order.append(tool_id)
|
||||
stream["chunks"].put_nowait({"type": "tool", "tool": merged})
|
||||
elif changed:
|
||||
tool_states[tool_id] = merged
|
||||
stream["chunks"].put_nowait({"type": "tool", "tool": merged})
|
||||
elif tool_event is not None:
|
||||
logger.warning("drop unroutable tool event: method=%s requestId=%s", method, params.get("requestId"))
|
||||
|
||||
|
||||
if method == "chat/finish":
|
||||
req_id = params.get("requestId")
|
||||
@@ -321,7 +417,8 @@ class LspWsRpcClient:
|
||||
"chunks": asyncio.Queue(),
|
||||
"done": asyncio.Event(),
|
||||
"finish": None,
|
||||
"tool_events": [],
|
||||
"tool_states": {},
|
||||
"tool_order": [],
|
||||
"started_at": time.monotonic(),
|
||||
"first_chunk_at": None,
|
||||
"finish_at": None,
|
||||
@@ -331,6 +428,9 @@ class LspWsRpcClient:
|
||||
stream = self._chat_streams.pop(request_id, None)
|
||||
if stream is None:
|
||||
return
|
||||
for tool_id, mapped_req in list(self._tool_stream_map.items()):
|
||||
if mapped_req == request_id:
|
||||
self._tool_stream_map.pop(tool_id, None)
|
||||
# Drain queue so no stray future gets stuck if the consumer bailed early.
|
||||
if not stream["done"].is_set():
|
||||
stream["done"].set()
|
||||
@@ -359,12 +459,20 @@ class LspWsRpcClient:
|
||||
first_ms = int((stream["first_chunk_at"] - stream["started_at"]) * 1000)
|
||||
if stream.get("finish_at") is not None:
|
||||
total_ms = int((stream["finish_at"] - stream["started_at"]) * 1000)
|
||||
|
||||
ordered_tool_events: list[dict[str, Any]] = []
|
||||
tool_states = stream.get("tool_states") or {}
|
||||
for tool_id in stream.get("tool_order") or []:
|
||||
event = tool_states.get(tool_id)
|
||||
if isinstance(event, dict):
|
||||
ordered_tool_events.append(event)
|
||||
|
||||
return {
|
||||
"text": "".join(stream.get("parts") or []),
|
||||
"finish": stream.get("finish") or {},
|
||||
"firstTokenLatencyMs": first_ms,
|
||||
"totalLatencyMs": total_ms,
|
||||
"toolEvents": stream.get("tool_events") or [],
|
||||
"toolEvents": ordered_tool_events,
|
||||
}
|
||||
|
||||
|
||||
@@ -733,9 +841,10 @@ class LingmaGatewayClient:
|
||||
request_id: str,
|
||||
*,
|
||||
is_reply: bool = False,
|
||||
tool_config: dict[str, Any] | None = None,
|
||||
):
|
||||
session_type = "developer" if ask_mode == "agent" else "chat"
|
||||
return {
|
||||
payload = {
|
||||
"requestId": request_id,
|
||||
"sessionId": session_id,
|
||||
"sessionType": session_type,
|
||||
@@ -764,6 +873,9 @@ class LingmaGatewayClient:
|
||||
"localeLang": "zh-CN",
|
||||
},
|
||||
}
|
||||
if tool_config is not None:
|
||||
payload["toolConfig"] = tool_config
|
||||
return payload
|
||||
|
||||
async def _kick_chat_ask(self, payload: dict) -> None:
|
||||
"""Fire chat/ask as a notification.
|
||||
@@ -784,12 +896,19 @@ class LingmaGatewayClient:
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
is_reply: bool = False,
|
||||
tool_config: dict[str, Any] | None = None,
|
||||
) -> dict:
|
||||
await self.ensure_ready()
|
||||
request_id = str(uuid.uuid4())
|
||||
sid = session_id or str(uuid.uuid4())
|
||||
payload = self._build_payload(
|
||||
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
|
||||
prompt,
|
||||
model_key,
|
||||
ask_mode,
|
||||
sid,
|
||||
request_id,
|
||||
is_reply=is_reply,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
self.rpc.create_stream(request_id)
|
||||
try:
|
||||
@@ -820,6 +939,7 @@ class LingmaGatewayClient:
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
is_reply: bool = False,
|
||||
tool_config: dict[str, Any] | None = None,
|
||||
out_meta: dict | None = None,
|
||||
) -> AsyncIterator[dict[str, Any]]:
|
||||
"""Stream chat events.
|
||||
@@ -837,7 +957,13 @@ class LingmaGatewayClient:
|
||||
request_id = str(uuid.uuid4())
|
||||
sid = session_id or str(uuid.uuid4())
|
||||
payload = self._build_payload(
|
||||
prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply
|
||||
prompt,
|
||||
model_key,
|
||||
ask_mode,
|
||||
sid,
|
||||
request_id,
|
||||
is_reply=is_reply,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
self.rpc.create_stream(request_id)
|
||||
try:
|
||||
|
||||
180
app/main.py
180
app/main.py
@@ -351,6 +351,70 @@ def _include_usage(stream_options: dict | None) -> bool:
|
||||
return bool(stream_options.get("include_usage"))
|
||||
|
||||
|
||||
def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None:
|
||||
if not settings.tool_forward_enabled:
|
||||
return None
|
||||
has_tools = isinstance(req.tools, list) and len(req.tools) > 0
|
||||
has_choice = req.tool_choice is not None
|
||||
if not has_tools and not has_choice:
|
||||
return None
|
||||
return {
|
||||
"provider": "openai",
|
||||
"tools": req.tools or [],
|
||||
"tool_choice": req.tool_choice,
|
||||
}
|
||||
|
||||
|
||||
def _anthropic_tool_config(req: AnthropicMessagesRequest) -> dict[str, Any] | None:
|
||||
if not settings.tool_forward_enabled:
|
||||
return None
|
||||
has_tools = isinstance(req.tools, list) and len(req.tools) > 0
|
||||
has_choice = req.tool_choice is not None
|
||||
if not has_tools and not has_choice:
|
||||
return None
|
||||
return {
|
||||
"provider": "anthropic",
|
||||
"tools": req.tools or [],
|
||||
"tool_choice": req.tool_choice,
|
||||
}
|
||||
|
||||
|
||||
def _openai_has_tooling_context(req: ChatCompletionsRequest, messages: list[dict[str, Any]]) -> bool:
|
||||
if isinstance(req.tools, list) and len(req.tools) > 0:
|
||||
return True
|
||||
if req.tool_choice is not None:
|
||||
return True
|
||||
for m in messages:
|
||||
role = m.get("role")
|
||||
if role == "tool":
|
||||
return True
|
||||
if role == "assistant" and m.get("tool_calls"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _anthropic_content_has_tool_blocks(content: Any) -> bool:
|
||||
if not isinstance(content, list):
|
||||
return False
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") in {"tool_use", "tool_result"}:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _anthropic_has_tooling_context(req: AnthropicMessagesRequest) -> bool:
|
||||
if isinstance(req.tools, list) and len(req.tools) > 0:
|
||||
return True
|
||||
if req.tool_choice is not None:
|
||||
return True
|
||||
if _anthropic_content_has_tool_blocks(req.system):
|
||||
return True
|
||||
for m in req.messages:
|
||||
if _anthropic_content_has_tool_blocks(m.content):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _stream_event_type(event: Any) -> str:
|
||||
if isinstance(event, dict):
|
||||
t = event.get("type")
|
||||
@@ -388,9 +452,9 @@ def _json_string(value: Any) -> str:
|
||||
return "{}"
|
||||
|
||||
|
||||
def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]:
|
||||
def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]:
|
||||
return {
|
||||
"id": str(tool.get("id") or f"call_{uuid.uuid4().hex}"),
|
||||
"id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": str(tool.get("name") or "tool"),
|
||||
@@ -399,16 +463,20 @@ def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _anthropic_tool_use_block(tool: dict[str, Any]) -> dict[str, Any]:
|
||||
def _anthropic_tool_use_block(
|
||||
tool: dict[str, Any], *, forced_id: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "tool_use",
|
||||
"id": str(tool.get("id") or f"toolu_{uuid.uuid4().hex}"),
|
||||
"id": str(tool.get("id") or forced_id or f"toolu_{uuid.uuid4().hex}"),
|
||||
"name": str(tool.get("name") or "tool"),
|
||||
"input": tool.get("input") if tool.get("input") is not None else {},
|
||||
}
|
||||
|
||||
|
||||
def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None:
|
||||
def _anthropic_tool_result_block(
|
||||
tool: dict[str, Any], *, forced_id: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
if "result" not in tool:
|
||||
return None
|
||||
result = tool.get("result")
|
||||
@@ -418,7 +486,7 @@ def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None:
|
||||
content = _json_string(result)
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": str(tool.get("id") or ""),
|
||||
"tool_use_id": str(tool.get("id") or forced_id or ""),
|
||||
"content": content,
|
||||
}
|
||||
|
||||
@@ -440,18 +508,22 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
if req.model.lower() in {"lingma-agent", "agent"}:
|
||||
ask_mode = "agent"
|
||||
|
||||
tool_config = _openai_tool_config(req)
|
||||
has_tooling_context = _openai_has_tooling_context(req, messages_dump)
|
||||
|
||||
reuse_eligible = (
|
||||
session_cache.enabled
|
||||
and ask_mode == "chat"
|
||||
and len(messages_dump) >= 2
|
||||
and not has_tooling_context
|
||||
)
|
||||
lookup_key: str | None = None
|
||||
write_key: str | None = None
|
||||
cached_session_id: str | None = None
|
||||
cached_instance_name: str | None = None
|
||||
if reuse_eligible:
|
||||
lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
|
||||
write_key = session_cache.build_key(api_key, messages_dump)
|
||||
lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
|
||||
write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config)
|
||||
entry = await session_cache.get(lookup_key)
|
||||
if entry is not None:
|
||||
cached_session_id = entry.session_id
|
||||
@@ -549,6 +621,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
|
||||
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
|
||||
success = False
|
||||
tool_call_indexes: dict[str, int] = {}
|
||||
saw_tool_call = False
|
||||
try:
|
||||
async for chunk in _inst.client.chat_stream(
|
||||
prompt,
|
||||
@@ -556,12 +630,21 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
ask_mode,
|
||||
session_id=cached_session_id,
|
||||
is_reply=is_reply,
|
||||
tool_config=tool_config,
|
||||
out_meta=_meta,
|
||||
):
|
||||
if _stream_event_type(chunk) == "tool":
|
||||
tool = _stream_tool_event(chunk)
|
||||
if not tool:
|
||||
continue
|
||||
tool_id = str(tool.get("id") or "")
|
||||
if not tool_id:
|
||||
tool_id = f"call_{len(tool_call_indexes)}"
|
||||
idx = tool_call_indexes.get(tool_id)
|
||||
if idx is None:
|
||||
idx = len(tool_call_indexes)
|
||||
tool_call_indexes[tool_id] = idx
|
||||
saw_tool_call = True
|
||||
payload = {
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
@@ -573,8 +656,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
**_openai_tool_call(tool),
|
||||
"index": idx,
|
||||
**_openai_tool_call(tool, forced_id=tool_id),
|
||||
}
|
||||
]
|
||||
},
|
||||
@@ -609,10 +692,17 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": "tool_calls" if saw_tool_call else "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
if include_usage:
|
||||
usage_payload = {
|
||||
"id": completion_id,
|
||||
@@ -670,6 +760,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
ask_mode,
|
||||
session_id=cached_session_id,
|
||||
is_reply=is_reply,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
|
||||
@@ -702,10 +793,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
tool_events = result.get("toolEvents") or []
|
||||
message_content = result.get("text") or ""
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
saw_tool_call = False
|
||||
if isinstance(tool_events, list):
|
||||
for item in tool_events:
|
||||
for idx, item in enumerate(tool_events):
|
||||
if isinstance(item, dict):
|
||||
tool_calls.append(_openai_tool_call(item))
|
||||
tool_id = str(item.get("id") or f"call_{idx}")
|
||||
tool_calls.append(_openai_tool_call(item, forced_id=tool_id))
|
||||
saw_tool_call = True
|
||||
response = ChatCompletionResponse(
|
||||
id=f"chatcmpl-{uuid.uuid4().hex}",
|
||||
created=int(time.time()),
|
||||
@@ -713,7 +807,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
finish_reason="stop",
|
||||
finish_reason="tool_calls" if saw_tool_call else "stop",
|
||||
message={
|
||||
"role": "assistant",
|
||||
"content": message_content,
|
||||
@@ -723,6 +817,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
data = response.model_dump()
|
||||
data["latency"] = {
|
||||
"first_token_ms": result.get("firstTokenLatencyMs"),
|
||||
@@ -749,13 +844,15 @@ def _anthropic_error(status_code: int, error_type: str, message: str) -> JSONRes
|
||||
)
|
||||
|
||||
|
||||
def _anthropic_stop_reason(completion_tokens: int, max_tokens: int) -> str:
|
||||
"""Approximate Anthropic `stop_reason`.
|
||||
|
||||
Lingma doesn't expose a `max_tokens` knob, so we can't truly enforce it;
|
||||
we report `max_tokens` only when the generated length meets or exceeds
|
||||
the caller's stated ceiling. Everything else is `end_turn`.
|
||||
"""
|
||||
def _anthropic_stop_reason(
|
||||
completion_tokens: int,
|
||||
max_tokens: int,
|
||||
*,
|
||||
has_pending_tool_use: bool = False,
|
||||
) -> str:
|
||||
"""Approximate Anthropic `stop_reason`."""
|
||||
if has_pending_tool_use:
|
||||
return "tool_use"
|
||||
if max_tokens and completion_tokens >= max_tokens:
|
||||
return "max_tokens"
|
||||
return "end_turn"
|
||||
@@ -818,16 +915,19 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
# Anthropic clients don't expose an ask_mode, so we always run in "chat".
|
||||
ask_mode = "chat"
|
||||
|
||||
tool_config = _anthropic_tool_config(req)
|
||||
has_tooling_context = _anthropic_has_tooling_context(req)
|
||||
|
||||
reuse_eligible = (
|
||||
session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2
|
||||
session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 and not has_tooling_context
|
||||
)
|
||||
lookup_key: str | None = None
|
||||
write_key: str | None = None
|
||||
cached_session_id: str | None = None
|
||||
cached_instance_name: str | None = None
|
||||
if reuse_eligible:
|
||||
lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
|
||||
write_key = session_cache.build_key(api_key, messages_dump)
|
||||
lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
|
||||
write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config)
|
||||
entry = await session_cache.get(lookup_key)
|
||||
if entry is not None:
|
||||
cached_session_id = entry.session_id
|
||||
@@ -875,7 +975,6 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
return _anthropic_error(400, "invalid_request_error", "messages is empty")
|
||||
|
||||
prompt_tokens = estimate_tokens(prompt)
|
||||
|
||||
# ------------------------------------------------------------- backpressure
|
||||
try:
|
||||
ticket = await chat_guard.try_acquire()
|
||||
@@ -927,6 +1026,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
success = False
|
||||
block_index = 0
|
||||
text_block_open = False
|
||||
saw_pending_tool_use = False
|
||||
try:
|
||||
# 1) message_start — Anthropic SDKs read this first to get
|
||||
# the message envelope (id/model/initial usage).
|
||||
@@ -956,6 +1056,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
ask_mode,
|
||||
session_id=cached_session_id,
|
||||
is_reply=is_reply,
|
||||
tool_config=tool_config,
|
||||
out_meta=_meta,
|
||||
):
|
||||
if _stream_event_type(chunk) == "tool":
|
||||
@@ -970,8 +1071,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
tool = _stream_tool_event(chunk)
|
||||
if not tool:
|
||||
continue
|
||||
tool_id = str(tool.get("id") or f"toolu_stream_{block_index}")
|
||||
|
||||
tool_use_block = _anthropic_tool_use_block(tool)
|
||||
tool_use_block = _anthropic_tool_use_block(tool, forced_id=tool_id)
|
||||
yield _sse(
|
||||
"content_block_start",
|
||||
{
|
||||
@@ -986,7 +1088,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
)
|
||||
block_index += 1
|
||||
|
||||
tool_result_block = _anthropic_tool_result_block(tool)
|
||||
tool_result_block = _anthropic_tool_result_block(tool, forced_id=tool_id)
|
||||
if tool_result_block is not None:
|
||||
yield _sse(
|
||||
"content_block_start",
|
||||
@@ -1001,6 +1103,8 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
{"type": "content_block_stop", "index": block_index},
|
||||
)
|
||||
block_index += 1
|
||||
else:
|
||||
saw_pending_tool_use = True
|
||||
continue
|
||||
|
||||
text = _stream_text(chunk)
|
||||
@@ -1036,7 +1140,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
# 5) message_delta carries the terminal stop_reason and
|
||||
# the final cumulative output_tokens count.
|
||||
stop_reason = _anthropic_stop_reason(
|
||||
completion_tokens_holder["n"], max_tokens
|
||||
completion_tokens_holder["n"],
|
||||
max_tokens,
|
||||
has_pending_tool_use=saw_pending_tool_use,
|
||||
)
|
||||
yield _sse(
|
||||
"message_delta",
|
||||
@@ -1050,6 +1156,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# 6) message_stop — terminal event, no [DONE] sentinel.
|
||||
yield _sse("message_stop", {"type": "message_stop"})
|
||||
success = True
|
||||
@@ -1109,6 +1216,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
ask_mode,
|
||||
session_id=cached_session_id,
|
||||
is_reply=is_reply,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("anthropic.complete error (inst=%s): %s", inst.name, exc)
|
||||
@@ -1139,14 +1247,18 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
if text:
|
||||
content_blocks.append({"type": "text", "text": text})
|
||||
tool_events = result.get("toolEvents") or []
|
||||
saw_pending_tool_use = False
|
||||
if isinstance(tool_events, list):
|
||||
for item in tool_events:
|
||||
for idx, item in enumerate(tool_events):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
content_blocks.append(_anthropic_tool_use_block(item))
|
||||
tool_result = _anthropic_tool_result_block(item)
|
||||
tool_id = str(item.get("id") or f"toolu_nonstream_{idx}")
|
||||
content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id))
|
||||
tool_result = _anthropic_tool_result_block(item, forced_id=tool_id)
|
||||
if tool_result is not None:
|
||||
content_blocks.append(tool_result)
|
||||
else:
|
||||
saw_pending_tool_use = True
|
||||
|
||||
response_body: dict = {
|
||||
"id": message_id,
|
||||
@@ -1154,7 +1266,11 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": content_blocks,
|
||||
"stop_reason": _anthropic_stop_reason(completion_tokens, req.max_tokens),
|
||||
"stop_reason": _anthropic_stop_reason(
|
||||
completion_tokens,
|
||||
req.max_tokens,
|
||||
has_pending_tool_use=saw_pending_tool_use,
|
||||
),
|
||||
"stop_sequence": None,
|
||||
"usage": {
|
||||
"input_tokens": prompt_tokens,
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
@@ -42,6 +43,16 @@ def hash_user_context(messages: list[dict]) -> str:
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def _tool_fingerprint(tool_config: dict | None) -> str:
|
||||
if not isinstance(tool_config, dict):
|
||||
return "-"
|
||||
try:
|
||||
canonical = json.dumps(tool_config, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
|
||||
except Exception:
|
||||
canonical = str(tool_config)
|
||||
return hashlib.sha1(canonical.encode("utf-8")).hexdigest()[:16]
|
||||
|
||||
|
||||
class SessionCache:
|
||||
"""LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId.
|
||||
|
||||
@@ -79,11 +90,11 @@ class SessionCache:
|
||||
def enabled(self) -> bool:
|
||||
return self.max > 0
|
||||
|
||||
def build_key(self, api_key: str, messages: list[dict]) -> str:
|
||||
def build_key(self, api_key: str, messages: list[dict], *, tool_config: dict | None = None) -> str:
|
||||
# API key scoping prevents cross-tenant session leakage even when
|
||||
# different clients happen to produce identical histories.
|
||||
key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12]
|
||||
return f"{key_scope}:{hash_user_context(messages)}"
|
||||
return f"{key_scope}:{hash_user_context(messages)}:{_tool_fingerprint(tool_config)}"
|
||||
|
||||
async def get(self, key: str) -> SessionEntry | None:
|
||||
if not self.enabled:
|
||||
|
||||
Reference in New Issue
Block a user