Files
lingma-openai-gateway/app/http/execution_core.py
GitHub Actions 109c34a8dc refactor: share request execution lifecycle
Extract the shared request startup, completion, and cleanup flow so OpenAI and Anthropic routes keep the same wire behavior with less duplicated orchestration.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-04-23 18:44:40 +08:00

282 lines
8.3 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Awaitable, Callable
from ..concurrency import InFlightGuard
from ..lingma_pool import LingmaPool, PoolInstance
from ..model_map import build_model_name_map, flatten_model_keys, resolve_model
from ..session_cache import SessionCache, hash_branch_context
@dataclass
class ExecutionContext:
    """Per-request state resolved by ``prepare_execution_context``.

    It is consumed by ``start_execution`` (slot acquisition and start logging)
    and by ``complete_execution`` (the upstream call plus session-cache
    bookkeeping).
    """

    ask_mode: str  # "agent" or the caller-supplied default (see _resolve_ask_mode)
    lookup_key: str | None  # cache key for messages[:-1]; None when reuse was not eligible
    write_key: str | None  # cache key for the full message list; None when reuse was not eligible
    cached_session_id: str | None  # upstream session being continued, or None for a fresh one
    inst: PoolInstance  # pool instance returned by pool.pick()
    model: str  # model name after resolve_model()
    prompt: str  # last user text when replying into a cached session, else the full rendered prompt
    is_reply: bool  # True exactly when cached_session_id was truthy at prompt-building time
    affinity: str | None  # affinity key passed to pool.pick(): cached instance name or caller key
@dataclass
class StartedExecution:
    """Result of ``start_execution``: a held concurrency slot plus token estimate."""

    ticket: Any  # object returned by chat_guard.try_acquire(); released via release_execution()
    prompt_tokens: int  # estimate_tokens(prompt) computed before the slot was acquired
@dataclass
class CompletedExecution:
    """Result of a successful non-streaming ``complete_execution`` call."""

    result: dict[str, Any]  # raw dict returned by inst.client.chat_complete()
    completion_tokens: int  # estimate_tokens() applied to result["text"] (or "")
class UpstreamExecutionError(Exception):
    """Raised by ``complete_execution`` when the upstream chat call fails.

    The underlying exception is attached as ``__cause__`` (``raise ... from``).
    """
def _resolve_ask_mode(model: str, has_tooling_context: bool, *, default_ask_mode: str) -> str:
    """Choose the upstream ask mode for a request.

    Returns ``"agent"`` when the requested model is ``lingma-agent`` or
    ``agent`` (case-insensitive) or when the request carries tooling context;
    otherwise returns ``default_ask_mode``. A falsy ``model`` is treated as "".
    """
    normalized_model = (model or "").lower()
    wants_agent = has_tooling_context or normalized_model in ("lingma-agent", "agent")
    return "agent" if wants_agent else default_ask_mode
async def _apply_cached_instance_or_invalidate(
    *,
    protocol: str,
    logger: Any,
    session_cache: SessionCache,
    inst: PoolInstance,
    cached_instance_name: str | None,
    cached_session_id: str | None,
    lookup_key: str | None,
) -> str | None:
    """Decide whether a cached upstream session can be used on the picked instance.

    A cached session is only usable on the instance it was recorded on. When
    there is no recorded instance name, or the picked ``inst`` is that same
    instance, ``cached_session_id`` is returned unchanged. Otherwise the
    fallback is logged, the cache entry at ``lookup_key`` (if any) is
    invalidated, and ``None`` is returned so the request starts without a
    cached session.
    """
    # Guard clause: nothing to reconcile, keep whatever session id we had.
    if not cached_instance_name or inst.name == cached_instance_name:
        return cached_session_id

    # The picker did not honour the cached instance's affinity (the log text
    # assumes it was unhealthy), so the cached session cannot be continued.
    logger.info(
        "%s session cache instance %s unhealthy, falling back to %s",
        protocol,
        cached_instance_name,
        inst.name,
    )
    if lookup_key:
        await session_cache.invalidate(lookup_key)
    return None
async def _find_reusable_session(
    *,
    session_cache: SessionCache,
    api_key: str,
    messages_dump: list[dict[str, Any]],
    tool_config: dict[str, Any] | None,
) -> tuple[str, str, str | None, str | None]:
    """Build this turn's session-cache keys and look up a previously cached session.

    Returns ``(lookup_key, write_key, session_id, instance_name)``:

    * ``lookup_key`` covers every message except the last one. If nothing is
      found under the branch-aware key, the key built without
      ``branch_context`` (the legacy format) is tried, and on a hit it becomes
      the returned lookup key.
    * ``write_key`` covers the full message list.
    * ``session_id`` / ``instance_name`` come from the cache entry, or are
      ``None`` on a miss. An empty ``instance_name`` is normalised to ``None``.
    """
    prefix_branch = hash_branch_context(messages_dump[:-1])
    lookup_key = session_cache.build_key(
        api_key,
        messages_dump[:-1],
        tool_config=tool_config,
        branch_context=prefix_branch,
    )
    write_key = session_cache.build_key(
        api_key,
        messages_dump,
        tool_config=tool_config,
        branch_context=hash_branch_context(messages_dump),
    )

    entry = await session_cache.get(lookup_key)
    if entry is None:
        legacy_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
        entry = await session_cache.get(legacy_key)
        if entry is None:
            return lookup_key, write_key, None, None
        lookup_key = legacy_key
    return lookup_key, write_key, entry.session_id, entry.instance_name or None


async def prepare_execution_context(
    *,
    protocol: str,
    requested_model: str,
    has_tooling_context: bool,
    tool_config: dict[str, Any] | None,
    messages_dump: list[dict[str, Any]],
    api_key: str,
    affinity_key: str | None,
    pool: LingmaPool,
    session_cache: SessionCache,
    logger: Any,
    default_model: str,
    default_ask_mode: str,
    ensure_instance_logged_in: Callable[[PoolInstance], Awaitable[Any]],
    last_user_text: Callable[[list[dict[str, Any]]], str],
    messages_to_prompt: Callable[[list[dict[str, Any]]], str],
) -> ExecutionContext:
    """Resolve everything needed to run one chat request.

    Steps, in order:

    1. Pick the ask mode (``_resolve_ask_mode``).
    2. For plain chat with at least two messages and no tooling context (and
       the cache enabled), compute session-cache keys and look up a cached
       upstream session (``_find_reusable_session``).
    3. Pick a pool instance, preferring the cached session's instance, then
       drop the cached session if a different instance was picked.
    4. Make sure the instance is logged in and resolve ``requested_model``
       against the models it reports.
    5. Build the prompt. When continuing a cached session only
       ``last_user_text(messages_dump)`` is sent; otherwise the whole
       conversation is rendered with ``messages_to_prompt``.

    Returns:
        ExecutionContext describing the prepared request.
    """
    ask_mode = _resolve_ask_mode(
        requested_model,
        has_tooling_context,
        default_ask_mode=default_ask_mode,
    )

    lookup_key: str | None = None
    write_key: str | None = None
    cached_session_id: str | None = None
    cached_instance_name: str | None = None
    # Only plain multi-turn chat takes part in session reuse; agent mode and
    # tool-bearing requests never look up or record a cached session.
    if (
        session_cache.enabled
        and ask_mode == "chat"
        and len(messages_dump) >= 2
        and not has_tooling_context
    ):
        lookup_key, write_key, cached_session_id, cached_instance_name = await _find_reusable_session(
            session_cache=session_cache,
            api_key=api_key,
            messages_dump=messages_dump,
            tool_config=tool_config,
        )

    affinity = cached_instance_name or affinity_key
    inst = pool.pick(affinity_key=affinity)
    cached_session_id = await _apply_cached_instance_or_invalidate(
        protocol=protocol,
        logger=logger,
        session_cache=session_cache,
        inst=inst,
        cached_instance_name=cached_instance_name,
        cached_session_id=cached_session_id,
        lookup_key=lookup_key,
    )

    await ensure_instance_logged_in(inst)
    models = await inst.client.query_models()
    model = resolve_model(
        requested_model,
        flatten_model_keys(models),
        default_model,
        build_model_name_map(models),
    )

    is_reply = bool(cached_session_id)
    prompt = last_user_text(messages_dump) if is_reply else messages_to_prompt(messages_dump)

    return ExecutionContext(
        ask_mode=ask_mode,
        lookup_key=lookup_key,
        write_key=write_key,
        cached_session_id=cached_session_id,
        inst=inst,
        model=model,
        prompt=prompt,
        is_reply=is_reply,
        affinity=affinity,
    )
async def start_execution(
    *,
    protocol: str,
    execution: ExecutionContext,
    stream: bool,
    chat_guard: InFlightGuard,
    logger: Any,
    estimate_tokens: Callable[[str], int],
    extra_log_context: dict[str, Any] | None = None,
) -> StartedExecution:
    """Acquire a concurrency slot for *execution* and log the request start.

    Args:
        protocol: Label used as the log-message prefix (e.g. which API route).
        execution: Context produced by ``prepare_execution_context``.
        stream: Whether the request is streamed (logged only).
        chat_guard: Guard that hands out concurrency tickets.
        logger: Logger receiving the ``<protocol>.start`` line.
        estimate_tokens: Token estimator applied to the prompt.
        extra_log_context: Optional extra fields merged into the log record's
            ``extra`` mapping (they override the built-in ``ctx_*`` keys).

    Returns:
        StartedExecution with the acquired ticket and the prompt token
        estimate. The caller owns the ticket and must eventually pass it to
        ``release_execution``.

    Raises:
        ValueError: If ``execution.prompt`` is empty. This is checked before
            any slot is acquired.
        Any exception from ``chat_guard.try_acquire`` propagates unchanged
        (nothing has been acquired at that point). Any exception raised after
        acquisition is re-raised after the slot has been released.
    """
    if not execution.prompt:
        raise ValueError("messages is empty")
    prompt_tokens = estimate_tokens(execution.prompt)

    ticket = await chat_guard.try_acquire()
    execution.inst.in_flight += 1
    # We now hold a slot, but the caller has not received the ticket yet. If
    # anything below raises, the caller can never release it, so give it back
    # here. Example: Logger.makeRecord raises KeyError when an ``extra`` key
    # collides with a LogRecord attribute such as "name" or "msg".
    try:
        session_reuse = bool(execution.cached_session_id)
        log_extra = {
            "ctx_instance": execution.inst.name,
            "ctx_model": execution.model,
            "ctx_ask_mode": execution.ask_mode,
            "ctx_stream": stream,
            "ctx_prompt_tokens": prompt_tokens,
            "ctx_in_flight": chat_guard.in_flight,
            "ctx_affinity": execution.affinity,
            "ctx_session_reuse": session_reuse,
        }
        if extra_log_context:
            log_extra.update(extra_log_context)
        logger.info(
            "%s.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d reuse=%s",
            protocol,
            execution.inst.name,
            execution.model,
            execution.ask_mode,
            stream,
            prompt_tokens,
            session_reuse,
            extra=log_extra,
        )
    except BaseException:
        release_execution(ticket=ticket, inst=execution.inst)
        raise
    return StartedExecution(ticket=ticket, prompt_tokens=prompt_tokens)
async def complete_execution(
    *,
    protocol: str,
    execution: ExecutionContext,
    prompt_tokens: int,
    tool_config: dict[str, Any] | None,
    logger: Any,
    stats_collector: Any,
    session_cache: SessionCache,
    estimate_tokens: Callable[[str], int],
) -> CompletedExecution:
    """Run a non-streaming upstream completion and record its outcome.

    Calls ``execution.inst.client.chat_complete`` with the prepared prompt,
    model, ask mode and optional cached session id.

    On success:

    * completion tokens are estimated from ``result["text"]`` (``""`` if
      missing or falsy);
    * a successful non-stream chat is recorded in ``stats_collector``;
    * when the request has a ``write_key`` and the result carries a
      ``sessionId``, that id is cached under ``write_key`` for this instance.

    This function does not release the ticket taken by ``start_execution``;
    presumably the caller does that via ``release_execution``.

    Returns:
        CompletedExecution wrapping the raw result and the completion token
        estimate.

    Raises:
        UpstreamExecutionError: Chained (``from``) to any ``Exception`` from
            ``chat_complete``. Before it is raised, a failed chat is recorded
            and, if a cached session was being reused, its lookup entry is
            invalidated.
    """
    inst = execution.inst
    try:
        result = await inst.client.chat_complete(
            execution.prompt,
            execution.model,
            execution.ask_mode,
            session_id=execution.cached_session_id,
            is_reply=execution.is_reply,
            tool_config=tool_config,
        )
    except Exception as exc:
        logger.warning("%s.complete error (inst=%s): %s", protocol, inst.name, exc)
        await stats_collector.record_chat(
            stream=False,
            success=False,
            prompt_tokens=prompt_tokens,
            completion_tokens=0,
        )
        # Drop the entry for a reused session so the next turn does not try
        # to continue it again.
        if execution.cached_session_id and execution.lookup_key:
            await session_cache.invalidate(execution.lookup_key)
        raise UpstreamExecutionError from exc
    else:
        completion_tokens = estimate_tokens(result.get("text") or "")
        await stats_collector.record_chat(
            stream=False,
            success=True,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
        new_session_id = result.get("sessionId") if execution.write_key else None
        if new_session_id:
            await session_cache.put(execution.write_key, new_session_id, inst.name)
        return CompletedExecution(result=result, completion_tokens=completion_tokens)
async def finalize_stream_execution(
    *,
    success: bool,
    write_key: str | None,
    session_id: str | None,
    inst: PoolInstance,
    ticket: Any,
    session_cache: SessionCache,
    stats_collector: Any,
    prompt_tokens: int,
    completion_tokens: int,
) -> None:
    """Record the outcome of a streamed request and release its slot.

    Args:
        success: Whether the stream finished successfully.
        write_key: Session-cache key for the full conversation, or None.
        session_id: Upstream session id to cache on success, if any.
        inst: Pool instance that served the request.
        ticket: Concurrency ticket from ``start_execution``.
        session_cache: Cache receiving ``write_key -> (session_id, inst.name)``.
        stats_collector: Receives one ``record_chat(stream=True, ...)`` call.
        prompt_tokens: Prompt token estimate to record.
        completion_tokens: Completion token estimate to record.

    On success with both a ``write_key`` and a ``session_id``, the session is
    cached. The chat is then recorded in ``stats_collector``. The ticket and
    ``inst.in_flight`` are always released.
    """
    # Release in ``finally``: if the cache write or the stats call raises,
    # or the task is cancelled while awaiting them, skipping
    # release_execution would leak the concurrency slot and leave
    # inst.in_flight permanently inflated.
    try:
        if success and write_key and session_id:
            await session_cache.put(write_key, session_id, inst.name)
        await stats_collector.record_chat(
            stream=True,
            success=success,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
    finally:
        release_execution(ticket=ticket, inst=inst)
def release_execution(*, ticket: Any, inst: PoolInstance) -> None:
    """Give back a concurrency slot taken by ``start_execution``.

    Decrements ``inst.in_flight``, clamping at zero so an extra release cannot
    drive it negative, then calls ``ticket.release()``.
    """
    remaining = inst.in_flight - 1
    inst.in_flight = remaining if remaining > 0 else 0
    ticket.release()