diff --git a/app/config.py b/app/config.py index 12716c0..75f9e3d 100644 --- a/app/config.py +++ b/app/config.py @@ -44,6 +44,7 @@ class Settings: session_reuse_enabled: bool = True session_cache_max_entries: int = 256 session_cache_ttl_sec: float = 1800.0 + tool_forward_enabled: bool = False def _bool_env(name: str, default: bool) -> bool: @@ -175,4 +176,5 @@ def load_settings() -> Settings: session_reuse_enabled=_bool_env("SESSION_REUSE_ENABLED", True), session_cache_max_entries=int(os.getenv("SESSION_CACHE_MAX_ENTRIES", "256")), session_cache_ttl_sec=float(os.getenv("SESSION_CACHE_TTL_SEC", "1800")), + tool_forward_enabled=_bool_env("TOOL_FORWARD_ENABLED", False), ) diff --git a/app/lingma_client.py b/app/lingma_client.py index 45dd8fd..2ff98fb 100644 --- a/app/lingma_client.py +++ b/app/lingma_client.py @@ -100,6 +100,7 @@ class LspWsRpcClient: self._reader_task: asyncio.Task | None = None self._rx_buffer = b"" self._chat_streams: dict[str, dict] = {} + self._tool_stream_map: dict[str, str] = {} self._on_disconnect = on_disconnect self._closed = False @@ -202,6 +203,7 @@ class LspWsRpcClient: stream["done"].set() stream["chunks"].put_nowait(None) self._chat_streams.clear() + self._tool_stream_map.clear() async def _send(self, payload: dict): async with self._send_lock: @@ -251,6 +253,92 @@ class LspWsRpcClient: except Exception: logger.exception("on_disconnect callback failed") + @staticmethod + def _normalize_tool_id(method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> str | None: + event_id = None + if isinstance(tool_event, dict): + event_id = tool_event.get("id") + if isinstance(event_id, str) and event_id.strip(): + return event_id.strip() + + fallback_id = params.get("toolCallId") or params.get("tool_call_id") + if isinstance(fallback_id, str) and fallback_id.strip(): + return fallback_id.strip() + + req_id = params.get("requestId") + name = None + if isinstance(tool_event, dict): + name = tool_event.get("name") + if not name: + name = params.get("name") + if isinstance(req_id, str) and req_id.strip() and isinstance(name, str) and name.strip(): + return f"{req_id.strip()}:tool:{name.strip()}" + if isinstance(req_id, str) and req_id.strip(): + return f"{req_id.strip()}:tool" + return None + + @staticmethod + def _merge_tool_event(existing: dict[str, Any] | None, incoming: dict[str, Any]) -> tuple[dict[str, Any], bool]: + merged = dict(existing or {}) + changed = False + + val = incoming.get("id") + if val and merged.get("id") != val: + merged["id"] = val + changed = True + + name = incoming.get("name") + if name: + existing_name = merged.get("name") + if not existing_name: + merged["name"] = name + changed = True + else: + existing_norm = str(existing_name).strip().lower() + incoming_norm = str(name).strip().lower() + if existing_norm == "tool" and incoming_norm != "tool": + merged["name"] = name + changed = True + elif existing_norm != "tool" and incoming_norm == "tool": + pass + elif merged.get("name") != name: + merged["name"] = name + changed = True + + if "input" in incoming and incoming.get("input") is not None: + incoming_input = incoming.get("input") + should_update_input = incoming_input != {} or "input" not in merged + if should_update_input and merged.get("input") != incoming_input: + merged["input"] = incoming_input + changed = True + + if "result" in incoming and incoming.get("result") is not None: + if merged.get("result") != incoming.get("result"): + merged["result"] = incoming.get("result") + changed = True + + return merged, changed + + + def _resolve_tool_stream(self, method: str, params: dict[str, Any], tool_event: dict[str, Any] | None) -> dict | None: + req_id = params.get("requestId") + if isinstance(req_id, str) and req_id.strip(): + stream = self._chat_streams.get(req_id) + if stream is not None and tool_event is not None: + tool_id = self._normalize_tool_id(method, params, tool_event) + if tool_id: + self._tool_stream_map[tool_id] = req_id + return stream + + if tool_event is not None: + tool_id = self._normalize_tool_id(method, params, tool_event) + if tool_id: + mapped_req = self._tool_stream_map.get(tool_id) + if mapped_req: + return self._chat_streams.get(mapped_req) + + return None + async def _handle_server_message(self, msg: dict): method = msg.get("method") params = msg.get("params") or {} @@ -268,21 +356,29 @@ class LspWsRpcClient: if method in {"tool/call/sync", "tool/invoke", "tool/call/approve", "tool/invokeResult"}: tool_event = self._extract_tool_event(params) - req_id = params.get("requestId") - stream = self._chat_streams.get(req_id) if req_id else None - - if stream is None and tool_event is not None: - for item in self._chat_streams.values(): - if any(evt.get("id") == tool_event["id"] for evt in item["tool_events"]): - stream = item - break - - if stream is None and len(self._chat_streams) == 1: - stream = next(iter(self._chat_streams.values())) + stream = self._resolve_tool_stream(method, params, tool_event) if stream is not None and tool_event is not None: - stream["tool_events"].append(tool_event) - stream["chunks"].put_nowait({"type": "tool", "tool": tool_event}) + tool_id = self._normalize_tool_id(method, params, tool_event) + if not tool_id: + logger.warning("drop unroutable tool event: method=%s missing tool id", method) + else: + tool_states = stream["tool_states"] + order = stream["tool_order"] + existing = tool_states.get(tool_id) + merged, changed = self._merge_tool_event(existing, tool_event) + if not existing: + if "id" not in merged or not merged.get("id"): + merged["id"] = tool_id + tool_states[tool_id] = merged + order.append(tool_id) + stream["chunks"].put_nowait({"type": "tool", "tool": merged}) + elif changed: + tool_states[tool_id] = merged + stream["chunks"].put_nowait({"type": "tool", "tool": merged}) + elif tool_event is not None: + logger.warning("drop unroutable tool event: method=%s requestId=%s", method, params.get("requestId")) + if method == "chat/finish": req_id = params.get("requestId") @@ -321,7 +417,8 @@ class LspWsRpcClient: "chunks": asyncio.Queue(), "done": asyncio.Event(), "finish": None, - "tool_events": [], + "tool_states": {}, + "tool_order": [], "started_at": time.monotonic(), "first_chunk_at": None, "finish_at": None, @@ -331,6 +428,9 @@ class LspWsRpcClient: stream = self._chat_streams.pop(request_id, None) if stream is None: return + for tool_id, mapped_req in list(self._tool_stream_map.items()): + if mapped_req == request_id: + self._tool_stream_map.pop(tool_id, None) # Drain queue so no stray future gets stuck if the consumer bailed early. if not stream["done"].is_set(): stream["done"].set() @@ -359,12 +459,20 @@ class LspWsRpcClient: first_ms = int((stream["first_chunk_at"] - stream["started_at"]) * 1000) if stream.get("finish_at") is not None: total_ms = int((stream["finish_at"] - stream["started_at"]) * 1000) + + ordered_tool_events: list[dict[str, Any]] = [] + tool_states = stream.get("tool_states") or {} + for tool_id in stream.get("tool_order") or []: + event = tool_states.get(tool_id) + if isinstance(event, dict): + ordered_tool_events.append(event) + return { "text": "".join(stream.get("parts") or []), "finish": stream.get("finish") or {}, "firstTokenLatencyMs": first_ms, "totalLatencyMs": total_ms, - "toolEvents": stream.get("tool_events") or [], + "toolEvents": ordered_tool_events, } @@ -733,9 +841,10 @@ class LingmaGatewayClient: request_id: str, *, is_reply: bool = False, + tool_config: dict[str, Any] | None = None, ): session_type = "developer" if ask_mode == "agent" else "chat" - return { + payload = { "requestId": request_id, "sessionId": session_id, "sessionType": session_type, @@ -764,6 +873,9 @@ class LingmaGatewayClient: "localeLang": "zh-CN", }, } + if tool_config is not None: + payload["toolConfig"] = tool_config + return payload async def _kick_chat_ask(self, payload: dict) -> None: """Fire chat/ask as a notification. @@ -784,12 +896,19 @@ class LingmaGatewayClient: *, session_id: str | None = None, is_reply: bool = False, + tool_config: dict[str, Any] | None = None, ) -> dict: await self.ensure_ready() request_id = str(uuid.uuid4()) sid = session_id or str(uuid.uuid4()) payload = self._build_payload( - prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply + prompt, + model_key, + ask_mode, + sid, + request_id, + is_reply=is_reply, + tool_config=tool_config, ) self.rpc.create_stream(request_id) try: @@ -820,6 +939,7 @@ class LingmaGatewayClient: *, session_id: str | None = None, is_reply: bool = False, + tool_config: dict[str, Any] | None = None, out_meta: dict | None = None, ) -> AsyncIterator[dict[str, Any]]: """Stream chat events. @@ -837,7 +957,13 @@ class LingmaGatewayClient: request_id = str(uuid.uuid4()) sid = session_id or str(uuid.uuid4()) payload = self._build_payload( - prompt, model_key, ask_mode, sid, request_id, is_reply=is_reply + prompt, + model_key, + ask_mode, + sid, + request_id, + is_reply=is_reply, + tool_config=tool_config, ) self.rpc.create_stream(request_id) try: diff --git a/app/main.py b/app/main.py index 082c30c..716256d 100644 --- a/app/main.py +++ b/app/main.py @@ -351,6 +351,70 @@ def _include_usage(stream_options: dict | None) -> bool: return bool(stream_options.get("include_usage")) +def _openai_tool_config(req: ChatCompletionsRequest) -> dict[str, Any] | None: + if not settings.tool_forward_enabled: + return None + has_tools = isinstance(req.tools, list) and len(req.tools) > 0 + has_choice = req.tool_choice is not None + if not has_tools and not has_choice: + return None + return { + "provider": "openai", + "tools": req.tools or [], + "tool_choice": req.tool_choice, + } + + +def _anthropic_tool_config(req: AnthropicMessagesRequest) -> dict[str, Any] | None: + if not settings.tool_forward_enabled: + return None + has_tools = isinstance(req.tools, list) and len(req.tools) > 0 + has_choice = req.tool_choice is not None + if not has_tools and not has_choice: + return None + return { + "provider": "anthropic", + "tools": req.tools or [], + "tool_choice": req.tool_choice, + } + + +def _openai_has_tooling_context(req: ChatCompletionsRequest, messages: list[dict[str, Any]]) -> bool: + if isinstance(req.tools, list) and len(req.tools) > 0: + return True + if req.tool_choice is not None: + return True + for m in messages: + role = m.get("role") + if role == "tool": + return True + if role == "assistant" and m.get("tool_calls"): + return True + return False + + +def _anthropic_content_has_tool_blocks(content: Any) -> bool: + if not isinstance(content, list): + return False + for item in content: + if isinstance(item, dict) and item.get("type") in {"tool_use", "tool_result"}: + return True + return False + + +def _anthropic_has_tooling_context(req: AnthropicMessagesRequest) -> bool: + if isinstance(req.tools, list) and len(req.tools) > 0: + return True + if req.tool_choice is not None: + return True + if _anthropic_content_has_tool_blocks(req.system): + return True + for m in req.messages: + if _anthropic_content_has_tool_blocks(m.content): + return True + return False + + def _stream_event_type(event: Any) -> str: if isinstance(event, dict): t = event.get("type") @@ -388,9 +452,9 @@ def _json_string(value: Any) -> str: return "{}" -def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]: +def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]: return { - "id": str(tool.get("id") or f"call_{uuid.uuid4().hex}"), + "id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"), "type": "function", "function": { "name": str(tool.get("name") or "tool"), @@ -399,16 +463,20 @@ def _openai_tool_call(tool: dict[str, Any]) -> dict[str, Any]: } -def _anthropic_tool_use_block(tool: dict[str, Any]) -> dict[str, Any]: +def _anthropic_tool_use_block( + tool: dict[str, Any], *, forced_id: str | None = None +) -> dict[str, Any]: return { "type": "tool_use", - "id": str(tool.get("id") or f"toolu_{uuid.uuid4().hex}"), + "id": str(tool.get("id") or forced_id or f"toolu_{uuid.uuid4().hex}"), "name": str(tool.get("name") or "tool"), "input": tool.get("input") if tool.get("input") is not None else {}, } -def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None: +def _anthropic_tool_result_block( + tool: dict[str, Any], *, forced_id: str | None = None +) -> dict[str, Any] | None: if "result" not in tool: return None result = tool.get("result") @@ -418,7 +486,7 @@ def _anthropic_tool_result_block(tool: dict[str, Any]) -> dict[str, Any] | None: content = _json_string(result) return { "type": "tool_result", - "tool_use_id": str(tool.get("id") or ""), + "tool_use_id": str(tool.get("id") or forced_id or ""), "content": content, } @@ -440,18 +508,22 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): if req.model.lower() in {"lingma-agent", "agent"}: ask_mode = "agent" + tool_config = _openai_tool_config(req) + has_tooling_context = _openai_has_tooling_context(req, messages_dump) + reuse_eligible = ( session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 + and not has_tooling_context ) lookup_key: str | None = None write_key: str | None = None cached_session_id: str | None = None cached_instance_name: str | None = None if reuse_eligible: - lookup_key = session_cache.build_key(api_key, messages_dump[:-1]) - write_key = session_cache.build_key(api_key, messages_dump) + lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config) + write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config) entry = await session_cache.get(lookup_key) if entry is not None: cached_session_id = entry.session_id @@ -549,6 +621,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta): success = False + tool_call_indexes: dict[str, int] = {} + saw_tool_call = False try: async for chunk in _inst.client.chat_stream( prompt, @@ -556,12 +630,21 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): ask_mode, session_id=cached_session_id, is_reply=is_reply, + tool_config=tool_config, out_meta=_meta, ): if _stream_event_type(chunk) == "tool": tool = _stream_tool_event(chunk) if not tool: continue + tool_id = str(tool.get("id") or "") + if not tool_id: + tool_id = f"call_{len(tool_call_indexes)}" + idx = tool_call_indexes.get(tool_id) + if idx is None: + idx = len(tool_call_indexes) + tool_call_indexes[tool_id] = idx + saw_tool_call = True payload = { "id": completion_id, "object": "chat.completion.chunk", @@ -573,8 +656,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): "delta": { "tool_calls": [ { - "index": 0, - **_openai_tool_call(tool), + "index": idx, + **_openai_tool_call(tool, forced_id=tool_id), } ] }, @@ -609,10 +692,17 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): "object": "chat.completion.chunk", "created": created, "model": model, - "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": "tool_calls" if saw_tool_call else "stop", + } + ], } yield f"data: {json.dumps(done_payload, ensure_ascii=False)}\n\n" + if include_usage: usage_payload = { "id": completion_id, @@ -670,6 +760,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): ask_mode, session_id=cached_session_id, is_reply=is_reply, + tool_config=tool_config, ) except Exception as exc: logger.warning("chat.complete error (inst=%s): %s", inst.name, exc) @@ -702,10 +793,13 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): tool_events = result.get("toolEvents") or [] message_content = result.get("text") or "" tool_calls: list[dict[str, Any]] = [] + saw_tool_call = False if isinstance(tool_events, list): - for item in tool_events: + for idx, item in enumerate(tool_events): if isinstance(item, dict): - tool_calls.append(_openai_tool_call(item)) + tool_id = str(item.get("id") or f"call_{idx}") + tool_calls.append(_openai_tool_call(item, forced_id=tool_id)) + saw_tool_call = True response = ChatCompletionResponse( id=f"chatcmpl-{uuid.uuid4().hex}", created=int(time.time()), @@ -713,7 +807,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): choices=[ ChatCompletionChoice( index=0, - finish_reason="stop", + finish_reason="tool_calls" if saw_tool_call else "stop", message={ "role": "assistant", "content": message_content, @@ -723,6 +817,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): ], ) + data = response.model_dump() data["latency"] = { "first_token_ms": result.get("firstTokenLatencyMs"), @@ -749,13 +844,15 @@ def _anthropic_error(status_code: int, error_type: str, message: str) -> JSONRes ) -def _anthropic_stop_reason(completion_tokens: int, max_tokens: int) -> str: - """Approximate Anthropic `stop_reason`. - - Lingma doesn't expose a `max_tokens` knob, so we can't truly enforce it; - we report `max_tokens` only when the generated length meets or exceeds - the caller's stated ceiling. Everything else is `end_turn`. - """ +def _anthropic_stop_reason( + completion_tokens: int, + max_tokens: int, + *, + has_pending_tool_use: bool = False, +) -> str: + """Approximate Anthropic `stop_reason`.""" + if has_pending_tool_use: + return "tool_use" if max_tokens and completion_tokens >= max_tokens: return "max_tokens" return "end_turn" @@ -818,16 +915,19 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): # Anthropic clients don't expose an ask_mode, so we always run in "chat". ask_mode = "chat" + tool_config = _anthropic_tool_config(req) + has_tooling_context = _anthropic_has_tooling_context(req) + reuse_eligible = ( - session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 + session_cache.enabled and ask_mode == "chat" and len(messages_dump) >= 2 and not has_tooling_context ) lookup_key: str | None = None write_key: str | None = None cached_session_id: str | None = None cached_instance_name: str | None = None if reuse_eligible: - lookup_key = session_cache.build_key(api_key, messages_dump[:-1]) - write_key = session_cache.build_key(api_key, messages_dump) + lookup_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config) + write_key = session_cache.build_key(api_key, messages_dump, tool_config=tool_config) entry = await session_cache.get(lookup_key) if entry is not None: cached_session_id = entry.session_id @@ -875,7 +975,6 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): return _anthropic_error(400, "invalid_request_error", "messages is empty") prompt_tokens = estimate_tokens(prompt) - # ------------------------------------------------------------- backpressure try: ticket = await chat_guard.try_acquire() @@ -927,6 +1026,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): success = False block_index = 0 text_block_open = False + saw_pending_tool_use = False try: # 1) message_start — Anthropic SDKs read this first to get # the message envelope (id/model/initial usage). @@ -956,6 +1056,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): ask_mode, session_id=cached_session_id, is_reply=is_reply, + tool_config=tool_config, out_meta=_meta, ): if _stream_event_type(chunk) == "tool": @@ -970,8 +1071,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): tool = _stream_tool_event(chunk) if not tool: continue + tool_id = str(tool.get("id") or f"toolu_stream_{block_index}") - tool_use_block = _anthropic_tool_use_block(tool) + tool_use_block = _anthropic_tool_use_block(tool, forced_id=tool_id) yield _sse( "content_block_start", { @@ -986,7 +1088,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): ) block_index += 1 - tool_result_block = _anthropic_tool_result_block(tool) + tool_result_block = _anthropic_tool_result_block(tool, forced_id=tool_id) if tool_result_block is not None: yield _sse( "content_block_start", @@ -1001,6 +1103,8 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): {"type": "content_block_stop", "index": block_index}, ) block_index += 1 + else: + saw_pending_tool_use = True continue text = _stream_text(chunk) @@ -1036,7 +1140,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): # 5) message_delta carries the terminal stop_reason and # the final cumulative output_tokens count. stop_reason = _anthropic_stop_reason( - completion_tokens_holder["n"], max_tokens + completion_tokens_holder["n"], + max_tokens, + has_pending_tool_use=saw_pending_tool_use, ) yield _sse( "message_delta", @@ -1050,6 +1156,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): }, ) + # 6) message_stop — terminal event, no [DONE] sentinel. yield _sse("message_stop", {"type": "message_stop"}) success = True @@ -1109,6 +1216,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): ask_mode, session_id=cached_session_id, is_reply=is_reply, + tool_config=tool_config, ) except Exception as exc: logger.warning("anthropic.complete error (inst=%s): %s", inst.name, exc) @@ -1139,14 +1247,18 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): if text: content_blocks.append({"type": "text", "text": text}) tool_events = result.get("toolEvents") or [] + saw_pending_tool_use = False if isinstance(tool_events, list): - for item in tool_events: + for idx, item in enumerate(tool_events): if not isinstance(item, dict): continue - content_blocks.append(_anthropic_tool_use_block(item)) - tool_result = _anthropic_tool_result_block(item) + tool_id = str(item.get("id") or f"toolu_nonstream_{idx}") + content_blocks.append(_anthropic_tool_use_block(item, forced_id=tool_id)) + tool_result = _anthropic_tool_result_block(item, forced_id=tool_id) if tool_result is not None: content_blocks.append(tool_result) + else: + saw_pending_tool_use = True response_body: dict = { "id": message_id, @@ -1154,7 +1266,11 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): "role": "assistant", "model": model, "content": content_blocks, - "stop_reason": _anthropic_stop_reason(completion_tokens, req.max_tokens), + "stop_reason": _anthropic_stop_reason( + completion_tokens, + req.max_tokens, + has_pending_tool_use=saw_pending_tool_use, + ), "stop_sequence": None, "usage": { "input_tokens": prompt_tokens, diff --git a/app/session_cache.py b/app/session_cache.py index 46d9778..1ada29a 100644 --- a/app/session_cache.py +++ b/app/session_cache.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import hashlib +import json import time from collections import OrderedDict from dataclasses import dataclass @@ -42,6 +43,16 @@ def hash_user_context(messages: list[dict]) -> str: return h.hexdigest() +def _tool_fingerprint(tool_config: dict | None) -> str: + if not isinstance(tool_config, dict): + return "-" + try: + canonical = json.dumps(tool_config, ensure_ascii=False, sort_keys=True, separators=(",", ":")) + except Exception: + canonical = str(tool_config) + return hashlib.sha1(canonical.encode("utf-8")).hexdigest()[:16] + + class SessionCache: """LRU + TTL cache: conversation-prefix hash -> upstream Lingma sessionId. @@ -79,11 +90,11 @@ class SessionCache: def enabled(self) -> bool: return self.max > 0 - def build_key(self, api_key: str, messages: list[dict]) -> str: + def build_key(self, api_key: str, messages: list[dict], *, tool_config: dict | None = None) -> str: # API key scoping prevents cross-tenant session leakage even when # different clients happen to produce identical histories. key_scope = hashlib.sha1((api_key or "-").encode("utf-8")).hexdigest()[:12] - return f"{key_scope}:{hash_user_context(messages)}" + return f"{key_scope}:{hash_user_context(messages)}:{_tool_fingerprint(tool_config)}" async def get(self, key: str) -> SessionEntry | None: if not self.enabled: diff --git a/tests/test_tool_call_bridge.py b/tests/test_tool_call_bridge.py index 05ea2b3..8667d13 100644 --- a/tests/test_tool_call_bridge.py +++ b/tests/test_tool_call_bridge.py @@ -6,6 +6,31 @@ import types import unittest from unittest.mock import AsyncMock, patch + +class _FakeSessionCache: + def __init__(self) -> None: + self.enabled = True + self.keys: list[str] = [] + self.get_calls: list[str] = [] + self.put_calls: list[tuple[str, str, str]] = [] + self.invalidate_calls: list[str] = [] + + def build_key(self, api_key: str, messages: list[dict], *, tool_config=None) -> str: + marker = "with_tool" if tool_config is not None else "no_tool" + key = f"{api_key}:{len(messages)}:{marker}" + self.keys.append(key) + return key + + async def get(self, key: str): + self.get_calls.append(key) + return None + + async def put(self, key: str, session_id: str, instance_name: str = "") -> None: + self.put_calls.append((key, session_id, instance_name)) + + async def invalidate(self, key: str) -> None: + self.invalidate_calls.append(key) + # app.main imports playwright via auto_login; tests don't exercise that path. # Inject a lightweight stub so unit tests run without installing playwright. _playwright = types.ModuleType("playwright") @@ -119,6 +144,38 @@ async def _collect_stream(response) -> str: return "".join(chunks) +class _SpyClient(_FakeClient): + def __init__(self, *, stream_events: list[dict], complete_result: dict) -> None: + super().__init__(stream_events=stream_events, complete_result=complete_result) + self.last_complete_kwargs: dict = {} + self.last_stream_kwargs: dict = {} + + async def chat_complete(self, *args, **kwargs) -> dict: + self.last_complete_kwargs = dict(kwargs) + return await super().chat_complete(*args, **kwargs) + + async def chat_stream(self, *args, **kwargs): + self.last_stream_kwargs = dict(kwargs) + async for event in super().chat_stream(*args, **kwargs): + yield event + + +class _SettingsPatch: + def __init__(self, **kwargs) -> None: + self._kwargs = kwargs + + def __enter__(self): + self._patchers = [patch.object(main.settings, k, v) for k, v in self._kwargs.items()] + for p in self._patchers: + p.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + for p in reversed(self._patchers): + p.stop() + return False + + class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): async def test_openai_non_stream_bridges_tool_calls(self) -> None: fake_client = _FakeClient( @@ -156,6 +213,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): message = payload["choices"][0]["message"] self.assertEqual(message["content"], "done") self.assertIsInstance(message["tool_calls"], list) + self.assertEqual(payload["choices"][0]["finish_reason"], "tool_calls") self.assertEqual(message["tool_calls"][0]["function"]["name"], "search_docs") self.assertEqual( json.loads(message["tool_calls"][0]["function"]["arguments"]), @@ -195,6 +253,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('"tool_calls"', body) self.assertIn('"content": "hello"', body) + self.assertIn('"finish_reason": "tool_calls"', body) self.assertIn('"usage"', body) self.assertIn("data: [DONE]", body) @@ -239,9 +298,132 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): payload = json.loads(response.body) types = [item["type"] for item in payload["content"]] self.assertEqual(types, ["text", "tool_use", "tool_result"]) + self.assertEqual(payload["stop_reason"], "end_turn") self.assertEqual(payload["content"][1]["name"], "lookup") self.assertEqual(payload["content"][2]["tool_use_id"], "toolu_1") + async def test_openai_stream_tool_call_indices_are_stable(self) -> None: + fake_client = _FakeClient( + stream_events=[ + { + "type": "tool", + "tool": { + "id": "call_a", + "name": "read_file", + "input": {"path": "README.md"}, + }, + }, + { + "type": "tool", + "tool": { + "id": "call_b", + "name": "search_docs", + "input": {"query": "gateway"}, + }, + }, + ], + complete_result={}, + ) + req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "hi"}], + stream=True, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + ): + response = await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + body = await _collect_stream(response) + + self.assertIn('"id": "call_a"', body) + self.assertIn('"id": "call_b"', body) + self.assertIn('"index": 0', body) + self.assertIn('"index": 1', body) + + async def test_anthropic_non_stream_returns_tool_use_stop_reason_when_result_missing(self) -> None: + fake_client = _FakeClient( + stream_events=[], + complete_result={ + "text": "", + "toolEvents": [ + { + "name": "lookup", + "input": {"k": "v"}, + } + ], + "sessionId": "sess-2", + }, + ) + req = AnthropicMessagesRequest( + model="claude-3-5-sonnet-20241022", + max_tokens=256, + messages=[{"role": "user", "content": "hi"}], + stream=False, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object(main.settings, "api_keys", ["test-key"]), + ): + response = await main.v1_messages( + req, + _make_request( + "/v1/messages", + headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + ), + ) + + payload = json.loads(response.body) + self.assertEqual(payload["stop_reason"], "tool_use") + self.assertEqual(len(payload["content"]), 1) + self.assertEqual(payload["content"][0]["type"], "tool_use") + + async def test_anthropic_stream_returns_tool_use_stop_reason_when_result_missing(self) -> None: + fake_client = _FakeClient( + stream_events=[ + { + "type": "tool", + "tool": { + "name": "read", + "input": {"file": "a.txt"}, + }, + } + ], + complete_result={}, + ) + req = AnthropicMessagesRequest( + model="claude-3-5-sonnet-20241022", + max_tokens=256, + messages=[{"role": "user", "content": "hi"}], + stream=True, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object(main.settings, "api_keys", ["test-key"]), + ): + response = await main.v1_messages( + req, + _make_request( + "/v1/messages", + headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + ), + ) + body = await _collect_stream(response) + + self.assertIn('"type": "tool_use"', body) + self.assertIn('"stop_reason": "tool_use"', body) + async def test_anthropic_stream_bridges_tool_and_text_events(self) -> None: fake_client = _FakeClient( stream_events=[ @@ -284,11 +466,262 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn("event: message_start", body) self.assertIn('"type": "tool_use"', body) self.assertIn('"type": "tool_result"', body) + self.assertIn('"stop_reason": "end_turn"', body) self.assertIn('"type": "text_delta"', body) self.assertIn("event: message_stop", body) -class LingmaClientToolEventExtractionTests(unittest.TestCase): + + async def test_openai_non_stream_forwards_tool_config_when_enabled(self) -> None: + spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "hi"}], + stream=False, + tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tool_choice={"type": "function", "function": {"name": "lookup"}}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + _SettingsPatch(tool_forward_enabled=True), + ): + await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + + self.assertIn("tool_config", spy_client.last_complete_kwargs) + cfg = spy_client.last_complete_kwargs["tool_config"] + self.assertEqual(cfg["provider"], "openai") + self.assertEqual(len(cfg["tools"]), 1) + self.assertIsInstance(cfg["tool_choice"], dict) + + async def test_openai_non_stream_does_not_forward_tool_config_when_disabled(self) -> None: + spy_client = _SpyClient(stream_events=[], complete_result={"text": "ok", "toolEvents": []}) + req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "hi"}], + stream=False, + tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tool_choice={"type": "function", "function": {"name": "lookup"}}, + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + _SettingsPatch(tool_forward_enabled=False), + ): + await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + + self.assertIn("tool_config", spy_client.last_complete_kwargs) + self.assertIsNone(spy_client.last_complete_kwargs["tool_config"]) + + + async def test_openai_tooling_context_disables_session_reuse_cache(self) -> None: + fake_cache = _FakeSessionCache() + fake_client = _FakeClient( + stream_events=[], + complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-3"}, + ) + req = ChatCompletionsRequest( + model="org_auto", + messages=[ + {"role": "user", "content": "turn-1"}, + {"role": "user", "content": "turn-2"}, + ], + stream=False, + tools=[{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + tool_choice={"type": "function", "function": {"name": "lookup"}}, + ) + + with ( + patch.object(main, "session_cache", fake_cache), + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + _SettingsPatch(tool_forward_enabled=True), + ): + await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + + self.assertEqual(fake_cache.keys, []) + self.assertEqual(fake_cache.get_calls, []) + self.assertEqual(fake_cache.put_calls, []) + + async def test_anthropic_tooling_context_disables_session_reuse_cache(self) -> None: + fake_cache = _FakeSessionCache() + fake_client = _FakeClient( + stream_events=[], + complete_result={"text": "ok", "toolEvents": [], "sessionId": "sess-4"}, + ) + req = AnthropicMessagesRequest( + model="claude-3-5-sonnet-20241022", + max_tokens=128, + messages=[ + {"role": "user", "content": "turn-1"}, + {"role": "user", "content": "turn-2"}, + ], + stream=False, + tools=[{"name": "lookup", "input_schema": {"type": "object", "properties": {}}}], + tool_choice={"type": "auto"}, + ) + + with ( + patch.object(main, "session_cache", fake_cache), + patch.object(main, "pool", _FakePool(_FakeInstance(fake_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object(main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"})), + patch.object(main.stats_collector, "record_chat", AsyncMock(return_value=None)), + patch.object(main.settings, "api_keys", ["test-key"]), + ): + await main.v1_messages( + req, + _make_request( + "/v1/messages", + headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + ), + ) + + self.assertEqual(fake_cache.keys, []) + self.assertEqual(fake_cache.get_calls, []) + self.assertEqual(fake_cache.put_calls, []) + + +class SessionCacheToolFingerprintTests(unittest.TestCase): + def test_build_key_changes_with_tool_config(self) -> None: + from app.session_cache import SessionCache + + cache = SessionCache(max_entries=8, ttl_sec=60) + messages = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hello"}, + ] + + cfg_a = { + "provider": "openai", + "tools": [{"type": "function", "function": {"name": "lookup", "parameters": {}}}], + "tool_choice": {"type": "function", "function": {"name": "lookup"}}, + } + cfg_a_reordered = { + "tool_choice": {"function": {"name": "lookup"}, "type": "function"}, + "tools": [{"function": {"parameters": {}, "name": "lookup"}, "type": "function"}], + "provider": "openai", + } + cfg_b = { + "provider": "openai", + "tools": [{"type": "function", "function": {"name": "lookup_v2", "parameters": {}}}], + "tool_choice": {"type": "function", "function": {"name": "lookup_v2"}}, + } + + key_no_tool = cache.build_key("api-key", messages) + key_a = cache.build_key("api-key", messages, tool_config=cfg_a) + key_a_reordered = cache.build_key("api-key", messages, tool_config=cfg_a_reordered) + key_b = cache.build_key("api-key", messages, tool_config=cfg_b) + + self.assertNotEqual(key_no_tool, key_a) + self.assertEqual(key_a, key_a_reordered) + self.assertNotEqual(key_a, key_b) + + + def test_handle_server_message_drops_unroutable_tool_event_without_request_id(self) -> None: + from app.lingma_client import LspWsRpcClient + + rpc = LspWsRpcClient("ws://127.0.0.1:1") + + async def run() -> None: + rpc.create_stream("req-1") + await rpc._handle_server_message( + { + "jsonrpc": "2.0", + "method": "tool/invoke", + "params": { + "name": "lookup", + "parameters": {"q": "x"}, + }, + } + ) + stream = rpc._chat_streams["req-1"] + self.assertEqual(stream["tool_order"], []) + self.assertEqual(stream["tool_states"], {}) + self.assertTrue(stream["chunks"].empty()) + + import asyncio + + asyncio.run(run()) + + def test_handle_server_message_routes_by_tool_map_without_request_id(self) -> None: + from app.lingma_client import LspWsRpcClient + + rpc = LspWsRpcClient("ws://127.0.0.1:1") + + async def run() -> None: + rpc.create_stream("req-1") + + await rpc._handle_server_message( + { + "jsonrpc": "2.0", + "method": "tool/invoke", + "params": { + "requestId": "req-1", + "toolCallId": "call-1", + "name": "lookup", + "parameters": {"q": "a"}, + }, + } + ) + + await rpc._handle_server_message( + { + "jsonrpc": "2.0", + "method": "tool/invokeResult", + "params": { + "toolCallId": "call-1", + "result": {"ok": True}, + }, + } + ) + + result = rpc.get_stream_result("req-1") + self.assertEqual(len(result["toolEvents"]), 1) + self.assertEqual(result["toolEvents"][0]["id"], "call-1") + self.assertEqual(result["toolEvents"][0]["input"], {"q": "a"}) + self.assertEqual(result["toolEvents"][0]["result"], {"ok": True}) + + import asyncio + + asyncio.run(run()) + + def test_handle_server_message_dedupes_identical_repeated_tool_events(self) -> None: + from app.lingma_client import LspWsRpcClient + + rpc = LspWsRpcClient("ws://127.0.0.1:1") + + async def run() -> None: + rpc.create_stream("req-1") + msg = { + "jsonrpc": "2.0", + "method": "tool/invoke", + "params": { + "requestId": "req-1", + "toolCallId": "call-dup", + "name": "lookup", + "parameters": {"q": "dup"}, + }, + } + await rpc._handle_server_message(msg) + await rpc._handle_server_message(msg) + + stream = rpc._chat_streams["req-1"] + self.assertEqual(stream["tool_order"], ["call-dup"]) + self.assertEqual(stream["chunks"].qsize(), 1) + + import asyncio + + asyncio.run(run()) + def test_extracts_tool_event_from_results_and_parameters(self) -> None: from app.lingma_client import LspWsRpcClient