333 lines
10 KiB
Python
333 lines
10 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, Awaitable, Callable
|
|
|
|
from ..concurrency import InFlightGuard
|
|
from ..lingma_pool import LingmaPool, PoolInstance
|
|
from ..model_map import build_model_name_map, flatten_model_keys, resolve_model
|
|
from ..session_cache import SessionCache, hash_branch_context
|
|
|
|
|
|
@dataclass
class ExecutionContext:
    """Everything needed to run one chat request against a pool instance.

    Built by ``prepare_execution_context`` and consumed by the start /
    complete / finalize helpers in this module.
    """

    # Ask mode sent upstream ("agent" or the configured default).
    ask_mode: str
    # Session-cache key a reusable session was looked up under (None when
    # session reuse was not attempted).
    lookup_key: str | None
    # Session-cache key under which the resulting upstream session id is stored.
    write_key: str | None
    # Upstream session id being continued, or None to start a fresh session.
    cached_session_id: str | None
    # Pool instance chosen to serve the request.
    inst: PoolInstance
    # Model name resolved against the instance's model catalogue.
    model: str
    # Text sent upstream: last user turn when continuing a session, otherwise
    # the whole rendered conversation.
    prompt: str
    # True when the request continues ``cached_session_id``.
    is_reply: bool
    # Affinity key that was passed to ``pool.pick``.
    affinity: str | None
|
|
|
|
|
|
@dataclass
class StartedExecution:
    """Result of ``start_execution``: the held concurrency ticket plus token estimate."""

    # Ticket returned by ``chat_guard.try_acquire()``; hand it back with
    # ``release_execution`` once the request is finished.
    ticket: Any
    # Estimated token count of the prompt.
    prompt_tokens: int
|
|
|
|
|
|
@dataclass
class CompletedExecution:
    """Result of ``complete_execution`` for a non-streaming request."""

    # Raw dict returned by the instance client's ``chat_complete``.
    result: dict[str, Any]
    # Estimated token count of ``result["text"]``.
    completion_tokens: int
|
|
|
|
|
|
class UpstreamExecutionError(Exception):
    """Raised when the upstream ``chat_complete`` call fails.

    ``complete_execution`` raises it with the original exception chained as
    ``__cause__``.
    """
|
|
|
|
|
|
def _resolve_ask_mode(model: str, has_tooling_context: bool, *, default_ask_mode: str) -> str:
|
|
model_name = (model or "").lower()
|
|
if model_name in {"lingma-agent", "agent"} or has_tooling_context:
|
|
return "agent"
|
|
return default_ask_mode
|
|
|
|
|
|
def _tool_config_summary(tool_config: dict[str, Any] | None) -> dict[str, Any]:
|
|
if not isinstance(tool_config, dict):
|
|
return {"present": False, "provider": None, "tool_names": [], "tool_choice": None}
|
|
tools = tool_config.get("tools")
|
|
tool_names: list[str] = []
|
|
if isinstance(tools, list):
|
|
for tool in tools:
|
|
if not isinstance(tool, dict):
|
|
continue
|
|
if tool.get("type") == "function":
|
|
fn = tool.get("function")
|
|
if isinstance(fn, dict) and isinstance(fn.get("name"), str) and fn.get("name").strip():
|
|
tool_names.append(fn.get("name").strip())
|
|
continue
|
|
name = tool.get("name")
|
|
if isinstance(name, str) and name.strip():
|
|
tool_names.append(name.strip())
|
|
return {
|
|
"present": True,
|
|
"provider": tool_config.get("provider"),
|
|
"tool_names": tool_names,
|
|
"tool_choice": tool_config.get("tool_choice"),
|
|
}
|
|
|
|
|
|
async def _apply_cached_instance_or_invalidate(
|
|
*,
|
|
protocol: str,
|
|
logger: Any,
|
|
session_cache: SessionCache,
|
|
inst: PoolInstance,
|
|
cached_instance_name: str | None,
|
|
cached_session_id: str | None,
|
|
lookup_key: str | None,
|
|
) -> str | None:
|
|
if cached_instance_name and inst.name != cached_instance_name:
|
|
logger.info(
|
|
"%s session cache instance %s unhealthy, falling back to %s",
|
|
protocol,
|
|
cached_instance_name,
|
|
inst.name,
|
|
)
|
|
if lookup_key:
|
|
await session_cache.invalidate(lookup_key)
|
|
return None
|
|
return cached_session_id
|
|
|
|
|
|
async def _lookup_cached_session(
    *,
    session_cache: SessionCache,
    api_key: str,
    messages_dump: list[dict[str, Any]],
    tool_config: dict[str, Any] | None,
) -> tuple[str, str, Any]:
    """Build this turn's session-cache keys and look up a reusable entry.

    The lookup key is built from every message except the last one, the write
    key from the full message list; both include a branch-context hash. If the
    lookup key misses, a key built without ``branch_context`` (a legacy key
    format, presumably from entries written before branch hashing existed) is
    tried and, on a hit, becomes the returned lookup key.

    Returns:
        ``(lookup_key, write_key, entry)`` where ``entry`` is the value returned
        by ``session_cache.get`` (``None`` on a miss).
    """
    prefix_context = hash_branch_context(messages_dump[:-1])
    lookup_key = session_cache.build_key(
        api_key,
        messages_dump[:-1],
        tool_config=tool_config,
        branch_context=prefix_context,
    )
    write_key = session_cache.build_key(
        api_key,
        messages_dump,
        tool_config=tool_config,
        branch_context=hash_branch_context(messages_dump),
    )

    entry = await session_cache.get(lookup_key)
    if entry is not None:
        return lookup_key, write_key, entry

    legacy_key = session_cache.build_key(api_key, messages_dump[:-1], tool_config=tool_config)
    legacy_entry = await session_cache.get(legacy_key)
    if legacy_entry is None:
        return lookup_key, write_key, None
    return legacy_key, write_key, legacy_entry


async def prepare_execution_context(
    *,
    protocol: str,
    requested_model: str,
    has_tooling_context: bool,
    tool_config: dict[str, Any] | None,
    messages_dump: list[dict[str, Any]],
    api_key: str,
    affinity_key: str | None,
    pool: LingmaPool,
    session_cache: SessionCache,
    logger: Any,
    default_model: str,
    default_ask_mode: str,
    ensure_instance_logged_in: Callable[[PoolInstance], Awaitable[Any]],
    last_user_text: Callable[[list[dict[str, Any]]], str],
    messages_to_prompt: Callable[[list[dict[str, Any]]], str],
) -> ExecutionContext:
    """Resolve everything needed to send one chat request upstream.

    Steps, in order:

    1. Resolve the ask mode.
    2. If the request is eligible, look up a cached upstream session. It must
       be a multi-message chat-mode request with no tooling context, and the
       cache must be enabled.
    3. Pick a pool instance, preferring the cached session's instance.
    4. Drop the cached session if the pool picked a different instance.
    5. Ensure the instance is logged in.
    6. Resolve the model against the instance's model list.
    7. Build the prompt. When continuing a session it is only the last user
       text; otherwise it is the whole conversation.

    Exceptions raised by ``ensure_instance_logged_in``, ``query_models`` or
    the session cache propagate unchanged.
    """
    ask_mode = _resolve_ask_mode(
        requested_model,
        has_tooling_context,
        default_ask_mode=default_ask_mode,
    )
    logger.info(
        "%s.prepare requested_model=%s ask_mode=%s tooling=%s tool_config=%s",
        protocol,
        requested_model,
        ask_mode,
        has_tooling_context,
        _tool_config_summary(tool_config),
    )

    # Only plain multi-turn chat conversations can continue an upstream session.
    reuse_eligible = (
        session_cache.enabled
        and ask_mode == "chat"
        and len(messages_dump) >= 2
        and not has_tooling_context
    )

    lookup_key: str | None = None
    write_key: str | None = None
    cached_session_id: str | None = None
    cached_instance_name: str | None = None
    if reuse_eligible:
        lookup_key, write_key, entry = await _lookup_cached_session(
            session_cache=session_cache,
            api_key=api_key,
            messages_dump=messages_dump,
            tool_config=tool_config,
        )
        if entry is not None:
            cached_session_id = entry.session_id
            cached_instance_name = entry.instance_name or None

    # Steer the pool towards the instance that owns the cached session.
    affinity = cached_instance_name or affinity_key
    inst = pool.pick(affinity_key=affinity)
    cached_session_id = await _apply_cached_instance_or_invalidate(
        protocol=protocol,
        logger=logger,
        session_cache=session_cache,
        inst=inst,
        cached_instance_name=cached_instance_name,
        cached_session_id=cached_session_id,
        lookup_key=lookup_key,
    )

    await ensure_instance_logged_in(inst)

    catalogue = await inst.client.query_models()
    model = resolve_model(
        requested_model,
        flatten_model_keys(catalogue),
        default_model,
        build_model_name_map(catalogue),
    )

    # A continued session already holds the history upstream, so only the
    # newest user text is sent; otherwise the full conversation is rendered.
    is_reply = bool(cached_session_id)
    prompt = last_user_text(messages_dump) if is_reply else messages_to_prompt(messages_dump)

    logger.info(
        "%s.context inst=%s model=%s ask_mode=%s reuse_eligible=%s reused_session=%s affinity=%s",
        protocol,
        inst.name,
        model,
        ask_mode,
        reuse_eligible,
        bool(cached_session_id),
        affinity,
    )

    return ExecutionContext(
        ask_mode=ask_mode,
        lookup_key=lookup_key,
        write_key=write_key,
        cached_session_id=cached_session_id,
        inst=inst,
        model=model,
        prompt=prompt,
        is_reply=is_reply,
        affinity=affinity,
    )
|
|
|
|
|
|
async def start_execution(
    *,
    protocol: str,
    execution: ExecutionContext,
    stream: bool,
    chat_guard: InFlightGuard,
    logger: Any,
    estimate_tokens: Callable[[str], int],
    extra_log_context: dict[str, Any] | None = None,
) -> StartedExecution:
    """Take a concurrency slot for ``execution`` and log the request start.

    Args:
        protocol: Protocol label used as the log-message prefix.
        execution: Prepared request context.
        stream: Whether the response will be streamed (used for logging only).
        chat_guard: Concurrency guard; a ticket is taken with ``try_acquire``.
        logger: Logger that receives the ``<protocol>.start`` record.
        estimate_tokens: Token estimator applied to the prompt.
        extra_log_context: Optional fields merged into the record's ``extra``
            mapping, overriding the ``ctx_*`` defaults on a name clash.

    Returns:
        StartedExecution with the acquired ticket and the prompt-token estimate.
        The caller owns the ticket and must give it back via ``release_execution``.

    Raises:
        ValueError: if ``execution.prompt`` is empty. This is checked before any
            slot is taken.
        Exceptions from ``chat_guard.try_acquire`` propagate unchanged. Any
        exception after the ticket is acquired is re-raised only after the
        ticket and the instance counter have been released.
    """
    if not execution.prompt:
        raise ValueError("messages is empty")

    prompt_tokens = estimate_tokens(execution.prompt)
    ticket = await chat_guard.try_acquire()
    execution.inst.in_flight += 1
    # From here on a ticket is held and the instance counter is bumped. If the
    # logging below fails, undo both: the caller never receives the ticket, so
    # it could never release it. For example, ``logging`` raises KeyError when
    # an ``extra`` key clashes with a LogRecord attribute.
    try:
        log_extra = {
            "ctx_instance": execution.inst.name,
            "ctx_model": execution.model,
            "ctx_ask_mode": execution.ask_mode,
            "ctx_stream": stream,
            "ctx_prompt_tokens": prompt_tokens,
            # Read after acquiring so the count includes this request.
            "ctx_in_flight": chat_guard.in_flight,
            "ctx_affinity": execution.affinity,
            "ctx_session_reuse": bool(execution.cached_session_id),
        }
        if extra_log_context:
            log_extra.update(extra_log_context)
        logger.info(
            "%s.start inst=%s model=%s ask_mode=%s stream=%s prompt_tokens~%d reuse=%s",
            protocol,
            execution.inst.name,
            execution.model,
            execution.ask_mode,
            stream,
            prompt_tokens,
            bool(execution.cached_session_id),
            extra=log_extra,
        )
    except BaseException:
        release_execution(ticket=ticket, inst=execution.inst)
        raise
    return StartedExecution(ticket=ticket, prompt_tokens=prompt_tokens)
|
|
|
|
|
|
async def complete_execution(
    *,
    protocol: str,
    execution: ExecutionContext,
    prompt_tokens: int,
    tool_config: dict[str, Any] | None,
    logger: Any,
    stats_collector: Any,
    session_cache: SessionCache,
    estimate_tokens: Callable[[str], int],
) -> CompletedExecution:
    """Run a non-streaming completion and record its outcome.

    On success, stats are recorded with the completion-token estimate. If
    ``execution.write_key`` is set and the result carries a ``sessionId``,
    that id is cached under ``write_key`` for the chosen instance.

    This function does not release the concurrency ticket; the caller keeps
    that responsibility.

    Raises:
        UpstreamExecutionError: if ``chat_complete`` fails. The original
            exception is chained as ``__cause__``. Before raising, a failed
            call is recorded in stats, and the continued session (if any) is
            invalidated in the cache.
    """
    # Kept outside the ``try`` so a logging failure is not misreported as an
    # upstream failure, and does not evict a healthy cached session.
    logger.info(
        "%s.complete inst=%s ask_mode=%s tool_config=%s",
        protocol,
        execution.inst.name,
        execution.ask_mode,
        _tool_config_summary(tool_config),
    )
    try:
        result = await execution.inst.client.chat_complete(
            execution.prompt,
            execution.model,
            execution.ask_mode,
            session_id=execution.cached_session_id,
            is_reply=execution.is_reply,
            tool_config=tool_config,
        )
    except Exception as exc:
        logger.warning("%s.complete error (inst=%s): %s", protocol, execution.inst.name, exc)
        await stats_collector.record_chat(
            stream=False,
            success=False,
            prompt_tokens=prompt_tokens,
            completion_tokens=0,
        )
        # A continued session that just failed should not be reused next time.
        if execution.cached_session_id and execution.lookup_key:
            await session_cache.invalidate(execution.lookup_key)
        raise UpstreamExecutionError from exc

    completion_tokens = estimate_tokens(result.get("text") or "")
    await stats_collector.record_chat(
        stream=False,
        success=True,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    if execution.write_key:
        sid = result.get("sessionId")
        if sid:
            await session_cache.put(execution.write_key, sid, execution.inst.name)
    return CompletedExecution(result=result, completion_tokens=completion_tokens)
|
|
|
|
|
|
async def finalize_stream_execution(
    *,
    success: bool,
    write_key: str | None,
    session_id: str | None,
    inst: PoolInstance,
    ticket: Any,
    session_cache: SessionCache,
    stats_collector: Any,
    prompt_tokens: int,
    completion_tokens: int,
) -> None:
    """Finish a streamed request: cache the session, record stats, free the slot.

    The upstream ``session_id`` is cached under ``write_key`` for ``inst`` only
    when the stream succeeded and both values are present. Stats are recorded
    as a streaming chat with the given token counts. The ticket and the
    instance's ``in_flight`` counter are always released, even when the cache
    write, the stats call, or the awaiting task itself fails or is cancelled.
    Any such exception still propagates after the release.
    """
    try:
        if success and write_key and session_id:
            await session_cache.put(write_key, session_id, inst.name)
        await stats_collector.record_chat(
            stream=True,
            success=success,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )
    finally:
        # This is the last step of a streamed request. If the release were
        # skipped on error, the ticket would leak and in_flight would stay
        # permanently inflated.
        release_execution(ticket=ticket, inst=inst)
|
|
|
|
|
|
def release_execution(*, ticket: Any, inst: PoolInstance) -> None:
    """Give back a request's concurrency slot.

    The instance's ``in_flight`` counter is decremented but clamped at zero,
    so a duplicate release cannot drive it negative. The ticket is released
    afterwards.
    """
    remaining = inst.in_flight - 1
    inst.in_flight = remaining if remaining > 0 else 0
    ticket.release()
|