fix: harden tooling session reuse and event routing

Ensure session reuse is disabled for tooling contexts, include tool config in cache keys, and stabilize tool event merge/routing with expanded bridge tests.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
GitHub Actions
2026-04-19 19:29:30 +08:00
parent 5aa7fbfae5
commit e600bae27c
5 changed files with 741 additions and 53 deletions

View File

@@ -351,6 +351,70 @@ def _include_usage(stream_options: dict | None) -> bool:
return bool(stream_options.get("include_usage"))
def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None:
if not settings.tool_forward_enabled:
return None
has_tools = isinstance(req.tools, list) and len(req.tools) > 0
has_choice = req.tool_choice is not None
if not has_tools and not has_choice:
return None
return {
"provider": "openai",
"tools": req.tools or [],
"tool_choice": req.tool_choice,
}
def _anthropic_tool_config(req: AnthropicMessagesRequest) -> dict[str, Any] | None:
if not settings.tool_forward_enabled:
return None
has_tools = isinstance(req.tools, list) and len(req.tools) > 0
has_choice = req.tool_choice is not None
if not has_tools and not has_choice:
return None
return {
"provider": "anthropic",
"tools": req.tools or [],
"tool_choice": req.tool_choice,
}
def _openai_has_tooling_context(req: ChatCompletionsRequest, messages: list[dict[str, Any]]) -> bool:
if isinstance(req.tools, list) and len(req.tools) > 0:
return True
if req.tool_choice is not None:
return True
for m in messages:
role = m.get("role")
if role == "tool":
return True
if role == "assistant" and m.get("tool_calls"):
return True
return False
def _anthropic_content_has_tool_blocks(content: Any) -> bool:
if not isinstance(content, list):
return False
for item in content:
if isinstance(item, dict) and item.get("type") in {"tool_use", "tool_result"}:
return True
return False
def _anthropic_has_tooling_context(req: AnthropicMessagesRequest) -> bool:
if isinstance(req.tools, list) and len(req.tools) > 0:
return True
if req.tool_choice is not None:
return True
if _anthropic_content_has_tool_blocks(req.system):
return True
for m in req.messages:
if _anthropic_content_has_tool_blocks(m.content):
return True
return False
def _stream_event_type(event: Any) -> str:
if isinstance(event, dict):
t = event.get("type")
@@ -388,9 +452,9 @@ def _json_string(value: Any) -> str:
return "{}"
def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]:
def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]:
return {
"id": str(tool.get("id") or f"call_{uuid.uuid4().hex}"),
"id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"),
"type": "function",
"function": {
"name": str(tool.get("name") or "tool"),
@@ -399,16 +463,20 @@ def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]:
}
def _anthropic_tool_use_block(tool: dict[str, Any]) -> dict[str, Any]:
def _anthropic_tool_use_block(
tool: dict[str, Any], *, forced_id: str | None = None
) -> dict[str, Any]:
return {
"type": "tool_use",
"id": str(tool.get("id") or f"toolu_{uuid.uuid4().hex}"),
"id": str(tool.get("id") or forced_id or f"toolu_{uuid.uuid4().hex}"),
"name": str(tool.get("name") or "tool"),
"input": tool.get("input") if tool.get("input") is not None else {},
}
def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None:
def _anthropic_tool_result_block(
tool: dict[str, Any], *, forced_id: str | None = None
) -> dict[str, Any] | None:
if "result" not in tool:
return None
result = tool.get("result")
@@ -418,7 +486,7 @@ def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None:
content = _json_string(result)
return {
"type": "tool_result",
"tool_use_id": str(tool.get("id") or ""),
"tool_use_id": str(tool.get("id") or forced_id or ""),
"content": content,
}
@@ -440,18 +508,22 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
if req.model.lower() in {"lingma-agent", "agent"}:
ask_mode = "agent"
tool_config = _openai_tool_config(req)
has_tooling_context = _openai_has_tooling_context(req, messages_dump)
reuse_eligible = (
session_cache.enabled
and ask_mode == "chat"
and len(messages_dump) >= 2
and not has_tooling_context
)
lookup_key: str | None = None
write_key: str | None = None
cached_session_id: str | None = None
cached_instance_name: str | None = None
if reuse_eligible:
lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
write_key = session_cache.build_key(api_key, messages_dump)
lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config)
entry = await session_cache.get(lookup_key)
if entry is not None:
cached_session_id = entry.session_id
@@ -549,6 +621,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta):
success = False
tool_call_indexes: dict[str, int] = {}
saw_tool_call = False
try:
async for chunk in _inst.client.chat_stream(
prompt,
@@ -556,12 +630,21 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
ask_mode,
session_id=cached_session_id,
is_reply=is_reply,
tool_config=tool_config,
out_meta=_meta,
):
if _stream_event_type(chunk) == "tool":
tool = _stream_tool_event(chunk)
if not tool:
continue
tool_id = str(tool.get("id") or "")
if not tool_id:
tool_id = f"call_{len(tool_call_indexes)}"
idx = tool_call_indexes.get(tool_id)
if idx is None:
idx = len(tool_call_indexes)
tool_call_indexes[tool_id] = idx
saw_tool_call = True
payload = {
"id": completion_id,
"object": "chat.completion.chunk",
@@ -573,8 +656,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
"delta": {
"tool_calls": [
{
"index": 0,
**_openai_tool_call(tool),
"index": idx,
**_openai_tool_call(tool, forced_id=tool_id),
}
]
},
@@ -609,10 +692,17 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
"choices": [
{
"index": 0,
"delta": {},
"finish_reason": "tool_calls" if saw_tool_call else "stop",
}
],
}
yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n"
if include_usage:
usage_payload = {
"id": completion_id,
@@ -670,6 +760,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
ask_mode,
session_id=cached_session_id,
is_reply=is_reply,
tool_config=tool_config,
)
except Exception as exc:
logger.warning("chat.complete error (inst=%s): %s", inst.name, exc)
@@ -702,10 +793,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
tool_events = result.get("toolEvents") or []
message_content = result.get("text") or ""
tool_calls: list[dict[str, Any]] = []
saw_tool_call = False
if isinstance(tool_events, list):
for item in tool_events:
for idx, item in enumerate(tool_events):
if isinstance(item, dict):
tool_calls.append(_openai_tool_call(item))
tool_id = str(item.get("id") or f"call_{idx}")
tool_calls.append(_openai_tool_call(item, forced_id=tool_id))
saw_tool_call = True
response = ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex}",
created=int(time.time()),
@@ -713,7 +807,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
choices=[
ChatCompletionChoice(
index=0,
finish_reason="stop",
finish_reason="tool_calls" if saw_tool_call else "stop",
message={
"role": "assistant",
"content": message_content,
@@ -723,6 +817,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
],
)
data = response.model_dump()
data["latency"] = {
"first_token_ms": result.get("firstTokenLatencyMs"),
@@ -749,13 +844,15 @@ def _anthropic_error(status_code: int, error_type: str, message: str) -> JSONRes
)
def _anthropic_stop_reason(completion_tokens: int, max_tokens: int) -> str:
"""Approximate Anthropic `stop_reason`.
Lingma doesn't expose a `max_tokens` knob, so we can't truly enforce it;
we report `max_tokens` only when the generated length meets or exceeds
the caller's stated ceiling. Everything else is `end_turn`.
"""
def _anthropic_stop_reason(
completion_tokens: int,
max_tokens: int,
*,
has_pending_tool_use: bool = False,
) -> str:
"""Approximate Anthropic `stop_reason`."""
if has_pending_tool_use:
return "tool_use"
if max_tokens and completion_tokens >= max_tokens:
return "max_tokens"
return "end_turn"
@@ -818,16 +915,19 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
# Anthropic clients don't expose an ask_mode, so we always run in "chat".
ask_mode = "chat"
tool_config = _anthropic_tool_config(req)
has_tooling_context = _anthropic_has_tooling_context(req)
reuse_eligible = (
session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2
session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 and not has_tooling_context
)
lookup_key: str | None = None
write_key: str | None = None
cached_session_id: str | None = None
cached_instance_name: str | None = None
if reuse_eligible:
lookup_key = session_cache.build_key(api_key, messages_dump[:-1])
write_key = session_cache.build_key(api_key, messages_dump)
lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config)
entry = await session_cache.get(lookup_key)
if entry is not None:
cached_session_id = entry.session_id
@@ -875,7 +975,6 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
return _anthropic_error(400, "invalid_request_error", "messages is empty")
prompt_tokens = estimate_tokens(prompt)
# ------------------------------------------------------------- backpressure
try:
ticket = await chat_guard.try_acquire()
@@ -927,6 +1026,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
success = False
block_index = 0
text_block_open = False
saw_pending_tool_use = False
try:
# 1) message_start — Anthropic SDKs read this first to get
# the message envelope (id/model/initial usage).
@@ -956,6 +1056,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
ask_mode,
session_id=cached_session_id,
is_reply=is_reply,
tool_config=tool_config,
out_meta=_meta,
):
if _stream_event_type(chunk) == "tool":
@@ -970,8 +1071,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
tool = _stream_tool_event(chunk)
if not tool:
continue
tool_id = str(tool.get("id") or f"toolu_stream_{block_index}")
tool_use_block = _anthropic_tool_use_block(tool)
tool_use_block = _anthropic_tool_use_block(tool, forced_id=tool_id)
yield _sse(
"content_block_start",
{
@@ -986,7 +1088,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
)
block_index += 1
tool_result_block = _anthropic_tool_result_block(tool)
tool_result_block = _anthropic_tool_result_block(tool, forced_id=tool_id)
if tool_result_block is not None:
yield _sse(
"content_block_start",
@@ -1001,6 +1103,8 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
{"type": "content_block_stop", "index": block_index},
)
block_index += 1
else:
saw_pending_tool_use = True
continue
text = _stream_text(chunk)
@@ -1036,7 +1140,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
# 5) message_delta carries the terminal stop_reason and
# the final cumulative output_tokens count.
stop_reason = _anthropic_stop_reason(
completion_tokens_holder["n"], max_tokens
completion_tokens_holder["n"],
max_tokens,
has_pending_tool_use=saw_pending_tool_use,
)
yield _sse(
"message_delta",
@@ -1050,6 +1156,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
},
)
# 6) message_stop — terminal event, no [DONE] sentinel.
yield _sse("message_stop", {"type": "message_stop"})
success = True
@@ -1109,6 +1216,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
ask_mode,
session_id=cached_session_id,
is_reply=is_reply,
tool_config=tool_config,
)
except Exception as exc:
logger.warning("anthropic.complete error (inst=%s): %s", inst.name, exc)
@@ -1139,14 +1247,18 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
if text:
content_blocks.append({"type": "text", "text": text})
tool_events = result.get("toolEvents") or []
saw_pending_tool_use = False
if isinstance(tool_events, list):
for item in tool_events:
for idx, item in enumerate(tool_events):
if not isinstance(item, dict):
continue
content_blocks.append(_anthropic_tool_use_block(item))
tool_result = _anthropic_tool_result_block(item)
tool_id = str(item.get("id") or f"toolu_nonstream_{idx}")
content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id))
tool_result = _anthropic_tool_result_block(item, forced_id=tool_id)
if tool_result is not None:
content_blocks.append(tool_result)
else:
saw_pending_tool_use = True
response_body: dict = {
"id": message_id,
@@ -1154,7 +1266,11 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request):
"role": "assistant",
"model": model,
"content": content_blocks,
"stop_reason": _anthropic_stop_reason(completion_tokens, req.max_tokens),
"stop_reason": _anthropic_stop_reason(
completion_tokens,
req.max_tokens,
has_pending_tool_use=saw_pending_tool_use,
),
"stop_sequence": None,
"usage": {
"input_tokens": prompt_tokens,