# Patch summary: adds server-sent-event streaming to the Anthropic- and
# OpenAI-compatible chat endpoints and the service plumbing behind them.
#
# NOTE(review): in this copy of the patch the line breaks and indentation look
# collapsed into single spaces -- presumably the original newlines have to be
# restored before patch tools can apply it; confirm against the source commit.
#
# internal/httpapi/server.go
#   * handleAnthropicMessages / handleOpenAIChatCompletions: the early 400
#     "streaming is not supported" rejection is removed. After the request has
#     been normalized, a request with req.Stream set is handed to
#     handleAnthropicStream / handleOpenAIStream and the handler returns.
#   * handleAnthropicStream: answers 500 api_error when the ResponseWriter is
#     not an http.Flusher or GenerateStream returns an error. Otherwise it
#     writes the streaming headers, then message_start and content_block_start,
#     then one content_block_delta per non-empty delta. Once both channels are
#     drained it writes either an "error" event (generation error, or no final
#     result) or content_block_stop, message_delta (stop_reason "end_turn",
#     usage.output_tokens taken from final.OutputTokens) and message_stop.
#     It returns without further output when the request context is done or a
#     write fails.
#   * handleOpenAIStream: same loop shape, emitting "chat.completion.chunk"
#     payloads: a first chunk carrying role "assistant", one chunk per
#     non-empty delta, then either an error payload or a chunk with
#     finish_reason "stop"; both endings are followed by "data: [DONE]".
#   * streamingHeaders sets Content-Type text/event-stream, Cache-Control
#     no-cache, Connection keep-alive, X-Accel-Buffering no and writes
#     status 200. writeSSEEvent writes "event:" and "data:" lines and flushes;
#     writeOpenAIChunk writes a "data:" line and flushes.
#
# internal/service/service.go
#   * New types StreamEvent{Delta} and StreamResult{Result, Err}.
#   * Generate now locks s.mu and delegates to generateLocked with a nil
#     delta callback.
#   * GenerateStream creates an events channel (buffer 256) and a done channel
#     (buffer 1) and starts a goroutine that holds s.mu while generateLocked
#     runs. The callback drops empty deltas and otherwise sends to events or
#     gives up when ctx is done. After unlocking, the goroutine closes events,
#     sends the StreamResult and closes done.
#     GenerateStream itself always returns a nil error, and both stream
#     handlers write the 200 status and headers before reading from done, so
#     generation failures reach the client only as in-stream error payloads.
#   * generateLocked carries the former body of Generate and passes onDelta
#     through to runPromptLocked; construction of the ChatResult moves into
#     buildChatResult.
#   * runPromptLocked takes a new onDelta parameter and, when it is non-nil,
#     calls it with each non-empty text chunk right after appending the chunk
#     to the builder.
diff --git a/internal/httpapi/server.go b/internal/httpapi/server.go index c0de7d6..e1010ed 100644 --- a/internal/httpapi/server.go +++ b/internal/httpapi/server.go @@ -150,10 +150,6 @@ func (s *Server) handleAnthropicMessages(w http.ResponseWriter, r *http.Request) writeAnthropicError(w, http.StatusBadRequest, "invalid_request_error", err.Error()) return } - if req.Stream { - writeAnthropicError(w, http.StatusBadRequest, "invalid_request_error", "streaming is not supported") - return - } normalized, err := normalizeAnthropicRequest(req) if err != nil { @@ -161,6 +157,11 @@ func (s *Server) handleAnthropicMessages(w http.ResponseWriter, r *http.Request) return } + if req.Stream { + s.handleAnthropicStream(w, r, normalized) + return + } + result, err := s.svc.Generate(r.Context(), normalized) if err != nil { writeAnthropicError(w, http.StatusInternalServerError, "api_error", err.Error()) @@ -202,10 +203,6 @@ func (s *Server) handleOpenAIChatCompletions(w http.ResponseWriter, r *http.Requ writeOpenAIError(w, http.StatusBadRequest, "invalid_request_error", err.Error()) return } - if req.Stream { - writeOpenAIError(w, http.StatusBadRequest, "invalid_request_error", "streaming is not supported") - return - } normalized, err := normalizeOpenAIRequest(req) if err != nil { @@ -213,6 +210,11 @@ func (s *Server) handleOpenAIChatCompletions(w http.ResponseWriter, r *http.Requ return } + if req.Stream { + s.handleOpenAIStream(w, r, normalized) + return + } + result, err := s.svc.Generate(r.Context(), normalized) if err != nil { writeOpenAIError(w, http.StatusInternalServerError, "api_error", err.Error()) @@ -243,6 +245,248 @@ func (s *Server) handleOpenAIChatCompletions(w http.ResponseWriter, r *http.Requ }) } +func (s *Server) handleAnthropicStream(w http.ResponseWriter, r *http.Request, req service.ChatRequest) { + flusher, ok := w.(http.Flusher) + if !ok { + writeAnthropicError(w, http.StatusInternalServerError, "api_error", "streaming is not supported by this server") + 
return + } + + events, done, err := s.svc.GenerateStream(r.Context(), req) + if err != nil { + writeAnthropicError(w, http.StatusInternalServerError, "api_error", err.Error()) + return + } + + model := strings.TrimSpace(req.Model) + if model == "" { + model = "lingma" + } + msgID := fmt.Sprintf("msg_%d", time.Now().UnixNano()) + streamingHeaders(w) + if err := writeSSEEvent(w, flusher, "message_start", map[string]any{ + "type": "message_start", + "message": map[string]any{ + "id": msgID, + "type": "message", + "role": "assistant", + "content": []any{}, + "model": model, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": 0, + "output_tokens": 0, + }, + }, + }); err != nil { + return + } + if err := writeSSEEvent(w, flusher, "content_block_start", map[string]any{ + "type": "content_block_start", + "index": 0, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }); err != nil { + return + } + + eventsCh := events + doneCh := done + var final *service.ChatResult + var finalErr error + + for eventsCh != nil || doneCh != nil { + select { + case <-r.Context().Done(): + return + case event, ok := <-eventsCh: + if !ok { + eventsCh = nil + continue + } + if event.Delta == "" { + continue + } + if err := writeSSEEvent(w, flusher, "content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": 0, + "delta": map[string]any{ + "type": "text_delta", + "text": event.Delta, + }, + }); err != nil { + return + } + case result, ok := <-doneCh: + if !ok { + doneCh = nil + continue + } + final = result.Result + finalErr = result.Err + doneCh = nil + } + } + + if finalErr != nil { + _ = writeSSEEvent(w, flusher, "error", map[string]any{ + "type": "error", + "error": map[string]any{ + "type": "api_error", + "message": finalErr.Error(), + }, + }) + return + } + if final == nil { + _ = writeSSEEvent(w, flusher, "error", map[string]any{ + "type": "error", + "error": map[string]any{ + "type": "api_error", + 
"message": "stream finished without a final result", + }, + }) + return + } + if err := writeSSEEvent(w, flusher, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": 0, + }); err != nil { + return + } + if err := writeSSEEvent(w, flusher, "message_delta", map[string]any{ + "type": "message_delta", + "delta": map[string]any{ + "stop_reason": "end_turn", + "stop_sequence": nil, + }, + "usage": map[string]any{ + "output_tokens": final.OutputTokens, + }, + }); err != nil { + return + } + _ = writeSSEEvent(w, flusher, "message_stop", map[string]any{ + "type": "message_stop", + }) +} + +func (s *Server) handleOpenAIStream(w http.ResponseWriter, r *http.Request, req service.ChatRequest) { + flusher, ok := w.(http.Flusher) + if !ok { + writeOpenAIError(w, http.StatusInternalServerError, "api_error", "streaming is not supported by this server") + return + } + + events, done, err := s.svc.GenerateStream(r.Context(), req) + if err != nil { + writeOpenAIError(w, http.StatusInternalServerError, "api_error", err.Error()) + return + } + + model := strings.TrimSpace(req.Model) + if model == "" { + model = "lingma" + } + chatID := fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()) + created := time.Now().Unix() + streamingHeaders(w) + if err := writeOpenAIChunk(w, flusher, map[string]any{ + "id": chatID, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]any{ + { + "index": 0, + "delta": map[string]any{ + "role": "assistant", + }, + "finish_reason": nil, + }, + }, + }); err != nil { + return + } + + eventsCh := events + doneCh := done + var finalErr error + + for eventsCh != nil || doneCh != nil { + select { + case <-r.Context().Done(): + return + case event, ok := <-eventsCh: + if !ok { + eventsCh = nil + continue + } + if event.Delta == "" { + continue + } + if err := writeOpenAIChunk(w, flusher, map[string]any{ + "id": chatID, + "object": "chat.completion.chunk", + "created": created, + "model": 
model, + "choices": []map[string]any{ + { + "index": 0, + "delta": map[string]any{ + "content": event.Delta, + }, + "finish_reason": nil, + }, + }, + }); err != nil { + return + } + case result, ok := <-doneCh: + if !ok { + doneCh = nil + continue + } + finalErr = result.Err + doneCh = nil + } + } + + if finalErr != nil { + _ = writeOpenAIChunk(w, flusher, map[string]any{ + "error": map[string]any{ + "message": finalErr.Error(), + "type": "api_error", + "code": nil, + "param": nil, + }, + }) + _, _ = fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() + return + } + if err := writeOpenAIChunk(w, flusher, map[string]any{ + "id": chatID, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]any{ + { + "index": 0, + "delta": map[string]any{}, + "finish_reason": "stop", + }, + }, + }); err != nil { + return + } + _, _ = fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() +} + func normalizeAnthropicRequest(req anthropicRequest) (service.ChatRequest, error) { messages := make([]service.ChatMessage, 0, len(req.Messages)) for _, message := range req.Messages { @@ -372,6 +616,41 @@ func writeOpenAIError(w http.ResponseWriter, status int, kind string, message st }) } +func streamingHeaders(w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + w.WriteHeader(http.StatusOK) +} + +func writeSSEEvent(w http.ResponseWriter, flusher http.Flusher, event string, payload any) error { + body, err := json.Marshal(payload) + if err != nil { + return err + } + if _, err := fmt.Fprintf(w, "event: %s\n", event); err != nil { + return err + } + if _, err := fmt.Fprintf(w, "data: %s\n\n", body); err != nil { + return err + } + flusher.Flush() + return nil +} + +func writeOpenAIChunk(w http.ResponseWriter, flusher http.Flusher, payload any) error { + body, err := json.Marshal(payload) 
+ if err != nil { + return err + } + if _, err := fmt.Fprintf(w, "data: %s\n\n", body); err != nil { + return err + } + flusher.Flush() + return nil +} + func withCORS(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") diff --git a/internal/service/service.go b/internal/service/service.go index 615a89c..05bb53e 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -60,6 +60,15 @@ type ChatResult struct { EffectiveSession SessionMode } +type StreamEvent struct { + Delta string +} + +type StreamResult struct { + Result *ChatResult + Err error +} + type Model struct { ID string `json:"id"` Name string `json:"name"` @@ -151,7 +160,39 @@ func (s *Service) ListModels(ctx context.Context) ([]Model, error) { func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, error) { s.mu.Lock() defer s.mu.Unlock() + return s.generateLocked(ctx, req, nil) +} +func (s *Service) GenerateStream(ctx context.Context, req ChatRequest) (<-chan StreamEvent, <-chan StreamResult, error) { + events := make(chan StreamEvent, 256) + done := make(chan StreamResult, 1) + + go func() { + s.mu.Lock() + result, err := s.generateLocked(ctx, req, func(delta string) { + if delta == "" { + return + } + select { + case events <- StreamEvent{Delta: delta}: + case <-ctx.Done(): + } + }) + s.mu.Unlock() + + close(events) + done <- StreamResult{Result: result, Err: err} + close(done) + }() + + return events, done, nil +} + +func (s *Service) generateLocked( + ctx context.Context, + req ChatRequest, + onDelta func(string), +) (*ChatResult, error) { requestCtx, cancel := context.WithTimeout(ctx, s.cfg.Timeout) defer cancel() @@ -198,7 +239,7 @@ func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, e } } - runResult, err := s.runPromptLocked(requestCtx, ipcClient, sessionID, prompt, requestID, meta) + runResult, err := 
s.runPromptLocked(requestCtx, ipcClient, sessionID, prompt, requestID, meta, onDelta) if err != nil { if effectiveMode == SessionModeReuse { s.stickySessionID = "" @@ -220,7 +261,18 @@ func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, e return nil, fmt.Errorf("Lingma IPC response remained incomplete before timeout. Partial reply: %s", truncate(runResult.AssistantText, 120)) } - result := &ChatResult{ + return s.buildChatResult(req, sessionID, requestID, prompt, runResult, effectiveMode), nil +} + +func (s *Service) buildChatResult( + req ChatRequest, + sessionID string, + requestID string, + prompt string, + runResult *promptRunResult, + effectiveMode SessionMode, +) *ChatResult { + return &ChatResult{ Text: runResult.AssistantText, Model: valueOr(strings.TrimSpace(req.Model), "lingma"), InputTokens: estimateTokens(prompt), @@ -234,7 +286,6 @@ func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, e PipePath: s.pipePath, EffectiveSession: effectiveMode, } - return result, nil } func (s *Service) ensureConnectedLocked(ctx context.Context) (*lingmaipc.Client, error) { @@ -322,6 +373,7 @@ func (s *Service) runPromptLocked( text string, requestID string, meta map[string]any, + onDelta func(string), ) (*promptRunResult, error) { notifications, cancel := client.Subscribe() defer cancel() @@ -367,6 +419,9 @@ func (s *Service) runPromptLocked( chunk := nestedString(nestedMap(update, "content"), "text") if chunk != "" { builder.WriteString(chunk) + if onDelta != nil { + onDelta(chunk) + } } case "notification": switch nestedString(update, "type") {