Keep app.main.v1_responses as the compatibility entrypoint while moving the Responses wrapper and SSE bridge into a dedicated module. This reduces app/main.py without changing the existing Responses behavior or test patch points. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
327 lines
12 KiB
Python
327 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import time
|
|
import uuid
|
|
from typing import Any, Awaitable, Callable
|
|
|
|
from fastapi import HTTPException, Request
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
|
from ..openai_schema import ChatCompletionsRequest, ResponsesRequest
|
|
from .responses_adapter import (
|
|
_responses_non_stream_from_chat_payload,
|
|
_responses_to_chat_request,
|
|
_responses_usage_from_chat,
|
|
_sse_data,
|
|
)
|
|
|
|
|
|
async def _responses_stream_from_chat_stream(
    chat_stream: StreamingResponse,
    *,
    response_id: str,
    model: str,
):
    """Re-emit a Chat Completions SSE stream as Responses-style SSE frames.

    The generator reads ``chat_stream.body_iterator``, looks at every
    ``data:`` frame and yields, in order:

    * ``response.created`` and ``response.output_item.added`` (assistant
      message item at output index 0) before any upstream data is read;
    * ``response.output_text.delta`` for each non-empty ``delta.content``;
    * ``response.output_item.added`` the first time a tool call is seen and
      ``response.function_call_arguments.delta`` for each argument fragment
      (tool calls occupy output indexes 1..N in first-seen order);
    * once upstream sends ``[DONE]``, ends, is cancelled or fails: the
      ``*.done`` frames, ``response.completed`` and a literal
      ``data: [DONE]`` frame.

    Args:
        chat_stream: Upstream streaming response whose ``body_iterator``
            yields ``str`` or bytes-like SSE chunks.
        response_id: Identifier echoed in every emitted event.
        model: Model name reported in the ``response`` envelopes.

    Yields:
        SSE frames as produced by ``_sse_data`` plus the final
        ``"data: [DONE]\\n\\n"`` string.
    """
    created_at = int(time.time())
    usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
    output_item_id = f"msg_{uuid.uuid4().hex}"
    output_index = 0
    content_index = 0
    output_text_parts: list[str] = []
    # Tool-call bookkeeping.  ``function_call_items`` keeps first-seen order;
    # the dicts map call ids (and upstream ``index`` values) back into it.
    function_call_items: list[dict[str, Any]] = []
    function_call_index_by_id: dict[str, int] = {}
    function_call_arguments_by_id: dict[str, str] = {}
    function_call_name_by_id: dict[str, str] = {}
    function_call_id_by_upstream_index: dict[int, str] = {}

    def _message_item(status: str) -> dict[str, Any]:
        """Build the assistant message item with all text received so far."""
        return {
            "id": output_item_id,
            "type": "message",
            "role": "assistant",
            "status": status,
            "content": [
                {
                    "type": "output_text",
                    "text": "".join(output_text_parts),
                }
            ],
        }

    def _function_call_item(call_id: str, *, status: str, name: str, arguments: str) -> dict[str, Any]:
        """Build a function_call output item; ``call_id`` doubles as item id."""
        return {
            "id": call_id,
            "type": "function_call",
            "call_id": call_id,
            "name": name,
            "arguments": arguments,
            "status": status,
        }

    def _completed_output_items() -> list[dict[str, Any]]:
        """Return the final output list: the message first, then tool calls."""
        return [_message_item("completed"), *function_call_items]

    def _completed_frame() -> str:
        """Build the ``response.completed`` frame from the current state."""
        return _sse_data(
            {
                "type": "response.completed",
                "response": {
                    "id": response_id,
                    "object": "response",
                    "created_at": created_at,
                    "status": "completed",
                    "model": model,
                    "output": _completed_output_items(),
                    "usage": usage,
                },
            }
        )

    def _finish_output_item_frames() -> list[str]:
        """Build the ``*.done`` frames for the message and every tool call."""
        frames = [
            _sse_data(
                {
                    "type": "response.output_text.done",
                    "response_id": response_id,
                    "item_id": output_item_id,
                    "output_index": output_index,
                    "content_index": content_index,
                    "text": "".join(output_text_parts),
                }
            ),
            _sse_data(
                {
                    "type": "response.output_item.done",
                    "response_id": response_id,
                    "output_index": output_index,
                    "item": _message_item("completed"),
                }
            ),
        ]
        # Tool calls start at output index 1 because the message owns index 0.
        for idx, item in enumerate(function_call_items, start=1):
            frames.append(
                _sse_data(
                    {
                        "type": "response.function_call_arguments.done",
                        "response_id": response_id,
                        "item_id": item["id"],
                        "output_index": idx,
                        "arguments": item["arguments"],
                    }
                )
            )
            frames.append(
                _sse_data(
                    {
                        "type": "response.output_item.done",
                        "response_id": response_id,
                        "output_index": idx,
                        "item": item,
                    }
                )
            )
        return frames

    def _terminal_frames() -> list[str]:
        """Return the closing sequence: done frames, completed, ``[DONE]``."""
        return [*_finish_output_item_frames(), _completed_frame(), "data: [DONE]\n\n"]

    def _ensure_function_call_item(call_id: str) -> list[str]:
        """Refresh the tracked item for ``call_id``; announce it if it is new.

        Returns the ``response.output_item.added`` frame the first time a
        call id is seen and an empty list afterwards.
        """
        name = function_call_name_by_id.get(call_id, "tool")
        arguments = function_call_arguments_by_id.get(call_id, "")
        # The stored snapshot is only emitted in done/completed frames, so it
        # always carries the "completed" status.
        snapshot = _function_call_item(
            call_id,
            status="completed",
            name=name,
            arguments=arguments,
        )
        existing_index = function_call_index_by_id.get(call_id)
        if existing_index is not None:
            function_call_items[existing_index] = snapshot
            return []
        function_call_items.append(snapshot)
        item_index = len(function_call_items) - 1
        function_call_index_by_id[call_id] = item_index
        return [
            _sse_data(
                {
                    "type": "response.output_item.added",
                    "response_id": response_id,
                    "output_index": item_index + 1,
                    "item": _function_call_item(
                        call_id,
                        status="in_progress",
                        name=name,
                        arguments="",
                    ),
                }
            )
        ]

    def _translate_payload(payload: dict[str, Any]):
        """Yield the Responses events for one decoded chat-completion chunk.

        Kept as a generator so events are handed downstream as they are
        produced, exactly like the inline loop it replaces.
        """
        nonlocal usage
        frame_usage = _responses_usage_from_chat(payload.get("usage"))
        # Only replace the running usage when the chunk reports something.
        if any(frame_usage.values()):
            usage = frame_usage

        choices = payload.get("choices")
        if not isinstance(choices, list) or not choices:
            return
        # Only the first choice is bridged.
        choice = choices[0] if isinstance(choices[0], dict) else {}
        delta = choice.get("delta") if isinstance(choice.get("delta"), dict) else {}

        text = delta.get("content")
        if isinstance(text, str) and text:
            output_text_parts.append(text)
            yield _sse_data(
                {
                    "type": "response.output_text.delta",
                    "response_id": response_id,
                    "item_id": output_item_id,
                    "output_index": output_index,
                    "content_index": content_index,
                    "delta": text,
                }
            )

        tool_calls = delta.get("tool_calls")
        if not isinstance(tool_calls, list):
            return
        for idx, tool_call in enumerate(tool_calls):
            if not isinstance(tool_call, dict):
                continue
            fn = tool_call.get("function") if isinstance(tool_call.get("function"), dict) else {}
            upstream_index_raw = tool_call.get("index")
            upstream_index = upstream_index_raw if isinstance(upstream_index_raw, int) else idx
            # Later fragments may omit the id; fall back to the id already
            # recorded for this upstream index, then to a synthetic one.
            call_id = str(
                tool_call.get("id")
                or function_call_id_by_upstream_index.get(upstream_index)
                or f"call_{upstream_index}"
            )
            function_call_id_by_upstream_index[upstream_index] = call_id
            name = str(fn.get("name") or function_call_name_by_id.get(call_id) or "tool")
            function_call_name_by_id[call_id] = name
            arguments_delta = str(fn.get("arguments") or "")
            function_call_arguments_by_id[call_id] = (
                function_call_arguments_by_id.get(call_id, "") + arguments_delta
            )
            for event in _ensure_function_call_item(call_id):
                yield event
            if arguments_delta:
                yield _sse_data(
                    {
                        "type": "response.function_call_arguments.delta",
                        "response_id": response_id,
                        "item_id": call_id,
                        "output_index": function_call_index_by_id[call_id] + 1,
                        "delta": arguments_delta,
                    }
                )

    yield _sse_data(
        {
            "type": "response.created",
            "response": {
                "id": response_id,
                "object": "response",
                "created_at": created_at,
                "status": "in_progress",
                "model": model,
                "output": [],
            },
        }
    )
    yield _sse_data(
        {
            "type": "response.output_item.added",
            "response_id": response_id,
            "output_index": output_index,
            "item": _message_item("in_progress"),
        }
    )

    try:
        upstream_done = False
        async for part in chat_stream.body_iterator:
            if isinstance(part, (bytes, bytearray, memoryview)):
                chunk = bytes(part).decode("utf-8")
            else:
                chunk = str(part)
            # NOTE(review): frames are split per chunk; a frame that straddles
            # two chunks is not reassembled and is dropped as unparseable.
            for frame in chunk.split("\n\n"):
                frame = frame.strip()
                if not frame.startswith("data:"):
                    continue
                body = frame[len("data:") :].strip()
                if body == "[DONE]":
                    upstream_done = True
                    break

                try:
                    payload = json.loads(body)
                except Exception:
                    # Malformed frame: skip it and keep streaming.
                    continue
                if not isinstance(payload, dict):
                    # Valid JSON but not an object; previously this raised
                    # AttributeError below and ended the stream prematurely.
                    continue

                for event in _translate_payload(payload):
                    yield event
            if upstream_done:
                break
    except asyncio.CancelledError:
        # NOTE(review): cancellation is deliberately swallowed so the closing
        # frames below are still attempted (best effort, as before).
        pass
    except Exception:
        # Any upstream/translation failure still closes the stream cleanly
        # with whatever has been accumulated so far.
        pass

    # Single exit path: upstream [DONE], exhaustion, cancellation and errors
    # all finish with the same closing sequence.
    for event in _terminal_frames():
        yield event
|
|
|
|
|
|
async def handle_responses(
    req: ResponsesRequest,
    request: Request,
    *,
    chat_completions_handler: Callable[[ChatCompletionsRequest, Request], Awaitable[Any]],
    streaming_response_headers: dict[str, str],
):
    """Serve a Responses request by delegating to the chat-completions handler.

    The request is converted with ``_responses_to_chat_request`` and passed,
    together with the raw ``request``, to ``chat_completions_handler``.

    * If the handler returns a ``StreamingResponse``, its stream is wrapped
      by ``_responses_stream_from_chat_stream`` under a fresh ``resp_…`` id
      and returned as ``text/event-stream`` with
      ``streaming_response_headers``.
    * Otherwise the handler result's ``body`` is parsed as JSON and converted
      with ``_responses_non_stream_from_chat_payload``.

    Raises:
        HTTPException: 502 with an ``upstream_error`` detail when the body
            cannot be read/parsed as JSON or is not a JSON object.
    """
    upstream = await chat_completions_handler(_responses_to_chat_request(req), request)

    if isinstance(upstream, StreamingResponse):
        bridged_events = _responses_stream_from_chat_stream(
            upstream,
            response_id=f"resp_{uuid.uuid4().hex}",
            model=req.model,
        )
        return StreamingResponse(
            bridged_events,
            media_type="text/event-stream",
            headers=streaming_response_headers,
        )

    def _bad_gateway() -> HTTPException:
        # Same 502 payload for every "unusable upstream body" case.
        return HTTPException(
            status_code=502,
            detail={
                "error": {"message": "invalid upstream response", "type": "upstream_error"}
            },
        )

    try:
        decoded = json.loads(upstream.body)
    except Exception:
        raise _bad_gateway()

    if isinstance(decoded, dict):
        return JSONResponse(content=_responses_non_stream_from_chat_payload(decoded))
    raise _bad_gateway()
|