refactor: extract OpenAI Responses route wrapper
Keep app.main.v1_responses as the compatibility entrypoint while moving the Responses wrapper and SSE bridge into a dedicated module. This reduces app/main.py without changing the existing Responses behavior or test patch points. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
326
app/http/openai_responses.py
Normal file
326
app/http/openai_responses.py
Normal file
@@ -0,0 +1,326 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from ..openai_schema import ChatCompletionsRequest, ResponsesRequest
|
||||
from .responses_adapter import (
|
||||
_responses_non_stream_from_chat_payload,
|
||||
_responses_to_chat_request,
|
||||
_responses_usage_from_chat,
|
||||
_sse_data,
|
||||
)
|
||||
|
||||
|
||||
async def _responses_stream_from_chat_stream(
|
||||
chat_stream: StreamingResponse,
|
||||
*,
|
||||
response_id: str,
|
||||
model: str,
|
||||
):
|
||||
created_at = int(time.time())
|
||||
usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
completed_sent = False
|
||||
output_item_id = f"msg_{uuid.uuid4().hex}"
|
||||
output_index = 0
|
||||
content_index = 0
|
||||
output_text_parts: list[str] = []
|
||||
function_call_items: list[dict[str, Any]] = []
|
||||
function_call_index_by_id: dict[str, int] = {}
|
||||
function_call_arguments_by_id: dict[str, str] = {}
|
||||
function_call_name_by_id: dict[str, str] = {}
|
||||
function_call_id_by_upstream_index: dict[int, str] = {}
|
||||
|
||||
def _message_item(status: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": output_item_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": status,
|
||||
"content": [
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": "".join(output_text_parts),
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def _function_call_item(call_id: str, *, status: str, name: str, arguments: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": call_id,
|
||||
"type": "function_call",
|
||||
"call_id": call_id,
|
||||
"name": name,
|
||||
"arguments": arguments,
|
||||
"status": status,
|
||||
}
|
||||
|
||||
def _completed_output_items() -> list[dict[str, Any]]:
|
||||
return [_message_item("completed"), *function_call_items]
|
||||
|
||||
def _completed_frame() -> str:
|
||||
return _sse_data(
|
||||
{
|
||||
"type": "response.completed",
|
||||
"response": {
|
||||
"id": response_id,
|
||||
"object": "response",
|
||||
"created_at": created_at,
|
||||
"status": "completed",
|
||||
"model": model,
|
||||
"output": _completed_output_items(),
|
||||
"usage": usage,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def _finish_output_item_frames() -> list[str]:
|
||||
frames = [
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.output_text.done",
|
||||
"response_id": response_id,
|
||||
"item_id": output_item_id,
|
||||
"output_index": output_index,
|
||||
"content_index": content_index,
|
||||
"text": "".join(output_text_parts),
|
||||
}
|
||||
),
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.output_item.done",
|
||||
"response_id": response_id,
|
||||
"output_index": output_index,
|
||||
"item": _message_item("completed"),
|
||||
}
|
||||
),
|
||||
]
|
||||
for idx, item in enumerate(function_call_items, start=1):
|
||||
frames.append(
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.function_call_arguments.done",
|
||||
"response_id": response_id,
|
||||
"item_id": item["id"],
|
||||
"output_index": idx,
|
||||
"arguments": item["arguments"],
|
||||
}
|
||||
)
|
||||
)
|
||||
frames.append(
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.output_item.done",
|
||||
"response_id": response_id,
|
||||
"output_index": idx,
|
||||
"item": item,
|
||||
}
|
||||
)
|
||||
)
|
||||
return frames
|
||||
|
||||
def _ensure_function_call_item(call_id: str) -> list[str]:
|
||||
existing_index = function_call_index_by_id.get(call_id)
|
||||
name = function_call_name_by_id.get(call_id, "tool")
|
||||
arguments = function_call_arguments_by_id.get(call_id, "")
|
||||
if existing_index is not None:
|
||||
function_call_items[existing_index] = _function_call_item(
|
||||
call_id,
|
||||
status="completed",
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
return []
|
||||
item = _function_call_item(
|
||||
call_id,
|
||||
status="completed",
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
function_call_items.append(item)
|
||||
item_index = len(function_call_items) - 1
|
||||
function_call_index_by_id[call_id] = item_index
|
||||
return [
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.output_item.added",
|
||||
"response_id": response_id,
|
||||
"output_index": item_index + 1,
|
||||
"item": _function_call_item(
|
||||
call_id,
|
||||
status="in_progress",
|
||||
name=name,
|
||||
arguments="",
|
||||
),
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
yield _sse_data(
|
||||
{
|
||||
"type": "response.created",
|
||||
"response": {
|
||||
"id": response_id,
|
||||
"object": "response",
|
||||
"created_at": created_at,
|
||||
"status": "in_progress",
|
||||
"model": model,
|
||||
"output": [],
|
||||
},
|
||||
}
|
||||
)
|
||||
yield _sse_data(
|
||||
{
|
||||
"type": "response.output_item.added",
|
||||
"response_id": response_id,
|
||||
"output_index": output_index,
|
||||
"item": _message_item("in_progress"),
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
async for part in chat_stream.body_iterator:
|
||||
chunk = part.decode("utf-8") if isinstance(part, bytes) else str(part)
|
||||
for frame in chunk.split("\n\n"):
|
||||
frame = frame.strip()
|
||||
if not frame or not frame.startswith("data:"):
|
||||
continue
|
||||
body = frame[len("data:") :].strip()
|
||||
if body == "[DONE]":
|
||||
for event in _finish_output_item_frames():
|
||||
yield event
|
||||
yield _completed_frame()
|
||||
yield "data: [DONE]\n\n"
|
||||
completed_sent = True
|
||||
return
|
||||
|
||||
try:
|
||||
payload = json.loads(body)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
frame_usage = _responses_usage_from_chat(payload.get("usage"))
|
||||
if any(frame_usage.values()):
|
||||
usage = frame_usage
|
||||
|
||||
choices = payload.get("choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
continue
|
||||
choice = choices[0] if isinstance(choices[0], dict) else {}
|
||||
delta = choice.get("delta") if isinstance(choice.get("delta"), dict) else {}
|
||||
|
||||
text = delta.get("content")
|
||||
if isinstance(text, str) and text:
|
||||
output_text_parts.append(text)
|
||||
yield _sse_data(
|
||||
{
|
||||
"type": "response.output_text.delta",
|
||||
"response_id": response_id,
|
||||
"item_id": output_item_id,
|
||||
"output_index": output_index,
|
||||
"content_index": content_index,
|
||||
"delta": text,
|
||||
}
|
||||
)
|
||||
|
||||
tool_calls = delta.get("tool_calls")
|
||||
if isinstance(tool_calls, list):
|
||||
for idx, tool_call in enumerate(tool_calls):
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
fn = tool_call.get("function") if isinstance(tool_call.get("function"), dict) else {}
|
||||
upstream_index_raw = tool_call.get("index")
|
||||
upstream_index = upstream_index_raw if isinstance(upstream_index_raw, int) else idx
|
||||
call_id = str(
|
||||
tool_call.get("id")
|
||||
or function_call_id_by_upstream_index.get(upstream_index)
|
||||
or f"call_{upstream_index}"
|
||||
)
|
||||
function_call_id_by_upstream_index[upstream_index] = call_id
|
||||
name = str(fn.get("name") or function_call_name_by_id.get(call_id) or "tool")
|
||||
function_call_name_by_id[call_id] = name
|
||||
arguments_delta = str(fn.get("arguments") or "")
|
||||
accumulated_arguments = (
|
||||
function_call_arguments_by_id.get(call_id, "") + arguments_delta
|
||||
)
|
||||
function_call_arguments_by_id[call_id] = accumulated_arguments
|
||||
for event in _ensure_function_call_item(call_id):
|
||||
yield event
|
||||
if arguments_delta:
|
||||
yield _sse_data(
|
||||
{
|
||||
"type": "response.function_call_arguments.delta",
|
||||
"response_id": response_id,
|
||||
"item_id": call_id,
|
||||
"output_index": function_call_index_by_id[call_id] + 1,
|
||||
"delta": arguments_delta,
|
||||
}
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
if not completed_sent:
|
||||
for event in _finish_output_item_frames():
|
||||
yield event
|
||||
yield _completed_frame()
|
||||
yield "data: [DONE]\n\n"
|
||||
completed_sent = True
|
||||
return
|
||||
except Exception:
|
||||
if not completed_sent:
|
||||
for event in _finish_output_item_frames():
|
||||
yield event
|
||||
yield _completed_frame()
|
||||
yield "data: [DONE]\n\n"
|
||||
completed_sent = True
|
||||
return
|
||||
|
||||
if not completed_sent:
|
||||
for event in _finish_output_item_frames():
|
||||
yield event
|
||||
yield _completed_frame()
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
async def handle_responses(
|
||||
req: ResponsesRequest,
|
||||
request: Request,
|
||||
*,
|
||||
chat_completions_handler: Callable[[ChatCompletionsRequest, Request], Awaitable[Any]],
|
||||
streaming_response_headers: dict[str, str],
|
||||
):
|
||||
chat_req = _responses_to_chat_request(req)
|
||||
chat_response = await chat_completions_handler(chat_req, request)
|
||||
|
||||
if isinstance(chat_response, StreamingResponse):
|
||||
response_id = f"resp_{uuid.uuid4().hex}"
|
||||
return StreamingResponse(
|
||||
_responses_stream_from_chat_stream(
|
||||
chat_response,
|
||||
response_id=response_id,
|
||||
model=req.model,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers=streaming_response_headers,
|
||||
)
|
||||
|
||||
invalid_upstream_error = {
|
||||
"error": {"message": "invalid upstream response", "type": "upstream_error"}
|
||||
}
|
||||
try:
|
||||
chat_payload = json.loads(chat_response.body)
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=invalid_upstream_error,
|
||||
)
|
||||
if not isinstance(chat_payload, dict):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=invalid_upstream_error,
|
||||
)
|
||||
return JSONResponse(content=_responses_non_stream_from_chat_payload(chat_payload))
|
||||
319
app/main.py
319
app/main.py
@@ -30,14 +30,7 @@ from .http.execution_core import (
|
||||
_resolve_ask_mode as _shared_resolve_ask_mode,
|
||||
prepare_execution_context,
|
||||
)
|
||||
from .http.responses_adapter import (
|
||||
_responses_id_from_chat_id,
|
||||
_responses_input_to_messages,
|
||||
_responses_non_stream_from_chat_payload,
|
||||
_responses_to_chat_request,
|
||||
_responses_usage_from_chat,
|
||||
_sse_data,
|
||||
)
|
||||
from .http.openai_responses import handle_responses
|
||||
from .http.tool_bridge import (
|
||||
_allowed_stream_tool_event,
|
||||
_allowed_tool_events,
|
||||
@@ -941,313 +934,15 @@ async def v1_chat_completions(req: ChatCompletionsRequest, request: Request):
|
||||
|
||||
|
||||
|
||||
async def _responses_stream_from_chat_stream(
|
||||
chat_stream: StreamingResponse,
|
||||
*,
|
||||
response_id: str,
|
||||
model: str,
|
||||
):
|
||||
created_at = int(time.time())
|
||||
usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
completed_sent = False
|
||||
output_item_id = f"msg_{uuid.uuid4().hex}"
|
||||
output_index = 0
|
||||
content_index = 0
|
||||
output_text_parts: list[str] = []
|
||||
function_call_items: list[dict[str, Any]] = []
|
||||
function_call_index_by_id: dict[str, int] = {}
|
||||
function_call_arguments_by_id: dict[str, str] = {}
|
||||
function_call_name_by_id: dict[str, str] = {}
|
||||
function_call_id_by_upstream_index: dict[int, str] = {}
|
||||
|
||||
def _message_item(status: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": output_item_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": status,
|
||||
"content": [
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": "".join(output_text_parts),
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
def _function_call_item(call_id: str, *, status: str, name: str, arguments: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": call_id,
|
||||
"type": "function_call",
|
||||
"call_id": call_id,
|
||||
"name": name,
|
||||
"arguments": arguments,
|
||||
"status": status,
|
||||
}
|
||||
|
||||
def _completed_output_items() -> list[dict[str, Any]]:
|
||||
return [_message_item("completed"), *function_call_items]
|
||||
|
||||
def _completed_frame() -> str:
|
||||
return _sse_data(
|
||||
{
|
||||
"type": "response.completed",
|
||||
"response": {
|
||||
"id": response_id,
|
||||
"object": "response",
|
||||
"created_at": created_at,
|
||||
"status": "completed",
|
||||
"model": model,
|
||||
"output": _completed_output_items(),
|
||||
"usage": usage,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def _finish_output_item_frames() -> list[str]:
|
||||
frames = [
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.output_text.done",
|
||||
"response_id": response_id,
|
||||
"item_id": output_item_id,
|
||||
"output_index": output_index,
|
||||
"content_index": content_index,
|
||||
"text": "".join(output_text_parts),
|
||||
}
|
||||
),
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.output_item.done",
|
||||
"response_id": response_id,
|
||||
"output_index": output_index,
|
||||
"item": _message_item("completed"),
|
||||
}
|
||||
),
|
||||
]
|
||||
for idx, item in enumerate(function_call_items, start=1):
|
||||
frames.append(
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.function_call_arguments.done",
|
||||
"response_id": response_id,
|
||||
"item_id": item["id"],
|
||||
"output_index": idx,
|
||||
"arguments": item["arguments"],
|
||||
}
|
||||
)
|
||||
)
|
||||
frames.append(
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.output_item.done",
|
||||
"response_id": response_id,
|
||||
"output_index": idx,
|
||||
"item": item,
|
||||
}
|
||||
)
|
||||
)
|
||||
return frames
|
||||
|
||||
def _ensure_function_call_item(call_id: str) -> list[str]:
|
||||
existing_index = function_call_index_by_id.get(call_id)
|
||||
name = function_call_name_by_id.get(call_id, "tool")
|
||||
arguments = function_call_arguments_by_id.get(call_id, "")
|
||||
if existing_index is not None:
|
||||
function_call_items[existing_index] = _function_call_item(
|
||||
call_id,
|
||||
status="completed",
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
return []
|
||||
item = _function_call_item(
|
||||
call_id,
|
||||
status="completed",
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
function_call_items.append(item)
|
||||
item_index = len(function_call_items) - 1
|
||||
function_call_index_by_id[call_id] = item_index
|
||||
return [
|
||||
_sse_data(
|
||||
{
|
||||
"type": "response.output_item.added",
|
||||
"response_id": response_id,
|
||||
"output_index": item_index + 1,
|
||||
"item": _function_call_item(
|
||||
call_id,
|
||||
status="in_progress",
|
||||
name=name,
|
||||
arguments="",
|
||||
),
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
yield _sse_data(
|
||||
{
|
||||
"type": "response.created",
|
||||
"response": {
|
||||
"id": response_id,
|
||||
"object": "response",
|
||||
"created_at": created_at,
|
||||
"status": "in_progress",
|
||||
"model": model,
|
||||
"output": [],
|
||||
},
|
||||
}
|
||||
)
|
||||
yield _sse_data(
|
||||
{
|
||||
"type": "response.output_item.added",
|
||||
"response_id": response_id,
|
||||
"output_index": output_index,
|
||||
"item": _message_item("in_progress"),
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
async for part in chat_stream.body_iterator:
|
||||
chunk = part.decode("utf-8") if isinstance(part, bytes) else str(part)
|
||||
for frame in chunk.split("\n\n"):
|
||||
frame = frame.strip()
|
||||
if not frame or not frame.startswith("data:"):
|
||||
continue
|
||||
body = frame[len("data:") :].strip()
|
||||
if body == "[DONE]":
|
||||
for event in _finish_output_item_frames():
|
||||
yield event
|
||||
yield _completed_frame()
|
||||
yield "data: [DONE]\n\n"
|
||||
completed_sent = True
|
||||
return
|
||||
|
||||
try:
|
||||
payload = json.loads(body)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
frame_usage = _responses_usage_from_chat(payload.get("usage"))
|
||||
if any(frame_usage.values()):
|
||||
usage = frame_usage
|
||||
|
||||
choices = payload.get("choices")
|
||||
if not isinstance(choices, list) or not choices:
|
||||
continue
|
||||
choice = choices[0] if isinstance(choices[0], dict) else {}
|
||||
delta = choice.get("delta") if isinstance(choice.get("delta"), dict) else {}
|
||||
|
||||
text = delta.get("content")
|
||||
if isinstance(text, str) and text:
|
||||
output_text_parts.append(text)
|
||||
yield _sse_data(
|
||||
{
|
||||
"type": "response.output_text.delta",
|
||||
"response_id": response_id,
|
||||
"item_id": output_item_id,
|
||||
"output_index": output_index,
|
||||
"content_index": content_index,
|
||||
"delta": text,
|
||||
}
|
||||
)
|
||||
|
||||
tool_calls = delta.get("tool_calls")
|
||||
if isinstance(tool_calls, list):
|
||||
for idx, tool_call in enumerate(tool_calls):
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
fn = tool_call.get("function") if isinstance(tool_call.get("function"), dict) else {}
|
||||
upstream_index_raw = tool_call.get("index")
|
||||
upstream_index = upstream_index_raw if isinstance(upstream_index_raw, int) else idx
|
||||
call_id = str(
|
||||
tool_call.get("id")
|
||||
or function_call_id_by_upstream_index.get(upstream_index)
|
||||
or f"call_{upstream_index}"
|
||||
)
|
||||
function_call_id_by_upstream_index[upstream_index] = call_id
|
||||
name = str(fn.get("name") or function_call_name_by_id.get(call_id) or "tool")
|
||||
function_call_name_by_id[call_id] = name
|
||||
arguments_delta = str(fn.get("arguments") or "")
|
||||
accumulated_arguments = (
|
||||
function_call_arguments_by_id.get(call_id, "") + arguments_delta
|
||||
)
|
||||
function_call_arguments_by_id[call_id] = accumulated_arguments
|
||||
for event in _ensure_function_call_item(call_id):
|
||||
yield event
|
||||
if arguments_delta:
|
||||
yield _sse_data(
|
||||
{
|
||||
"type": "response.function_call_arguments.delta",
|
||||
"response_id": response_id,
|
||||
"item_id": call_id,
|
||||
"output_index": function_call_index_by_id[call_id] + 1,
|
||||
"delta": arguments_delta,
|
||||
}
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
if not completed_sent:
|
||||
for event in _finish_output_item_frames():
|
||||
yield event
|
||||
yield _completed_frame()
|
||||
yield "data: [DONE]\n\n"
|
||||
completed_sent = True
|
||||
return
|
||||
except Exception:
|
||||
if not completed_sent:
|
||||
for event in _finish_output_item_frames():
|
||||
yield event
|
||||
yield _completed_frame()
|
||||
yield "data: [DONE]\n\n"
|
||||
completed_sent = True
|
||||
return
|
||||
|
||||
if not completed_sent:
|
||||
for event in _finish_output_item_frames():
|
||||
yield event
|
||||
yield _completed_frame()
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
|
||||
@app.post("/responses", dependencies=[Depends(auth_guard)])
|
||||
@app.post("/v1/responses", dependencies=[Depends(auth_guard)])
|
||||
async def v1_responses(req: ResponsesRequest, request: Request):
|
||||
chat_req = _responses_to_chat_request(req)
|
||||
chat_response = await v1_chat_completions(chat_req, request)
|
||||
|
||||
if isinstance(chat_response, StreamingResponse):
|
||||
response_id = f"resp_{uuid.uuid4().hex}"
|
||||
return StreamingResponse(
|
||||
_responses_stream_from_chat_stream(
|
||||
chat_response,
|
||||
response_id=response_id,
|
||||
model=req.model,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
invalid_upstream_error = {
|
||||
"error": {"message": "invalid upstream response", "type": "upstream_error"}
|
||||
}
|
||||
try:
|
||||
chat_payload = json.loads(chat_response.body)
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=invalid_upstream_error,
|
||||
)
|
||||
if not isinstance(chat_payload, dict):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=invalid_upstream_error,
|
||||
)
|
||||
return JSONResponse(content=_responses_non_stream_from_chat_payload(chat_payload))
|
||||
return await handle_responses(
|
||||
req,
|
||||
request,
|
||||
chat_completions_handler=v1_chat_completions,
|
||||
streaming_response_headers=STREAMING_RESPONSE_HEADERS,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user