diff --git a/README.md b/README.md index 522ae92..8a3fb8b 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ - OpenAI:`/v1/models`、`/v1/chat/completions`(含 stream) - Anthropic:`/v1/messages`、`/v1/messages/count_tokens`(含 stream) +- 能力探测:`/capabilities`、`/v1/capabilities` +- 内省端点:`/internal/effective-config`、`/internal/debug/requests` - 内置:多实例池、会话复用、Prometheus 指标、登录态 bundle 注入 - 工具事件桥接:Lingma 上游返回 `tool` 事件时,网关会输出为 OpenAI `tool_calls`(stream/non-stream)和 Anthropic `tool_use` / `tool_result`(stream/non-stream);请求侧 `tools` / `tool_choice` 仅在 `TOOL_FORWARD_ENABLED=true` 时透传(默认开启,可显式关闭) - 工具模拟回退:当 Lingma 未稳定外显原生 `tool/*` 事件时,网关会把注入后的 `json action` / `#Tool Call` 等动作文本归一化为 OpenAI `tool_calls`,并支持 tool result continuation @@ -56,6 +58,7 @@ API_KEY=$(grep '^API_KEYS=' .env | cut -d= -f2 | cut -d, -f1) curl -s "http://127.0.0.1:${PORT}/healthz" curl -s "http://127.0.0.1:${PORT}/v1/models" \ -H "Authorization: Bearer ${API_KEY}" +curl -s "http://127.0.0.1:${PORT}/capabilities" ``` --- @@ -172,6 +175,32 @@ curl -s "http://127.0.0.1:${PORT}/v1/messages/count_tokens" \ }' ``` +### 能力探测 + +```bash +curl -s "http://127.0.0.1:${PORT}/capabilities" + +curl -s "http://127.0.0.1:${PORT}/v1/capabilities" \ + -H "x-api-key: ${API_KEY}" \ + -H "anthropic-version: 2023-06-01" +``` + +### 内省端点(admin) + +如果配置了 `ADMIN_TOKEN`,以下端点需要使用该 token;否则会回退复用 `API_KEYS`。 + +```bash +ADMIN_TOKEN=${ADMIN_TOKEN:-$API_KEY} + +curl -s "http://127.0.0.1:${PORT}/internal/effective-config" \ + -H "Authorization: Bearer ${ADMIN_TOKEN}" + +curl -s "http://127.0.0.1:${PORT}/internal/debug/requests?limit=5" \ + -H "Authorization: Bearer ${ADMIN_TOKEN}" +``` + +> `internal/debug/requests` 会对 token、session bundle、data URL 图片和超长工具参数做脱敏/截断。 + --- ## 部署与更新 diff --git a/app/main.py b/app/main.py index 9476543..ae6d24f 100644 --- a/app/main.py +++ b/app/main.py @@ -5,6 +5,7 @@ import hashlib import json import time import uuid +from collections import deque from contextlib import asynccontextmanager from typing import Any @@ -15,6 +16,7 @@ from .anthropic_schema import ( AnthropicMessagesRequest, affinity_key_for_anthropic, anthropic_to_internal_messages, + flatten_anthropic_content, ) from .auth import ( AnthropicAuthError, @@ -112,6 +114,8 @@ STREAMING_RESPONSE_HEADERS = { "Connection": "keep-alive", } +_DEBUG_REQUEST_LOG: deque[dict[str, Any]] = deque(maxlen=100) + def _require_pool() -> LingmaPool: if pool is None: @@ -249,6 +253,63 @@ def _log_auth_posture() -> None: ) +def _safe_setting_value(key: str, value: Any) -> Any: + key_upper = key.upper() + if any( + marker in key_upper + for marker in {"KEY", "TOKEN", "PASSWORD", "SECRET", "BUNDLE"} + ): + if isinstance(value, list): + return ["***" for _ in value] + return "***" + return value + + +def _redact_debug_value(path: tuple[str, ...], value: Any) -> Any: + if isinstance(value, dict): + return { + k: _redact_debug_value(path + (str(k).lower(),), v) + for k, v in value.items() + } + if isinstance(value, list): + return [_redact_debug_value(path + ("[]",), item) for item in value] + if isinstance(value, str): + lowered_path = "/".join(path) + if any(marker in lowered_path for marker in ("authorization", "x-api-key", "api_key", "token", "password", "secret", "session_bundle")): + return "***" + if value.startswith("data:"): + return "[redacted-data-url]" + if "session bundle" in value.lower(): + return "[redacted-session-bundle]" + if any(part in {"args", "arguments"} for part in path) and len(value) > 2048: + return value[:1024] + "... [truncated]" + return value + + +def _record_debug_request(protocol: str, path: str, body: dict[str, Any], request: Request) -> None: + _DEBUG_REQUEST_LOG.appendleft( + { + "timestamp": int(time.time()), + "protocol": protocol, + "path": path, + "request_id": request.headers.get("x-request-id", ""), + "body": _redact_debug_value((), body), + } + ) + + +@app.get("/internal/debug/requests", dependencies=[Depends(admin_auth_guard)]) +async def internal_debug_requests(limit: int = 20): + safe_limit = min(max(limit, 1), 100) + return JSONResponse( + content={ + "ok": True, + "count": min(safe_limit, len(_DEBUG_REQUEST_LOG)), + "items": list(_DEBUG_REQUEST_LOG)[:safe_limit], + } + ) + + @app.get("/healthz") async def healthz(): if pool is None: @@ -267,6 +328,62 @@ async def healthz(): } +def _capabilities_payload() -> dict[str, Any]: + return { + "service": "lingma-openai-gateway", + "version": app.version, + "protocols": { + "openai": { + "models": True, + "chat_completions": True, + "responses": True, + "streaming": True, + "response_tool_calls": True, + "request_tools_forwarded": settings.tool_forward_enabled, + }, + "anthropic": { + "messages": True, + "count_tokens": True, + "streaming": True, + "response_tool_use": True, + "request_tools_forwarded": settings.tool_forward_enabled, + }, + }, + "features": { + "session_reuse": { + "enabled": settings.session_reuse_enabled, + "cache_max_entries": settings.session_cache_max_entries, + "cache_ttl_sec": settings.session_cache_ttl_sec, + }, + "tooling": { + "forward_enabled": settings.tool_forward_enabled, + "allowlist": settings.tool_allowlist, + "emulation_bridge_enabled": True, + }, + "pool": { + "configured_instance_count": settings.instance_count, + "default_model": settings.default_model, + "default_ask_mode": settings.default_ask_mode, + }, + "auth": { + "v1_requires_auth": bool(settings.api_keys), + "admin_token_configured": bool(settings.admin_token), + "metrics_public": settings.metrics_public, + }, + }, + } + + +@app.get("/capabilities") +async def capabilities(): + return JSONResponse(content=_capabilities_payload()) + + +@app.get("/v1/capabilities", dependencies=[Depends(anthropic_auth_guard)]) +async def v1_capabilities(): + return JSONResponse(content=_capabilities_payload()) + + async def _ensure_instance_logged_in(inst: PoolInstance) -> dict: client = inst.client auto_login = inst.auto_login @@ -433,6 +550,75 @@ def _messages_to_prompt(messages: list[dict]) -> str: return "\n".join(parts).strip() +def _assistant_tool_calls_to_emulation_text(tool_calls: Any) -> str: + if not isinstance(tool_calls, list): + return "" + blocks: list[str] = [] + for item in tool_calls: + if not isinstance(item, dict): + continue + fn = item.get("function") if isinstance(item.get("function"), dict) else None + name = str((fn or {}).get("name") or item.get("name") or "").strip() + if not name: + continue + arguments = (fn or {}).get("arguments") + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except Exception: + arguments = {"raw": arguments} + if not isinstance(arguments, dict): + arguments = {} + blocks.append( + "```json action\n" + + json.dumps( + {"tool": name, "parameters": arguments}, ensure_ascii=False, indent=2 + ) + + "\n```" + ) + return "\n\n".join(blocks) + + +def _tool_action_block(name: str, arguments: dict[str, Any]) -> str: + return ( + "```json action\n" + + json.dumps( + {"tool": name, "parameters": arguments}, ensure_ascii=False, indent=2 + ) + + "\n```" + ) + + +def _anthropic_flattened_tool_history_to_emulation_text(text: str) -> str: + if not text: + return "" + out: list[str] = [] + for line in text.splitlines(): + stripped = line.strip() + if stripped.startswith("[tool_use]"): + raw = stripped[len("[tool_use]") :].strip() + try: + payload = json.loads(raw) + except Exception: + out.append(line) + continue + if not isinstance(payload, dict): + out.append(line) + continue + name = str(payload.get("name") or "").strip() + arguments = payload.get("input") + if name and isinstance(arguments, dict): + out.append(_tool_action_block(name, arguments)) + else: + out.append(line) + continue + if stripped.startswith("[tool_result]"): + out.append(action_output_prompt(None, stripped[len("[tool_result]") :].strip())) + continue + out.append(line) + return "\n".join(part for part in out if part).strip() + + def _messages_to_emulation_prompt( messages: list[dict[str, Any]], *, @@ -446,6 +632,10 @@ def _messages_to_emulation_prompt( if role in {"system", "developer"}: continue text = flatten_content(message.get("content")) + if role == "assistant" and message.get("tool_calls"): + projected = _assistant_tool_calls_to_emulation_text(message.get("tool_calls")) + if projected: + text = "\n\n".join(part for part in [text, projected] if part) if role == "tool": text = action_output_prompt(message.get("tool_call_id"), text) role = "user" @@ -472,6 +662,22 @@ def _messages_to_emulation_prompt( return "\n\n".join(parts).strip() +def _effective_tool_config_for_emulation( + tool_config: dict[str, Any] | None, + *, + use_emulation: bool, +) -> dict[str, Any] | None: + if use_emulation: + return None + return tool_config + + +def _emulation_tools(raw_tools: list[dict[str, Any]] | None, tool_config: dict[str, Any] | None) -> list[dict[str, Any]] | None: + if isinstance(tool_config, dict) and isinstance(tool_config.get("tools"), list): + return tool_config.get("tools") + return raw_tools + + def _anthropic_messages_to_emulation_prompt( messages: list[dict[str, Any]], *, @@ -483,6 +689,10 @@ def _anthropic_messages_to_emulation_prompt( for message in messages: role = str(message.get("role") or "").strip().lower() text = str(message.get("content") or "").strip() + if role == "assistant" and "[tool_use]" in text: + text = _anthropic_flattened_tool_history_to_emulation_text(text) + elif role == "user" and "[tool_result]" in text: + text = _anthropic_flattened_tool_history_to_emulation_text(text) if role == "tool": text = action_output_prompt(message.get("tool_call_id"), text) role = "user" @@ -575,6 +785,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): p = _require_pool() messages_dump = [m.model_dump() for m in req.messages] + _record_debug_request("openai", "/v1/chat/completions", req.model_dump(mode="json"), request) api_key = _extract_api_key(request) or "-" # ------------------------------------------------------------- session reuse @@ -617,9 +828,11 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): is_reply = execution.is_reply include_usage = _include_usage(req.stream_options) - em_tools = _em_extract_openai_tools(req.tools) + emulation_tools = _emulation_tools(req.tools, tool_config) + em_tools = _em_extract_openai_tools(emulation_tools) em_choice = _em_extract_openai_tool_choice(req.tool_choice) - if _em_has_tool_request(em_tools, em_choice): + use_emulation = has_tooling_context + if use_emulation: system_parts = [ flatten_content(m.content) for m in req.messages @@ -628,9 +841,14 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): prompt = _messages_to_emulation_prompt( messages_dump, system_text="\n\n".join(system_parts), - tools=req.tools, + tools=emulation_tools, tool_choice=req.tool_choice, ) + execution.prompt = prompt + effective_tool_config = _effective_tool_config_for_emulation( + tool_config, + use_emulation=use_emulation, + ) try: started = await start_execution( @@ -708,7 +926,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): ask_mode, session_id=cached_session_id, is_reply=is_reply, - tool_config=tool_config, + tool_config=effective_tool_config, out_meta=_meta, ): if _stream_event_type(chunk) == "tool": @@ -763,6 +981,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): continue buffered_text_parts.append(text) completion_tokens_holder["n"] += estimate_tokens(text) + if use_emulation: + continue full_text = "".join(buffered_text_parts) if req.tools: @@ -855,9 +1075,6 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): buffered_text_parts.clear() yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" - if buffered_text_parts and forced_tool_name and saw_tool_call: - buffered_text_parts.clear() - if buffered_text_parts and req.tools and not saw_tool_call: merged_text = "".join(buffered_text_parts) inferred = _infer_tool_event_from_declared_tools( @@ -924,6 +1141,11 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" buffered_text_parts = [remaining] if remaining else [] + if buffered_text_parts and saw_tool_call: + text_to_yield = "".join(buffered_text_parts) + buffered_text_parts.clear() + yield _text_payload(text_to_yield) + done_payload = { "id": completion_id, "object": "chat.completion.chunk", @@ -996,7 +1218,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): protocol="chat", execution=execution, prompt_tokens=prompt_tokens, - tool_config=tool_config, + tool_config=effective_tool_config, logger=logger, stats_collector=stats_collector, session_cache=session_cache, @@ -1095,7 +1317,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): ask_mode, session_id=None, is_reply=False, - tool_config=tool_config, + tool_config=effective_tool_config, ) retry_text = retry_result.get("text") or "" parsed_calls, remaining = parse_action_blocks(retry_text, em_tools) @@ -1227,6 +1449,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): ) messages_dump = anthropic_to_internal_messages(req) + _record_debug_request("anthropic", "/v1/messages", req.model_dump(mode="json"), request) # Prefer the auth token actually accepted so session-cache bucketing is # consistent regardless of which auth header style the caller used. api_key = ( @@ -1284,16 +1507,23 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): model = execution.model prompt = execution.prompt is_reply = execution.is_reply - em_anthropic_tools = _em_extract_anthropic_tools(req.tools) + emulation_tools = _emulation_tools(req.tools, tool_config) + em_anthropic_tools = _em_extract_anthropic_tools(emulation_tools) em_anthropic_choice = _em_extract_anthropic_tool_choice(req.tool_choice) - if _em_has_tool_request(em_anthropic_tools, em_anthropic_choice): + use_emulation = has_tooling_context + if use_emulation: system_text = flatten_anthropic_content(req.system) if req.system else "" prompt = _anthropic_messages_to_emulation_prompt( messages_dump, system_text=system_text, - tools=req.tools, + tools=emulation_tools, tool_choice=req.tool_choice, ) + execution.prompt = prompt + effective_tool_config = _effective_tool_config_for_emulation( + tool_config, + use_emulation=use_emulation, + ) try: started = await start_execution( @@ -1372,7 +1602,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): ask_mode, session_id=cached_session_id, is_reply=is_reply, - tool_config=tool_config, + tool_config=effective_tool_config, out_meta=_meta, ): if _stream_event_type(chunk) == "tool": @@ -1703,7 +1933,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): protocol="anthropic", execution=execution, prompt_tokens=prompt_tokens, - tool_config=tool_config, + tool_config=effective_tool_config, logger=logger, stats_collector=stats_collector, session_cache=session_cache, @@ -1757,10 +1987,8 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): text = remaining if not saw_tool_event and em_anthropic_tools: - inferred_call = infer_declared_tool_call_from_text(text, em_anthropic_tools) - if inferred_call is None: - inferred_calls = infer_tool_calls_from_text(text, em_anthropic_tools) - inferred_call = inferred_calls[0] if inferred_calls else None + inferred_calls = infer_tool_calls_from_text(text, em_anthropic_tools) + inferred_call = inferred_calls[0] if inferred_calls else None if inferred_call is not None: content_blocks = [ { @@ -1774,7 +2002,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): saw_pending_tool_use = True text = "" - if not saw_tool_event and em_anthropic_tools: + if not saw_tool_event and em_anthropic_tools and not text.strip(): retry_prompt = f"{prompt}\n\n{force_tooling_prompt(em_anthropic_choice)}" retry_result = await inst.client.chat_complete( retry_prompt, @@ -1782,53 +2010,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): ask_mode, session_id=None, is_reply=False, - tool_config=tool_config, - ) - retry_text = retry_result.get("text") or "" - parsed_calls, remaining = parse_action_blocks(retry_text, em_anthropic_tools) - if parsed_calls: - content_blocks = [] - if remaining: - content_blocks.append({"type": "text", "text": remaining}) - for call in parsed_calls: - content_blocks.append( - { - "type": "tool_use", - "id": call.id, - "name": call.name, - "input": call.arguments, - } - ) - saw_tool_event = True - saw_pending_tool_use = True - text = remaining - else: - inferred_call = infer_declared_tool_call_from_text(retry_text, em_anthropic_tools) - if inferred_call is None: - inferred_calls = infer_tool_calls_from_text(retry_text, em_anthropic_tools) - inferred_call = inferred_calls[0] if inferred_calls else None - if inferred_call is not None: - content_blocks = [ - { - "type": "tool_use", - "id": inferred_call.id, - "name": inferred_call.name, - "input": inferred_call.arguments, - } - ] - saw_tool_event = True - saw_pending_tool_use = True - text = "" - - if not saw_tool_event and em_anthropic_tools and text.strip(): - retry_prompt = f"{prompt}\n\n{force_tooling_prompt(em_anthropic_choice)}" - retry_result = await inst.client.chat_complete( - retry_prompt, - model, - ask_mode, - session_id=None, - is_reply=False, - tool_config=tool_config, + tool_config=effective_tool_config, ) retry_text = retry_result.get("text") or "" parsed_calls, remaining = parse_action_blocks(retry_text, em_anthropic_tools) @@ -2090,6 +2272,60 @@ async def internal_stats(): } +@app.get("/internal/effective-config", dependencies=[Depends(admin_auth_guard)]) +async def internal_effective_config(): + cfg = settings + return JSONResponse(content={ + "ok": True, + "settings": { + "host": cfg.host, + "port": cfg.port, + "api_keys": _safe_setting_value("api_keys", cfg.api_keys), + "metrics_token": _safe_setting_value("metrics_token", cfg.metrics_token), + "admin_token": _safe_setting_value("admin_token", cfg.admin_token), + "metrics_public": cfg.metrics_public, + "log_level": cfg.log_level, + "gateway_max_in_flight": cfg.gateway_max_in_flight, + "gateway_queue_timeout_sec": cfg.gateway_queue_timeout_sec, + "lingma_bin": cfg.lingma_bin, + "lingma_work_dir": cfg.lingma_work_dir, + "lingma_socket_port": cfg.lingma_socket_port, + "lingma_startup_timeout": cfg.lingma_startup_timeout, + "lingma_rpc_timeout": cfg.lingma_rpc_timeout, + "default_model": cfg.default_model, + "default_ask_mode": cfg.default_ask_mode, + "dedicated_domain_url": cfg.dedicated_domain_url, + "auto_login_enabled": cfg.auto_login_enabled, + "auto_login_headless": cfg.auto_login_headless, + "auto_login_timeout": cfg.auto_login_timeout, + "auto_login_max_retry": cfg.auto_login_max_retry, + "instance_count": cfg.instance_count, + "session_reuse_enabled": cfg.session_reuse_enabled, + "session_cache_max_entries": cfg.session_cache_max_entries, + "session_cache_ttl_sec": cfg.session_cache_ttl_sec, + "tool_forward_enabled": cfg.tool_forward_enabled, + "tool_allowlist": cfg.tool_allowlist, + "accounts": [ + { + "username": account.username, + "password": _safe_setting_value("password", account.password), + "session_bundle_b64": _safe_setting_value( + "session_bundle_b64", account.session_bundle_b64 + ), + "session_bundle_file": account.session_bundle_file, + } + for account in cfg.accounts + ], + }, + "feature_flags": { + "tool_forward_enabled": cfg.tool_forward_enabled, + "session_reuse_enabled": cfg.session_reuse_enabled, + "metrics_public": cfg.metrics_public, + "auto_login_enabled": cfg.auto_login_enabled, + }, + }) + + @app.get("/metrics", dependencies=[Depends(metrics_auth_guard)]) async def metrics(): base = await stats_collector.prometheus_text() diff --git a/requirements.txt b/requirements.txt index e2c9c76..a5bff5e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ fastapi==0.115.0 +starlette==0.38.6 uvicorn[standard]==0.30.6 websockets==13.1 pydantic==2.9.2 diff --git a/tests/test_auth_concurrency.py b/tests/test_auth_concurrency.py index d8699c1..65b126c 100644 --- a/tests/test_auth_concurrency.py +++ b/tests/test_auth_concurrency.py @@ -1,14 +1,37 @@ from __future__ import annotations import asyncio +import sys +import types import unittest +from unittest.mock import patch from fastapi import HTTPException +from fastapi.testclient import TestClient from starlette.requests import Request from app.auth import AnthropicAuthError, require_anthropic_key, require_bearer, require_metrics_access from app.concurrency import BackpressureRejected, InFlightGuard +_playwright = types.ModuleType("playwright") +_playwright_async = types.ModuleType("playwright.async_api") + + +class _StubPlaywrightTimeoutError(Exception): + pass + + +async def _stub_async_playwright(): + raise RuntimeError("playwright is stubbed in unit tests") + + +_playwright_async.TimeoutError = _StubPlaywrightTimeoutError +_playwright_async.async_playwright = _stub_async_playwright +sys.modules.setdefault("playwright", _playwright) +sys.modules.setdefault("playwright.async_api", _playwright_async) + +import app.main as main + def _req(headers: dict[str, str] | None = None) -> Request: pairs = [] @@ -82,5 +105,48 @@ class AuthAndConcurrencyTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(guard.in_flight, 0) +class DebugRequestRecordingTests(unittest.TestCase): + def setUp(self) -> None: + main._DEBUG_REQUEST_LOG.clear() + + def test_redacts_sensitive_fields_and_data_urls(self) -> None: + body = { + "authorization": "Bearer abc", + "x-api-key": "secret", + "session_bundle": "very-secret", + "images": ["data:image/png;base64,ABC"], + "tool": {"args": "x" * 3000}, + } + redacted = main._redact_debug_value((), body) + + self.assertEqual(redacted["authorization"], "***") + self.assertEqual(redacted["x-api-key"], "***") + self.assertEqual(redacted["session_bundle"], "***") + self.assertEqual(redacted["images"][0], "[redacted-data-url]") + self.assertIn("[truncated]", redacted["tool"]["args"]) + + def test_internal_debug_requests_requires_admin_and_returns_items(self) -> None: + with patch.object(main.settings, "api_keys", ["k1"]), patch.object(main.settings, "admin_token", "admin-1"): + client = TestClient(main.app) + req_payload = { + "model": "org_auto", + "messages": [{"role": "user", "content": "hello"}], + } + main._record_debug_request("openai", "/v1/chat/completions", req_payload, _req({"x-request-id": "req-1"})) + + denied = client.get("/internal/debug/requests") + self.assertEqual(denied.status_code, 401) + + ok = client.get( + "/internal/debug/requests?limit=1", + headers={"Authorization": "Bearer admin-1"}, + ) + self.assertEqual(ok.status_code, 200) + data = ok.json() + self.assertTrue(data["ok"]) + self.assertEqual(data["count"], 1) + self.assertEqual(data["items"][0]["protocol"], "openai") + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_tool_call_bridge.py b/tests/test_tool_call_bridge.py index a1c2a94..e9f0833 100644 --- a/tests/test_tool_call_bridge.py +++ b/tests/test_tool_call_bridge.py @@ -5,6 +5,7 @@ import sys import types import unittest import asyncio +from types import SimpleNamespace from unittest.mock import AsyncMock, patch @@ -1251,7 +1252,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('"type": "tool_result"', body) self.assertIn('"stop_reason": "end_turn"', body) - async def test_openai_non_stream_forwards_tool_config_when_enabled(self) -> None: + async def test_openai_non_stream_uses_emulation_instead_of_forwarding_tool_config(self) -> None: spy_client = _SpyClient( stream_events=[], complete_result={"text": "ok", "toolEvents": []} ) @@ -1279,13 +1280,10 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) self.assertIn("tool_config", spy_client.last_complete_kwargs) - cfg = spy_client.last_complete_kwargs["tool_config"] - self.assertEqual(cfg["provider"], "openai") - self.assertEqual(len(cfg["tools"]), 1) - self.assertIsInstance(cfg["tool_choice"], dict) + self.assertIsNone(spy_client.last_complete_kwargs["tool_config"]) self.assertEqual(spy_client.last_complete_args[2], "agent") - async def test_openai_stream_forwards_tool_config_when_enabled(self) -> None: + async def test_openai_stream_uses_emulation_instead_of_forwarding_tool_config(self) -> None: spy_client = _SpyClient( stream_events=[{"type": "text", "text": "ok"}], complete_result={} ) @@ -1316,10 +1314,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): await _collect_stream(response) self.assertIn("tool_config", spy_client.last_stream_kwargs) - cfg = spy_client.last_stream_kwargs["tool_config"] - self.assertEqual(cfg["provider"], "openai") - self.assertEqual(len(cfg["tools"]), 1) - self.assertIsInstance(cfg["tool_choice"], dict) + self.assertIsNone(spy_client.last_stream_kwargs["tool_config"]) self.assertEqual(spy_client.last_stream_args[2], "agent") async def test_openai_non_stream_does_not_forward_tool_config_when_disabled( @@ -1355,7 +1350,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIsNone(spy_client.last_complete_kwargs["tool_config"]) self.assertEqual(spy_client.last_complete_args[2], "agent") - async def test_openai_non_stream_filters_tools_by_allowlist(self) -> None: + async def test_openai_non_stream_filters_tools_by_allowlist_before_emulation(self) -> None: spy_client = _SpyClient( stream_events=[], complete_result={"text": "ok", "toolEvents": []} ) @@ -1386,11 +1381,9 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): ): await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) - cfg = spy_client.last_complete_kwargs["tool_config"] - self.assertEqual( - [tool["function"]["name"] for tool in cfg["tools"]], ["lookup"] - ) - self.assertEqual(cfg["tool_choice"], req.tool_choice) + prompt = spy_client.last_complete_args[0] + self.assertIn("lookup(", prompt) + self.assertNotIn("write_file(", prompt) async def test_openai_non_stream_rejects_forced_tool_outside_allowlist( self, @@ -1579,7 +1572,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(openai_spy.last_complete_args[2], "chat") self.assertEqual(anthropic_spy.last_complete_args[2], "chat") - async def test_anthropic_stream_forwards_tool_config_when_enabled(self) -> None: + async def test_anthropic_stream_uses_emulation_instead_of_forwarding_tool_config(self) -> None: spy_client = _SpyClient( stream_events=[{"type": "text", "text": "ok"}], complete_result={} ) @@ -1619,9 +1612,7 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): await _collect_stream(response) self.assertIn("tool_config", spy_client.last_stream_kwargs) - cfg = spy_client.last_stream_kwargs["tool_config"] - self.assertEqual(cfg["provider"], "anthropic") - self.assertEqual(len(cfg["tools"]), 1) + self.assertIsNone(spy_client.last_stream_kwargs["tool_config"]) self.assertEqual(spy_client.last_stream_args[2], "agent") async def test_anthropic_non_stream_does_not_forward_tool_config_when_disabled( @@ -1710,12 +1701,10 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): ) self.assertIn("tool_config", spy_client.last_complete_kwargs) - cfg = spy_client.last_complete_kwargs["tool_config"] - self.assertEqual(cfg["provider"], "anthropic") - self.assertEqual(len(cfg["tools"]), 1) + self.assertIsNone(spy_client.last_complete_kwargs["tool_config"]) self.assertEqual(spy_client.last_complete_args[2], "agent") - async def test_anthropic_non_stream_filters_tools_by_allowlist(self) -> None: + async def test_anthropic_non_stream_filters_tools_by_allowlist_before_emulation(self) -> None: spy_client = _SpyClient( stream_events=[], complete_result={"text": "ok", "toolEvents": []} ) @@ -1760,9 +1749,9 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): ), ) - cfg = spy_client.last_complete_kwargs["tool_config"] - self.assertEqual([tool["name"] for tool in cfg["tools"]], ["lookup"]) - self.assertEqual(cfg["tool_choice"], req.tool_choice) + prompt = spy_client.last_complete_args[0] + self.assertIn("lookup(", prompt) + self.assertNotIn("write_file(", prompt) async def test_anthropic_non_stream_rejects_forced_tool_outside_allowlist( self, @@ -2183,6 +2172,201 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertIn('{"temperature":"22C"}', prompt) self.assertIn("Assistant:", prompt) + async def test_openai_assistant_tool_calls_are_projected_into_emulation_prompt(self) -> None: + spy_client = _SpyClient( + stream_events=[], + complete_result={ + "text": "done", + "toolEvents": [], + "sessionId": "sess-emulated-tool-history", + }, + ) + req = ChatCompletionsRequest( + model="org_auto", + messages=[ + { + "role": "assistant", + "content": "I will check that", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "fetch_weather", + "arguments": '{"city":"Hangzhou"}', + }, + } + ], + }, + {"role": "user", "content": "continue"}, + ], + stream=False, + tools=[ + { + "type": "function", + "function": { + "name": "fetch_weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + ], + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), + ): + await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + + prompt = spy_client.last_complete_args[0] + self.assertIn("I will check that", prompt) + self.assertIn('"tool": "fetch_weather"', prompt) + self.assertIn('"city": "Hangzhou"', prompt) + + async def test_openai_emulation_prompt_includes_proxy_tool_guidance(self) -> None: + spy_client = _SpyClient( + stream_events=[], + complete_result={"text": "done", "toolEvents": [], "sessionId": "sess-guidance"}, + ) + req = ChatCompletionsRequest( + model="org_auto", + messages=[{"role": "user", "content": "inspect README"}], + stream=False, + tools=[ + { + "type": "function", + "function": { + "name": "read_file", + "description": "Read a file", + "parameters": { + "type": "object", + "properties": {"path": {"type": "string"}}, + "required": ["path"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "bash", + "parameters": { + "type": "object", + "properties": {"command": {"type": "string"}}, + }, + }, + }, + ], + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), + ): + await main.v1_chat_completions(req, _make_request("/v1/chat/completions")) + + prompt = spy_client.last_complete_args[0] + self.assertIn("DIRECT tool access inside an IDE", prompt) + self.assertIn("Tool routing guide:", prompt) + self.assertIn("Read a specific local file or code path: use read_file.", prompt) + self.assertIn("Core tool syntax examples", prompt) + self.assertIn("Coding and file-work discipline:", prompt) + self.assertIn("NEVER say that tools are unavailable", prompt) + + async def test_anthropic_tool_history_is_projected_into_emulation_prompt(self) -> None: + spy_client = _SpyClient( + stream_events=[], + complete_result={ + "text": "done", + "toolEvents": [], + "sessionId": "sess-anthropic-history", + }, + ) + req = AnthropicMessagesRequest( + model="claude-3-5-sonnet-20241022", + max_tokens=128, + messages=[ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I will check"}, + { + "type": "tool_use", + "id": "toolu_1", + "name": "fetch_weather", + "input": {"city": "Hangzhou"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_1", + "content": '{"temperature":"22C"}', + } + ], + }, + {"role": "user", "content": "continue"}, + ], + stream=False, + tools=[ + { + "name": "fetch_weather", + "input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + } + ], + ) + + with ( + patch.object(main, "pool", _FakePool(_FakeInstance(spy_client))), + patch.object(main, "chat_guard", _FakeGuard()), + patch.object( + main, "_ensure_instance_logged_in", AsyncMock(return_value={"id": "u"}) + ), + patch.object( + main.stats_collector, "record_chat", AsyncMock(return_value=None) + ), + patch.object(main.settings, "api_keys", ["test-key"]), + ): + await main.v1_messages( + req, + _make_request( + "/v1/messages", + headers={ + "x-api-key": "test-key", + "anthropic-version": "2023-06-01", + }, + ), + ) + + prompt = spy_client.last_complete_args[0] + self.assertIn("I will check", prompt) + self.assertIn('"tool": "fetch_weather"', prompt) + self.assertIn('"city": "Hangzhou"', prompt) + self.assertIn("Tool result:", prompt) + self.assertIn('{"temperature":"22C"}', prompt) + async def test_anthropic_non_stream_synthesizes_tool_use_from_json_action_block( self, ) -> None: @@ -2434,6 +2618,177 @@ class ToolCallBridgeTests(unittest.IsolatedAsyncioTestCase): self.assertEqual(detail["error"]["message"], "invalid upstream response") +class CapabilitiesEndpointTests(unittest.IsolatedAsyncioTestCase): + async def test_capabilities_payload_shape(self) -> None: + with ( + patch.object(main.settings, "tool_forward_enabled", True), + patch.object(main.settings, "tool_allowlist", ["lookup"]), + patch.object(main.settings, "session_reuse_enabled", True), + patch.object(main.settings, "session_cache_max_entries", 123), + patch.object(main.settings, "session_cache_ttl_sec", 45.0), + patch.object(main.settings, "instance_count", 2), + patch.object(main.settings, "default_model", "org_auto"), + patch.object(main.settings, "default_ask_mode", "chat"), + patch.object(main.settings, "api_keys", ["test-key"]), + patch.object(main.settings, "admin_token", "adm"), + patch.object(main.settings, "metrics_public", False), + ): + response = await main.capabilities() + + self.assertEqual(response.status_code, 200) + payload = json.loads(response.body) + self.assertEqual(payload["service"], "lingma-openai-gateway") + self.assertIn("protocols", payload) + self.assertIn("features", payload) + self.assertTrue(payload["protocols"]["openai"]["chat_completions"]) + self.assertTrue(payload["protocols"]["anthropic"]["messages"]) + self.assertTrue(payload["protocols"]["openai"]["request_tools_forwarded"]) + self.assertEqual(payload["features"]["tooling"]["allowlist"], ["lookup"]) + self.assertEqual(payload["features"]["pool"]["configured_instance_count"], 2) + self.assertTrue(payload["features"]["auth"]["v1_requires_auth"]) + + async def test_v1_capabilities_auth_guard_requires_authentication(self) -> None: + with patch.object(main.settings, "api_keys", ["test-key"]): + with self.assertRaises(main.AnthropicAuthError) as ctx: + main.anthropic_auth_guard( + _make_request( + "/v1/capabilities", + headers={"anthropic-version": "2023-06-01"}, + ) + ) + + self.assertEqual(ctx.exception.status_code, 401) + + async def test_v1_capabilities_returns_payload_with_auth(self) -> None: + with ( + patch.object(main.settings, "api_keys", ["test-key"]), + patch.object(main.settings, "tool_forward_enabled", False), + ): + main.anthropic_auth_guard( + _make_request( + "/v1/capabilities", + headers={"x-api-key": "test-key", "anthropic-version": "2023-06-01"}, + ) + ) + response = await main.v1_capabilities() + + self.assertEqual(response.status_code, 200) + payload = json.loads(response.body) + self.assertFalse(payload["protocols"]["openai"]["request_tools_forwarded"]) + + + +class AdminIntrospectionEndpointTests(unittest.IsolatedAsyncioTestCase): + async def test_internal_effective_config_requires_admin_token(self) -> None: + with ( + patch.object(main.settings, "api_keys", ["api-key"]), + patch.object(main.settings, "admin_token", "admin-secret"), + ): + with self.assertRaises(main.HTTPException) as ctx: + main.admin_auth_guard( + _make_request( + "/internal/effective-config", + headers={"authorization": "Bearer wrong-token"}, + ) + ) + + self.assertEqual(ctx.exception.status_code, 401) + + async def test_internal_effective_config_redacts_secrets(self) -> None: + with ( + patch.object(main.settings, "api_keys", ["api-key-1", "api-key-2"]), + patch.object(main.settings, "admin_token", "admin-secret"), + patch.object(main.settings, "metrics_token", "metrics-secret"), + patch.object(main.settings, "default_model", "org_auto"), + patch.object(main.settings, "tool_forward_enabled", True), + patch.object(main.settings, "session_reuse_enabled", True), + patch.object(main.settings, "metrics_public", False), + patch.object(main.settings, "auto_login_enabled", True), + patch.object( + main.settings, + "accounts", + [ + SimpleNamespace( + username="user-a", + password="pass-a", + session_bundle_b64="bundle-a", + session_bundle_file="/secrets/bundle-a.txt", + ) + ], + ), + ): + main.admin_auth_guard( + _make_request( + "/internal/effective-config", + headers={"authorization": "Bearer admin-secret"}, + ) + ) + response = await main.internal_effective_config() + + self.assertEqual(response.status_code, 200) + payload = json.loads(response.body) + settings_payload = payload["settings"] + self.assertEqual(settings_payload["api_keys"], ["***", "***"]) + self.assertEqual(settings_payload["admin_token"], "***") + self.assertEqual(settings_payload["metrics_token"], "***") + self.assertEqual(settings_payload["accounts"][0]["password"], "***") + self.assertEqual(settings_payload["accounts"][0]["session_bundle_b64"], "***") + self.assertEqual(settings_payload["accounts"][0]["username"], "user-a") + self.assertEqual( + settings_payload["accounts"][0]["session_bundle_file"], + "/secrets/bundle-a.txt", + ) + self.assertTrue(payload["feature_flags"]["tool_forward_enabled"]) + self.assertTrue(payload["feature_flags"]["session_reuse_enabled"]) + + async def test_internal_debug_requests_redacts_sensitive_fields(self) -> None: + main._DEBUG_REQUEST_LOG.clear() + main._record_debug_request( + "openai", + "/v1/chat/completions", + { + "api_key": "secret-key", + "session_bundle": "bundle-value", + "image_url": "data:image/png;base64,abcd", + "tool_calls": [ + { + "function": { + "arguments": "x" * 3001, + } + } + ], + }, + _make_request("/v1/chat/completions", headers={"x-request-id": "req-123"}), + ) + + response = await main.internal_debug_requests(limit=10) + + self.assertEqual(response.status_code, 200) + payload = json.loads(response.body) + self.assertEqual(payload["count"], 1) + item = payload["items"][0] + self.assertEqual(item["request_id"], "req-123") + self.assertEqual(item["body"]["api_key"], "***") + self.assertEqual(item["body"]["session_bundle"], "***") + self.assertEqual(item["body"]["image_url"], "[redacted-data-url]") + self.assertTrue(item["body"]["tool_calls"][0]["function"]["arguments"].endswith("... [truncated]")) + + async def test_internal_debug_requests_requires_admin_token(self) -> None: + with ( + patch.object(main.settings, "api_keys", ["api-key"]), + patch.object(main.settings, "admin_token", "admin-secret"), + ): + with self.assertRaises(main.HTTPException) as ctx: + main.admin_auth_guard( + _make_request( + "/internal/debug/requests", + headers={"authorization": "Bearer wrong-token"}, + ) + ) + + self.assertEqual(ctx.exception.status_code, 401) + + class SessionCacheToolFingerprintTests(unittest.TestCase): def test_build_key_changes_with_tool_config(self) -> None: from app.session_cache import SessionCache