Add experimental Lingma remote backend
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"lingma-ipc-proxy/internal/service"
|
||||
@@ -23,9 +24,11 @@ import (
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
svc *service.Service
|
||||
http *http.Server
|
||||
sem chan struct{}
|
||||
svc *service.Service
|
||||
http *http.Server
|
||||
sem chan struct{}
|
||||
recMu sync.RWMutex
|
||||
records []debugRequestRecord
|
||||
// OnRequest is called after each request completes with summary info.
|
||||
// method, path, statusCode, duration, requestBody, responseBody
|
||||
OnRequest func(method, path string, statusCode int, duration time.Duration, reqBody, respBody string)
|
||||
@@ -84,6 +87,16 @@ type modelResponse struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
type debugRequestRecord struct {
|
||||
Time string `json:"time"`
|
||||
Method string `json:"method"`
|
||||
Path string `json:"path"`
|
||||
StatusCode int `json:"statusCode"`
|
||||
DurationMS int64 `json:"durationMs"`
|
||||
Request string `json:"request,omitempty"`
|
||||
Response string `json:"response,omitempty"`
|
||||
}
|
||||
|
||||
func NewServer(addr string, svc *service.Service) *Server {
|
||||
s := &Server{
|
||||
svc: svc,
|
||||
@@ -92,6 +105,10 @@ func NewServer(addr string, svc *service.Service) *Server {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/", s.handleRoot)
|
||||
mux.HandleFunc("/health", s.handleRoot)
|
||||
mux.HandleFunc("/debug/requests", s.handleDebugRequests)
|
||||
mux.HandleFunc("/debug/logs", s.handleDebugRequests)
|
||||
mux.HandleFunc("/api/requests", s.handleDebugRequests)
|
||||
mux.HandleFunc("/api/logs", s.handleDebugRequests)
|
||||
mux.HandleFunc("/capabilities", s.handleCapabilities)
|
||||
mux.HandleFunc("/v1/capabilities", s.handleCapabilities)
|
||||
mux.HandleFunc("/v1/models", s.handleModels)
|
||||
@@ -151,6 +168,10 @@ func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
if r.Method == http.MethodHead {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodGet {
|
||||
writeOpenAIError(w, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
|
||||
return
|
||||
@@ -162,6 +183,44 @@ func (s *Server) handleRoot(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleDebugRequests(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
if r.Method == http.MethodHead {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodGet {
|
||||
writeOpenAIError(w, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
limit := 50
|
||||
if raw := strings.TrimSpace(r.URL.Query().Get("limit")); raw != "" {
|
||||
if parsed, err := strconv.Atoi(raw); err == nil {
|
||||
switch {
|
||||
case parsed < 1:
|
||||
limit = 1
|
||||
case parsed > 200:
|
||||
limit = 200
|
||||
default:
|
||||
limit = parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
records := s.debugRecords(limit)
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"ok": true,
|
||||
"service": "lingma-ipc-proxy",
|
||||
"count": len(records),
|
||||
"requests": records,
|
||||
"state": s.svc.State(),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) handleModels(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
@@ -552,6 +611,101 @@ func (s *Server) handleAnthropicStream(w http.ResponseWriter, r *http.Request, r
|
||||
}
|
||||
msgID := fmt.Sprintf("msg_%d", time.Now().UnixNano())
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
result, err := s.svc.Generate(r.Context(), req)
|
||||
if err != nil {
|
||||
writeAnthropicError(w, http.StatusInternalServerError, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
streamingHeaders(w)
|
||||
if err := writeSSEEvent(w, flusher, "message_start", map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []any{},
|
||||
"model": model,
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{
|
||||
"input_tokens": result.InputTokens,
|
||||
"output_tokens": 0,
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
index := 0
|
||||
if strings.TrimSpace(result.Text) != "" {
|
||||
if err := writeSSEEvent(w, flusher, "content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": index,
|
||||
"content_block": map[string]any{"type": "text", "text": ""},
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
if err := writeSSEEvent(w, flusher, "content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": map[string]any{"type": "text_delta", "text": result.Text},
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
if err := writeSSEEvent(w, flusher, "content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": index,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
index++
|
||||
}
|
||||
|
||||
for _, tc := range result.ToolCalls {
|
||||
if err := writeSSEEvent(w, flusher, "content_block_start", map[string]any{
|
||||
"type": "content_block_start",
|
||||
"index": index,
|
||||
"content_block": map[string]any{"type": "tool_use", "id": tc.ID, "name": tc.Name, "input": map[string]any{}},
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
if err := writeSSEEvent(w, flusher, "content_block_delta", map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": index,
|
||||
"delta": map[string]any{"type": "input_json_delta", "partial_json": string(argsJSON)},
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
if err := writeSSEEvent(w, flusher, "content_block_stop", map[string]any{
|
||||
"type": "content_block_stop",
|
||||
"index": index,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
index++
|
||||
}
|
||||
|
||||
stopReason := "end_turn"
|
||||
if len(result.ToolCalls) > 0 {
|
||||
stopReason = "tool_use"
|
||||
}
|
||||
_ = writeSSEEvent(w, flusher, "message_delta", map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]any{
|
||||
"output_tokens": result.OutputTokens,
|
||||
},
|
||||
})
|
||||
_ = writeSSEEvent(w, flusher, "message_stop", map[string]any{"type": "message_stop"})
|
||||
return
|
||||
}
|
||||
|
||||
events, done, err := s.svc.GenerateStream(r.Context(), req)
|
||||
if err != nil {
|
||||
writeAnthropicError(w, http.StatusInternalServerError, "api_error", err.Error())
|
||||
@@ -1141,10 +1295,11 @@ func (rw *recordingResponseWriter) Flush() {
|
||||
|
||||
func (s *Server) withRecorder(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if s.OnRequest == nil {
|
||||
if isDebugInspectionPath(r.URL.Path) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
// Read request body for recording, then restore for downstream handler
|
||||
@@ -1161,10 +1316,54 @@ func (s *Server) withRecorder(next http.Handler) http.Handler {
|
||||
|
||||
respBody := sanitizeRecordedBody(rw.body)
|
||||
|
||||
go s.OnRequest(r.Method, r.URL.Path, rw.statusCode, duration, reqBody, respBody)
|
||||
s.recordRequest(r.Method, r.URL.Path, rw.statusCode, duration, reqBody, respBody)
|
||||
if s.OnRequest != nil {
|
||||
go s.OnRequest(r.Method, r.URL.Path, rw.statusCode, duration, reqBody, respBody)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// isDebugInspectionPath reports whether path is exactly one of the request-log
// inspection endpoints. The comparison is case-sensitive and does not accept
// trailing slashes.
func isDebugInspectionPath(path string) bool {
	return path == "/debug/requests" ||
		path == "/debug/logs" ||
		path == "/api/requests" ||
		path == "/api/logs"
}
|
||||
|
||||
func (s *Server) recordRequest(method, path string, statusCode int, duration time.Duration, reqBody, respBody string) {
|
||||
s.recMu.Lock()
|
||||
defer s.recMu.Unlock()
|
||||
|
||||
s.records = append(s.records, debugRequestRecord{
|
||||
Time: time.Now().Format(time.RFC3339),
|
||||
Method: method,
|
||||
Path: path,
|
||||
StatusCode: statusCode,
|
||||
DurationMS: duration.Milliseconds(),
|
||||
Request: reqBody,
|
||||
Response: respBody,
|
||||
})
|
||||
if len(s.records) > 200 {
|
||||
s.records = s.records[len(s.records)-200:]
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) debugRecords(limit int) []debugRequestRecord {
|
||||
s.recMu.RLock()
|
||||
defer s.recMu.RUnlock()
|
||||
|
||||
if limit > len(s.records) {
|
||||
limit = len(s.records)
|
||||
}
|
||||
out := make([]debugRequestRecord, 0, limit)
|
||||
for i := len(s.records) - 1; i >= 0 && len(out) < limit; i-- {
|
||||
out = append(out, s.records[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func sanitizeRecordedBody(body []byte) string {
|
||||
if len(body) == 0 {
|
||||
return ""
|
||||
@@ -1254,7 +1453,7 @@ func truncateRecordedString(value string) string {
|
||||
func withCORS(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, HEAD, POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, x-api-key, anthropic-version")
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
|
||||
464
internal/remote/client.go
Normal file
464
internal/remote/client.go
Normal file
@@ -0,0 +1,464 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultBaseURL = "https://lingma.alibabacloud.com"
|
||||
chatPath = "/algo/api/v2/service/pro/sse/agent_chat_generation"
|
||||
chatQuery = "?FetchKeys=llm_model_result&AgentId=agent_common"
|
||||
modelListPath = "/algo/api/v2/model/list"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
BaseURL string
|
||||
AuthFile string
|
||||
CosyVersion string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
cfg Config
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
Key string `json:"key"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Model string `json:"model"`
|
||||
Enable bool `json:"enable"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string
|
||||
Prompt string
|
||||
Stream bool
|
||||
Temperature *float64
|
||||
}
|
||||
|
||||
type ChatResult struct {
|
||||
Text string
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
RequestID string
|
||||
CredentialSrc string
|
||||
}
|
||||
|
||||
type StreamEvent struct {
|
||||
Delta string
|
||||
}
|
||||
|
||||
func New(cfg Config) *Client {
|
||||
if cfg.BaseURL == "" {
|
||||
cfg.BaseURL = ResolveBaseURL("")
|
||||
}
|
||||
if cfg.CosyVersion == "" {
|
||||
cfg.CosyVersion = "2.11.2"
|
||||
}
|
||||
if cfg.Timeout <= 0 {
|
||||
cfg.Timeout = 120 * time.Second
|
||||
}
|
||||
cfg.BaseURL = strings.TrimRight(cfg.BaseURL, "/")
|
||||
return &Client{cfg: cfg, client: &http.Client{Timeout: cfg.Timeout}}
|
||||
}
|
||||
|
||||
func ResolveBaseURL(explicit string) string {
|
||||
if strings.TrimSpace(explicit) != "" {
|
||||
return strings.TrimRight(strings.TrimSpace(explicit), "/")
|
||||
}
|
||||
if value := strings.TrimSpace(os.Getenv("LINGMA_REMOTE_BASE_URL")); value != "" {
|
||||
return strings.TrimRight(value, "/")
|
||||
}
|
||||
for _, path := range candidateConfigFiles() {
|
||||
if value := readBaseURLHint(path); value != "" {
|
||||
return strings.TrimRight(value, "/")
|
||||
}
|
||||
}
|
||||
return DefaultBaseURL
|
||||
}
|
||||
|
||||
func (c *Client) Warmup(ctx context.Context) error {
|
||||
_, err := LoadCredential(c.cfg.AuthFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer cancel()
|
||||
_, err = c.ListModels(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) ListModels(ctx context.Context) ([]Model, error) {
|
||||
cred, err := LoadCredential(c.cfg.AuthFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
headers, err := c.headers(cred, modelListPath, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.cfg.BaseURL+modelListPath, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("remote model list status %d: %s", resp.StatusCode, truncate(string(body), 500))
|
||||
}
|
||||
var payload struct {
|
||||
Chat []Model `json:"chat"`
|
||||
Inline []Model `json:"inline"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return append(payload.Chat, payload.Inline...), nil
|
||||
}
|
||||
|
||||
func (c *Client) Chat(ctx context.Context, request ChatRequest, onDelta func(string)) (*ChatResult, error) {
|
||||
cred, err := LoadCredential(c.cfg.AuthFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestID := newHexID()
|
||||
body, err := c.buildBody(requestID, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
headers, err := c.headers(cred, chatPath, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.BaseURL+chatPath+chatQuery, strings.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("remote chat status %d: %s", resp.StatusCode, truncate(string(respBody), 1000))
|
||||
}
|
||||
var builder strings.Builder
|
||||
if err := scanSSE(resp.Body, func(event sseEvent) error {
|
||||
if event.Done {
|
||||
return nil
|
||||
}
|
||||
if event.Content == "" {
|
||||
return nil
|
||||
}
|
||||
builder.WriteString(event.Content)
|
||||
if onDelta != nil {
|
||||
onDelta(event.Content)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
text := builder.String()
|
||||
return &ChatResult{
|
||||
Text: text,
|
||||
InputTokens: estimateTokens(request.Prompt),
|
||||
OutputTokens: estimateTokens(text),
|
||||
RequestID: requestID,
|
||||
CredentialSrc: cred.Source,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) buildBody(requestID string, request ChatRequest) (string, error) {
|
||||
temperature := 0.1
|
||||
if request.Temperature != nil {
|
||||
temperature = *request.Temperature
|
||||
}
|
||||
model := strings.TrimSpace(request.Model)
|
||||
if strings.EqualFold(model, "auto") {
|
||||
model = ""
|
||||
}
|
||||
payload := map[string]any{
|
||||
"request_id": requestID,
|
||||
"request_set_id": "",
|
||||
"chat_record_id": requestID,
|
||||
"stream": true,
|
||||
"image_urls": nil,
|
||||
"is_reply": false,
|
||||
"is_retry": false,
|
||||
"session_id": "",
|
||||
"code_language": "",
|
||||
"source": 0,
|
||||
"version": "3",
|
||||
"chat_prompt": "",
|
||||
"parameters": map[string]float64{"temperature": temperature},
|
||||
"aliyun_user_type": "personal_standard",
|
||||
"agent_id": "agent_common",
|
||||
"task_id": "question_refine",
|
||||
"model_config": map[string]any{
|
||||
"key": model,
|
||||
"display_name": "",
|
||||
"model": model,
|
||||
"format": "",
|
||||
"is_vl": false,
|
||||
"is_reasoning": false,
|
||||
"api_key": "",
|
||||
"url": "",
|
||||
"source": "",
|
||||
"enable": false,
|
||||
},
|
||||
"messages": []map[string]any{{
|
||||
"role": "user",
|
||||
"content": request.Prompt,
|
||||
"response_meta": map[string]any{
|
||||
"id": "",
|
||||
"usage": map[string]int{
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
},
|
||||
},
|
||||
"reasoning_content_signature": "",
|
||||
}},
|
||||
"business": map[string]any{
|
||||
"product": "jb_plugin",
|
||||
"version": c.cfg.CosyVersion,
|
||||
"type": "memory",
|
||||
"id": newUUID(),
|
||||
"begin_at": time.Now().UnixMilli(),
|
||||
"stage": "start",
|
||||
"name": "memory_intent_recognition_" + requestID,
|
||||
},
|
||||
}
|
||||
body, err := json.Marshal(payload)
|
||||
return string(body), err
|
||||
}
|
||||
|
||||
func (c *Client) headers(cred Credential, path string, body string) (map[string]string, error) {
|
||||
if err := validateCredential(cred); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
date := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
authPayload := map[string]string{
|
||||
"cosyVersion": c.cfg.CosyVersion,
|
||||
"ideVersion": "",
|
||||
"info": cred.EncryptUserInfo,
|
||||
"requestId": newUUID(),
|
||||
"version": "v1",
|
||||
}
|
||||
authPayloadBytes, err := json.Marshal(authPayload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payloadBase64 := base64.StdEncoding.EncodeToString(authPayloadBytes)
|
||||
preimage := strings.Join([]string{
|
||||
payloadBase64,
|
||||
cred.CosyKey,
|
||||
date,
|
||||
body,
|
||||
normalizePath(path),
|
||||
}, "\n")
|
||||
signature := md5.Sum([]byte(preimage))
|
||||
return map[string]string{
|
||||
"Authorization": fmt.Sprintf("Bearer COSY.%s.%x", payloadBase64, signature),
|
||||
"Content-Type": "application/json",
|
||||
"Appcode": "cosy",
|
||||
"Cosy-Date": date,
|
||||
"Cosy-Key": cred.CosyKey,
|
||||
"Cosy-Machineid": cred.MachineID,
|
||||
"Cosy-User": cred.UserID,
|
||||
"Cosy-Clientip": "198.18.0.1",
|
||||
"Cosy-Clienttype": "2",
|
||||
"Cosy-Machineos": "x86_64_windows",
|
||||
"Cosy-Machinetoken": "",
|
||||
"Cosy-Machinetype": "",
|
||||
"Cosy-Version": c.cfg.CosyVersion,
|
||||
"Login-Version": "v2",
|
||||
"User-Agent": "lingma-ipc-proxy/remote",
|
||||
"Accept": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// normalizePath removes a leading "/algo" from path, if present, and returns
// the remainder. The prefix is matched literally, not per path segment.
func normalizePath(path string) string {
	const prefix = "/algo"
	if strings.HasPrefix(path, prefix) {
		return path[len(prefix):]
	}
	return path
}
|
||||
|
||||
type outerSSE struct {
|
||||
Body string `json:"body"`
|
||||
StatusCode int `json:"statusCodeValue"`
|
||||
}
|
||||
|
||||
type innerSSE struct {
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
type sseEvent struct {
|
||||
Content string
|
||||
Done bool
|
||||
}
|
||||
|
||||
func scanSSE(reader io.Reader, onEvent func(sseEvent) error) error {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || !strings.HasPrefix(line, "data:") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
if payload == "[DONE]" {
|
||||
return onEvent(sseEvent{Done: true})
|
||||
}
|
||||
event, ok, err := parseSSEPayload(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := onEvent(event); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func parseSSEPayload(payload string) (sseEvent, bool, error) {
|
||||
var outer outerSSE
|
||||
if err := json.Unmarshal([]byte(payload), &outer); err != nil {
|
||||
return sseEvent{}, false, err
|
||||
}
|
||||
if outer.StatusCode >= 400 {
|
||||
return sseEvent{}, false, fmt.Errorf("remote sse status %d", outer.StatusCode)
|
||||
}
|
||||
if outer.Body == "" {
|
||||
return sseEvent{}, false, nil
|
||||
}
|
||||
if outer.Body == "[DONE]" {
|
||||
return sseEvent{Done: true}, true, nil
|
||||
}
|
||||
var inner innerSSE
|
||||
if err := json.Unmarshal([]byte(outer.Body), &inner); err != nil {
|
||||
return sseEvent{}, false, err
|
||||
}
|
||||
var builder strings.Builder
|
||||
for _, choice := range inner.Choices {
|
||||
builder.WriteString(choice.Delta.Content)
|
||||
}
|
||||
return sseEvent{Content: builder.String()}, true, nil
|
||||
}
|
||||
|
||||
// candidateConfigFiles lists, in lookup order, the config.json locations under
// the user's home directory that may contain a base-URL hint. It returns nil
// when the home directory cannot be determined.
func candidateConfigFiles() []string {
	home, err := os.UserHomeDir()
	if err != nil {
		return nil
	}

	lingmaDir := filepath.Join(home, ".lingma")
	extensionDir := filepath.Join(lingmaDir, "extension")
	return []string{
		filepath.Join(extensionDir, "server", "config.json"),
		filepath.Join(extensionDir, "local", "config.json"),
		filepath.Join(lingmaDir, "bin", "config.json"),
		filepath.Join(home, ".config", "lingma-ipc-proxy", "config.json"),
	}
}
|
||||
|
||||
func readBaseURLHint(path string) string {
|
||||
body, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var value any
|
||||
if err := json.Unmarshal(body, &value); err != nil {
|
||||
text := string(body)
|
||||
if strings.Contains(text, "lingma.alibabacloud.com") {
|
||||
return DefaultBaseURL
|
||||
}
|
||||
return ""
|
||||
}
|
||||
return findBaseURL(value)
|
||||
}
|
||||
|
||||
// findBaseURL searches a decoded JSON value for a Lingma base URL and returns
// it trimmed of surrounding whitespace, or "" when none is found.
//
// In an object, a key whose lower-cased name contains "base", "domain" or
// "url" matches directly when its value is a string that starts with "http"
// and contains "lingma"; otherwise the value is searched recursively. When
// several keys of one object produce a match, the lexicographically smallest
// key wins, so the result does not depend on Go's randomized map iteration
// order. Arrays are searched in element order.
func findBaseURL(value any) string {
	switch typed := value.(type) {
	case map[string]any:
		bestKey, best := "", ""
		for key, item := range typed {
			if best != "" && key >= bestKey {
				continue // a smaller key already produced a match
			}

			candidate := ""
			lower := strings.ToLower(key)
			if strings.Contains(lower, "base") || strings.Contains(lower, "domain") || strings.Contains(lower, "url") {
				if text, ok := item.(string); ok {
					trimmed := strings.TrimSpace(text)
					if strings.HasPrefix(trimmed, "http") && strings.Contains(trimmed, "lingma") {
						candidate = trimmed
					}
				}
			}
			if candidate == "" {
				candidate = findBaseURL(item)
			}
			if candidate != "" {
				bestKey, best = key, candidate
			}
		}
		return best
	case []any:
		for _, item := range typed {
			if nested := findBaseURL(item); nested != "" {
				return nested
			}
		}
	}
	return ""
}
|
||||
|
||||
// estimateTokens gives a rough token count for text: the number of Unicode
// code points in the whitespace-trimmed text divided by four, rounded down.
func estimateTokens(text string) int {
	trimmed := strings.TrimSpace(text)
	if trimmed == "" {
		return 0
	}
	runes := 0
	for range trimmed {
		runes++
	}
	return runes / 4
}
|
||||
|
||||
// truncate trims surrounding whitespace from value and, when the result is
// longer than max bytes, cuts it and appends "... [truncated]".
//
// The cut never splits a multi-byte UTF-8 sequence: if byte max falls inside
// one, the cut moves back to the start of that sequence, so the kept prefix
// may be slightly shorter than max. A negative max is treated as 0.
func truncate(value string, max int) string {
	value = strings.TrimSpace(value)
	if max < 0 {
		max = 0
	}
	if len(value) <= max {
		return value
	}

	// value[cut] is the first dropped byte. While it is a UTF-8 continuation
	// byte (10xxxxxx) the cut sits inside a rune; back up at most three bytes,
	// the maximum number of continuation bytes in one encoded rune.
	cut := max
	for cut > 0 && max-cut < 3 && value[cut]&0xC0 == 0x80 {
		cut--
	}
	return value[:cut] + "... [truncated]"
}
|
||||
|
||||
// expandHome replaces a leading "~/" in path with the user's home directory.
// Any other path, including a bare "~", is returned unchanged, as is the
// input when the home directory cannot be determined.
func expandHome(path string) string {
	rest := strings.TrimPrefix(path, "~/")
	if rest == path {
		// No "~/" prefix was present.
		return path
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return path
	}
	return filepath.Join(home, rest)
}
|
||||
|
||||
// valueOr returns value unless it is empty or consists only of whitespace, in
// which case fallback is returned. A non-blank value is returned untrimmed.
func valueOr(value string, fallback string) string {
	if strings.TrimSpace(value) == "" {
		return fallback
	}
	return value
}
|
||||
|
||||
var hexCounter uint64
|
||||
205
internal/remote/credentials.go
Normal file
205
internal/remote/credentials.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Credential struct {
|
||||
CosyKey string
|
||||
EncryptUserInfo string
|
||||
UserID string
|
||||
MachineID string
|
||||
Source string
|
||||
TokenExpireTime int64
|
||||
}
|
||||
|
||||
type storedCredentialFile struct {
|
||||
Source string `json:"source"`
|
||||
TokenExpireTime string `json:"token_expire_time"`
|
||||
Auth struct {
|
||||
CosyKey string `json:"cosy_key"`
|
||||
EncryptUserInfo string `json:"encrypt_user_info"`
|
||||
UserID string `json:"user_id"`
|
||||
MachineID string `json:"machine_id"`
|
||||
} `json:"auth"`
|
||||
}
|
||||
|
||||
func LoadCredential(authFile string) (Credential, error) {
|
||||
if path := strings.TrimSpace(authFile); path != "" {
|
||||
return loadCredentialFile(expandHome(path))
|
||||
}
|
||||
return importLingmaCacheCredential()
|
||||
}
|
||||
|
||||
func loadCredentialFile(path string) (Credential, error) {
|
||||
body, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return Credential{}, fmt.Errorf("read remote auth file: %w", err)
|
||||
}
|
||||
var stored storedCredentialFile
|
||||
if err := json.Unmarshal(body, &stored); err != nil {
|
||||
return Credential{}, fmt.Errorf("parse remote auth file: %w", err)
|
||||
}
|
||||
cred := Credential{
|
||||
CosyKey: stored.Auth.CosyKey,
|
||||
EncryptUserInfo: stored.Auth.EncryptUserInfo,
|
||||
UserID: stored.Auth.UserID,
|
||||
MachineID: stored.Auth.MachineID,
|
||||
Source: valueOr(stored.Source, path),
|
||||
TokenExpireTime: parseExpire(stored.TokenExpireTime),
|
||||
}
|
||||
return cred, validateCredential(cred)
|
||||
}
|
||||
|
||||
func importLingmaCacheCredential() (Credential, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return Credential{}, err
|
||||
}
|
||||
lingmaDir := filepath.Join(home, ".lingma")
|
||||
machineID, err := loadMachineID(lingmaDir)
|
||||
if err != nil {
|
||||
return Credential{}, err
|
||||
}
|
||||
encrypted, err := os.ReadFile(filepath.Join(lingmaDir, "cache", "user"))
|
||||
if err != nil {
|
||||
return Credential{}, fmt.Errorf("read ~/.lingma/cache/user: %w", err)
|
||||
}
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(strings.TrimSpace(string(encrypted)))
|
||||
if err != nil {
|
||||
return Credential{}, fmt.Errorf("decode ~/.lingma/cache/user: %w", err)
|
||||
}
|
||||
plaintext, err := decryptCacheUser(machineID, ciphertext)
|
||||
if err != nil {
|
||||
return Credential{}, err
|
||||
}
|
||||
var payload struct {
|
||||
Key string `json:"key"`
|
||||
EncryptUserInfo string `json:"encrypt_user_info"`
|
||||
UserID string `json:"uid"`
|
||||
ExpireTime any `json:"expire_time"`
|
||||
}
|
||||
if err := json.Unmarshal(plaintext, &payload); err != nil {
|
||||
return Credential{}, fmt.Errorf("parse ~/.lingma/cache/user: %w", err)
|
||||
}
|
||||
cred := Credential{
|
||||
CosyKey: payload.Key,
|
||||
EncryptUserInfo: payload.EncryptUserInfo,
|
||||
UserID: payload.UserID,
|
||||
MachineID: machineID,
|
||||
Source: "~/.lingma/cache/user",
|
||||
TokenExpireTime: parseExpireAny(payload.ExpireTime),
|
||||
}
|
||||
return cred, validateCredential(cred)
|
||||
}
|
||||
|
||||
// loadMachineID returns the Lingma machine id found under lingmaDir.
//
// The trimmed content of cache/id is used when that file exists and is not
// blank. Otherwise logs/lingma.log is searched, case-insensitively, for the
// last occurrence of "using machine id from file:" and then of "machine id:";
// the rest of that line, trimmed, is the id. An error is returned when the
// log cannot be read or contains no usable marker.
func loadMachineID(lingmaDir string) (string, error) {
	if body, err := os.ReadFile(filepath.Join(lingmaDir, "cache", "id")); err == nil {
		if value := strings.TrimSpace(string(body)); value != "" {
			return value, nil
		}
	}

	logBody, err := os.ReadFile(filepath.Join(lingmaDir, "logs", "lingma.log"))
	if err != nil {
		return "", fmt.Errorf("remote credential requires ~/.lingma/cache/id or lingma.log machine id: %w", err)
	}
	text := string(logBody)

	// Lower-case ASCII letters only, byte for byte. strings.ToLower may change
	// the byte length of non-ASCII text, which would make offsets found in the
	// folded copy invalid for slicing the original. The markers are pure
	// ASCII, so ASCII folding is sufficient. Done once, outside the loop.
	foldedBytes := make([]byte, len(logBody))
	for i, b := range logBody {
		if 'A' <= b && b <= 'Z' {
			b += 'a' - 'A'
		}
		foldedBytes[i] = b
	}
	folded := string(foldedBytes)

	for _, marker := range []string{"using machine id from file:", "machine id:"} {
		index := strings.LastIndex(folded, marker)
		if index < 0 {
			continue
		}
		line := text[index+len(marker):]
		if newline := strings.IndexByte(line, '\n'); newline >= 0 {
			line = line[:newline]
		}
		if value := strings.TrimSpace(line); value != "" {
			return value, nil
		}
	}
	return "", errors.New("machine id not found in ~/.lingma cache")
}
|
||||
|
||||
func decryptCacheUser(machineID string, ciphertext []byte) ([]byte, error) {
|
||||
if len(machineID) < aes.BlockSize {
|
||||
return nil, errors.New("machine id too short for cache decryption")
|
||||
}
|
||||
if len(ciphertext) == 0 || len(ciphertext)%aes.BlockSize != 0 {
|
||||
return nil, errors.New("invalid cache/user ciphertext size")
|
||||
}
|
||||
key := []byte(machineID[:aes.BlockSize])
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
cipher.NewCBCDecrypter(block, key).CryptBlocks(plaintext, ciphertext)
|
||||
return unpadPKCS7(plaintext)
|
||||
}
|
||||
|
||||
// unpadPKCS7 removes PKCS#7 padding from data and returns the unpadded prefix,
// which shares data's backing array. It fails when data is empty, when the
// final byte is not a pad length in 1..aes.BlockSize that fits in data, or
// when any padding byte differs from the pad length.
func unpadPKCS7(data []byte) ([]byte, error) {
	total := len(data)
	if total == 0 {
		return nil, errors.New("empty plaintext")
	}

	pad := int(data[total-1])
	switch {
	case pad <= 0, pad > aes.BlockSize, pad > total:
		return nil, errors.New("invalid cache/user padding")
	}

	bodyLen := total - pad
	for i := bodyLen; i < total; i++ {
		if int(data[i]) != pad {
			return nil, errors.New("invalid cache/user padding bytes")
		}
	}
	return data[:bodyLen], nil
}
|
||||
|
||||
func validateCredential(cred Credential) error {
|
||||
if strings.TrimSpace(cred.CosyKey) == "" {
|
||||
return errors.New("remote credential missing cosy_key")
|
||||
}
|
||||
if strings.TrimSpace(cred.EncryptUserInfo) == "" {
|
||||
return errors.New("remote credential missing encrypt_user_info")
|
||||
}
|
||||
if strings.TrimSpace(cred.UserID) == "" {
|
||||
return errors.New("remote credential missing user_id")
|
||||
}
|
||||
if strings.TrimSpace(cred.MachineID) == "" {
|
||||
return errors.New("remote credential missing machine_id")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseExpire converts a base-10 integer string, ignoring surrounding
// whitespace, to int64. The parse error is deliberately discarded: the value
// strconv.ParseInt produced is returned as is, which is 0 for text that is not
// a number.
func parseExpire(value string) int64 {
	trimmed := strings.TrimSpace(value)
	expire, _ := strconv.ParseInt(trimmed, 10, 64)
	return expire
}
|
||||
|
||||
func parseExpireAny(value any) int64 {
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return parseExpire(typed)
|
||||
case float64:
|
||||
return int64(typed)
|
||||
case int64:
|
||||
return typed
|
||||
case int:
|
||||
return int64(typed)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func IsExpired(cred Credential, margin time.Duration) bool {
|
||||
return cred.TokenExpireTime > 0 && time.Now().Add(margin).UnixMilli() > cred.TokenExpireTime
|
||||
}
|
||||
28
internal/remote/id.go
Normal file
28
internal/remote/id.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// newUUID returns a random version-4 UUID in the canonical 8-4-4-4-12
// lower-case hex form. If the system random source fails, it returns
// "fallback-" followed by the current time in nanoseconds instead.
func newUUID() string {
	raw := make([]byte, 16)
	if _, err := rand.Read(raw); err != nil {
		return fmt.Sprintf("fallback-%d", time.Now().UnixNano())
	}

	raw[6] = raw[6]&0x0f | 0x40 // version 4
	raw[8] = raw[8]&0x3f | 0x80 // RFC 4122 variant

	digits := hex.EncodeToString(raw)
	return digits[0:8] + "-" + digits[8:12] + "-" + digits[12:16] + "-" + digits[16:20] + "-" + digits[20:32]
}
|
||||
|
||||
func newHexID() string {
|
||||
var data [16]byte
|
||||
if _, err := rand.Read(data[:]); err != nil {
|
||||
seq := atomic.AddUint64(&hexCounter, 1)
|
||||
return fmt.Sprintf("fallback%x%x", time.Now().UnixNano(), seq)
|
||||
}
|
||||
return hex.EncodeToString(data[:])
|
||||
}
|
||||
@@ -15,9 +15,17 @@ import (
|
||||
"time"
|
||||
|
||||
"lingma-ipc-proxy/internal/lingmaipc"
|
||||
"lingma-ipc-proxy/internal/remote"
|
||||
"lingma-ipc-proxy/internal/toolemulation"
|
||||
)
|
||||
|
||||
type BackendMode string
|
||||
|
||||
const (
|
||||
BackendIPC BackendMode = "ipc"
|
||||
BackendRemote BackendMode = "remote"
|
||||
)
|
||||
|
||||
type SessionMode string
|
||||
|
||||
const (
|
||||
@@ -29,9 +37,13 @@ const (
|
||||
type Config struct {
|
||||
Host string
|
||||
Port int
|
||||
Backend BackendMode
|
||||
Transport lingmaipc.Transport
|
||||
Pipe string
|
||||
WebSocketURL string
|
||||
RemoteBaseURL string
|
||||
RemoteAuthFile string
|
||||
RemoteVersion string
|
||||
Cwd string
|
||||
CurrentFilePath string
|
||||
Mode string
|
||||
@@ -129,6 +141,7 @@ type Service struct {
|
||||
stickySessionID string
|
||||
stickyModelID string
|
||||
modelMap map[string]string // official name -> internal id
|
||||
remoteClient *remote.Client
|
||||
}
|
||||
|
||||
type promptRunResult struct {
|
||||
@@ -158,6 +171,9 @@ func New(cfg Config) *Service {
|
||||
if cfg.Transport == "" {
|
||||
cfg.Transport = lingmaipc.TransportAuto
|
||||
}
|
||||
if cfg.Backend == "" {
|
||||
cfg.Backend = BackendIPC
|
||||
}
|
||||
if cfg.SessionMode == "" {
|
||||
cfg.SessionMode = SessionModeAuto
|
||||
}
|
||||
@@ -177,6 +193,9 @@ func (s *Service) DefaultModel() string {
|
||||
}
|
||||
|
||||
func (s *Service) Warmup(ctx context.Context) error {
|
||||
if s.backend() == BackendRemote {
|
||||
return s.remoteClientLocked().Warmup(ctx)
|
||||
}
|
||||
_, err := s.ensureConnected(ctx)
|
||||
return err
|
||||
}
|
||||
@@ -190,6 +209,14 @@ func (s *Service) Close() error {
|
||||
func (s *Service) State() State {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.cfg.Backend == BackendRemote {
|
||||
return State{
|
||||
Endpoint: remote.ResolveBaseURL(s.cfg.RemoteBaseURL),
|
||||
Transport: "remote",
|
||||
Connected: s.remoteClient != nil,
|
||||
SessionMode: s.cfg.SessionMode,
|
||||
}
|
||||
}
|
||||
return State{
|
||||
PipePath: s.pipePath,
|
||||
Endpoint: s.endpoint,
|
||||
@@ -201,6 +228,29 @@ func (s *Service) State() State {
|
||||
}
|
||||
|
||||
func (s *Service) ListModels(ctx context.Context) ([]Model, error) {
|
||||
if s.backend() == BackendRemote {
|
||||
models, err := s.remoteClientLocked().ListModels(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]Model, 0, len(models)+1)
|
||||
seen := map[string]bool{"Auto": true}
|
||||
out = append(out, Model{ID: "Auto", Name: "Auto"})
|
||||
for _, model := range models {
|
||||
id := strings.TrimSpace(model.Key)
|
||||
if id == "" || seen[id] {
|
||||
continue
|
||||
}
|
||||
seen[id] = true
|
||||
name := strings.TrimSpace(model.DisplayName)
|
||||
if name == "" {
|
||||
name = id
|
||||
}
|
||||
out = append(out, Model{ID: id, Name: name})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
ipcClient, err := s.ensureConnected(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -229,6 +279,9 @@ func (s *Service) ListModels(ctx context.Context) ([]Model, error) {
|
||||
}
|
||||
|
||||
func (s *Service) Generate(ctx context.Context, req ChatRequest) (*ChatResult, error) {
|
||||
if s.backend() == BackendRemote {
|
||||
return s.generateRemote(ctx, req, nil)
|
||||
}
|
||||
return s.generateWithReconnect(ctx, req, nil)
|
||||
}
|
||||
|
||||
@@ -237,7 +290,11 @@ func (s *Service) GenerateStream(ctx context.Context, req ChatRequest) (<-chan S
|
||||
done := make(chan StreamResult, 1)
|
||||
|
||||
go func() {
|
||||
result, err := s.generateWithReconnect(ctx, req, func(delta string) {
|
||||
generate := s.generateWithReconnect
|
||||
if s.backend() == BackendRemote {
|
||||
generate = s.generateRemote
|
||||
}
|
||||
result, err := generate(ctx, req, func(delta string) {
|
||||
if delta == "" {
|
||||
return
|
||||
}
|
||||
@@ -269,6 +326,67 @@ func (s *Service) generateWithReconnect(
|
||||
return s.generateLocked(ctx, req, onDelta)
|
||||
}
|
||||
|
||||
func (s *Service) generateRemote(
|
||||
ctx context.Context,
|
||||
req ChatRequest,
|
||||
onDelta func(string),
|
||||
) (*ChatResult, error) {
|
||||
requestCtx, cancel := context.WithTimeout(ctx, s.cfg.Timeout)
|
||||
defer cancel()
|
||||
|
||||
if strings.TrimSpace(req.Model) == "" {
|
||||
req.Model = s.DefaultModel()
|
||||
}
|
||||
prompt, err := buildLingmaPrompt(req, SessionModeFresh)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.TrimSpace(prompt) == "" {
|
||||
return nil, errors.New("empty user message")
|
||||
}
|
||||
|
||||
client := s.remoteClientLocked()
|
||||
remoteResult, err := client.Chat(requestCtx, remote.ChatRequest{
|
||||
Model: req.Model,
|
||||
Prompt: prompt,
|
||||
Stream: onDelta != nil,
|
||||
Temperature: req.Temperature,
|
||||
}, onDelta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := &ChatResult{
|
||||
Text: remoteResult.Text,
|
||||
Model: valueOr(strings.TrimSpace(req.Model), "lingma"),
|
||||
InputTokens: remoteResult.InputTokens,
|
||||
OutputTokens: remoteResult.OutputTokens,
|
||||
SessionID: "",
|
||||
RequestID: remoteResult.RequestID,
|
||||
FinishReason: "stop",
|
||||
StopReason: "stop",
|
||||
Endpoint: remote.ResolveBaseURL(s.cfg.RemoteBaseURL),
|
||||
Transport: "remote",
|
||||
EffectiveSession: SessionModeFresh,
|
||||
}
|
||||
s.applyToolEmulation(requestCtx, req, prompt, result, onDelta, func(hintPrompt string) (string, int, error) {
|
||||
retryResult, retryErr := client.Chat(requestCtx, remote.ChatRequest{
|
||||
Model: req.Model,
|
||||
Prompt: hintPrompt,
|
||||
Stream: onDelta != nil,
|
||||
Temperature: req.Temperature,
|
||||
}, onDelta)
|
||||
if retryErr != nil {
|
||||
return "", 0, retryErr
|
||||
}
|
||||
if retryResult == nil {
|
||||
return "", 0, nil
|
||||
}
|
||||
return retryResult.Text, retryResult.OutputTokens, nil
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Service) generateLocked(
|
||||
ctx context.Context,
|
||||
req ChatRequest,
|
||||
@@ -361,6 +479,56 @@ func (s *Service) generateLocked(
|
||||
|
||||
result = s.buildChatResult(req, sessionID, requestID, prompt, runResult, effectiveMode)
|
||||
|
||||
s.applyToolEmulation(requestCtx, req, prompt, result, onDelta, func(hintPrompt string) (string, int, error) {
|
||||
retryRequestID := lingmaipc.CreateRequestID("serve-tool")
|
||||
retryMeta := lingmaipc.CreateMeta(lingmaipc.MetaOptions{
|
||||
RequestID: retryRequestID,
|
||||
Mode: s.cfg.Mode,
|
||||
Model: internalModelID,
|
||||
ShellType: s.cfg.ShellType,
|
||||
CurrentFilePath: s.cfg.CurrentFilePath,
|
||||
EnabledMCP: []any{},
|
||||
})
|
||||
retryRunResult, retryErr := s.runPromptLocked(requestCtx, ipcClient, sessionID, hintPrompt, images, retryRequestID, retryMeta, onDelta)
|
||||
if retryErr != nil {
|
||||
return "", 0, retryErr
|
||||
}
|
||||
return retryRunResult.AssistantText, estimateTokens(retryRunResult.AssistantText), nil
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Service) backend() BackendMode {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.cfg.Backend == "" {
|
||||
return BackendIPC
|
||||
}
|
||||
return s.cfg.Backend
|
||||
}
|
||||
|
||||
func (s *Service) remoteClientLocked() *remote.Client {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.remoteClient == nil {
|
||||
s.remoteClient = remote.New(remote.Config{
|
||||
BaseURL: s.cfg.RemoteBaseURL,
|
||||
AuthFile: s.cfg.RemoteAuthFile,
|
||||
CosyVersion: s.cfg.RemoteVersion,
|
||||
Timeout: s.cfg.Timeout,
|
||||
})
|
||||
}
|
||||
return s.remoteClient
|
||||
}
|
||||
|
||||
func (s *Service) applyToolEmulation(
|
||||
ctx context.Context,
|
||||
req ChatRequest,
|
||||
prompt string,
|
||||
result *ChatResult,
|
||||
onDelta func(string),
|
||||
retry func(string) (string, int, error),
|
||||
) {
|
||||
if len(req.Tools) > 0 {
|
||||
calls, remaining, parseErr := toolemulation.ParseActionBlocks(result.Text, req.Tools, toolemulation.Config{})
|
||||
if parseErr == nil && len(calls) > 0 {
|
||||
@@ -368,28 +536,36 @@ func (s *Service) generateLocked(
|
||||
result.ToolCalls = calls
|
||||
} else if shouldRetryTooling(req.ToolChoice, result.Text) {
|
||||
hintPrompt := prompt + "\n\n" + toolemulation.ForceToolingPrompt(req.ToolChoice)
|
||||
retryRequestID := lingmaipc.CreateRequestID("retry")
|
||||
retryMeta := lingmaipc.CreateMeta(lingmaipc.MetaOptions{
|
||||
RequestID: retryRequestID,
|
||||
Mode: s.cfg.Mode,
|
||||
Model: internalModelID,
|
||||
ShellType: s.cfg.ShellType,
|
||||
CurrentFilePath: s.cfg.CurrentFilePath,
|
||||
EnabledMCP: []any{},
|
||||
})
|
||||
retryResult, retryErr := s.runPromptLocked(requestCtx, ipcClient, sessionID, hintPrompt, nil, retryRequestID, retryMeta, onDelta)
|
||||
if retryErr == nil && retryResult != nil {
|
||||
retryCalls, retryRemaining, retryParseErr := toolemulation.ParseActionBlocks(retryResult.AssistantText, req.Tools, toolemulation.Config{})
|
||||
retryText := ""
|
||||
if retry != nil {
|
||||
text, outputTokens, retryErr := retry(hintPrompt)
|
||||
if retryErr == nil {
|
||||
retryText = text
|
||||
if outputTokens > 0 {
|
||||
result.OutputTokens = outputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
if retryText != "" {
|
||||
retryCalls, retryRemaining, retryParseErr := toolemulation.ParseActionBlocks(retryText, req.Tools, toolemulation.Config{})
|
||||
if retryParseErr == nil && len(retryCalls) > 0 {
|
||||
result.Text = retryRemaining
|
||||
result.ToolCalls = retryCalls
|
||||
result.OutputTokens = estimateTokens(retryResult.AssistantText)
|
||||
result.OutputTokens = estimateTokens(retryText)
|
||||
} else if inferred := toolemulation.InferToolCallsFromText(retryText, req.Tools); len(inferred) > 0 {
|
||||
result.Text = ""
|
||||
result.ToolCalls = inferred
|
||||
result.OutputTokens = estimateTokens(retryText)
|
||||
}
|
||||
}
|
||||
if len(result.ToolCalls) == 0 {
|
||||
if inferred := toolemulation.InferToolCallsFromText(result.Text, req.Tools); len(inferred) > 0 {
|
||||
result.Text = ""
|
||||
result.ToolCalls = inferred
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func shouldRetryTooling(choice toolemulation.ToolChoice, text string) bool {
|
||||
|
||||
@@ -402,6 +402,7 @@ func ForceToolingPrompt(choice ToolChoice) string {
|
||||
"You must respond with at least one valid action block now. " +
|
||||
"Select the single most appropriate available tool for the user request. " +
|
||||
"The proxy tools from the previous system message are available even if native Lingma tools are not. " +
|
||||
"If the user asked to inspect the local computer, run a shell command, read files, search files, or check current data, call the matching tool immediately. " +
|
||||
"Do not explain. Do not say tools are unavailable. Output the action block directly."
|
||||
if choice.Mode == "tool" && strings.TrimSpace(choice.Name) != "" {
|
||||
prompt += " You must call \"" + strings.TrimSpace(choice.Name) + "\"."
|
||||
@@ -420,12 +421,28 @@ func LooksLikeRefusal(text string) bool {
|
||||
"tools are unavailable",
|
||||
"cannot call tools",
|
||||
"can't call tools",
|
||||
"cannot execute",
|
||||
"can't execute",
|
||||
"cannot run commands",
|
||||
"can't run commands",
|
||||
"cannot access your computer",
|
||||
"can't access your computer",
|
||||
"cannot access your local machine",
|
||||
"can't access your local machine",
|
||||
"没有可用的工具",
|
||||
"无法调用",
|
||||
"工具不可用",
|
||||
"不能调用工具",
|
||||
"我不具备",
|
||||
"受限于当前环境",
|
||||
"当前环境限制",
|
||||
"无法直接执行",
|
||||
"不能直接执行",
|
||||
"无法执行系统命令",
|
||||
"不能执行系统命令",
|
||||
"无法访问你的电脑",
|
||||
"无法访问本机",
|
||||
"没有权限访问",
|
||||
}
|
||||
for _, needle := range needles {
|
||||
if strings.Contains(t, needle) {
|
||||
@@ -455,9 +472,16 @@ func LooksLikeMissedToolUse(text string) bool {
|
||||
"i will search",
|
||||
"please run",
|
||||
"manually run",
|
||||
"run the following command",
|
||||
"you can run",
|
||||
"you could run",
|
||||
"paste the file",
|
||||
"无法直接访问",
|
||||
"无法直接查询",
|
||||
"无法直接查看",
|
||||
"无法直接执行",
|
||||
"不能直接执行",
|
||||
"无法执行系统命令",
|
||||
"没有可用",
|
||||
"no tools available",
|
||||
"native lingma tools",
|
||||
@@ -470,6 +494,10 @@ func LooksLikeMissedToolUse(text string) bool {
|
||||
"查看文件",
|
||||
"查询天气",
|
||||
"手动运行",
|
||||
"你可以在终端中运行",
|
||||
"你可以运行",
|
||||
"请你运行",
|
||||
"请手动运行",
|
||||
"粘贴给我",
|
||||
"切换到计划模式",
|
||||
}
|
||||
@@ -481,6 +509,60 @@ func LooksLikeMissedToolUse(text string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func InferToolCallsFromText(text string, tools []ToolDef) []ToolCall {
|
||||
if !LooksLikeRefusal(text) && !LooksLikeMissedToolUse(text) {
|
||||
return nil
|
||||
}
|
||||
|
||||
commandTool, ok := selectCommandTool(tools)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if command := inferLocalCommand(text); command != "" {
|
||||
return []ToolCall{{
|
||||
ID: newCallID(),
|
||||
Name: commandTool.Name,
|
||||
Arguments: filterArgsBySchema(map[string]any{
|
||||
"command": command,
|
||||
}, commandTool.InputSchema),
|
||||
}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func selectCommandTool(tools []ToolDef) (ToolDef, bool) {
|
||||
for _, tool := range tools {
|
||||
name := strings.ToLower(strings.TrimSpace(tool.Name))
|
||||
if name == "bash" || name == "terminal" || name == "shell" || strings.Contains(name, "bash") || strings.Contains(name, "terminal") || strings.Contains(name, "shell") {
|
||||
if toolHasCommandArg(tool.InputSchema) {
|
||||
return tool, true
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tool := range tools {
|
||||
if toolHasCommandArg(tool.InputSchema) {
|
||||
return tool, true
|
||||
}
|
||||
}
|
||||
return ToolDef{}, false
|
||||
}
|
||||
|
||||
// toolHasCommandArg reports whether schema has a "properties" object that
// contains a "command" key. A nil schema, a missing "properties" entry, or a
// "properties" value that is not a map[string]any all yield false.
func toolHasCommandArg(schema map[string]any) bool {
	properties, isObject := schema["properties"].(map[string]any)
	if !isObject {
		return false
	}
	_, present := properties["command"]
	return present
}
|
||||
|
||||
// inferLocalCommand maps a model reply to a shell command worth running.
//
// Matching is case-insensitive on the trimmed text. Only memory-related
// replies are recognized at present; they map to a fixed diagnostic pipeline
// built from vm_stat, memory_pressure and top. Anything else yields "".
func inferLocalCommand(text string) string {
	lowered := strings.ToLower(strings.TrimSpace(text))
	memoryKeywords := []string{"内存", "memory", "physmem", "vm_stat"}
	for _, keyword := range memoryKeywords {
		if strings.Contains(lowered, keyword) {
			return `vm_stat && echo "---" && memory_pressure && echo "---" && top -l 1 -s 0 | head -n 15`
		}
	}
	return ""
}
|
||||
|
||||
func ParseActionBlocks(text string, tools []ToolDef, cfg Config) ([]ToolCall, string, error) {
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return nil, "", nil
|
||||
|
||||
@@ -9,8 +9,11 @@ func TestLooksLikeMissedToolUseDetectsLocalToolAvoidance(t *testing.T) {
|
||||
cases := []string{
|
||||
"我需要使用终端工具来查看内存。",
|
||||
"由于当前环境限制,请手动运行 top。",
|
||||
"当前环境限制,我无法直接执行系统命令查看你的内存占用。",
|
||||
"你可以在终端中运行 top -l 1 | grep PhysMem。",
|
||||
"I need to read the file first.",
|
||||
"Let me use the web search tool.",
|
||||
"You can run the following command in your terminal.",
|
||||
"现在我需要切换到计划模式。",
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -20,6 +23,42 @@ func TestLooksLikeMissedToolUseDetectsLocalToolAvoidance(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeRefusalDetectsLocalAccessRefusals(t *testing.T) {
|
||||
cases := []string{
|
||||
"当前环境限制,我无法直接执行系统命令查看你的内存占用。",
|
||||
"我无法访问你的电脑或本机文件。",
|
||||
"I cannot execute commands in your local machine.",
|
||||
"I can't access your computer directly.",
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if !LooksLikeRefusal(tc) {
|
||||
t.Fatalf("LooksLikeRefusal(%q) = false", tc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferToolCallsFromTextConvertsMemoryRefusalToBash(t *testing.T) {
|
||||
calls := InferToolCallsFromText("当前无法执行系统命令。你可以运行 vm_stat 查看内存占用。", []ToolDef{{
|
||||
Name: "Bash",
|
||||
InputSchema: map[string]any{
|
||||
"properties": map[string]any{
|
||||
"command": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []any{"command"},
|
||||
},
|
||||
}})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("call count = %d", len(calls))
|
||||
}
|
||||
if calls[0].Name != "Bash" {
|
||||
t.Fatalf("tool name = %q", calls[0].Name)
|
||||
}
|
||||
command, _ := calls[0].Arguments["command"].(string)
|
||||
if !strings.Contains(command, "vm_stat") || !strings.Contains(command, "memory_pressure") {
|
||||
t.Fatalf("unexpected command = %q", command)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeMissedToolUseIgnoresFinalAnswers(t *testing.T) {
|
||||
text := "这个文件负责 HTTP API 路由和 OpenAI 兼容响应。"
|
||||
if LooksLikeMissedToolUse(text) {
|
||||
|
||||
Reference in New Issue
Block a user