Fix tool loop handling and count tokens endpoint

This commit is contained in:
lutc5
2026-05-06 15:58:37 +08:00
parent 1c349227a3
commit fe1d5b5348
6 changed files with 182 additions and 7 deletions

View File

@@ -117,6 +117,7 @@ func NewServer(addr string, svc *service.Service) *Server {
mux.HandleFunc("/v1/props", s.handleModelProps)
mux.HandleFunc("/props", s.handleModelProps)
mux.HandleFunc("/version", s.handleVersion)
mux.HandleFunc("/v1/messages/count_tokens", s.handleAnthropicCountTokens)
mux.HandleFunc("/v1/messages", s.handleAnthropicMessages)
mux.HandleFunc("/v1/chat/completions", s.handleOpenAIChatCompletions)
mux.HandleFunc("/api/v1/chat/completions", s.handleOpenAIChatCompletions)
@@ -446,6 +447,27 @@ func (s *Server) handleVersion(w http.ResponseWriter, r *http.Request) {
})
}
func (s *Server) handleAnthropicCountTokens(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
if r.Method != http.MethodPost {
writeAnthropicError(w, http.StatusMethodNotAllowed, "invalid_request_error", "method not allowed")
return
}
var req anthropicRequest
if err := decodeJSON(r, &req); err != nil {
writeAnthropicError(w, http.StatusBadRequest, "invalid_request_error", err.Error())
return
}
writeJSON(w, http.StatusOK, map[string]any{
"input_tokens": estimateAnthropicInputTokens(req),
})
}
func (s *Server) handleAnthropicMessages(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
@@ -1292,6 +1314,9 @@ func anthropicHostedWebSearchCall(req anthropicRequest) (toolemulation.ToolCall,
if !hasAnthropicHostedWebSearchTool(req.Tools) {
return toolemulation.ToolCall{}, false
}
if hasAnthropicToolResult(req.Messages) {
return toolemulation.ToolCall{}, false
}
if !anthropicHostedWebSearchRequested(req.Tools, req.ToolChoice) {
return toolemulation.ToolCall{}, false
}
@@ -1307,6 +1332,46 @@ func anthropicHostedWebSearchCall(req anthropicRequest) (toolemulation.ToolCall,
}, true
}
func hasAnthropicToolResult(messages []rawMessage) bool {
for _, message := range messages {
items, ok := message.Content.([]any)
if !ok {
continue
}
for _, item := range items {
m, ok := item.(map[string]any)
if ok && stringFromAny(m["type"]) == "tool_result" {
return true
}
}
}
return false
}
func estimateAnthropicInputTokens(req anthropicRequest) int {
payload := map[string]any{
"model": req.Model,
"system": req.System,
"messages": req.Messages,
"tools": req.Tools,
"tool_choice": req.ToolChoice,
"thinking": req.Thinking,
}
raw, err := json.Marshal(payload)
if err != nil {
return 1
}
runes := len([]rune(string(raw)))
if runes == 0 {
return 1
}
tokens := (runes + 2) / 3
if tokens < 1 {
return 1
}
return tokens
}
func hasAnthropicHostedWebSearchTool(raw any) bool {
items, ok := raw.([]any)
if !ok {

View File

@@ -218,6 +218,64 @@ func TestAnthropicHostedWebSearchCallIgnoresRegularClientWebSearch(t *testing.T)
}
}
func TestAnthropicHostedWebSearchCallIgnoresToolResultFollowup(t *testing.T) {
req := anthropicRequest{
Tools: []any{
map[string]any{
"name": "web_search",
"type": "web_search_20250305",
},
},
ToolChoice: map[string]any{
"type": "tool",
"name": "web_search",
},
Messages: []rawMessage{{
Role: "user",
Content: []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "toolu_123",
"content": "result",
},
},
}},
}
if _, ok := anthropicHostedWebSearchCall(req); ok {
t.Fatal("hosted web_search should not short-circuit after a tool_result")
}
}
func TestAnthropicCountTokensEndpoint(t *testing.T) {
server := NewServer("", service.New(service.Config{
Model: "Qwen3-Coder",
Timeout: time.Second,
}))
req := httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", strings.NewReader(`{
"model":"kmodel",
"max_tokens":128,
"system":"You are concise.",
"messages":[{"role":"user","content":"hello"}],
"tools":[{"name":"read_file","input_schema":{"type":"object","properties":{"file_path":{"type":"string"}},"required":["file_path"]}}]
}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
server.http.Handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d body = %s", rec.Code, rec.Body.String())
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatal(err)
}
if body["input_tokens"].(float64) <= 0 {
t.Fatalf("input_tokens = %#v", body["input_tokens"])
}
}
func TestDiscoveryCompatibilityEndpoints(t *testing.T) {
server := NewServer("", service.New(service.Config{
Model: "Qwen3-Coder",