# Execution helpers: resolve the ask mode, pick a pool instance (optionally
# reusing a cached upstream session), start and complete a chat call, and
# release the in-flight bookkeeping afterwards.
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Awaitable, Callable

from ..concurrency import InFlightGuard
from ..lingma_pool import LingmaPool, PoolInstance
from ..model_map import build_model_name_map, flatten_model_keys, resolve_model
from ..session_cache import SessionCache, hash_branch_context


@dataclass
class ExecutionContext:
    """Everything needed to issue one upstream chat call, as built by prepare_execution_context."""

    ask_mode: str
    # Cache key the reused session was found under; used to invalidate it on failure.
    lookup_key: str | None
    # Cache key the resulting session id is stored under after a successful call.
    write_key: str | None
    # Upstream session id being continued, or None for a fresh conversation.
    cached_session_id: str | None
    inst: PoolInstance
    model: str
    prompt: str
    # True when ``prompt`` is only the last user text sent into a reused session.
    is_reply: bool
    affinity: str | None


@dataclass
class StartedExecution:
    """Result of start_execution: the acquired guard ticket and the estimated prompt size."""

    # Whatever InFlightGuard.try_acquire returned; handed back to release_execution.
    ticket: Any
    prompt_tokens: int


@dataclass
class CompletedExecution:
    """Result of complete_execution."""

    # Response dict returned by ``inst.client.chat_complete``.
    result: dict[str, Any]
    completion_tokens: int


class UpstreamExecutionError(Exception):
    """Raised by complete_execution when the upstream call fails; the cause is chained."""


def _resolve_ask_mode(model: str, has_tooling_context: bool, *, default_ask_mode: str) -> str:
    """Return ``"agent"`` for the agent model aliases or any tool-bearing request, else the default."""
    model_name = (model or "").lower()
    if model_name in {"lingma-agent", "agent"} or has_tooling_context:
        return "agent"
    return default_ask_mode


def _tool_config_summary(tool_config: dict[str, Any] | None) -> dict[str, Any]:
    """Build a compact, log-friendly summary of a tool configuration.

    Tool names are taken from ``{"type": "function", "function": {"name": ...}}``
    entries, or from a top-level ``"name"`` for any other entry type. Non-dict
    entries and blank/non-string names are skipped.

    Returns:
        A dict with keys ``present``, ``provider``, ``tool_names`` and ``tool_choice``.
    """
    if not isinstance(tool_config, dict):
        return {"present": False, "provider": None, "tool_names": [], "tool_choice": None}
    tools = tool_config.get("tools")
    tool_names: list[str] = []
    if isinstance(tools, list):
        for tool in tools:
            if not isinstance(tool, dict):
                continue
            if tool.get("type") == "function":
                # Function tools carry their name in the nested "function" object only.
                fn = tool.get("function")
                name = fn.get("name") if isinstance(fn, dict) else None
            else:
                name = tool.get("name")
            if isinstance(name, str) and name.strip():
                tool_names.append(name.strip())
    return {
        "present": True,
        "provider": tool_config.get("provider"),
        "tool_names": tool_names,
        "tool_choice": tool_config.get("tool_choice"),
    }


async def _apply_cached_instance_or_invalidate(
    *,
    protocol: str,
    logger: Any,
    session_cache: SessionCache,
    inst: PoolInstance,
    cached_instance_name: str | None,
    cached_session_id: str | None,
    lookup_key: str | None,
) -> str | None:
    """Keep the cached session id only if the pool picked the instance that owns it.

    When the picked instance differs from the cached one, the session cannot be
    continued there: the cache entry under ``lookup_key`` is invalidated and
    None is returned so the caller sends the full prompt instead.
    """
    if cached_instance_name and inst.name != cached_instance_name:
        logger.info(
            "%s session cache instance %s unhealthy, falling back to %s",
            protocol,
            cached_instance_name,
            inst.name,
        )
        if lookup_key:
            await session_cache.invalidate(lookup_key)
        return None
    return cached_session_id


async def prepare_execution_context(
    *,
    protocol: str,
    requested_model: str,
    has_tooling_context: bool,
    tool_config: dict[str, Any] | None,
    messages_dump: list[dict[str, Any]],
    api_key: str,
    affinity_key: str | None,
    pool: LingmaPool,
    session_cache: SessionCache,
    logger: Any,
    default_model: str,
    default_ask_mode: str,
    ensure_instance_logged_in: Callable[[PoolInstance], Awaitable[Any]],
    last_user_text: Callable[[list[dict[str, Any]]], str],
    messages_to_prompt: Callable[[list[dict[str, Any]]], str],
) -> ExecutionContext:
    """Resolve mode, instance, model and prompt for one request.

    Session reuse is attempted only when the cache is enabled, the mode is
    ``"chat"``, there are at least two messages and no tooling context. The
    conversation prefix (all messages but the last) is looked up first with a
    branch-context hash, then with the key built without one. The full message
    list yields the key the new session will be written under.

    Args:
        protocol: Label used as the prefix of every log line.
        requested_model: Model name from the request; also drives ask-mode resolution.
        has_tooling_context: Whether the request carries tools; forces agent mode and disables reuse.
        tool_config: Tool configuration, included in cache keys and logs.
        messages_dump: The request messages.
        api_key: Included in the session-cache keys.
        affinity_key: Preferred instance hint when no cached instance applies.
        pool: Pool to pick the upstream instance from.
        session_cache: Cache mapping conversation keys to (session id, instance name).
        logger: Logger used for informational lines.
        default_model: Fallback passed to ``resolve_model``.
        default_ask_mode: Mode used when agent mode does not apply.
        ensure_instance_logged_in: Awaited on the picked instance before querying models.
        last_user_text: Builds the prompt when continuing a cached session.
        messages_to_prompt: Builds the full prompt for a fresh session.

    Returns:
        The populated ExecutionContext.
    """
    ask_mode = _resolve_ask_mode(
        requested_model,
        has_tooling_context,
        default_ask_mode=default_ask_mode,
    )
    logger.info(
        "%s.prepare requested_model=%s ask_mode=%s tooling=%s tool_config=%s",
        protocol,
        requested_model,
        ask_mode,
        has_tooling_context,
        _tool_config_summary(tool_config),
    )
    reuse_eligible = (
        session_cache.enabled
        and ask_mode == "chat"
        and len(messages_dump) >= 2
        and not has_tooling_context
    )
    lookup_key: str | None = None
    write_key: str | None = None
    cached_session_id: str | None = None
    cached_instance_name: str | None = None
    if reuse_eligible:
        prefix = messages_dump[:-1]
        lookup_key = session_cache.build_key(
            api_key,
            prefix,
            tool_config=tool_config,
            branch_context=hash_branch_context(prefix),
        )
        write_key = session_cache.build_key(
            api_key,
            messages_dump,
            tool_config=tool_config,
            branch_context=hash_branch_context(messages_dump),
        )
        entry = await session_cache.get(lookup_key)
        if entry is None:
            # Fall back to the key built without a branch context ("legacy" entries).
            legacy_lookup_key = session_cache.build_key(api_key, prefix, tool_config=tool_config)
            entry = await session_cache.get(legacy_lookup_key)
            if entry is not None:
                lookup_key = legacy_lookup_key
        if entry is not None:
            cached_session_id = entry.session_id
            cached_instance_name = entry.instance_name or None

    # Prefer the instance that owns the cached session; otherwise the caller's hint.
    affinity = cached_instance_name or affinity_key
    inst = pool.pick(affinity_key=affinity)
    cached_session_id = await _apply_cached_instance_or_invalidate(
        protocol=protocol,
        logger=logger,
        session_cache=session_cache,
        inst=inst,
        cached_instance_name=cached_instance_name,
        cached_session_id=cached_session_id,
        lookup_key=lookup_key,
    )

    await ensure_instance_logged_in(inst)
    models = await inst.client.query_models()
    available = flatten_model_keys(models)
    name_map = build_model_name_map(models)
    model = resolve_model(requested_model, available, default_model, name_map)

    if cached_session_id:
        # The upstream session already holds the history; send only the new turn.
        prompt = last_user_text(messages_dump)
        is_reply = True
    else:
        prompt = messages_to_prompt(messages_dump)
        is_reply = False

    logger.info(
        "%s.context inst=%s model=%s ask_mode=%s reuse_eligible=%s reused_session=%s affinity=%s",
        protocol,
        inst.name,
        model,
        ask_mode,
        reuse_eligible,
        bool(cached_session_id),
        affinity,
    )
    return ExecutionContext(
        ask_mode=ask_mode,
        lookup_key=lookup_key,
        write_key=write_key,
        cached_session_id=cached_session_id,
        inst=inst,
        model=model,
        prompt=prompt,
        is_reply=is_reply,
        affinity=affinity,
    )


async def start_execution(
    *,
    protocol: str,
    execution: ExecutionContext,
    stream: bool,
    chat_guard: InFlightGuard,
    logger: Any,
    estimate_tokens: Callable[[str], int],
    extra_log_context: dict[str, Any] | None = None,
) -> StartedExecution:
    """Acquire a guard slot, bump the instance in-flight count and log the start.

    The returned ticket must later be released via release_execution (or
    finalize_stream_execution for streams).

    Raises:
        ValueError: If the prepared prompt is empty (checked before acquiring anything).
    """
    if not execution.prompt:
        raise ValueError("messages is empty")
    prompt_tokens = estimate_tokens(execution.prompt)
    ticket = await chat_guard.try_acquire()
    execution.inst.in_flight += 1
    try:
        log_extra = {
            "ctx_instance": execution.inst.name,
            "ctx_model": execution.model,
            "ctx_ask_mode": execution.ask_mode,
            "ctx_stream": stream,
            "ctx_prompt_tokens": prompt_tokens,
            "ctx_in_flight": chat_guard.in_flight,
            "ctx_affinity": execution.affinity,
            "ctx_session_reuse": bool(execution.cached_session_id),
        }
        if extra_log_context:
            log_extra.update(extra_log_context)
        logger.info(
            "%s.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d reuse=%s",
            protocol,
            execution.inst.name,
            execution.model,
            execution.ask_mode,
            stream,
            prompt_tokens,
            bool(execution.cached_session_id),
            extra=log_extra,
        )
    except BaseException:
        # The caller never receives the ticket if we fail here, so undo the
        # acquisition ourselves. (E.g. stdlib logging raises KeyError when an
        # ``extra`` key would overwrite a LogRecord attribute.)
        release_execution(ticket=ticket, inst=execution.inst)
        raise
    return StartedExecution(ticket=ticket, prompt_tokens=prompt_tokens)


async def complete_execution(
    *,
    protocol: str,
    execution: ExecutionContext,
    prompt_tokens: int,
    tool_config: dict[str, Any] | None,
    logger: Any,
    stats_collector: Any,
    session_cache: SessionCache,
    estimate_tokens: Callable[[str], int],
) -> CompletedExecution:
    """Run one non-streaming upstream call, record stats and cache the new session id.

    Does not release the guard ticket; that is left to the caller.

    Raises:
        UpstreamExecutionError: If ``chat_complete`` (or the preceding log call) raises;
            the original exception is chained as ``__cause__``.
    """
    try:
        logger.info(
            "%s.complete inst=%s ask_mode=%s tool_config=%s",
            protocol,
            execution.inst.name,
            execution.ask_mode,
            _tool_config_summary(tool_config),
        )
        result = await execution.inst.client.chat_complete(
            execution.prompt,
            execution.model,
            execution.ask_mode,
            session_id=execution.cached_session_id,
            is_reply=execution.is_reply,
            tool_config=tool_config,
        )
    except Exception as exc:
        logger.warning("%s.complete error (inst=%s): %s", protocol, execution.inst.name, exc)
        try:
            await stats_collector.record_chat(
                stream=False,
                success=False,
                prompt_tokens=prompt_tokens,
                completion_tokens=0,
            )
        finally:
            # Drop the reused session's cache entry so the next request does not
            # try to continue it, even if stats recording itself failed.
            if execution.cached_session_id and execution.lookup_key:
                await session_cache.invalidate(execution.lookup_key)
        raise UpstreamExecutionError from exc

    completion_tokens = estimate_tokens(result.get("text") or "")
    await stats_collector.record_chat(
        stream=False,
        success=True,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    if execution.write_key:
        sid = result.get("sessionId")
        if sid:
            await session_cache.put(execution.write_key, sid, execution.inst.name)
    return CompletedExecution(result=result, completion_tokens=completion_tokens)


async def finalize_stream_execution(
    *,
    success: bool,
    write_key: str | None,
    session_id: str | None,
    inst: PoolInstance,
    ticket: Any,
    session_cache: SessionCache,
    stats_collector: Any,
    prompt_tokens: int,
    completion_tokens: int,
) -> None:
    """Finish a streaming call: cache the session on success, record stats, release the slot.

    The guard ticket and the instance in-flight count are always released, even
    if the cache write or stats recording raises.
    """
    try:
        if success and write_key and session_id:
            await session_cache.put(write_key, session_id, inst.name)
        await stats_collector.record_chat(
            stream=True,
            success=success,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
    finally:
        release_execution(ticket=ticket, inst=inst)


def release_execution(*, ticket: Any, inst: PoolInstance) -> None:
    """Decrement the instance in-flight count (never below zero) and release the guard ticket."""
    inst.in_flight = max(0, inst.in_flight - 1)
    ticket.release()