diff --git a/app/http/execution_core.py b/app/http/execution_core.py index 0591a82..edb6497 100644 --- a/app/http/execution_core.py +++ b/app/http/execution_core.py @@ -3,6 +3,7 @@ from __future__ import annotations from dataclasses import dataclass from typing import Any, Awaitable, Callable +from ..concurrency import InFlightGuard from ..lingma_pool import LingmaPool, PoolInstance from ..model_map import build_model_name_map, flatten_model_keys, resolve_model from ..session_cache import SessionCache, hash_branch_context @@ -21,6 +22,22 @@ class ExecutionContext: affinity: str | None +@dataclass +class StartedExecution: + ticket: Any + prompt_tokens: int + + +@dataclass +class CompletedExecution: + result: dict[str, Any] + completion_tokens: int + + +class UpstreamExecutionError(Exception): + pass + + def _resolve_ask_mode(model: str, has_tooling_context: bool, *, default_ask_mode: str) -> str: model_name = (model or "").lower() if model_name in {"lingma-agent", "agent"} or has_tooling_context: @@ -146,3 +163,119 @@ async def prepare_execution_context( is_reply=is_reply, affinity=affinity, ) + + +async def start_execution( + *, + protocol: str, + execution: ExecutionContext, + stream: bool, + chat_guard: InFlightGuard, + logger: Any, + estimate_tokens: Callable[[str], int], + extra_log_context: dict[str, Any] | None = None, +) -> StartedExecution: + if not execution.prompt: + raise ValueError("messages is empty") + + prompt_tokens = estimate_tokens(execution.prompt) + ticket = await chat_guard.try_acquire() + execution.inst.in_flight += 1 + log_extra = { + "ctx_instance": execution.inst.name, + "ctx_model": execution.model, + "ctx_ask_mode": execution.ask_mode, + "ctx_stream": stream, + "ctx_prompt_tokens": prompt_tokens, + "ctx_in_flight": chat_guard.in_flight, + "ctx_affinity": execution.affinity, + "ctx_session_reuse": bool(execution.cached_session_id), + } + if extra_log_context: + log_extra.update(extra_log_context) + logger.info( + "%s.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d reuse=%s", + protocol, + execution.inst.name, + execution.model, + execution.ask_mode, + stream, + prompt_tokens, + bool(execution.cached_session_id), + extra=log_extra, + ) + return StartedExecution(ticket=ticket, prompt_tokens=prompt_tokens) + + +async def complete_execution( + *, + protocol: str, + execution: ExecutionContext, + prompt_tokens: int, + tool_config: dict[str, Any] | None, + logger: Any, + stats_collector: Any, + session_cache: SessionCache, + estimate_tokens: Callable[[str], int], +) -> CompletedExecution: + try: + result = await execution.inst.client.chat_complete( + execution.prompt, + execution.model, + execution.ask_mode, + session_id=execution.cached_session_id, + is_reply=execution.is_reply, + tool_config=tool_config, + ) + except Exception as exc: + logger.warning("%s.complete error (inst=%s): %s", protocol, execution.inst.name, exc) + await stats_collector.record_chat( + stream=False, + success=False, + prompt_tokens=prompt_tokens, + completion_tokens=0, + ) + if execution.cached_session_id and execution.lookup_key: + await session_cache.invalidate(execution.lookup_key) + raise UpstreamExecutionError from exc + + completion_tokens = estimate_tokens(result.get("text") or "") + await stats_collector.record_chat( + stream=False, + success=True, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + if execution.write_key: + sid = result.get("sessionId") + if sid: + await session_cache.put(execution.write_key, sid, execution.inst.name) + return CompletedExecution(result=result, completion_tokens=completion_tokens) + + +async def finalize_stream_execution( + *, + success: bool, + write_key: str | None, + session_id: str | None, + inst: PoolInstance, + ticket: Any, + session_cache: SessionCache, + stats_collector: Any, + prompt_tokens: int, + completion_tokens: int, +) -> None: + if success and write_key and session_id: + await session_cache.put(write_key, session_id, inst.name) + await stats_collector.record_chat( + stream=True, + success=success, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + release_execution(ticket=ticket, inst=inst) + + +def release_execution(*, ticket: Any, inst: PoolInstance) -> None: + inst.in_flight = max(0, inst.in_flight - 1) + ticket.release() diff --git a/app/main.py b/app/main.py index a0c2dcc..f54ee8d 100644 --- a/app/main.py +++ b/app/main.py @@ -28,7 +28,12 @@ from .config import Settings, load_settings from .http.execution_core import ( _apply_cached_instance_or_invalidate as _shared_apply_cached_instance_or_invalidate, _resolve_ask_mode as _shared_resolve_ask_mode, + UpstreamExecutionError, + complete_execution, + finalize_stream_execution, prepare_execution_context, + release_execution, + start_execution, ) from .http.openai_responses import handle_responses from .http.tool_bridge import ( @@ -472,27 +477,29 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): messages_to_prompt=_messages_to_prompt, ) ask_mode = execution.ask_mode - lookup_key = execution.lookup_key write_key = execution.write_key cached_session_id = execution.cached_session_id inst = execution.inst model = execution.model prompt = execution.prompt is_reply = execution.is_reply - affinity = execution.affinity - if not prompt: + include_usage = _include_usage(req.stream_options) + + try: + started = await start_execution( + protocol="chat", + execution=execution, + stream=req.stream, + chat_guard=chat_guard, + logger=logger, + estimate_tokens=estimate_tokens, + ) + except ValueError: raise HTTPException( status_code=400, detail={"error": {"message": "messages is empty", "type": "invalid_request_error"}}, ) - prompt_tokens = estimate_tokens(prompt) - include_usage = _include_usage(req.stream_options) - - # Backpressure: acquire a slot *after* the cheap validation but before any - # upstream call. This ensures we reject quickly when saturated. - try: - ticket = await chat_guard.try_acquire() except BackpressureRejected as exc: retry_after = max(1, int(exc.retry_after)) logger.warning("chat rejected by backpressure, retry_after=%ds", retry_after) @@ -508,26 +515,8 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): headers={"Retry-After": str(retry_after)}, ) - inst.in_flight += 1 - logger.info( - "chat.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d reuse=%s", - inst.name, - model, - ask_mode, - req.stream, - prompt_tokens, - bool(cached_session_id), - extra={ - "ctx_instance": inst.name, - "ctx_model": model, - "ctx_ask_mode": ask_mode, - "ctx_stream": req.stream, - "ctx_prompt_tokens": prompt_tokens, - "ctx_in_flight": chat_guard.in_flight, - "ctx_affinity": affinity, - "ctx_session_reuse": bool(cached_session_id), - }, - ) + ticket = started.ticket + prompt_tokens = started.prompt_tokens ticket_transferred = False @@ -715,59 +704,40 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): exc, ) finally: - if success and write_key: - sid = _meta.get("session_id") - if sid: - await session_cache.put(write_key, sid, _inst.name) - await stats_collector.record_chat( - stream=True, + await finalize_stream_execution( success=success, + write_key=write_key, + session_id=_meta.get("session_id"), + inst=_inst, + ticket=_ticket, + session_cache=session_cache, + stats_collector=stats_collector, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens_holder["n"], ) - _inst.in_flight = max(0, _inst.in_flight - 1) - _ticket.release() ticket_transferred = True return _streaming_response(event_stream()) try: - result = await inst.client.chat_complete( - prompt, - model, - ask_mode, - session_id=cached_session_id, - is_reply=is_reply, - tool_config=tool_config, - ) - except Exception as exc: - logger.warning("chat.complete error (inst=%s): %s", inst.name, exc) - await stats_collector.record_chat( - stream=False, - success=False, + completed = await complete_execution( + protocol="chat", + execution=execution, prompt_tokens=prompt_tokens, - completion_tokens=0, + tool_config=tool_config, + logger=logger, + stats_collector=stats_collector, + session_cache=session_cache, + estimate_tokens=estimate_tokens, ) - # If we used a cached session and the call blew up, drop it so the - # next turn can start fresh instead of hitting the same dead session. - if cached_session_id and lookup_key: - await session_cache.invalidate(lookup_key) + except UpstreamExecutionError: raise HTTPException( status_code=502, detail={"error": {"message": "upstream lingma error", "type": "upstream_error"}}, ) - completion_tokens = estimate_tokens(result.get("text") or "") - await stats_collector.record_chat( - stream=False, - success=True, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - ) - if write_key: - sid = result.get("sessionId") - if sid: - await session_cache.put(write_key, sid, inst.name) + result = completed.result + completion_tokens = completed.completion_tokens forced_tool_name = _openai_forced_tool_name(req.tool_choice) tool_events = _allowed_tool_events( result.get("toolEvents"), @@ -823,8 +793,7 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request): return JSONResponse(content=data) finally: if not ticket_transferred: - inst.in_flight = max(0, inst.in_flight - 1) - ticket.release() + release_execution(ticket=ticket, inst=inst) @@ -949,22 +918,25 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): msg = (detail.get("error") or {}).get("message") or str(detail) or "upstream error" return _anthropic_error(exc.status_code, err_type, msg) ask_mode = execution.ask_mode - lookup_key = execution.lookup_key write_key = execution.write_key cached_session_id = execution.cached_session_id inst = execution.inst model = execution.model prompt = execution.prompt is_reply = execution.is_reply - affinity = execution.affinity - if not prompt: - return _anthropic_error(400, "invalid_request_error", "messages is empty") - - prompt_tokens = estimate_tokens(prompt) - # ------------------------------------------------------------- backpressure try: - ticket = await chat_guard.try_acquire() + started = await start_execution( + protocol="anthropic", + execution=execution, + stream=req.stream, + chat_guard=chat_guard, + logger=logger, + estimate_tokens=estimate_tokens, + extra_log_context={"ctx_api": "anthropic"}, + ) + except ValueError: + return _anthropic_error(400, "invalid_request_error", "messages is empty") except BackpressureRejected as exc: retry_after = max(1, int(exc.retry_after)) logger.warning("anthropic rejected by backpressure, retry_after=%ds", retry_after) @@ -976,27 +948,9 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): resp.headers["Retry-After"] = str(retry_after) return resp - inst.in_flight += 1 + ticket = started.ticket + prompt_tokens = started.prompt_tokens message_id = f"msg_{uuid.uuid4().hex}" - logger.info( - "anthropic.start inst=%s model=%s stream=%s prompt_tokens~%d reuse=%s", - inst.name, - model, - req.stream, - prompt_tokens, - bool(cached_session_id), - extra={ - "ctx_instance": inst.name, - "ctx_model": model, - "ctx_ask_mode": ask_mode, - "ctx_stream": req.stream, - "ctx_prompt_tokens": prompt_tokens, - "ctx_in_flight": chat_guard.in_flight, - "ctx_affinity": affinity, - "ctx_session_reuse": bool(cached_session_id), - "ctx_api": "anthropic", - }, - ) ticket_transferred = False @@ -1175,59 +1129,39 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): except Exception: pass finally: - # Session write-back only on clean finish — partial streams - # leave Lingma's session in an indeterminate state. - if success and write_key: - sid = _meta.get("session_id") - if sid: - await session_cache.put(write_key, sid, _inst.name) - await stats_collector.record_chat( - stream=True, + await finalize_stream_execution( success=success, + write_key=write_key, + session_id=_meta.get("session_id"), + inst=_inst, + ticket=_ticket, + session_cache=session_cache, + stats_collector=stats_collector, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens_holder["n"], ) - _inst.in_flight = max(0, _inst.in_flight - 1) - _ticket.release() ticket_transferred = True return _streaming_response(event_stream()) - # ------------------------------------------------------------- non-stream try: - result = await inst.client.chat_complete( - prompt, - model, - ask_mode, - session_id=cached_session_id, - is_reply=is_reply, - tool_config=tool_config, - ) - except Exception as exc: - logger.warning("anthropic.complete error (inst=%s): %s", inst.name, exc) - await stats_collector.record_chat( - stream=False, - success=False, + completed = await complete_execution( + protocol="anthropic", + execution=execution, prompt_tokens=prompt_tokens, - completion_tokens=0, + tool_config=tool_config, + logger=logger, + stats_collector=stats_collector, + session_cache=session_cache, + estimate_tokens=estimate_tokens, ) - if cached_session_id and lookup_key: - await session_cache.invalidate(lookup_key) + except UpstreamExecutionError: return _anthropic_error(502, "api_error", "upstream lingma error") + result = completed.result text = result.get("text") or "" - completion_tokens = estimate_tokens(text) - await stats_collector.record_chat( - stream=False, - success=True, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - ) - if write_key: - sid = result.get("sessionId") - if sid: - await session_cache.put(write_key, sid, inst.name) + completion_tokens = completed.completion_tokens content_blocks: list[dict[str, Any]] = [] if text: @@ -1286,8 +1220,7 @@ async def v1_messages(req: AnthropicMessagesRequest, request: Request): return JSONResponse(content=response_body) finally: if not ticket_transferred: - inst.in_flight = max(0, inst.in_flight - 1) - ticket.release() + release_execution(ticket=ticket, inst=inst) @app.post("/internal/auto-login/start", dependencies=[Depends(admin_auth_guard)])