feat: add streaming responses for anthropic and openai

This commit is contained in:
coolxll
2026-03-25 22:13:40 +08:00
parent 585c3ba5ab
commit a3907959c1
2 changed files with 345 additions and 11 deletions

View File

@@ -150,10 +150,6 @@ func (s *Server) handleAnthropicMessages(w http.ResponseWriter, r *http.Request)
writeAnthropicError(w, http.StatusBadRequest, "invalid_request_error", err.Error()) writeAnthropicError(w, http.StatusBadRequest, "invalid_request_error", err.Error())
return return
} }
if req.Stream {
writeAnthropicError(w, http.StatusBadRequest, "invalid_request_error", "streaming is not supported")
return
}
normalized, err := normalizeAnthropicRequest(req) normalized, err := normalizeAnthropicRequest(req)
if err != nil { if err != nil {
@@ -161,6 +157,11 @@ func (s *Server) handleAnthropicMessages(w http.ResponseWriter, r *http.Request)
return return
} }
if req.Stream {
s.handleAnthropicStream(w, r, normalized)
return
}
result, err := s.svc.Generate(r.Context(), normalized) result, err := s.svc.Generate(r.Context(), normalized)
if err != nil { if err != nil {
writeAnthropicError(w, http.StatusInternalServerError, "api_error", err.Error()) writeAnthropicError(w, http.StatusInternalServerError, "api_error", err.Error())
@@ -202,10 +203,6 @@ func (s *Server) handleOpenAIChatCompletions(w http.ResponseWriter, r *http.Requ
writeOpenAIError(w, http.StatusBadRequest, "invalid_request_error", err.Error()) writeOpenAIError(w, http.StatusBadRequest, "invalid_request_error", err.Error())
return return
} }
if req.Stream {
writeOpenAIError(w, http.StatusBadRequest, "invalid_request_error", "streaming is not supported")
return
}
normalized, err := normalizeOpenAIRequest(req) normalized, err := normalizeOpenAIRequest(req)
if err != nil { if err != nil {
@@ -213,6 +210,11 @@ func (s *Server) handleOpenAIChatCompletions(w http.ResponseWriter, r *http.Requ
return return
} }
if req.Stream {
s.handleOpenAIStream(w, r, normalized)
return
}
result, err := s.svc.Generate(r.Context(), normalized) result, err := s.svc.Generate(r.Context(), normalized)
if err != nil { if err != nil {
writeOpenAIError(w, http.StatusInternalServerError, "api_error", err.Error()) writeOpenAIError(w, http.StatusInternalServerError, "api_error", err.Error())
@@ -243,6 +245,248 @@ func (s *Server) handleOpenAIChatCompletions(w http.ResponseWriter, r *http.Requ
}) })
} }
// handleAnthropicStream answers a streaming request by relaying the output of
// s.svc.GenerateStream as Server-Sent Events using Anthropic-style event
// names, in this order: message_start, content_block_start, zero or more
// content_block_delta, content_block_stop, message_delta, message_stop.
//
// If w cannot flush, or GenerateStream fails up front, a regular JSON error is
// written instead. Once streaming has begun, a generation failure is reported
// as an "error" event. The handler returns silently when the request context
// is cancelled or a write to the client fails.
func (s *Server) handleAnthropicStream(w http.ResponseWriter, r *http.Request, req service.ChatRequest) {
	flusher, canFlush := w.(http.Flusher)
	if !canFlush {
		writeAnthropicError(w, http.StatusInternalServerError, "api_error", "streaming is not supported by this server")
		return
	}

	deltas, outcome, err := s.svc.GenerateStream(r.Context(), req)
	if err != nil {
		writeAnthropicError(w, http.StatusInternalServerError, "api_error", err.Error())
		return
	}

	model := strings.TrimSpace(req.Model)
	if model == "" {
		model = "lingma"
	}
	msgID := fmt.Sprintf("msg_%d", time.Now().UnixNano())

	// frame pairs an SSE event name with its JSON payload.
	type frame struct {
		name    string
		payload map[string]any
	}
	emit := func(f frame) error {
		return writeSSEEvent(w, flusher, f.name, f.payload)
	}
	// emitError is best-effort: the stream is being abandoned either way.
	emitError := func(message string) {
		_ = emit(frame{"error", map[string]any{
			"type": "error",
			"error": map[string]any{
				"type":    "api_error",
				"message": message,
			},
		}})
	}

	streamingHeaders(w)

	opening := []frame{
		{"message_start", map[string]any{
			"type": "message_start",
			"message": map[string]any{
				"id":            msgID,
				"type":          "message",
				"role":          "assistant",
				"content":       []any{},
				"model":         model,
				"stop_reason":   nil,
				"stop_sequence": nil,
				"usage": map[string]any{
					"input_tokens":  0,
					"output_tokens": 0,
				},
			},
		}},
		{"content_block_start", map[string]any{
			"type":  "content_block_start",
			"index": 0,
			"content_block": map[string]any{
				"type": "text",
				"text": "",
			},
		}},
	}
	for _, f := range opening {
		if err := emit(f); err != nil {
			return
		}
	}

	var (
		final    *service.ChatResult
		finalErr error
	)
	// A channel is set to nil once it is exhausted so its select case blocks
	// forever; the loop ends when both channels have been retired.
	for deltas != nil || outcome != nil {
		select {
		case <-r.Context().Done():
			return
		case ev, open := <-deltas:
			switch {
			case !open:
				deltas = nil
			case ev.Delta != "":
				if err := emit(frame{"content_block_delta", map[string]any{
					"type":  "content_block_delta",
					"index": 0,
					"delta": map[string]any{
						"type": "text_delta",
						"text": ev.Delta,
					},
				}}); err != nil {
					return
				}
			}
		case res, open := <-outcome:
			if open {
				final, finalErr = res.Result, res.Err
			}
			outcome = nil
		}
	}

	switch {
	case finalErr != nil:
		emitError(finalErr.Error())
		return
	case final == nil:
		emitError("stream finished without a final result")
		return
	}

	closing := []frame{
		{"content_block_stop", map[string]any{
			"type":  "content_block_stop",
			"index": 0,
		}},
		{"message_delta", map[string]any{
			"type": "message_delta",
			"delta": map[string]any{
				"stop_reason":   "end_turn",
				"stop_sequence": nil,
			},
			"usage": map[string]any{
				"output_tokens": final.OutputTokens,
			},
		}},
		{"message_stop", map[string]any{
			"type": "message_stop",
		}},
	}
	for _, f := range closing {
		if err := emit(f); err != nil {
			return
		}
	}
}
// handleOpenAIStream answers a streaming request by relaying the output of
// s.svc.GenerateStream as "chat.completion.chunk" SSE data frames: one chunk
// announcing the assistant role, one chunk per non-empty text delta, then
// either a finish_reason "stop" chunk or an error object, followed by the
// literal "data: [DONE]" terminator.
//
// If w cannot flush, or GenerateStream fails up front, a regular JSON error is
// written instead. The handler returns silently when the request context is
// cancelled or a chunk cannot be written.
func (s *Server) handleOpenAIStream(w http.ResponseWriter, r *http.Request, req service.ChatRequest) {
	flusher, canFlush := w.(http.Flusher)
	if !canFlush {
		writeOpenAIError(w, http.StatusInternalServerError, "api_error", "streaming is not supported by this server")
		return
	}

	deltas, outcome, err := s.svc.GenerateStream(r.Context(), req)
	if err != nil {
		writeOpenAIError(w, http.StatusInternalServerError, "api_error", err.Error())
		return
	}

	model := strings.TrimSpace(req.Model)
	if model == "" {
		model = "lingma"
	}
	chatID := fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano())
	created := time.Now().Unix()

	// sendChunk wraps a delta in the envelope shared by every chunk of this
	// response; a nil finishReason is encoded as JSON null.
	sendChunk := func(delta map[string]any, finishReason any) error {
		return writeOpenAIChunk(w, flusher, map[string]any{
			"id":      chatID,
			"object":  "chat.completion.chunk",
			"created": created,
			"model":   model,
			"choices": []map[string]any{
				{
					"index":         0,
					"delta":         delta,
					"finish_reason": finishReason,
				},
			},
		})
	}

	streamingHeaders(w)
	if err := sendChunk(map[string]any{"role": "assistant"}, nil); err != nil {
		return
	}

	var finalErr error
	// A channel is set to nil once it is exhausted so its select case blocks
	// forever; the loop ends when both channels have been retired.
	for deltas != nil || outcome != nil {
		select {
		case <-r.Context().Done():
			return
		case ev, open := <-deltas:
			switch {
			case !open:
				deltas = nil
			case ev.Delta != "":
				if err := sendChunk(map[string]any{"content": ev.Delta}, nil); err != nil {
					return
				}
			}
		case res, open := <-outcome:
			if open {
				finalErr = res.Err
			}
			outcome = nil
		}
	}

	if finalErr != nil {
		// Best-effort: the terminator below is still sent after the error.
		_ = writeOpenAIChunk(w, flusher, map[string]any{
			"error": map[string]any{
				"message": finalErr.Error(),
				"type":    "api_error",
				"code":    nil,
				"param":   nil,
			},
		})
	} else if err := sendChunk(map[string]any{}, "stop"); err != nil {
		return
	}

	_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
	flusher.Flush()
}
func normalizeAnthropicRequest(req anthropicRequest) (service.ChatRequest, error) { func normalizeAnthropicRequest(req anthropicRequest) (service.ChatRequest, error) {
messages := make([]service.ChatMessage, 0, len(req.Messages)) messages := make([]service.ChatMessage, 0, len(req.Messages))
for _, message := range req.Messages { for _, message := range req.Messages {
@@ -372,6 +616,41 @@ func writeOpenAIError(w http.ResponseWriter, status int, kind string, message st
}) })
} }
// streamingHeaders sets the response headers used for a Server-Sent Events
// stream and commits the response with status 200. It must be called before
// any event data is written.
func streamingHeaders(w http.ResponseWriter) {
	header := w.Header()
	for _, kv := range [][2]string{
		{"Content-Type", "text/event-stream"},
		{"Cache-Control", "no-cache"},
		{"Connection", "keep-alive"},
		{"X-Accel-Buffering", "no"},
	} {
		header.Set(kv[0], kv[1])
	}
	w.WriteHeader(http.StatusOK)
}
// writeSSEEvent JSON-encodes payload and writes it to w as one Server-Sent
// Events frame ("event: <event>" line, "data: <json>" line, blank line), then
// flushes. It returns the marshal or write error, if any; nothing is written
// when encoding fails.
func writeSSEEvent(w http.ResponseWriter, flusher http.Flusher, event string, payload any) error {
	encoded, err := json.Marshal(payload)
	if err != nil {
		return err
	}
	_, err = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", event, encoded)
	if err != nil {
		return err
	}
	flusher.Flush()
	return nil
}
// writeOpenAIChunk JSON-encodes payload and writes it to w as a single
// unnamed SSE frame ("data: <json>" followed by a blank line), then flushes.
// It returns the marshal or write error, if any; nothing is written when
// encoding fails.
func writeOpenAIChunk(w http.ResponseWriter, flusher http.Flusher, payload any) error {
	encoded, err := json.Marshal(payload)
	if err != nil {
		return err
	}
	_, err = fmt.Fprintf(w, "data: %s\n\n", encoded)
	if err != nil {
		return err
	}
	flusher.Flush()
	return nil
}
func withCORS(next http.Handler) http.Handler { func withCORS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Origin", "*")

View File

@@ -60,6 +60,15 @@ type ChatResult struct {
EffectiveSession SessionMode EffectiveSession SessionMode
} }
// StreamEvent is one incremental update delivered on the events channel
// returned by GenerateStream.
type StreamEvent struct {
// Delta is the newly produced text chunk; GenerateStream never sends an
// event with an empty Delta.
Delta string
}
// StreamResult is the single terminal value delivered on the done channel
// returned by GenerateStream once generation has finished.
type StreamResult struct {
// Result is the *ChatResult returned by the underlying generation call.
Result *ChatResult
// Err is the error returned by the underlying generation call, if any.
Err error
}
type Model struct { type Model struct {
ID string `json:"id"` ID string `json:"id"`
Name string `json:"name"` Name string `json:"name"`
@@ -151,7 +160,39 @@ func (s *Service) ListModels(ctx context.Context) ([]Model, error) {
func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, error) { func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
return s.generateLocked(ctx, req, nil)
}
// GenerateStream starts a generation in a background goroutine and returns
// two receive-only channels:
//
//   - the first carries one StreamEvent per non-empty text delta and is closed
//     when generation ends;
//   - the second delivers exactly one StreamResult (final result or error)
//     after the first channel has been closed, and is then closed itself.
//
// The goroutine holds s.mu for the duration of the generation. Deltas that
// cannot be queued before ctx is done are dropped. The returned error is
// currently always nil.
func (s *Service) GenerateStream(ctx context.Context, req ChatRequest) (<-chan StreamEvent, <-chan StreamResult, error) {
	deltaCh := make(chan StreamEvent, 256)
	// Capacity 1 lets the goroutine finish even if nobody reads the result.
	resultCh := make(chan StreamResult, 1)

	// forward queues one delta for the consumer, giving up once ctx is done
	// so a departed consumer cannot block the generation.
	forward := func(delta string) {
		if delta == "" {
			return
		}
		select {
		case deltaCh <- StreamEvent{Delta: delta}:
		case <-ctx.Done():
		}
	}

	go func() {
		s.mu.Lock()
		res, genErr := s.generateLocked(ctx, req, forward)
		s.mu.Unlock()

		// Close the delta channel first so the result is observable only
		// after every delta has been queued.
		close(deltaCh)
		resultCh <- StreamResult{Result: res, Err: genErr}
		close(resultCh)
	}()

	return deltaCh, resultCh, nil
}
func (s *Service) generateLocked(
ctx context.Context,
req ChatRequest,
onDelta func(string),
) (*ChatResult, error) {
requestCtx, cancel := context.WithTimeout(ctx, s.cfg.Timeout) requestCtx, cancel := context.WithTimeout(ctx, s.cfg.Timeout)
defer cancel() defer cancel()
@@ -198,7 +239,7 @@ func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, e
} }
} }
runResult, err := s.runPromptLocked(requestCtx, ipcClient, sessionID, prompt, requestID, meta) runResult, err := s.runPromptLocked(requestCtx, ipcClient, sessionID, prompt, requestID, meta, onDelta)
if err != nil { if err != nil {
if effectiveMode == SessionModeReuse { if effectiveMode == SessionModeReuse {
s.stickySessionID = "" s.stickySessionID = ""
@@ -220,7 +261,18 @@ func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, e
return nil, fmt.Errorf("Lingma IPC response remained incomplete before timeout. Partial reply: %s", truncate(runResult.AssistantText, 120)) return nil, fmt.Errorf("Lingma IPC response remained incomplete before timeout. Partial reply: %s", truncate(runResult.AssistantText, 120))
} }
result := &ChatResult{ return s.buildChatResult(req, sessionID, requestID, prompt, runResult, effectiveMode), nil
}
func (s *Service) buildChatResult(
req ChatRequest,
sessionID string,
requestID string,
prompt string,
runResult *promptRunResult,
effectiveMode SessionMode,
) *ChatResult {
return &ChatResult{
Text: runResult.AssistantText, Text: runResult.AssistantText,
Model: valueOr(strings.TrimSpace(req.Model), "lingma"), Model: valueOr(strings.TrimSpace(req.Model), "lingma"),
InputTokens: estimateTokens(prompt), InputTokens: estimateTokens(prompt),
@@ -234,7 +286,6 @@ func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, e
PipePath: s.pipePath, PipePath: s.pipePath,
EffectiveSession: effectiveMode, EffectiveSession: effectiveMode,
} }
return result, nil
} }
func (s *Service) ensureConnectedLocked(ctx context.Context) (*lingmaipc.Client, error) { func (s *Service) ensureConnectedLocked(ctx context.Context) (*lingmaipc.Client, error) {
@@ -322,6 +373,7 @@ func (s *Service) runPromptLocked(
text string, text string,
requestID string, requestID string,
meta map[string]any, meta map[string]any,
onDelta func(string),
) (*promptRunResult, error) { ) (*promptRunResult, error) {
notifications, cancel := client.Subscribe() notifications, cancel := client.Subscribe()
defer cancel() defer cancel()
@@ -367,6 +419,9 @@ func (s *Service) runPromptLocked(
chunk := nestedString(nestedMap(update, "content"), "text") chunk := nestedString(nestedMap(update, "content"), "text")
if chunk != "" { if chunk != "" {
builder.WriteString(chunk) builder.WriteString(chunk)
if onDelta != nil {
onDelta(chunk)
}
} }
case "notification": case "notification":
switch nestedString(update, "type") { switch nestedString(update, "type") {