from __future__ import annotations

import asyncio
import codecs
import json
import time
import uuid
from typing import Any, AsyncIterable, AsyncIterator, Awaitable, Callable

from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse

from ..openai_schema import ChatCompletionsRequest, ResponsesRequest
from .responses_adapter import (
    _responses_non_stream_from_chat_payload,
    _responses_to_chat_request,
    _responses_usage_from_chat,
    _sse_data,
)


def _sse_event_data(event_block: str) -> str | None:
    """Return the ``data:`` payload of one SSE event block, or ``None`` if it has none.

    Every ``data:`` line in the block contributes one line to the payload
    (joined with newlines); comment lines and other fields such as ``event:``
    or ``id:`` are ignored.
    """
    data_lines: list[str] = []
    for raw_line in event_block.split("\n"):
        line = raw_line.strip()
        if line.startswith("data:"):
            data_lines.append(line[len("data:") :].strip())
    return "\n".join(data_lines).strip() if data_lines else None


async def _iter_sse_data(chunks: AsyncIterable[Any]) -> AsyncIterator[str]:
    """Yield the ``data:`` payload of each SSE event in a chunked stream.

    ``chunks`` may yield ``str`` or bytes-like parts whose boundaries are
    arbitrary: an event, or a multi-byte UTF-8 character, can be split across
    parts. Bytes are therefore decoded incrementally, and a trailing partial
    event is buffered until its terminating blank line arrives. An
    unterminated final event is still emitted when the stream ends.
    """
    decoder = codecs.getincrementaldecoder("utf-8")()
    pending = ""
    async for chunk in chunks:
        if isinstance(chunk, (bytes, bytearray, memoryview)):
            text = decoder.decode(bytes(chunk))
        else:
            text = str(chunk)
        # Normalise CRLF over the whole buffer so a "\r" that ends one chunk
        # still pairs with the "\n" that starts the next one.
        pending = (pending + text).replace("\r\n", "\n")
        *complete_events, pending = pending.split("\n\n")
        for event_block in complete_events:
            data = _sse_event_data(event_block)
            if data is not None:
                yield data
    pending += decoder.decode(b"", final=True)
    data = _sse_event_data(pending)
    if data is not None:
        yield data


async def _aclose_quietly(iterator: Any) -> None:
    """Call ``iterator.aclose()`` when it exists, ignoring errors raised while closing.

    Used on cleanup paths, where a failure to close must not replace the
    outcome (normal completion or an exception) that is already in flight.
    """
    aclose = getattr(iterator, "aclose", None)
    if aclose is None:
        return
    try:
        await aclose()
    except Exception:
        # Best-effort resource release only; see docstring.
        pass


async def _responses_stream_from_chat_stream(
    chat_stream: StreamingResponse,
    *,
    response_id: str,
    model: str,
) -> AsyncIterator[str]:
    """Translate a Chat Completions SSE stream into Responses API SSE events.

    Emits ``response.created`` and ``response.output_item.added`` for the
    assistant message, then ``response.output_text.delta`` for each text delta
    and ``response.output_item.added`` / ``response.function_call_arguments.delta``
    for tool calls. The stream ends with the ``*.done`` events, a terminal
    ``response.completed`` event (``response.failed`` if the upstream stream
    raises), and a trailing ``data: [DONE]`` frame.

    Output index 0 is always the assistant message; function calls take
    indices 1..N in the order they are first seen.

    Cancellation is propagated rather than swallowed, and the upstream body
    iterator is closed on every exit path.
    """
    created_at = int(time.time())
    usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
    output_item_id = f"msg_{uuid.uuid4().hex}"
    output_index = 0
    content_index = 0
    output_text_parts: list[str] = []
    # Function-call output items in first-seen order (output index = position + 1).
    function_call_items: list[dict[str, Any]] = []
    function_call_index_by_id: dict[str, int] = {}
    function_call_arguments_by_id: dict[str, str] = {}
    function_call_name_by_id: dict[str, str] = {}
    # Tool-call deltas may omit "id"; the upstream "index" maps them back to
    # the call id seen earlier.
    function_call_id_by_upstream_index: dict[int, str] = {}

    def _message_item(status: str) -> dict[str, Any]:
        """Build the assistant message item with all text received so far."""
        return {
            "id": output_item_id,
            "type": "message",
            "role": "assistant",
            "status": status,
            "content": [
                {
                    "type": "output_text",
                    "text": "".join(output_text_parts),
                }
            ],
        }

    def _function_call_item(call_id: str, *, status: str, name: str, arguments: str) -> dict[str, Any]:
        """Build a ``function_call`` output item."""
        return {
            "id": call_id,
            "type": "function_call",
            "call_id": call_id,
            "name": name,
            "arguments": arguments,
            "status": status,
        }

    def _finish_output_item_frames(item_status: str) -> list[str]:
        """Return the ``*.done`` frames closing every output item with ``item_status``."""
        frames = [
            _sse_data(
                {
                    "type": "response.output_text.done",
                    "response_id": response_id,
                    "item_id": output_item_id,
                    "output_index": output_index,
                    "content_index": content_index,
                    "text": "".join(output_text_parts),
                }
            ),
            _sse_data(
                {
                    "type": "response.output_item.done",
                    "response_id": response_id,
                    "output_index": output_index,
                    "item": _message_item(item_status),
                }
            ),
        ]
        for position, item in enumerate(function_call_items, start=1):
            finished_item = {**item, "status": item_status}
            frames.append(
                _sse_data(
                    {
                        "type": "response.function_call_arguments.done",
                        "response_id": response_id,
                        "item_id": finished_item["id"],
                        "output_index": position,
                        "arguments": finished_item["arguments"],
                    }
                )
            )
            frames.append(
                _sse_data(
                    {
                        "type": "response.output_item.done",
                        "response_id": response_id,
                        "output_index": position,
                        "item": finished_item,
                    }
                )
            )
        return frames

    def _terminal_frames(*, failed: bool) -> list[str]:
        """Return the closing frames: item ``*.done`` events, the terminal event, ``[DONE]``."""
        item_status = "incomplete" if failed else "completed"
        frames = _finish_output_item_frames(item_status)
        response: dict[str, Any] = {
            "id": response_id,
            "object": "response",
            "created_at": created_at,
            "status": "failed" if failed else "completed",
            "model": model,
            "output": [
                _message_item(item_status),
                *({**item, "status": item_status} for item in function_call_items),
            ],
            "usage": usage,
        }
        if failed:
            response["error"] = {
                "code": "server_error",
                "message": "upstream stream terminated with an error",
            }
        frames.append(
            _sse_data(
                {
                    "type": "response.failed" if failed else "response.completed",
                    "response": response,
                }
            )
        )
        frames.append("data: [DONE]\n\n")
        return frames

    def _ensure_function_call_item(call_id: str) -> list[str]:
        """Create or refresh the item for ``call_id``; return an ``output_item.added`` frame if new."""
        name = function_call_name_by_id.get(call_id, "tool")
        arguments = function_call_arguments_by_id.get(call_id, "")
        item = _function_call_item(call_id, status="completed", name=name, arguments=arguments)
        existing_index = function_call_index_by_id.get(call_id)
        if existing_index is not None:
            function_call_items[existing_index] = item
            return []
        function_call_items.append(item)
        item_index = len(function_call_items) - 1
        function_call_index_by_id[call_id] = item_index
        return [
            _sse_data(
                {
                    "type": "response.output_item.added",
                    "response_id": response_id,
                    "output_index": item_index + 1,
                    "item": _function_call_item(
                        call_id,
                        status="in_progress",
                        name=name,
                        arguments="",
                    ),
                }
            )
        ]

    def _events_for_delta(delta: dict[str, Any]) -> list[str]:
        """Record one chat ``delta`` (text and tool calls); return the frames it produces."""
        events: list[str] = []
        text = delta.get("content")
        if isinstance(text, str) and text:
            output_text_parts.append(text)
            events.append(
                _sse_data(
                    {
                        "type": "response.output_text.delta",
                        "response_id": response_id,
                        "item_id": output_item_id,
                        "output_index": output_index,
                        "content_index": content_index,
                        "delta": text,
                    }
                )
            )
        tool_calls = delta.get("tool_calls")
        if not isinstance(tool_calls, list):
            return events
        for position, tool_call in enumerate(tool_calls):
            if not isinstance(tool_call, dict):
                continue
            fn = tool_call.get("function")
            if not isinstance(fn, dict):
                fn = {}
            upstream_index = tool_call.get("index")
            if not isinstance(upstream_index, int):
                upstream_index = position
            call_id = str(
                tool_call.get("id")
                or function_call_id_by_upstream_index.get(upstream_index)
                or f"call_{upstream_index}"
            )
            function_call_id_by_upstream_index[upstream_index] = call_id
            name = str(fn.get("name") or function_call_name_by_id.get(call_id) or "tool")
            function_call_name_by_id[call_id] = name
            arguments_delta = str(fn.get("arguments") or "")
            function_call_arguments_by_id[call_id] = (
                function_call_arguments_by_id.get(call_id, "") + arguments_delta
            )
            events.extend(_ensure_function_call_item(call_id))
            if arguments_delta:
                events.append(
                    _sse_data(
                        {
                            "type": "response.function_call_arguments.delta",
                            "response_id": response_id,
                            "item_id": call_id,
                            "output_index": function_call_index_by_id[call_id] + 1,
                            "delta": arguments_delta,
                        }
                    )
                )
        return events

    yield _sse_data(
        {
            "type": "response.created",
            "response": {
                "id": response_id,
                "object": "response",
                "created_at": created_at,
                "status": "in_progress",
                "model": model,
                "output": [],
            },
        }
    )
    yield _sse_data(
        {
            "type": "response.output_item.added",
            "response_id": response_id,
            "output_index": output_index,
            "item": _message_item("in_progress"),
        }
    )

    data_payloads = _iter_sse_data(chat_stream.body_iterator)
    failed = False
    try:
        async for body in data_payloads:
            if body == "[DONE]":
                break
            try:
                payload = json.loads(body)
            except ValueError:
                # Non-JSON data payloads are skipped.
                continue
            if not isinstance(payload, dict):
                continue
            frame_usage = _responses_usage_from_chat(payload.get("usage"))
            if any(frame_usage.values()):
                usage = frame_usage
            choices = payload.get("choices")
            if not isinstance(choices, list) or not choices:
                continue
            choice = choices[0]
            delta = choice.get("delta") if isinstance(choice, dict) else None
            if not isinstance(delta, dict):
                continue
            for event in _events_for_delta(delta):
                yield event
    except asyncio.CancelledError:
        # Cancellation (e.g. the client disconnecting) must propagate; emitting
        # more frames here would swallow it.
        raise
    except Exception:
        # Report the truncated result as failed instead of as a success.
        failed = True
    finally:
        # Release the parser and the upstream stream on every exit path,
        # including the early break on "[DONE]".
        await _aclose_quietly(data_payloads)
        await _aclose_quietly(chat_stream.body_iterator)

    for frame in _terminal_frames(failed=failed):
        yield frame


async def handle_responses(
    req: ResponsesRequest,
    request: Request,
    *,
    chat_completions_handler: Callable[[ChatCompletionsRequest, Request], Awaitable[Any]],
    streaming_response_headers: dict[str, str],
):
    """Serve a Responses API request by delegating to the Chat Completions handler.

    The request is converted to a chat request and passed to
    ``chat_completions_handler``.
    - An upstream response with status >= 400 is returned unchanged.
    - A ``StreamingResponse`` is translated event-by-event into a Responses
      SSE stream that uses ``streaming_response_headers``.
    - Any other result is parsed as a JSON object. It may be a dict or a
      response object with ``.body``. It is then converted into a Responses
      payload.

    Raises:
        HTTPException: 502 when a non-streaming upstream result is not a JSON object.
    """
    chat_req = _responses_to_chat_request(req)
    chat_response = await chat_completions_handler(chat_req, request)

    # Keep upstream errors' own status code and body instead of re-wrapping
    # them as a successful (200) Responses payload.
    status_code = getattr(chat_response, "status_code", None)
    if isinstance(status_code, int) and status_code >= 400:
        return chat_response

    if isinstance(chat_response, StreamingResponse):
        response_id = f"resp_{uuid.uuid4().hex}"
        return StreamingResponse(
            _responses_stream_from_chat_stream(
                chat_response,
                response_id=response_id,
                model=req.model,
            ),
            media_type="text/event-stream",
            headers=streaming_response_headers,
        )

    invalid_upstream_error = {
        "error": {"message": "invalid upstream response", "type": "upstream_error"}
    }
    if isinstance(chat_response, dict):
        chat_payload = chat_response
    else:
        try:
            chat_payload = json.loads(chat_response.body)
        except (AttributeError, TypeError, ValueError) as exc:
            raise HTTPException(
                status_code=502,
                detail=invalid_upstream_error,
            ) from exc
    if not isinstance(chat_payload, dict):
        raise HTTPException(
            status_code=502,
            detail=invalid_upstream_error,
        )
    return JSONResponse(content=_responses_non_stream_from_chat_payload(chat_payload))