feat: add streaming responses for anthropic and openai
This commit is contained in:
@@ -60,6 +60,15 @@ type ChatResult struct {
|
||||
EffectiveSession SessionMode
|
||||
}
|
||||
|
||||
type StreamEvent struct {
|
||||
Delta string
|
||||
}
|
||||
|
||||
type StreamResult struct {
|
||||
Result *ChatResult
|
||||
Err error
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
@@ -151,7 +160,39 @@ func (s *Service) ListModels(ctx context.Context) ([]Model, error) {
|
||||
func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.generateLocked(ctx, req, nil)
|
||||
}
|
||||
|
||||
func (s *Service) GenerateStream(ctx context.Context, req ChatRequest) (<-chan StreamEvent, <-chan StreamResult, error) {
|
||||
events := make(chan StreamEvent, 256)
|
||||
done := make(chan StreamResult, 1)
|
||||
|
||||
go func() {
|
||||
s.mu.Lock()
|
||||
result, err := s.generateLocked(ctx, req, func(delta string) {
|
||||
if delta == "" {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case events <- StreamEvent{Delta: delta}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
})
|
||||
s.mu.Unlock()
|
||||
|
||||
close(events)
|
||||
done <- StreamResult{Result: result, Err: err}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
return events, done, nil
|
||||
}
|
||||
|
||||
func (s *Service) generateLocked(
|
||||
ctx context.Context,
|
||||
req ChatRequest,
|
||||
onDelta func(string),
|
||||
) (*ChatResult, error) {
|
||||
requestCtx, cancel := context.WithTimeout(ctx, s.cfg.Timeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -198,7 +239,7 @@ func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, e
|
||||
}
|
||||
}
|
||||
|
||||
runResult, err := s.runPromptLocked(requestCtx, ipcClient, sessionID, prompt, requestID, meta)
|
||||
runResult, err := s.runPromptLocked(requestCtx, ipcClient, sessionID, prompt, requestID, meta, onDelta)
|
||||
if err != nil {
|
||||
if effectiveMode == SessionModeReuse {
|
||||
s.stickySessionID = ""
|
||||
@@ -220,7 +261,18 @@ func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, e
|
||||
return nil, fmt.Errorf("Lingma IPC response remained incomplete before timeout. Partial reply: %s", truncate(runResult.AssistantText, 120))
|
||||
}
|
||||
|
||||
result := &ChatResult{
|
||||
return s.buildChatResult(req, sessionID, requestID, prompt, runResult, effectiveMode), nil
|
||||
}
|
||||
|
||||
func (s *Service) buildChatResult(
|
||||
req ChatRequest,
|
||||
sessionID string,
|
||||
requestID string,
|
||||
prompt string,
|
||||
runResult *promptRunResult,
|
||||
effectiveMode SessionMode,
|
||||
) *ChatResult {
|
||||
return &ChatResult{
|
||||
Text: runResult.AssistantText,
|
||||
Model: valueOr(strings.TrimSpace(req.Model), "lingma"),
|
||||
InputTokens: estimateTokens(prompt),
|
||||
@@ -234,7 +286,6 @@ func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, e
|
||||
PipePath: s.pipePath,
|
||||
EffectiveSession: effectiveMode,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Service) ensureConnectedLocked(ctx context.Context) (*lingmaipc.Client, error) {
|
||||
@@ -322,6 +373,7 @@ func (s *Service) runPromptLocked(
|
||||
text string,
|
||||
requestID string,
|
||||
meta map[string]any,
|
||||
onDelta func(string),
|
||||
) (*promptRunResult, error) {
|
||||
notifications, cancel := client.Subscribe()
|
||||
defer cancel()
|
||||
@@ -367,6 +419,9 @@ func (s *Service) runPromptLocked(
|
||||
chunk := nestedString(nestedMap(update, "content"), "text")
|
||||
if chunk != "" {
|
||||
builder.WriteString(chunk)
|
||||
if onDelta != nil {
|
||||
onDelta(chunk)
|
||||
}
|
||||
}
|
||||
case "notification":
|
||||
switch nestedString(update, "type") {
|
||||
|
||||
Reference in New Issue
Block a user