feat: add streaming responses for anthropic and openai
This commit is contained in:
@@ -150,10 +150,6 @@ func (s *Server) handleAnthropicMessages(w http.ResponseWriter, r *http.Request)
|
|||||||
writeAnthropicError(w, http.StatusBadRequest, "invalid_request_error", err.Error())
|
writeAnthropicError(w, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if req.Stream {
|
|
||||||
writeAnthropicError(w, http.StatusBadRequest, "invalid_request_error", "streaming is not supported")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
normalized, err := normalizeAnthropicRequest(req)
|
normalized, err := normalizeAnthropicRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -161,6 +157,11 @@ func (s *Server) handleAnthropicMessages(w http.ResponseWriter, r *http.Request)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if req.Stream {
|
||||||
|
s.handleAnthropicStream(w, r, normalized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
result, err := s.svc.Generate(r.Context(), normalized)
|
result, err := s.svc.Generate(r.Context(), normalized)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeAnthropicError(w, http.StatusInternalServerError, "api_error", err.Error())
|
writeAnthropicError(w, http.StatusInternalServerError, "api_error", err.Error())
|
||||||
@@ -202,10 +203,6 @@ func (s *Server) handleOpenAIChatCompletions(w http.ResponseWriter, r *http.Requ
|
|||||||
writeOpenAIError(w, http.StatusBadRequest, "invalid_request_error", err.Error())
|
writeOpenAIError(w, http.StatusBadRequest, "invalid_request_error", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if req.Stream {
|
|
||||||
writeOpenAIError(w, http.StatusBadRequest, "invalid_request_error", "streaming is not supported")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
normalized, err := normalizeOpenAIRequest(req)
|
normalized, err := normalizeOpenAIRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -213,6 +210,11 @@ func (s *Server) handleOpenAIChatCompletions(w http.ResponseWriter, r *http.Requ
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if req.Stream {
|
||||||
|
s.handleOpenAIStream(w, r, normalized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
result, err := s.svc.Generate(r.Context(), normalized)
|
result, err := s.svc.Generate(r.Context(), normalized)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeOpenAIError(w, http.StatusInternalServerError, "api_error", err.Error())
|
writeOpenAIError(w, http.StatusInternalServerError, "api_error", err.Error())
|
||||||
@@ -243,6 +245,248 @@ func (s *Server) handleOpenAIChatCompletions(w http.ResponseWriter, r *http.Requ
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleAnthropicStream(w http.ResponseWriter, r *http.Request, req service.ChatRequest) {
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
writeAnthropicError(w, http.StatusInternalServerError, "api_error", "streaming is not supported by this server")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
events, done, err := s.svc.GenerateStream(r.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
writeAnthropicError(w, http.StatusInternalServerError, "api_error", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
model := strings.TrimSpace(req.Model)
|
||||||
|
if model == "" {
|
||||||
|
model = "lingma"
|
||||||
|
}
|
||||||
|
msgID := fmt.Sprintf("msg_%d", time.Now().UnixNano())
|
||||||
|
streamingHeaders(w)
|
||||||
|
if err := writeSSEEvent(w, flusher, "message_start", map[string]any{
|
||||||
|
"type": "message_start",
|
||||||
|
"message": map[string]any{
|
||||||
|
"id": msgID,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": []any{},
|
||||||
|
"model": model,
|
||||||
|
"stop_reason": nil,
|
||||||
|
"stop_sequence": nil,
|
||||||
|
"usage": map[string]any{
|
||||||
|
"input_tokens": 0,
|
||||||
|
"output_tokens": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := writeSSEEvent(w, flusher, "content_block_start", map[string]any{
|
||||||
|
"type": "content_block_start",
|
||||||
|
"index": 0,
|
||||||
|
"content_block": map[string]any{
|
||||||
|
"type": "text",
|
||||||
|
"text": "",
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsCh := events
|
||||||
|
doneCh := done
|
||||||
|
var final *service.ChatResult
|
||||||
|
var finalErr error
|
||||||
|
|
||||||
|
for eventsCh != nil || doneCh != nil {
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return
|
||||||
|
case event, ok := <-eventsCh:
|
||||||
|
if !ok {
|
||||||
|
eventsCh = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if event.Delta == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := writeSSEEvent(w, flusher, "content_block_delta", map[string]any{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]any{
|
||||||
|
"type": "text_delta",
|
||||||
|
"text": event.Delta,
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case result, ok := <-doneCh:
|
||||||
|
if !ok {
|
||||||
|
doneCh = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
final = result.Result
|
||||||
|
finalErr = result.Err
|
||||||
|
doneCh = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if finalErr != nil {
|
||||||
|
_ = writeSSEEvent(w, flusher, "error", map[string]any{
|
||||||
|
"type": "error",
|
||||||
|
"error": map[string]any{
|
||||||
|
"type": "api_error",
|
||||||
|
"message": finalErr.Error(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if final == nil {
|
||||||
|
_ = writeSSEEvent(w, flusher, "error", map[string]any{
|
||||||
|
"type": "error",
|
||||||
|
"error": map[string]any{
|
||||||
|
"type": "api_error",
|
||||||
|
"message": "stream finished without a final result",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := writeSSEEvent(w, flusher, "content_block_stop", map[string]any{
|
||||||
|
"type": "content_block_stop",
|
||||||
|
"index": 0,
|
||||||
|
}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := writeSSEEvent(w, flusher, "message_delta", map[string]any{
|
||||||
|
"type": "message_delta",
|
||||||
|
"delta": map[string]any{
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"stop_sequence": nil,
|
||||||
|
},
|
||||||
|
"usage": map[string]any{
|
||||||
|
"output_tokens": final.OutputTokens,
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = writeSSEEvent(w, flusher, "message_stop", map[string]any{
|
||||||
|
"type": "message_stop",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleOpenAIStream(w http.ResponseWriter, r *http.Request, req service.ChatRequest) {
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
writeOpenAIError(w, http.StatusInternalServerError, "api_error", "streaming is not supported by this server")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
events, done, err := s.svc.GenerateStream(r.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
writeOpenAIError(w, http.StatusInternalServerError, "api_error", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
model := strings.TrimSpace(req.Model)
|
||||||
|
if model == "" {
|
||||||
|
model = "lingma"
|
||||||
|
}
|
||||||
|
chatID := fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano())
|
||||||
|
created := time.Now().Unix()
|
||||||
|
streamingHeaders(w)
|
||||||
|
if err := writeOpenAIChunk(w, flusher, map[string]any{
|
||||||
|
"id": chatID,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": created,
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
},
|
||||||
|
"finish_reason": nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsCh := events
|
||||||
|
doneCh := done
|
||||||
|
var finalErr error
|
||||||
|
|
||||||
|
for eventsCh != nil || doneCh != nil {
|
||||||
|
select {
|
||||||
|
case <-r.Context().Done():
|
||||||
|
return
|
||||||
|
case event, ok := <-eventsCh:
|
||||||
|
if !ok {
|
||||||
|
eventsCh = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if event.Delta == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := writeOpenAIChunk(w, flusher, map[string]any{
|
||||||
|
"id": chatID,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": created,
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]any{
|
||||||
|
"content": event.Delta,
|
||||||
|
},
|
||||||
|
"finish_reason": nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case result, ok := <-doneCh:
|
||||||
|
if !ok {
|
||||||
|
doneCh = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
finalErr = result.Err
|
||||||
|
doneCh = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if finalErr != nil {
|
||||||
|
_ = writeOpenAIChunk(w, flusher, map[string]any{
|
||||||
|
"error": map[string]any{
|
||||||
|
"message": finalErr.Error(),
|
||||||
|
"type": "api_error",
|
||||||
|
"code": nil,
|
||||||
|
"param": nil,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
|
||||||
|
flusher.Flush()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := writeOpenAIChunk(w, flusher, map[string]any{
|
||||||
|
"id": chatID,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": created,
|
||||||
|
"model": model,
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": map[string]any{},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeAnthropicRequest(req anthropicRequest) (service.ChatRequest, error) {
|
func normalizeAnthropicRequest(req anthropicRequest) (service.ChatRequest, error) {
|
||||||
messages := make([]service.ChatMessage, 0, len(req.Messages))
|
messages := make([]service.ChatMessage, 0, len(req.Messages))
|
||||||
for _, message := range req.Messages {
|
for _, message := range req.Messages {
|
||||||
@@ -372,6 +616,41 @@ func writeOpenAIError(w http.ResponseWriter, status int, kind string, message st
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// streamingHeaders sets the response headers used for a server-sent-event
// stream and commits a 200 status. It must be called before any body bytes
// are written.
func streamingHeaders(w http.ResponseWriter) {
	header := w.Header()
	for _, kv := range [...][2]string{
		{"Content-Type", "text/event-stream"},
		{"Cache-Control", "no-cache"},
		{"Connection", "keep-alive"},
		{"X-Accel-Buffering", "no"},
	} {
		header.Set(kv[0], kv[1])
	}
	w.WriteHeader(http.StatusOK)
}
|
||||||
|
|
||||||
|
// writeSSEEvent encodes payload as JSON and writes it to w as one named
// server-sent event ("event: <event>" followed by "data: <json>" and a blank
// line), then flushes.
//
// It returns the JSON encoding error or the first write error. Nothing is
// written when encoding fails, and the flush is skipped when a write fails.
func writeSSEEvent(w http.ResponseWriter, flusher http.Flusher, event string, payload any) error {
	encoded, err := json.Marshal(payload)
	if err != nil {
		return err
	}

	_, err = fmt.Fprintf(w, "event: %s\n", event)
	if err == nil {
		_, err = fmt.Fprintf(w, "data: %s\n\n", encoded)
	}
	if err == nil {
		flusher.Flush()
	}
	return err
}
|
||||||
|
|
||||||
|
// writeOpenAIChunk encodes payload as JSON and writes it to w as one unnamed
// server-sent "data:" frame terminated by a blank line, then flushes.
//
// It returns the JSON encoding error or the write error. Nothing is written
// when encoding fails, and the flush is skipped on any error.
func writeOpenAIChunk(w http.ResponseWriter, flusher http.Flusher, payload any) error {
	encoded, err := json.Marshal(payload)
	if err == nil {
		_, err = fmt.Fprintf(w, "data: %s\n\n", encoded)
	}
	if err != nil {
		return err
	}
	flusher.Flush()
	return nil
}
|
||||||
|
|
||||||
func withCORS(next http.Handler) http.Handler {
|
func withCORS(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
|
|||||||
@@ -60,6 +60,15 @@ type ChatResult struct {
|
|||||||
EffectiveSession SessionMode
|
EffectiveSession SessionMode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StreamEvent is one incremental update delivered on the events channel
// returned by GenerateStream.
type StreamEvent struct {
	// Delta is the newly produced text fragment. GenerateStream never sends
	// an event with an empty Delta.
	Delta string
}
|
||||||
|
|
||||||
|
type StreamResult struct {
|
||||||
|
Result *ChatResult
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
@@ -151,7 +160,39 @@ func (s *Service) ListModels(ctx context.Context) ([]Model, error) {
|
|||||||
func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, error) {
|
func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
return s.generateLocked(ctx, req, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) GenerateStream(ctx context.Context, req ChatRequest) (<-chan StreamEvent, <-chan StreamResult, error) {
|
||||||
|
events := make(chan StreamEvent, 256)
|
||||||
|
done := make(chan StreamResult, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
s.mu.Lock()
|
||||||
|
result, err := s.generateLocked(ctx, req, func(delta string) {
|
||||||
|
if delta == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case events <- StreamEvent{Delta: delta}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
})
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
close(events)
|
||||||
|
done <- StreamResult{Result: result, Err: err}
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return events, done, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) generateLocked(
|
||||||
|
ctx context.Context,
|
||||||
|
req ChatRequest,
|
||||||
|
onDelta func(string),
|
||||||
|
) (*ChatResult, error) {
|
||||||
requestCtx, cancel := context.WithTimeout(ctx, s.cfg.Timeout)
|
requestCtx, cancel := context.WithTimeout(ctx, s.cfg.Timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -198,7 +239,7 @@ func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, e
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
runResult, err := s.runPromptLocked(requestCtx, ipcClient, sessionID, prompt, requestID, meta)
|
runResult, err := s.runPromptLocked(requestCtx, ipcClient, sessionID, prompt, requestID, meta, onDelta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if effectiveMode == SessionModeReuse {
|
if effectiveMode == SessionModeReuse {
|
||||||
s.stickySessionID = ""
|
s.stickySessionID = ""
|
||||||
@@ -220,7 +261,18 @@ func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, e
|
|||||||
return nil, fmt.Errorf("Lingma IPC response remained incomplete before timeout. Partial reply: %s", truncate(runResult.AssistantText, 120))
|
return nil, fmt.Errorf("Lingma IPC response remained incomplete before timeout. Partial reply: %s", truncate(runResult.AssistantText, 120))
|
||||||
}
|
}
|
||||||
|
|
||||||
result := &ChatResult{
|
return s.buildChatResult(req, sessionID, requestID, prompt, runResult, effectiveMode), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) buildChatResult(
|
||||||
|
req ChatRequest,
|
||||||
|
sessionID string,
|
||||||
|
requestID string,
|
||||||
|
prompt string,
|
||||||
|
runResult *promptRunResult,
|
||||||
|
effectiveMode SessionMode,
|
||||||
|
) *ChatResult {
|
||||||
|
return &ChatResult{
|
||||||
Text: runResult.AssistantText,
|
Text: runResult.AssistantText,
|
||||||
Model: valueOr(strings.TrimSpace(req.Model), "lingma"),
|
Model: valueOr(strings.TrimSpace(req.Model), "lingma"),
|
||||||
InputTokens: estimateTokens(prompt),
|
InputTokens: estimateTokens(prompt),
|
||||||
@@ -234,7 +286,6 @@ func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, e
|
|||||||
PipePath: s.pipePath,
|
PipePath: s.pipePath,
|
||||||
EffectiveSession: effectiveMode,
|
EffectiveSession: effectiveMode,
|
||||||
}
|
}
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) ensureConnectedLocked(ctx context.Context) (*lingmaipc.Client, error) {
|
func (s *Service) ensureConnectedLocked(ctx context.Context) (*lingmaipc.Client, error) {
|
||||||
@@ -322,6 +373,7 @@ func (s *Service) runPromptLocked(
|
|||||||
text string,
|
text string,
|
||||||
requestID string,
|
requestID string,
|
||||||
meta map[string]any,
|
meta map[string]any,
|
||||||
|
onDelta func(string),
|
||||||
) (*promptRunResult, error) {
|
) (*promptRunResult, error) {
|
||||||
notifications, cancel := client.Subscribe()
|
notifications, cancel := client.Subscribe()
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -367,6 +419,9 @@ func (s *Service) runPromptLocked(
|
|||||||
chunk := nestedString(nestedMap(update, "content"), "text")
|
chunk := nestedString(nestedMap(update, "content"), "text")
|
||||||
if chunk != "" {
|
if chunk != "" {
|
||||||
builder.WriteString(chunk)
|
builder.WriteString(chunk)
|
||||||
|
if onDelta != nil {
|
||||||
|
onDelta(chunk)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case "notification":
|
case "notification":
|
||||||
switch nestedString(update, "type") {
|
switch nestedString(update, "type") {
|
||||||
|
|||||||
Reference in New Issue
Block a user