refactor: share request execution lifecycle
Extract the shared request startup, completion, and cleanup flow so OpenAI and Anthropic routes keep the same wire behavior with less duplicated orchestration. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from ..concurrency import InFlightGuard
|
||||
from ..lingma_pool import LingmaPool, PoolInstance
|
||||
from ..model_map import build_model_name_map, flatten_model_keys, resolve_model
|
||||
from ..session_cache import SessionCache, hash_branch_context
|
||||
@@ -21,6 +22,22 @@ class ExecutionContext:
|
||||
affinity: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StartedExecution:
|
||||
ticket: Any
|
||||
prompt_tokens: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompletedExecution:
|
||||
result: dict[str, Any]
|
||||
completion_tokens: int
|
||||
|
||||
|
||||
class UpstreamExecutionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _resolve_ask_mode(model: str, has_tooling_context: bool, *, default_ask_mode: str) -> str:
|
||||
model_name = (model or "").lower()
|
||||
if model_name in {"lingma-agent", "agent"} or has_tooling_context:
|
||||
@@ -146,3 +163,119 @@ async def prepare_execution_context(
|
||||
is_reply=is_reply,
|
||||
affinity=affinity,
|
||||
)
|
||||
|
||||
|
||||
async def start_execution(
|
||||
*,
|
||||
protocol: str,
|
||||
execution: ExecutionContext,
|
||||
stream: bool,
|
||||
chat_guard: InFlightGuard,
|
||||
logger: Any,
|
||||
estimate_tokens: Callable[[str], int],
|
||||
extra_log_context: dict[str, Any] | None = None,
|
||||
) -> StartedExecution:
|
||||
if not execution.prompt:
|
||||
raise ValueError("messages is empty")
|
||||
|
||||
prompt_tokens = estimate_tokens(execution.prompt)
|
||||
ticket = await chat_guard.try_acquire()
|
||||
execution.inst.in_flight += 1
|
||||
log_extra = {
|
||||
"ctx_instance": execution.inst.name,
|
||||
"ctx_model": execution.model,
|
||||
"ctx_ask_mode": execution.ask_mode,
|
||||
"ctx_stream": stream,
|
||||
"ctx_prompt_tokens": prompt_tokens,
|
||||
"ctx_in_flight": chat_guard.in_flight,
|
||||
"ctx_affinity": execution.affinity,
|
||||
"ctx_session_reuse": bool(execution.cached_session_id),
|
||||
}
|
||||
if extra_log_context:
|
||||
log_extra.update(extra_log_context)
|
||||
logger.info(
|
||||
"%s.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d reuse=%s",
|
||||
protocol,
|
||||
execution.inst.name,
|
||||
execution.model,
|
||||
execution.ask_mode,
|
||||
stream,
|
||||
prompt_tokens,
|
||||
bool(execution.cached_session_id),
|
||||
extra=log_extra,
|
||||
)
|
||||
return StartedExecution(ticket=ticket, prompt_tokens=prompt_tokens)
|
||||
|
||||
|
||||
async def complete_execution(
|
||||
*,
|
||||
protocol: str,
|
||||
execution: ExecutionContext,
|
||||
prompt_tokens: int,
|
||||
tool_config: dict[str, Any] | None,
|
||||
logger: Any,
|
||||
stats_collector: Any,
|
||||
session_cache: SessionCache,
|
||||
estimate_tokens: Callable[[str], int],
|
||||
) -> CompletedExecution:
|
||||
try:
|
||||
result = await execution.inst.client.chat_complete(
|
||||
execution.prompt,
|
||||
execution.model,
|
||||
execution.ask_mode,
|
||||
session_id=execution.cached_session_id,
|
||||
is_reply=execution.is_reply,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("%s.complete error (inst=%s): %s", protocol, execution.inst.name, exc)
|
||||
await stats_collector.record_chat(
|
||||
stream=False,
|
||||
success=False,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=0,
|
||||
)
|
||||
if execution.cached_session_id and execution.lookup_key:
|
||||
await session_cache.invalidate(execution.lookup_key)
|
||||
raise UpstreamExecutionError from exc
|
||||
|
||||
completion_tokens = estimate_tokens(result.get("text") or "")
|
||||
await stats_collector.record_chat(
|
||||
stream=False,
|
||||
success=True,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
if execution.write_key:
|
||||
sid = result.get("sessionId")
|
||||
if sid:
|
||||
await session_cache.put(execution.write_key, sid, execution.inst.name)
|
||||
return CompletedExecution(result=result, completion_tokens=completion_tokens)
|
||||
|
||||
|
||||
async def finalize_stream_execution(
|
||||
*,
|
||||
success: bool,
|
||||
write_key: str | None,
|
||||
session_id: str | None,
|
||||
inst: PoolInstance,
|
||||
ticket: Any,
|
||||
session_cache: SessionCache,
|
||||
stats_collector: Any,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
) -> None:
|
||||
if success and write_key and session_id:
|
||||
await session_cache.put(write_key, session_id, inst.name)
|
||||
await stats_collector.record_chat(
|
||||
stream=True,
|
||||
success=success,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
release_execution(ticket=ticket, inst=inst)
|
||||
|
||||
|
||||
def release_execution(*, ticket: Any, inst: PoolInstance) -> None:
|
||||
inst.in_flight = max(0, inst.in_flight - 1)
|
||||
ticket.release()
|
||||
|
||||
Reference in New Issue
Block a user