diff --git a/.env.example b/.env.example index b935136..82e2d1c 100644 --- a/.env.example +++ b/.env.example @@ -1,22 +1,14 @@ +# ==================== 必要配置(先填这部分) ==================== + # 网关监听地址 HOST=0.0.0.0 # 网关监听端口 PORT=8317 -# API Key,可配置多个(逗号分隔)。空 = 不鉴权(启动会打 warning),仅用于本地 dev -API_KEYS=sk-your-api-key -# 独立的 /metrics 鉴权 token(留空则退化为 API_KEYS 亦可访问;若与 API_KEYS 同时为空,/metrics 默认 503) -METRICS_TOKEN= -# 显式把 /metrics 设为公开(仅在私网采集器场景使用) -METRICS_PUBLIC=false -# 独立的 /internal/* 管理 token(留空则退化为 API_KEYS);强烈建议生产环境单独配置 -ADMIN_TOKEN= -# 日志级别(DEBUG / INFO / WARNING / ERROR) -LOG_LEVEL=INFO -# /v1/chat/completions 并发上限(<=0 表示不限流) -GATEWAY_MAX_IN_FLIGHT=4 -# 排队等待超时秒数,超过后返回 429 + Retry-After -GATEWAY_QUEUE_TIMEOUT_SEC=30 +# API Key,可配置多个(逗号分隔)。空 = 不鉴权(仅建议本地 dev) +API_KEYS=sk-your-api-key +# /internal/* 管理 token(留空则退化为 API_KEYS) +ADMIN_TOKEN= # 容器内 Lingma 二进制路径 LINGMA_BIN=/app/data/bin/Lingma @@ -26,12 +18,11 @@ LINGMA_SOURCE_TYPE=marketplace LINGMA_MARKETPLACE_PUBLISHER=Alibaba-Cloud # Marketplace 扩展名 LINGMA_MARKETPLACE_EXTENSION=tongyi-lingma -# VSIX 下载地址(最新优先) -LINGMA_VSIX_URL=https://tongyi-code.oss-cn-hangzhou.aliyuncs.com/vscode/tongyi-lingma-latest.vsix # 启动时总是尝试从 VSIX 刷新二进制 LINGMA_BOOTSTRAP_ALWAYS=true # 强制刷新(true 时忽略本地缓存) LINGMA_FORCE_REFRESH=false + # Lingma 工作目录(登录/会话数据) LINGMA_WORK_DIR=/app/data/.lingma/vscode/sharedClientCache # Lingma WebSocket 端口 @@ -43,11 +34,39 @@ LINGMA_RPC_TIMEOUT=30 # 默认模型(无法映射时使用) DEFAULT_MODEL=org_auto -# 默认模式:chat 或 agent -DEFAULT_ASK_MODE=chat +# 默认模式:chat 或 agent(工具调用建议 agent) +DEFAULT_ASK_MODE=agent -# 请求侧 tools/tool_choice 透传到 Lingma(默认开启,可显式关闭) +# 请求侧 tools/tool_choice 透传到 Lingma(工具调用建议开启) TOOL_FORWARD_ENABLED=true + +# 登录方式(二选一) +# A. 账号密码(单实例) +LINGMA_USERNAME= +LINGMA_PASSWORD= +# B. 
会话 bundle(推荐生产) +# LINGMA_SESSION_BUNDLE= +# LINGMA_SESSION_BUNDLE_FILE=/secrets/lingma-session.b64 + + +# ==================== 可选配置(按需) ==================== + +# 独立的 /metrics 鉴权 token(留空则退化为 API_KEYS 亦可访问) +METRICS_TOKEN= +# 显式把 /metrics 设为公开(仅私网采集器场景) +METRICS_PUBLIC=false + +# 日志级别(DEBUG / INFO / WARNING / ERROR) +LOG_LEVEL=INFO + +# /v1/chat/completions 并发上限(<=0 表示不限流) +GATEWAY_MAX_IN_FLIGHT=4 +# 排队等待超时秒数,超过后返回 429 + Retry-After +GATEWAY_QUEUE_TIMEOUT_SEC=30 + +# VSIX 下载地址(仅 LINGMA_SOURCE_TYPE=vsix 或 marketplace 回退时使用) +LINGMA_VSIX_URL=https://tongyi-code.oss-cn-hangzhou.aliyuncs.com/vscode/tongyi-lingma-latest.vsix + # 可选:允许透传的工具名白名单,逗号分隔;为空表示不额外限制 TOOL_ALLOWLIST= @@ -63,41 +82,15 @@ AUTO_LOGIN_TIMEOUT=180 # 自动登录重试次数 AUTO_LOGIN_MAX_RETRY=2 -# Lingma 登录用户名(仅当 LINGMA_ACCOUNTS 为空时生效,单实例模式) -LINGMA_USERNAME= -# Lingma 登录密码(仅当 LINGMA_ACCOUNTS 为空时生效) -LINGMA_PASSWORD= - -# ==== 多实例池(方案乙:多账号) ==== +# ==== 多实例池(可选) ==== # 多账号列表,支持两种格式: # CSV: user1:pass1,user2:pass2 # JSON: [{"username":"u1","password":"p1"},{"username":"u2","password":"p2"}] -# 配置后每个账号对应一个独立 Lingma 实例(独立 workDir + 独立自动登录) LINGMA_ACCOUNTS= -# 实例数量:默认等于 LINGMA_ACCOUNTS 数;显式指定时账号不足会循环复用并打 warning +# 实例数量:默认等于 LINGMA_ACCOUNTS 数;显式指定时账号不足会循环复用 LINGMA_INSTANCE_COUNT= -# ==== 登录态注入:跳过 Playwright 自动登录 ==== -# 方式 1:base64 字符串,内容 = tar.gz(workDir/cache/{id,user,quota,config.json}) -# 通过 `POST /internal/session/export` 从另一个已登录实例导出得到。 -# 配了这个就可以不填 LINGMA_USERNAME / LINGMA_PASSWORD。 -# LINGMA_SESSION_BUNDLE= - -# 方式 2:指向宿主机上的 bundle 文件路径(文件内容即 base64 字符串) -# LINGMA_SESSION_BUNDLE_FILE=/secrets/lingma-session.b64 - -# 多账号时走 JSON 模式,每个账号可以独立带 session_bundle: -# LINGMA_ACCOUNTS=[ -# {"username":"u1","password":"p1","session_bundle":"H4sI..."}, -# {"username":"u2","password":"p2","session_bundle_file":"/secrets/u2.b64"} -# ] -# 注意:一旦 workDir 里已经有登录态(cache/user 非空),bundle 会被跳过, -# 你手动登录的 / 旧容器的登录态不会被覆盖。 - -# ==== 会话复用(多轮对话命中上游 KV cache,减少首 token 延迟) ==== -# 开关(默认开) +# ==== 会话复用(可选,默认开) ==== SESSION_REUSE_ENABLED=true 
def _infer_release_root(member_path: str) -> str:
    """Return the archive directory treated as the release root for *member_path*.

    The path is split on "/" (empty components are dropped) and resolved as:

    * if an ``x86_64_linux`` component exists and is not the first one,
      everything before it, joined with "/";
    * otherwise, if the path has more than one component, the first component;
    * otherwise the empty string (the member sits at the archive top level).
    """
    parts = [p for p in member_path.split("/") if p]
    if "x86_64_linux" in parts:
        idx = parts.index("x86_64_linux")
        if idx > 0:
            return "/".join(parts[:idx])
    if len(parts) > 1:
        return parts[0]
    return ""


def _extract_release_tree(
    inner_zip: zipfile.ZipFile, release_root: str, out_dir: Path
) -> None:
    """Extract every file below *release_root* in *inner_zip* into *out_dir*.

    Args:
        inner_zip: Open archive to read members from.
        release_root: Archive directory whose content is extracted; the
            ``release_root/`` prefix is stripped from the written paths.
            An empty string extracts the whole archive.
        out_dir: Destination directory. Parent directories of each written
            file are created on demand.

    Directory entries and members outside *release_root* are ignored.
    Members whose name would resolve outside *out_dir* (``..`` components or
    absolute paths) are skipped with a log line instead of being written,
    because the archive comes from a download and its names are untrusted.
    """
    # Function-scope import keeps this block self-contained.
    import shutil

    prefix = f"{release_root}/" if release_root else ""
    base = out_dir.resolve()
    for info in inner_zip.infolist():
        name = info.filename
        if not name or name.endswith("/"):
            continue
        if prefix and not name.startswith(prefix):
            continue
        rel = name[len(prefix):]
        if not rel:
            continue
        dest = (base / rel).resolve()
        # Zip-slip guard: the resolved target must live strictly inside base.
        if base not in dest.parents:
            print(f"[bootstrap] skip unsafe archive member: {name}")
            continue
        dest.parent.mkdir(parents=True, exist_ok=True)
        # Stream the member instead of loading it fully into memory.
        with inner_zip.open(info, "r") as src, dest.open("wb") as dst:
            shutil.copyfileobj(src, dst)
req.add_header("x-market-client-id", "VSCode 1.115.0") @@ -83,7 +117,11 @@ def _query_marketplace_latest_vsix(publisher: str, extension: str) -> tuple[str, "https://marketplace.visualstudio.com/_apis/public/gallery/" f"publishers/{publisher}/vsextensions/{extension}/{version}/vspackage" ) - return vsix_url, version, {"publisher": publisher, "extension": extension, "version": version} + return ( + vsix_url, + version, + {"publisher": publisher, "extension": extension, "version": version}, + ) def bootstrap_from_vsix() -> None: @@ -106,7 +144,9 @@ def bootstrap_from_vsix() -> None: old_marker = {} if marker_path.exists(): try: - old_marker = json.loads(marker_path.read_text(encoding="utf-8", errors="ignore")) + old_marker = json.loads( + marker_path.read_text(encoding="utf-8", errors="ignore") + ) except Exception: old_marker = {} @@ -115,15 +155,17 @@ def bootstrap_from_vsix() -> None: source_meta = {"source": source_type} if source_type == "marketplace": try: - resolved_url, resolved_version, source_meta = _query_marketplace_latest_vsix( - mp_publisher, mp_extension + resolved_url, resolved_version, source_meta = ( + _query_marketplace_latest_vsix(mp_publisher, mp_extension) ) print( f"[bootstrap] marketplace latest: {mp_publisher}.{mp_extension} " f"version={resolved_version}" ) except Exception as exc: - print(f"[bootstrap] marketplace query failed, fallback to LINGMA_VSIX_URL: {exc}") + print( + f"[bootstrap] marketplace query failed, fallback to LINGMA_VSIX_URL: {exc}" + ) resolved_url = vsix_url if ( @@ -144,9 +186,18 @@ def bootstrap_from_vsix() -> None: print(f"[bootstrap] downloading VSIX: {resolved_url}") try: - with urllib.request.urlopen(resolved_url, timeout=120) as r: - data = r.read() - vsix_path.write_bytes(data) + with ( + urllib.request.urlopen(resolved_url, timeout=30) as r, + vsix_path.open("wb") as f, + ): + total = 0 + while True: + chunk = r.read(1024 * 1024) + if not chunk: + break + f.write(chunk) + total += len(chunk) + print(f"[bootstrap] 
VSIX downloaded bytes={total}") except Exception as exc: if lingma_bin.exists(): print(f"[bootstrap] download failed, fallback to existing Lingma: {exc}") @@ -162,10 +213,18 @@ def bootstrap_from_vsix() -> None: with zipfile.ZipFile(io.BytesIO(nested_zip_bytes), "r") as inner_zip: lingma_member = _pick_lingma_binary_path(inner_zip) lingma_bytes = inner_zip.read(lingma_member) + release_root = _infer_release_root(lingma_member) + lingma_bin.parent.mkdir(parents=True, exist_ok=True) + release_dir = lingma_bin.parent / (release_root or "2.5.20") + _extract_release_tree(inner_zip, release_root, release_dir) - lingma_bin.parent.mkdir(parents=True, exist_ok=True) lingma_bin.write_bytes(lingma_bytes) os.chmod(lingma_bin, 0o755) + extension_main = release_dir / "extension" / "main.js" + if extension_main.exists(): + print(f"[bootstrap] extension ready: {extension_main}") + else: + print(f"[bootstrap] extension missing under: {release_dir}") marker = { "source": source_type, @@ -174,6 +233,7 @@ def bootstrap_from_vsix() -> None: "downloaded_at": int(time.time()), "nested_zip": nested_zip_name, "member": lingma_member, + "release_root": release_root, "size": len(lingma_bytes), } marker.update(source_meta) diff --git a/app/http/tool_bridge.py b/app/http/tool_bridge.py index 8f3b8fd..6ab982d 100644 --- a/app/http/tool_bridge.py +++ b/app/http/tool_bridge.py @@ -2,6 +2,7 @@ from __future__ import annotations import ast import json +import re import uuid from typing import Any @@ -50,10 +51,16 @@ def _tool_event_allowed( *, forced_tool_name: str | None = None, ) -> bool: - if not (tool_config and isinstance(tool_config.get("tools"), list) and tool_config.get("tools")): + if not ( + tool_config + and isinstance(tool_config.get("tools"), list) + and tool_config.get("tools") + ): return True for tool in tool_config.get("tools") or []: - if tool_name == _anthropic_tool_name(tool) or tool_name == _openai_tool_name(tool): + if tool_name == _anthropic_tool_name(tool) or tool_name == 
_openai_tool_name( + tool + ): return True return bool(forced_tool_name and tool_name == forced_tool_name) @@ -67,7 +74,9 @@ def _allowed_tool_event( if not isinstance(tool, dict): return None tool_name = str(tool.get("name") or "") - if not _tool_event_allowed(tool_name, tool_config, forced_tool_name=forced_tool_name): + if not _tool_event_allowed( + tool_name, tool_config, forced_tool_name=forced_tool_name + ): return None return tool @@ -104,7 +113,9 @@ def _allowed_stream_tool_event( if not isinstance(tool, dict): return None tool_name = str(tool.get("name") or "") - if not _tool_event_allowed(tool_name, tool_config, forced_tool_name=forced_tool_name): + if not _tool_event_allowed( + tool_name, tool_config, forced_tool_name=forced_tool_name + ): return None return tool @@ -150,7 +161,9 @@ def _json_object_from_text(text: str) -> dict[str, Any] | None: return parsed if isinstance(parsed, dict) else None -def _tool_code_single_arg_name(tools: list[dict[str, Any]] | None, forced_tool_name: str) -> str | None: +def _tool_code_single_arg_name( + tools: list[dict[str, Any]] | None, forced_tool_name: str +) -> str | None: if not isinstance(tools, list): return None for tool in tools: @@ -228,7 +241,9 @@ def _forced_tool_event_from_text( ) -> dict[str, Any] | None: parsed = _json_object_from_text(text) if parsed is None: - parsed = _tool_code_object_from_text(text, forced_tool_name, single_arg_name=single_arg_name) + parsed = _tool_code_object_from_text( + text, forced_tool_name, single_arg_name=single_arg_name + ) if parsed is None: return None @@ -288,7 +303,9 @@ def _forced_tool_fallback_event( ) -def _openai_tool_call(tool: dict[str, Any], *, forced_id: str | None = None) -> dict[str, Any]: +def _openai_tool_call( + tool: dict[str, Any], *, forced_id: str | None = None +) -> dict[str, Any]: return { "id": str(tool.get("id") or forced_id or f"call_{uuid.uuid4().hex}"), "type": "function", @@ -299,6 +316,42 @@ def _openai_tool_call(tool: dict[str, Any], *, forced_id: 
def _extract_function_call_event_from_text(
    text: str,
    *,
    forced_tool_name: str | None,
) -> dict[str, Any] | None:
    """Infer a tool-call event from a JSON object embedded in plain text.

    The first decodable JSON object in *text* is taken as the payload. It
    must carry a non-blank string ``name``; ``arguments`` may be a dict, a
    JSON-encoded string that decodes to a dict, or missing/``None`` (treated
    as ``{}``).

    Args:
        text: Model output that may contain a function-call JSON object.
        forced_tool_name: When set, only a payload whose ``name`` equals this
            value is accepted.

    Returns:
        ``{"name": <tool name>, "input": <arguments dict>}``, or ``None``
        when no acceptable payload is found.
    """
    raw = (text or "").strip()
    if not raw:
        return None

    # A non-greedy regex such as r"\{.*?\}" stops at the first "}", which
    # truncates any payload with a nested object (e.g. non-empty arguments)
    # or a "}" inside a string. raw_decode parses one complete JSON value
    # from a given offset, so nesting and string contents are handled.
    decoder = json.JSONDecoder()
    payload = None
    start = raw.find("{")
    while start != -1:
        try:
            payload, _end = decoder.raw_decode(raw, start)
            break
        except (ValueError, RecursionError):
            start = raw.find("{", start + 1)
    if not isinstance(payload, dict):
        return None

    name = payload.get("name")
    if not isinstance(name, str) or not name.strip():
        return None
    name = name.strip()
    if forced_tool_name and name != forced_tool_name:
        return None

    arguments = payload.get("arguments")
    if isinstance(arguments, str):
        # Arguments may arrive as a JSON-encoded string.
        try:
            arguments = json.loads(arguments)
        except (ValueError, RecursionError):
            return None
    if arguments is None:
        arguments = {}
    if not isinstance(arguments, dict):
        return None
    return {"name": name, "input": arguments}
"pool not initialized", "type": "service_unavailable"}}, + detail={ + "error": { + "message": "pool not initialized", + "type": "service_unavailable", + } + }, ) return pool @@ -254,7 +259,12 @@ async def _ensure_instance_logged_in(inst: PoolInstance) -> dict: logger.warning("[%s] auth_status failed before chat: %s", inst.name, exc) raise HTTPException( status_code=503, - detail={"error": {"message": "Lingma is not ready", "type": "service_unavailable"}}, + detail={ + "error": { + "message": "Lingma is not ready", + "type": "service_unavailable", + } + }, ) if status and status.get("id"): @@ -263,13 +273,20 @@ async def _ensure_instance_logged_in(inst: PoolInstance) -> dict: if not settings.auto_login_enabled: raise HTTPException( status_code=401, - detail={"error": {"message": "Lingma not logged in", "type": "invalid_request_error"}}, + detail={ + "error": { + "message": "Lingma not logged in", + "type": "invalid_request_error", + } + }, ) if settings.dedicated_domain_url: try: current = await client.get_endpoint() - current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else "" + current_ep = ( + (current or {}).get("endpoint", "") if isinstance(current, dict) else "" + ) if current_ep != settings.dedicated_domain_url: await client.update_endpoint(settings.dedicated_domain_url) except Exception as exc: @@ -281,13 +298,23 @@ async def _ensure_instance_logged_in(inst: PoolInstance) -> dict: logger.warning("[%s] generate_login_url failed: %s", inst.name, exc) raise HTTPException( status_code=502, - detail={"error": {"message": "generate login url failed", "type": "upstream_error"}}, + detail={ + "error": { + "message": "generate login url failed", + "type": "upstream_error", + } + }, ) if not login_url: raise HTTPException( status_code=502, - detail={"error": {"message": "generate login url failed", "type": "upstream_error"}}, + detail={ + "error": { + "message": "generate login url failed", + "type": "upstream_error", + } + }, ) await 
auto_login.ensure_started(login_url) @@ -312,7 +339,12 @@ async def _ensure_instance_logged_in(inst: PoolInstance) -> dict: ) raise HTTPException( status_code=401, - detail={"error": {"message": "Lingma auto login failed", "type": "invalid_request_error"}}, + detail={ + "error": { + "message": "Lingma auto login failed", + "type": "invalid_request_error", + } + }, ) @@ -416,7 +448,6 @@ async def _apply_cached_instance_or_invalidate( ) - def _streaming_response(event_stream) -> StreamingResponse: return StreamingResponse( event_stream, @@ -505,7 +536,12 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): except ValueError: raise HTTPException( status_code=400, - detail={"error": {"message": "messages is empty", "type": "invalid_request_error"}}, + detail={ + "error": { + "message": "messages is empty", + "type": "invalid_request_error", + } + }, ) except BackpressureRejected as exc: retry_after = max(1, int(exc.retry_after)) @@ -534,7 +570,6 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): completion_tokens_holder = {"n": 0} stream_meta: dict = {} forced_tool_name = _openai_forced_tool_name(req.tool_choice) - forced_tool_single_arg_name = _tool_code_single_arg_name(req.tools, forced_tool_name) if forced_tool_name else None async def event_stream(_ticket=ticket, _inst=inst, _meta=stream_meta): success = False @@ -602,7 +637,9 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): "tool_calls": [ { "index": idx, - **_openai_tool_call(tool, forced_id=tool_id), + **_openai_tool_call( + tool, forced_id=tool_id + ), } ] }, @@ -622,18 +659,15 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): continue yield _text_payload(text) - if buffered_text_parts and not saw_tool_call and forced_tool_name: - fallback_event = _forced_tool_event_from_text( + if buffered_text_parts and forced_tool_name and not saw_tool_call: + inferred = 
_extract_function_call_event_from_text( "".join(buffered_text_parts), - forced_tool_name, - single_arg_name=forced_tool_single_arg_name, + forced_tool_name=forced_tool_name, ) - if fallback_event is not None: + if inferred is not None: + tool_id = "call_inferred_0" + tool_call_indexes[tool_id] = 0 saw_tool_call = True - tool_id = "call_fallback_0" - idx = 0 - tool_call_indexes[tool_id] = idx - fallback_tool_call = _openai_tool_call(fallback_event, forced_id=tool_id) payload = { "id": completion_id, "object": "chat.completion.chunk", @@ -645,8 +679,10 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): "delta": { "tool_calls": [ { - "index": idx, - **fallback_tool_call, + "index": 0, + **_openai_tool_call( + inferred, forced_id=tool_id + ), } ] }, @@ -669,7 +705,9 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): { "index": 0, "delta": {}, - "finish_reason": "tool_calls" if saw_tool_call else "stop", + "finish_reason": "tool_calls" + if saw_tool_call + else "stop", } ], } @@ -685,7 +723,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): "usage": { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens_holder["n"], - "total_tokens": prompt_tokens + completion_tokens_holder["n"], + "total_tokens": prompt_tokens + + completion_tokens_holder["n"], }, } yield f"data: {json.dumps(usage_payload, ensure_ascii=False)}\n\n" @@ -738,7 +777,12 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): except UpstreamExecutionError: raise HTTPException( status_code=502, - detail={"error": {"message": "upstream lingma error", "type": "upstream_error"}}, + detail={ + "error": { + "message": "upstream lingma error", + "type": "upstream_error", + } + }, ) result = completed.result @@ -757,13 +801,14 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): tool_calls.append(_openai_tool_call(item, forced_id=tool_id)) saw_tool_call = 
True if not saw_tool_call and forced_tool_name: - fallback_event = _forced_tool_fallback_event( + inferred = _extract_function_call_event_from_text( message_content, forced_tool_name=forced_tool_name, - tools=req.tools, ) - if fallback_event is not None: - tool_calls.append(_openai_tool_call(fallback_event, forced_id="call_fallback_0")) + if inferred is not None: + tool_calls.append( + _openai_tool_call(inferred, forced_id="call_inferred_0") + ) saw_tool_call = True message_content = "" response = ChatCompletionResponse( @@ -783,7 +828,6 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): ], ) - data = response.model_dump() data["latency"] = { "first_token_ms": result.get("firstTokenLatencyMs"), @@ -801,8 +845,6 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): release_execution(ticket=ticket, inst=inst) - - @app.post("/responses", dependencies=[Depends(auth_guard)]) @app.post("/v1/responses", dependencies=[Depends(auth_guard)]) async def v1_responses(req: ResponsesRequest, request: Request): @@ -814,7 +856,6 @@ async def v1_responses(req: ResponsesRequest, request: Request): ) - def _anthropic_error(status_code: int, error_type: str, message: str) -> JSONResponse: """Build an Anthropic-shaped error response (`type:error` envelope).""" return JSONResponse( @@ -879,15 +920,15 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): try: p = _require_pool() except HTTPException as exc: - return _anthropic_error(exc.status_code, "overloaded_error", "gateway not ready") + return _anthropic_error( + exc.status_code, "overloaded_error", "gateway not ready" + ) messages_dump = anthropic_to_internal_messages(req) # Prefer the auth token actually accepted so session-cache bucketing is # consistent regardless of which auth header style the caller used. 
api_key = ( - request.headers.get("x-api-key", "").strip() - or _extract_api_key(request) - or "-" + request.headers.get("x-api-key", "").strip() or _extract_api_key(request) or "-" ) # ------------------------------------------------------------- session reuse @@ -924,9 +965,15 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): messages_to_prompt=_messages_to_prompt, ) except HTTPException as exc: - err_type = "authentication_error" if exc.status_code == 401 else "overloaded_error" + err_type = ( + "authentication_error" if exc.status_code == 401 else "overloaded_error" + ) detail = exc.detail if isinstance(exc.detail, dict) else {} - msg = (detail.get("error") or {}).get("message") or str(detail) or "upstream error" + msg = ( + (detail.get("error") or {}).get("message") + or str(detail) + or "upstream error" + ) return _anthropic_error(exc.status_code, err_type, msg) ask_mode = execution.ask_mode write_key = execution.write_key @@ -950,7 +997,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): return _anthropic_error(400, "invalid_request_error", "messages is empty") except BackpressureRejected as exc: retry_after = max(1, int(exc.retry_after)) - logger.warning("anthropic rejected by backpressure, retry_after=%ds", retry_after) + logger.warning( + "anthropic rejected by backpressure, retry_after=%ds", retry_after + ) resp = _anthropic_error( 429, "overloaded_error", @@ -1016,7 +1065,10 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): if text_block_open: yield _sse( "content_block_stop", - {"type": "content_block_stop", "index": block_index}, + { + "type": "content_block_stop", + "index": block_index, + }, ) block_index += 1 text_block_open = False @@ -1029,9 +1081,13 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): if not tool: continue - tool_id = str(tool.get("id") or f"toolu_stream_{block_index}") + tool_id = str( + tool.get("id") or 
f"toolu_stream_{block_index}" + ) - tool_use_block = _anthropic_tool_use_block(tool, forced_id=tool_id) + tool_use_block = _anthropic_tool_use_block( + tool, forced_id=tool_id + ) yield _sse( "content_block_start", { @@ -1046,7 +1102,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): ) block_index += 1 - tool_result_block = _anthropic_tool_result_block(tool, forced_id=tool_id) + tool_result_block = _anthropic_tool_result_block( + tool, forced_id=tool_id + ) if tool_result_block is not None: yield _sse( "content_block_start", @@ -1058,7 +1116,10 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): ) yield _sse( "content_block_stop", - {"type": "content_block_stop", "index": block_index}, + { + "type": "content_block_stop", + "index": block_index, + }, ) block_index += 1 else: @@ -1114,7 +1175,6 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): }, ) - # 6) message_stop — terminal event, no [DONE] sentinel. yield _sse("message_stop", {"type": "message_stop"}) success = True @@ -1122,7 +1182,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): logger.info("anthropic.stream cancelled (inst=%s)", _inst.name) raise except Exception as exc: - logger.warning("anthropic.stream error (inst=%s): %s", _inst.name, exc) + logger.warning( + "anthropic.stream error (inst=%s): %s", _inst.name, exc + ) # Best-effort error frame. Anthropic clients treat any # unexpected event gracefully; we prefer visibility over # silent truncation. 
@@ -1155,7 +1217,6 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): ticket_transferred = True return _streaming_response(event_stream()) - try: completed = await complete_execution( protocol="anthropic", @@ -1196,22 +1257,18 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): saw_pending_tool_use = True if not saw_tool_event and forced_tool_name: - fallback_event = _forced_tool_fallback_event( + inferred = _extract_function_call_event_from_text( text, forced_tool_name=forced_tool_name, - tools=req.tools, ) - if fallback_event is not None: - content_blocks = [] - tool_id = "toolu_fallback_0" - content_blocks.append(_anthropic_tool_use_block(fallback_event, forced_id=tool_id)) - tool_result = _anthropic_tool_result_block(fallback_event, forced_id=tool_id) - saw_pending_tool_use = tool_result is None - if tool_result is not None: - content_blocks.append(tool_result) + if inferred is not None: + content_blocks = [ + _anthropic_tool_use_block(inferred, forced_id="toolu_inferred_0") + ] + saw_tool_event = True + saw_pending_tool_use = True response_body: dict = { - "id": message_id, "type": "message", "role": "assistant", @@ -1256,25 +1313,38 @@ async def internal_auto_login_start(instance: str | None = None): status = await client.auth_status() if status and status.get("id"): - return {"ok": True, "state": "already_logged_in", "instance": target.name, "auth": status} + return { + "ok": True, + "state": "already_logged_in", + "instance": target.name, + "auth": status, + } if settings.dedicated_domain_url: try: current = await client.get_endpoint() - current_ep = (current or {}).get("endpoint", "") if isinstance(current, dict) else "" + current_ep = ( + (current or {}).get("endpoint", "") if isinstance(current, dict) else "" + ) if current_ep != settings.dedicated_domain_url: await client.update_endpoint(settings.dedicated_domain_url) except Exception as exc: - logger.warning("[%s] switch dedicated endpoint failed: %s", 
target.name, exc) + logger.warning( + "[%s] switch dedicated endpoint failed: %s", target.name, exc + ) try: login_url, _login_raw = await client.generate_login_url() except Exception as exc: logger.warning("[%s] generate_login_url failed: %s", target.name, exc) - raise HTTPException(status_code=502, detail={"error": {"message": "generate login url failed"}}) + raise HTTPException( + status_code=502, detail={"error": {"message": "generate login url failed"}} + ) if not login_url: - raise HTTPException(status_code=502, detail={"error": {"message": "generate login url failed"}}) + raise HTTPException( + status_code=502, detail={"error": {"message": "generate login url failed"}} + ) started = await auto_login.ensure_started(login_url) return { @@ -1327,7 +1397,9 @@ async def internal_session_export(instance: str | None = None): target = inst break if target is None: - raise HTTPException(status_code=404, detail={"error": f"instance {instance} not found"}) + raise HTTPException( + status_code=404, detail={"error": f"instance {instance} not found"} + ) else: target = p.pick() @@ -1380,7 +1452,9 @@ async def internal_models_raw(instance: str | None = None): target = inst break if target is None: - raise HTTPException(status_code=404, detail={"error": f"instance {instance} not found"}) + raise HTTPException( + status_code=404, detail={"error": f"instance {instance} not found"} + ) else: target = p.pick() await _ensure_instance_logged_in(target) @@ -1414,4 +1488,6 @@ async def metrics(): lines.extend(pool.prometheus_lines()) lines.extend(session_cache.prometheus_lines()) extra = "\n".join(lines) + "\n" - return StreamingResponse(iter([base + extra]), media_type="text/plain; version=0.0.4") + return StreamingResponse( + iter([base + extra]), media_type="text/plain; version=0.0.4" + )