From 5a7553b35b49d2a3ed4643cd054c95e2bf782506 Mon Sep 17 00:00:00 2001 From: GitHub Actions Date: Wed, 22 Apr 2026 07:37:00 +0800 Subject: [PATCH] refactor: share execution prep for tool-call phase Keep the current tool-call bridge contract stable while extracting shared execution setup and tightening Anthropic forwarding regressions. Co-Authored-By: Claude Opus 4.7 --- app/anthropic_schema.py | 7 +- app/http/execution_core.py | 148 ++++++++++++++++ app/http/tool_bridge.py | 43 +++++ app/main.py | 299 +++++++++++---------------------- tests/test_tool_call_bridge.py | 31 ++++ 5 files changed, 319 insertions(+), 209 deletions(-) create mode 100644 app/http/execution_core.py diff --git a/app/anthropic_schema.py b/app/anthropic_schema.py index 05dece0..b239081 100644 --- a/app/anthropic_schema.py +++ b/app/anthropic_schema.py @@ -52,10 +52,9 @@ class AnthropicMessagesRequest(BaseModel): stop_sequences: list[str] | None = None # metadata.user_id is the official hint for per-user routing / abuse tracking. metadata: dict[str, Any] | None = None - # Tools / tool_choice are accepted but we can't forward them to Lingma yet — - # they're preserved here so the request doesn't 422, and the flattener - # surfaces any tool_use blocks as `[tool_use] {...}` text so the assistant - # still sees the context. + # Tools / tool_choice are accepted for compatibility and, when forwarding is + # enabled, are passed upstream as tool_config; tool_use / tool_result blocks + # are still flattened into text so the assistant can see prior tool context. tools: list[dict[str, Any]] | None = None tool_choice: dict[str, Any] | None = None diff --git a/app/http/execution_core.py b/app/http/execution_core.py new file mode 100644 index 0000000..0591a82 --- /dev/null +++ b/app/http/execution_core.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Awaitable, Callable + +from ..lingma_pool import LingmaPool, PoolInstance +from ..model_map import build_model_name_map, flatten_model_keys, resolve_model +from ..session_cache import SessionCache, hash_branch_context + + +@dataclass +class ExecutionContext: + ask_mode: str + lookup_key: str | None + write_key: str | None + cached_session_id: str | None + inst: PoolInstance + model: str + prompt: str + is_reply: bool + affinity: str | None + + +def _resolve_ask_mode(model: str, has_tooling_context: bool, *, default_ask_mode: str) -> str: + model_name = (model or "").lower() + if model_name in {"lingma-agent", "agent"} or has_tooling_context: + return "agent" + return default_ask_mode + + +async def _apply_cached_instance_or_invalidate( + *, + protocol: str, + logger: Any, + session_cache: SessionCache, + inst: PoolInstance, + cached_instance_name: str | None, + cached_session_id: str | None, + lookup_key: str | None, +) -> str | None: + if cached_instance_name and inst.name != cached_instance_name: + logger.info( + "%s session cache instance %s unhealthy, falling back to %s", + protocol, + cached_instance_name, + inst.name, + ) + if lookup_key: + await session_cache.invalidate(lookup_key) + return None + return cached_session_id + + +async def prepare_execution_context( + *, + protocol: str, + requested_model: str, + has_tooling_context: bool, + tool_config: dict[str, Any] | None, + messages_dump: list[dict[str, Any]], + api_key: str, + affinity_key: str | None, + pool: LingmaPool, + session_cache: SessionCache, + logger: Any, + default_model: str, + default_ask_mode: str, + ensure_instance_logged_in: Callable[[PoolInstance], Awaitable[Any]], + last_user_text: Callable[[list[dict[str, Any]]], str], + messages_to_prompt: Callable[[list[dict[str, Any]]], str], +) -> ExecutionContext: + ask_mode = _resolve_ask_mode( + requested_model, + has_tooling_context, + default_ask_mode=default_ask_mode, + ) + + reuse_eligible = ( + session_cache.enabled + and ask_mode == "chat" + and len(messages_dump) >= 2 + and not has_tooling_context + ) + lookup_key: str | None = None + write_key: str | None = None + cached_session_id: str | None = None + cached_instance_name: str | None = None + if reuse_eligible: + prefix_branch_context = hash_branch_context(messages_dump[:-1]) + lookup_key = session_cache.build_key( + api_key, + messages_dump[:-1], + tool_config=tool_config, + branch_context=prefix_branch_context, + ) + write_key = session_cache.build_key( + api_key, + messages_dump, + tool_config=tool_config, + branch_context=hash_branch_context(messages_dump), + ) + entry = await session_cache.get(lookup_key) + if entry is None: + legacy_lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config) + entry = await session_cache.get(legacy_lookup_key) + if entry is not None: + lookup_key = legacy_lookup_key + if entry is not None: + cached_session_id = entry.session_id + cached_instance_name = entry.instance_name or None + + affinity = cached_instance_name or affinity_key + inst = pool.pick(affinity_key=affinity) + cached_session_id = await _apply_cached_instance_or_invalidate( + protocol=protocol, + logger=logger, + session_cache=session_cache, + inst=inst, + cached_instance_name=cached_instance_name, + cached_session_id=cached_session_id, + lookup_key=lookup_key, + ) + + await ensure_instance_logged_in(inst) + + models = await inst.client.query_models() + available = flatten_model_keys(models) + name_map = build_model_name_map(models) + model = resolve_model(requested_model, available, default_model, name_map) + + if cached_session_id: + prompt = last_user_text(messages_dump) + is_reply = True + else: + prompt = messages_to_prompt(messages_dump) + is_reply = False + + return ExecutionContext( + ask_mode=ask_mode, + lookup_key=lookup_key, + write_key=write_key, + cached_session_id=cached_session_id, + inst=inst, + model=model, + prompt=prompt, + is_reply=is_reply, + affinity=affinity, + ) diff --git a/app/http/tool_bridge.py b/app/http/tool_bridge.py index 8e3a431..fc71fcd 100644 --- a/app/http/tool_bridge.py +++ b/app/http/tool_bridge.py @@ -15,6 +15,49 @@ def _json_string(value: Any) -> str: return "{}" +def _openai_tool_name(tool: Any) -> str | None: + if not isinstance(tool, dict): + return None + if tool.get("type") == "function": + fn = tool.get("function") + if isinstance(fn, dict): + name = fn.get("name") + if isinstance(name, str) and name.strip(): + return name.strip() + name = tool.get("name") + if isinstance(name, str) and name.strip(): + return name.strip() + return None + + +def _anthropic_tool_name(tool: Any) -> str | None: + if not isinstance(tool, dict): + return None + name = tool.get("name") + if isinstance(name, str) and name.strip(): + return name.strip() + fn = tool.get("function") + if isinstance(fn, dict): + nested_name = fn.get("name") + if isinstance(nested_name, str) and nested_name.strip(): + return nested_name.strip() + return None + + +def _tool_event_allowed( + tool_name: str, + tool_config: dict[str, Any] | None, + *, + forced_tool_name: str | None = None, +) -> bool: + if not (tool_config and isinstance(tool_config.get("tools"), list) and tool_config.get("tools")): + return True + for tool in tool_config.get("tools") or []: + if tool_name == _anthropic_tool_name(tool) or tool_name == _openai_tool_name(tool): + return True + return bool(forced_tool_name and tool_name == forced_tool_name) + + def _openai_forced_tool_name(tool_choice: Any) -> str | None: if not isinstance(tool_choice, dict): return None diff --git a/app/main.py b/app/main.py index 9793ccc..c34433c 100644 --- a/app/main.py +++ b/app/main.py @@ -25,6 +25,11 @@ from .auth import ( ) from .concurrency import BackpressureRejected, InFlightGuard from .config import Settings, load_settings +from .http.execution_core import ( + _apply_cached_instance_or_invalidate as _shared_apply_cached_instance_or_invalidate, + _resolve_ask_mode as _shared_resolve_ask_mode, + prepare_execution_context, +) from .http.responses_adapter import ( _responses_id_from_chat_id, _responses_input_to_messages, @@ -35,6 +40,7 @@ from .http.responses_adapter import ( ) from .http.tool_bridge import ( _anthropic_forced_tool_name, + _anthropic_tool_name as _shared_anthropic_tool_name, _anthropic_tool_result_block, _anthropic_tool_use_block, _forced_tool_event_from_text, @@ -42,8 +48,10 @@ from .http.tool_bridge import ( _json_string, _openai_forced_tool_name, _openai_tool_call, + _openai_tool_name as _shared_openai_tool_name, _tool_code_object_from_text, _tool_code_single_arg_name, + _tool_event_allowed, ) from .lingma_pool import LingmaPool, PoolInstance from .logging_config import configure_logging, get_logger, request_id_var @@ -383,32 +391,11 @@ def _tool_allowlist() -> set[str]: def _openai_tool_name(tool: Any) -> str | None: - if not isinstance(tool, dict): - return None - if tool.get("type") == "function": - fn = tool.get("function") - if isinstance(fn, dict): - name = fn.get("name") - if isinstance(name, str) and name.strip(): - return name.strip() - name = tool.get("name") - if isinstance(name, str) and name.strip(): - return name.strip() - return None + return _shared_openai_tool_name(tool) def _anthropic_tool_name(tool: Any) -> str | None: - if not isinstance(tool, dict): - return None - name = tool.get("name") - if isinstance(name, str) and name.strip(): - return name.strip() - fn = tool.get("function") - if isinstance(fn, dict): - nested_name = fn.get("name") - if isinstance(nested_name, str) and nested_name.strip(): - return nested_name.strip() - return None + return _shared_anthropic_tool_name(tool) def _filter_allowed_tools(tools: list[dict[str, Any]], *, provider: str) -> list[dict[str, Any]]: @@ -509,10 +496,11 @@ def _anthropic_has_tooling_context(req: AnthropicMessagesRequest) -> bool: def _resolve_ask_mode(model: str, has_tooling_context: bool) -> str: - model_name = (model or "").lower() - if model_name in {"lingma-agent", "agent"} or has_tooling_context: - return "agent" - return settings.default_ask_mode + return _shared_resolve_ask_mode( + model, + has_tooling_context, + default_ask_mode=settings.default_ask_mode, + ) async def _apply_cached_instance_or_invalidate( @@ -523,17 +511,15 @@ async def _apply_cached_instance_or_invalidate( cached_session_id: str | None, lookup_key: str | None, ) -> str | None: - if cached_instance_name and inst.name != cached_instance_name: - logger.info( - "%s session cache instance %s unhealthy, falling back to %s", - protocol, - cached_instance_name, - inst.name, - ) - if lookup_key: - await session_cache.invalidate(lookup_key) - return None - return cached_session_id + return await _shared_apply_cached_instance_or_invalidate( + protocol=protocol, + logger=logger, + session_cache=session_cache, + inst=inst, + cached_instance_name=cached_instance_name, + cached_session_id=cached_session_id, + lookup_key=lookup_key, + ) @@ -588,68 +574,32 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): # 3. Stick the request to the pool instance that originally served it. tool_config = _openai_tool_config(req) has_tooling_context = _openai_has_tooling_context(req, messages_dump) - - ask_mode = _resolve_ask_mode(req.model, has_tooling_context) - - reuse_eligible = ( - session_cache.enabled - and ask_mode == "chat" - and len(messages_dump) >= 2 - and not has_tooling_context - ) - lookup_key: str | None = None - write_key: str | None = None - cached_session_id: str | None = None - cached_instance_name: str | None = None - if reuse_eligible: - prefix_branch_context = hash_branch_context(messages_dump[:-1]) - lookup_key = session_cache.build_key( - api_key, - messages_dump[:-1], - tool_config=tool_config, - branch_context=prefix_branch_context, - ) - write_key = session_cache.build_key( - api_key, - messages_dump, - tool_config=tool_config, - branch_context=hash_branch_context(messages_dump), - ) - entry = await session_cache.get(lookup_key) - if entry is None: - legacy_lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config) - entry = await session_cache.get(legacy_lookup_key) - if entry is not None: - lookup_key = legacy_lookup_key - if entry is not None: - cached_session_id = entry.session_id - cached_instance_name = entry.instance_name or None - affinity = cached_instance_name or _affinity_key_for(req) - inst = p.pick(affinity_key=affinity) - - cached_session_id = await _apply_cached_instance_or_invalidate( + execution = await prepare_execution_context( protocol="chat", - inst=inst, - cached_instance_name=cached_instance_name, - cached_session_id=cached_session_id, - lookup_key=lookup_key, + requested_model=req.model, + has_tooling_context=has_tooling_context, + tool_config=tool_config, + messages_dump=messages_dump, + api_key=api_key, + affinity_key=_affinity_key_for(req), + pool=p, + session_cache=session_cache, + logger=logger, + default_model=settings.default_model, + default_ask_mode=settings.default_ask_mode, + ensure_instance_logged_in=_ensure_instance_logged_in, + last_user_text=_last_user_text, + messages_to_prompt=_messages_to_prompt, ) - - await _ensure_instance_logged_in(inst) - - models = await inst.client.query_models() - available = flatten_model_keys(models) - name_map = build_model_name_map(models) - model = resolve_model(req.model, available, settings.default_model, name_map) - - # Prompt construction: on cache hit send only the last user turn so Lingma's - # stored context isn't duplicated. - if cached_session_id: - prompt = _last_user_text(messages_dump) - is_reply = True - else: - prompt = _messages_to_prompt(messages_dump) - is_reply = False + ask_mode = execution.ask_mode + lookup_key = execution.lookup_key + write_key = execution.write_key + cached_session_id = execution.cached_session_id + inst = execution.inst + model = execution.model + prompt = execution.prompt + is_reply = execution.is_reply + affinity = execution.affinity if not prompt: raise HTTPException( @@ -748,16 +698,11 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): continue tool_name = str(tool.get("name") or "") - allowed = True - if tool_config and isinstance(tool_config.get("tools"), list) and tool_config.get("tools"): - allowed = False - for t in tool_config.get("tools"): - if tool_name == _anthropic_tool_name(t) or tool_name == _openai_tool_name(t): - allowed = True - break - if not allowed and forced_tool_name and tool_name == forced_tool_name: - allowed = True - if not allowed: + if not _tool_event_allowed( + tool_name, + tool_config, + forced_tool_name=forced_tool_name, + ): continue if buffered_text_parts: @@ -956,16 +901,11 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): for idx, item in enumerate(tool_events): if isinstance(item, dict): tool_name = str(item.get("name") or "") - allowed = True - if tool_config and isinstance(tool_config.get("tools"), list) and tool_config.get("tools"): - allowed = False - for t in tool_config.get("tools"): - if tool_name == _anthropic_tool_name(t) or tool_name == _openai_tool_name(t): - allowed = True - break - if not allowed and forced_tool_name and tool_name == forced_tool_name: - allowed = True - if not allowed: + if not _tool_event_allowed( + tool_name, + tool_config, + forced_tool_name=forced_tool_name, + ): continue tool_id = str(item.get("id") or f"call_{idx}") @@ -1414,77 +1354,38 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): message = error.get("message") or str(detail) or "invalid tool configuration" return _anthropic_error(exc.status_code, "invalid_request_error", message) has_tooling_context = _anthropic_has_tooling_context(req) - - ask_mode = _resolve_ask_mode(req.model, has_tooling_context) - - reuse_eligible = ( - session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 and not has_tooling_context - ) - lookup_key: str | None = None - write_key: str | None = None - cached_session_id: str | None = None - cached_instance_name: str | None = None - if reuse_eligible: - prefix_branch_context = hash_branch_context(messages_dump[:-1]) - lookup_key = session_cache.build_key( - api_key, - messages_dump[:-1], - tool_config=tool_config, - branch_context=prefix_branch_context, - ) - write_key = session_cache.build_key( - api_key, - messages_dump, - tool_config=tool_config, - branch_context=hash_branch_context(messages_dump), - ) - entry = await session_cache.get(lookup_key) - if entry is None: - legacy_lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config) - entry = await session_cache.get(legacy_lookup_key) - if entry is not None: - lookup_key = legacy_lookup_key - if entry is not None: - cached_session_id = entry.session_id - cached_instance_name = entry.instance_name or None - - affinity = cached_instance_name or affinity_key_for_anthropic(req) - inst = p.pick(affinity_key=affinity) - - if cached_instance_name and inst.name != cached_instance_name: - logger.info( - "anthropic session cache instance %s unhealthy, falling back to %s", - cached_instance_name, - inst.name, - ) - cached_session_id = None - if lookup_key: - await session_cache.invalidate(lookup_key) - try: - await _ensure_instance_logged_in(inst) + execution = await prepare_execution_context( + protocol="anthropic", + requested_model=req.model, + has_tooling_context=has_tooling_context, + tool_config=tool_config, + messages_dump=messages_dump, + api_key=api_key, + affinity_key=affinity_key_for_anthropic(req), + pool=p, + session_cache=session_cache, + logger=logger, + default_model=settings.default_model, + default_ask_mode=settings.default_ask_mode, + ensure_instance_logged_in=_ensure_instance_logged_in, + last_user_text=_last_user_text, + messages_to_prompt=_messages_to_prompt, + ) except HTTPException as exc: - # 503/401/502 from login: map to closest Anthropic kind. err_type = "authentication_error" if exc.status_code == 401 else "overloaded_error" detail = exc.detail if isinstance(exc.detail, dict) else {} msg = (detail.get("error") or {}).get("message") or str(detail) or "upstream error" return _anthropic_error(exc.status_code, err_type, msg) - - # ------------------------------------------------------------- prompt & model - models = await inst.client.query_models() - available = flatten_model_keys(models) - name_map = build_model_name_map(models) - # Anthropic callers send `claude-*` model names. resolve_model's - # final fallback (default_model / first available) handles that cleanly - # without us having to hard-code a mapping table. - model = resolve_model(req.model, available, settings.default_model, name_map) - - if cached_session_id: - prompt = _last_user_text(messages_dump) - is_reply = True - else: - prompt = _messages_to_prompt(messages_dump) - is_reply = False + ask_mode = execution.ask_mode + lookup_key = execution.lookup_key + write_key = execution.write_key + cached_session_id = execution.cached_session_id + inst = execution.inst + model = execution.model + prompt = execution.prompt + is_reply = execution.is_reply + affinity = execution.affinity if not prompt: return _anthropic_error(400, "invalid_request_error", "messages is empty") @@ -1588,17 +1489,11 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): continue tool_name = str(tool.get("name") or "") - allowed = True - if tool_config and isinstance(tool_config.get("tools"), list) and tool_config.get("tools"): - allowed = False - for t in tool_config.get("tools"): - if tool_name == _anthropic_tool_name(t) or tool_name == _openai_tool_name(t): - allowed = True - break - forced_tool_name = _anthropic_forced_tool_name(req.tool_choice) - if not allowed and forced_tool_name and tool_name == forced_tool_name: - allowed = True - if not allowed: + if not _tool_event_allowed( + tool_name, + tool_config, + forced_tool_name=_anthropic_forced_tool_name(req.tool_choice), + ): continue tool_id = str(tool.get("id") or f"toolu_stream_{block_index}") @@ -1778,17 +1673,11 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): continue tool_name = str(item.get("name") or "") - allowed = True - if tool_config and isinstance(tool_config.get("tools"), list) and tool_config.get("tools"): - allowed = False - for t in tool_config.get("tools"): - if tool_name == _anthropic_tool_name(t) or tool_name == _openai_tool_name(t): - allowed = True - break - forced_tool_name = _anthropic_forced_tool_name(req.tool_choice) - if not allowed and forced_tool_name and tool_name == forced_tool_name: - allowed = True - if not allowed: + if not _tool_event_allowed( + tool_name, + tool_config, + forced_tool_name=_anthropic_forced_tool_name(req.tool_choice), + ): continue saw_tool_event = True diff --git a/tests/test_tool_call_bridge.py b/tests/test_tool_call_bridge.py index 66c4955..e875776 100644 --- a/tests/test_tool_call_bridge.py +++ b/tests/test_tool_call_bridge.py @@ -991,6 +991,37 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(openai_spy.last_complete_args[2], "chat") self.assertEqual(anthropic_spy.last_complete_args[2], "chat") + async def test_anthropic_non_stream_does_not_forward_tool_config_when_disabled(self) -> None: + spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + req = AnthropicMessagesRequest( + model="claude-3-5-sonnet-20241022", + max_tokens=128, + messages=[{"role": "user", "content": "hi"}], + stream=False, + tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}], + tool_choice={"type": "tool", "name": "lookup"}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object(main.settings, "api_keys", ["test-key"]), + _SettingsPatch(tool_forward_enabled=False, default_ask_mode="chat"), + ): + await main.v1_messages( + req, + _make_request( + "/v1/messages", + headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + ), + ) + + self.assertIn("tool_config", spy_client.last_complete_kwargs) + self.assertIsNone(spy_client.last_complete_kwargs["tool_config"]) + self.assertEqual(spy_client.last_complete_args[2], "agent") + async def test_anthropic_non_stream_with_tools_uses_agent_mode(self) -> None: spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) req = AnthropicMessagesRequest(