feat: add websocket transport support

This commit is contained in:
coolxll
2026-03-26 09:37:00 +08:00
parent e5d1134502
commit c184c2a5e6
9 changed files with 518 additions and 143 deletions

View File

@@ -1,7 +1,6 @@
package lingmaipc
import (
"bufio"
"context"
"crypto/rand"
"encoding/hex"
@@ -9,17 +8,12 @@ import (
"errors"
"fmt"
"io"
"net"
"os"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
winio "github.com/Microsoft/go-winio"
)
const (
@@ -65,56 +59,18 @@ type responseEnvelope struct {
}
type Client struct {
conn net.Conn
reader *bufio.Reader
writeMu sync.Mutex
pendingMu sync.Mutex
pending map[int]chan responseEnvelope
subsMu sync.RWMutex
subs map[int]chan Notification
nextID atomic.Int64
nextSubID atomic.Int64
closeOnce sync.Once
closed chan struct{}
closeErr atomic.Value
}
func ResolvePipePath(explicit string) (string, error) {
if runtime.GOOS != "windows" {
return "", errors.New("Lingma IPC proxy currently requires Windows")
}
if pipe := strings.TrimSpace(explicit); pipe != "" {
return normalizePipePath(pipe), nil
}
if pipe := strings.TrimSpace(os.Getenv("LINGMA_IPC_PIPE")); pipe != "" {
return normalizePipePath(pipe), nil
}
entries, err := os.ReadDir(PipeDir)
if err != nil {
return "", fmt.Errorf("enumerate Lingma named pipes: %w", err)
}
names := make([]string, 0, len(entries))
for _, entry := range entries {
name := entry.Name()
if strings.HasPrefix(name, PipePrefix) {
names = append(names, name)
}
}
sort.Strings(names)
if len(names) == 0 {
return "", errors.New("no active Lingma named pipe was found")
}
return PipeDir + names[len(names)-1], nil
}
func normalizePipePath(pipe string) string {
if strings.HasPrefix(pipe, PipeDir) {
return pipe
}
return PipeDir + pipe
transport framedTransport
kind Transport
pendingMu sync.Mutex
pending map[int]chan responseEnvelope
subsMu sync.RWMutex
subs map[int]chan Notification
nextID atomic.Int64
nextSubID atomic.Int64
closeOnce sync.Once
closed chan struct{}
closeErr atomic.Value
responseMu sync.Mutex
}
func DefaultShellType() string {
@@ -162,43 +118,27 @@ func CreateMeta(opts MetaOptions) map[string]any {
return meta
}
func Connect(ctx context.Context, pipePath string) (*Client, error) {
if runtime.GOOS != "windows" {
return nil, errors.New("Lingma IPC proxy currently requires Windows")
}
conn, err := winio.DialPipeContext(ctx, pipePath)
func Connect(ctx context.Context, opts DialOptions) (*Client, error) {
transport, err := connectTransport(ctx, opts)
if err != nil {
return nil, fmt.Errorf("connect Lingma IPC pipe %s: %w", pipePath, err)
return nil, err
}
client := &Client{
conn: conn,
reader: bufio.NewReader(conn),
pending: make(map[int]chan responseEnvelope),
subs: make(map[int]chan Notification),
closed: make(chan struct{}),
transport: transport,
kind: opts.Transport,
pending: make(map[int]chan responseEnvelope),
subs: make(map[int]chan Notification),
closed: make(chan struct{}),
}
go client.readLoop()
return client, nil
}
func (c *Client) Request(ctx context.Context, method string, params any, out any) error {
if params == nil {
params = map[string]any{}
}
id := int(c.nextID.Add(1))
payload := map[string]any{
"jsonrpc": "2.0",
"id": id,
"method": method,
"params": params,
}
body, err := json.Marshal(payload)
payload, id, err := c.buildRequest(method, params)
if err != nil {
return fmt.Errorf("marshal request %s: %w", method, err)
return err
}
responseCh := make(chan responseEnvelope, 1)
@@ -206,7 +146,7 @@ func (c *Client) Request(ctx context.Context, method string, params any, out any
c.pending[id] = responseCh
c.pendingMu.Unlock()
if err := c.writeFrame(body); err != nil {
if err := c.transport.WriteFrame(payload); err != nil {
c.pendingMu.Lock()
delete(c.pending, id)
c.pendingMu.Unlock()
@@ -235,6 +175,14 @@ func (c *Client) Request(ctx context.Context, method string, params any, out any
}
}
func (c *Client) Send(method string, params any) error {
payload, _, err := c.buildRequest(method, params)
if err != nil {
return err
}
return c.transport.WriteFrame(payload)
}
func (c *Client) Subscribe() (<-chan Notification, func()) {
id := int(c.nextSubID.Add(1))
ch := make(chan Notification, 2048)
@@ -253,10 +201,21 @@ func (c *Client) Subscribe() (<-chan Notification, func()) {
return ch, cancel
}
func (c *Client) Address() string {
if c.transport == nil {
return ""
}
return c.transport.Address()
}
func (c *Client) Transport() Transport {
return c.kind
}
func (c *Client) Close() error {
c.closeOnce.Do(func() {
close(c.closed)
if err := c.conn.Close(); err != nil {
if err := c.transport.Close(); err != nil {
c.closeErr.Store(err)
}
c.failPending(io.EOF)
@@ -268,26 +227,32 @@ func (c *Client) Close() error {
return nil
}
func (c *Client) writeFrame(body []byte) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
func (c *Client) buildRequest(method string, params any) ([]byte, int, error) {
if params == nil {
params = map[string]any{}
}
frame := []byte(fmt.Sprintf("Content-Length: %d\r\n\r\n", len(body)))
if _, err := c.conn.Write(frame); err != nil {
return fmt.Errorf("write frame header: %w", err)
id := int(c.nextID.Add(1))
payload := map[string]any{
"jsonrpc": "2.0",
"id": id,
"method": method,
"params": params,
}
if _, err := c.conn.Write(body); err != nil {
return fmt.Errorf("write frame body: %w", err)
body, err := json.Marshal(payload)
if err != nil {
return nil, 0, fmt.Errorf("marshal request %s: %w", method, err)
}
return nil
return body, id, nil
}
func (c *Client) readLoop() {
defer c.Close()
for {
body, err := c.readFrame()
body, err := c.transport.ReadFrame()
if err != nil {
if !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
if !errors.Is(err, io.EOF) {
c.closeErr.Store(err)
}
return
@@ -299,8 +264,11 @@ func (c *Client) readLoop() {
return
}
if envelope.Method != "" && envelope.ID == nil {
if envelope.Method != "" {
c.broadcast(Notification{JSONRPC: envelope.JSONRPC, Method: envelope.Method, Params: envelope.Params})
if envelope.ID != nil {
_ = c.sendEmptyResponse(*envelope.ID)
}
continue
}
@@ -321,35 +289,19 @@ func (c *Client) readLoop() {
}
}
func (c *Client) readFrame() ([]byte, error) {
contentLength := -1
for {
line, err := c.reader.ReadString('\n')
if err != nil {
return nil, err
}
if line == "\r\n" {
break
}
line = strings.TrimSpace(line)
if strings.HasPrefix(strings.ToLower(line), "content-length:") {
raw := strings.TrimSpace(line[len("content-length:"):])
n, err := strconv.Atoi(raw)
if err != nil {
return nil, fmt.Errorf("parse content length %q: %w", raw, err)
}
contentLength = n
}
}
if contentLength < 0 {
return nil, errors.New("missing Content-Length header")
}
func (c *Client) sendEmptyResponse(id int) error {
c.responseMu.Lock()
defer c.responseMu.Unlock()
body := make([]byte, contentLength)
if _, err := io.ReadFull(c.reader, body); err != nil {
return nil, err
body, err := json.Marshal(map[string]any{
"jsonrpc": "2.0",
"id": id,
"result": nil,
})
if err != nil {
return err
}
return body, nil
return c.transport.WriteFrame(body)
}
func (c *Client) broadcast(notification Notification) {

View File

@@ -0,0 +1,344 @@
package lingmaipc
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"time"
winio "github.com/Microsoft/go-winio"
"github.com/gorilla/websocket"
)
type Transport string
const (
TransportAuto Transport = "auto"
TransportPipe Transport = "pipe"
TransportWebSocket Transport = "websocket"
)
type DialOptions struct {
Transport Transport
PipePath string
WebSocketURL string
}
type framedTransport interface {
ReadFrame() ([]byte, error)
WriteFrame([]byte) error
Close() error
Address() string
}
func ParseTransport(value string) (Transport, error) {
switch strings.ToLower(strings.TrimSpace(value)) {
case "", string(TransportAuto):
return TransportAuto, nil
case string(TransportPipe):
return TransportPipe, nil
case "ws", string(TransportWebSocket):
return TransportWebSocket, nil
default:
return "", fmt.Errorf("invalid Lingma transport %q; expected auto, pipe, or websocket", value)
}
}
func ResolveDialOptions(transport Transport, explicitPipe string, explicitWebSocketURL string) (DialOptions, error) {
switch transport {
case "", TransportAuto:
if hasConfiguredWebSocketURL(explicitWebSocketURL) {
wsURL, err := ResolveWebSocketURL(explicitWebSocketURL)
if err != nil {
return DialOptions{}, err
}
return DialOptions{Transport: TransportWebSocket, WebSocketURL: wsURL}, nil
}
pipePath, pipeErr := ResolvePipePath(explicitPipe)
if pipeErr == nil {
return DialOptions{Transport: TransportPipe, PipePath: pipePath}, nil
}
wsURL, wsErr := ResolveWebSocketURL(explicitWebSocketURL)
if wsErr == nil {
return DialOptions{Transport: TransportWebSocket, WebSocketURL: wsURL}, nil
}
return DialOptions{}, fmt.Errorf("resolve Lingma transport automatically: pipe: %w; websocket: %v", pipeErr, wsErr)
case TransportPipe:
pipePath, err := ResolvePipePath(explicitPipe)
if err != nil {
return DialOptions{}, err
}
return DialOptions{Transport: TransportPipe, PipePath: pipePath}, nil
case TransportWebSocket:
wsURL, err := ResolveWebSocketURL(explicitWebSocketURL)
if err != nil {
return DialOptions{}, err
}
return DialOptions{Transport: TransportWebSocket, WebSocketURL: wsURL}, nil
default:
return DialOptions{}, fmt.Errorf("unsupported Lingma transport %q", transport)
}
}
func ResolvePipePath(explicit string) (string, error) {
if runtime.GOOS != "windows" {
return "", errors.New("Lingma pipe transport currently requires Windows")
}
if pipe := strings.TrimSpace(explicit); pipe != "" {
return normalizePipePath(pipe), nil
}
if pipe := strings.TrimSpace(os.Getenv("LINGMA_IPC_PIPE")); pipe != "" {
return normalizePipePath(pipe), nil
}
entries, err := os.ReadDir(PipeDir)
if err != nil {
return "", fmt.Errorf("enumerate Lingma named pipes: %w", err)
}
names := make([]string, 0, len(entries))
for _, entry := range entries {
name := entry.Name()
if strings.HasPrefix(name, PipePrefix) {
names = append(names, name)
}
}
sort.Strings(names)
if len(names) == 0 {
return "", errors.New("no active Lingma named pipe was found")
}
return PipeDir + names[len(names)-1], nil
}
func ResolveWebSocketURL(explicit string) (string, error) {
value := strings.TrimSpace(explicit)
if value == "" {
value = strings.TrimSpace(os.Getenv("LINGMA_PROXY_WS_URL"))
}
if value == "" {
return "", errors.New("no Lingma websocket URL configured")
}
parsed, err := url.Parse(value)
if err != nil {
return "", fmt.Errorf("parse Lingma websocket URL %q: %w", value, err)
}
if parsed.Scheme != "ws" && parsed.Scheme != "wss" {
return "", fmt.Errorf("Lingma websocket URL must start with ws:// or wss://: %q", value)
}
if parsed.Host == "" {
return "", fmt.Errorf("Lingma websocket URL is missing a host: %q", value)
}
if parsed.Path == "" {
parsed.Path = "/"
}
return parsed.String(), nil
}
func hasConfiguredWebSocketURL(explicit string) bool {
return strings.TrimSpace(explicit) != "" || strings.TrimSpace(os.Getenv("LINGMA_PROXY_WS_URL")) != ""
}
func normalizePipePath(pipe string) string {
if strings.HasPrefix(pipe, PipeDir) {
return pipe
}
return PipeDir + pipe
}
func connectTransport(ctx context.Context, opts DialOptions) (framedTransport, error) {
switch opts.Transport {
case TransportPipe:
return connectPipeTransport(ctx, opts.PipePath)
case TransportWebSocket:
return connectWebSocketTransport(ctx, opts.WebSocketURL)
default:
return nil, fmt.Errorf("unsupported Lingma transport %q", opts.Transport)
}
}
type pipeTransport struct {
path string
conn net.Conn
reader *framedReader
write sync.Mutex
}
func connectPipeTransport(ctx context.Context, pipePath string) (*pipeTransport, error) {
conn, err := winio.DialPipeContext(ctx, pipePath)
if err != nil {
return nil, fmt.Errorf("connect Lingma IPC pipe %s: %w", pipePath, err)
}
return &pipeTransport{
path: pipePath,
conn: conn,
reader: newFramedReader(conn),
}, nil
}
func (t *pipeTransport) ReadFrame() ([]byte, error) {
return t.reader.ReadFrame()
}
func (t *pipeTransport) WriteFrame(body []byte) error {
t.write.Lock()
defer t.write.Unlock()
frame := []byte(fmt.Sprintf("Content-Length: %d\r\n\r\n", len(body)))
if _, err := t.conn.Write(frame); err != nil {
return fmt.Errorf("write frame header: %w", err)
}
if _, err := t.conn.Write(body); err != nil {
return fmt.Errorf("write frame body: %w", err)
}
return nil
}
func (t *pipeTransport) Close() error {
return t.conn.Close()
}
func (t *pipeTransport) Address() string {
return t.path
}
type websocketTransport struct {
url string
conn *websocket.Conn
buffer bytes.Buffer
writeMu sync.Mutex
}
func connectWebSocketTransport(ctx context.Context, wsURL string) (*websocketTransport, error) {
dialer := websocket.Dialer{HandshakeTimeout: 5 * time.Second}
conn, _, err := dialer.DialContext(ctx, wsURL, nil)
if err != nil {
return nil, fmt.Errorf("connect Lingma websocket %s: %w", wsURL, err)
}
return &websocketTransport{url: wsURL, conn: conn}, nil
}
func (t *websocketTransport) ReadFrame() ([]byte, error) {
for {
if body, ok, err := tryReadBufferedFrame(&t.buffer); ok || err != nil {
return body, err
}
messageType, payload, err := t.conn.ReadMessage()
if err != nil {
return nil, err
}
if messageType != websocket.TextMessage && messageType != websocket.BinaryMessage {
continue
}
t.buffer.Write(payload)
}
}
func (t *websocketTransport) WriteFrame(body []byte) error {
t.writeMu.Lock()
defer t.writeMu.Unlock()
frame := []byte(fmt.Sprintf("Content-Length: %d\r\n\r\n", len(body)))
frame = append(frame, body...)
if err := t.conn.WriteMessage(websocket.TextMessage, frame); err != nil {
return fmt.Errorf("write websocket frame: %w", err)
}
return nil
}
func (t *websocketTransport) Close() error {
return t.conn.Close()
}
func (t *websocketTransport) Address() string {
return t.url
}
type framedReader struct {
reader *bufio.Reader
}
func newFramedReader(r io.Reader) *framedReader {
return &framedReader{reader: bufio.NewReader(r)}
}
func (r *framedReader) ReadFrame() ([]byte, error) {
contentLength := -1
for {
line, err := r.reader.ReadString('\n')
if err != nil {
return nil, err
}
if line == "\r\n" {
break
}
line = strings.TrimSpace(line)
if strings.HasPrefix(strings.ToLower(line), "content-length:") {
raw := strings.TrimSpace(line[len("content-length:"):])
n, err := strconv.Atoi(raw)
if err != nil {
return nil, fmt.Errorf("parse content length %q: %w", raw, err)
}
contentLength = n
}
}
if contentLength < 0 {
return nil, errors.New("missing Content-Length header")
}
body := make([]byte, contentLength)
if _, err := io.ReadFull(r.reader, body); err != nil {
return nil, err
}
return body, nil
}
func tryReadBufferedFrame(buffer *bytes.Buffer) ([]byte, bool, error) {
data := buffer.Bytes()
headerEnd := bytes.Index(data, []byte("\r\n\r\n"))
if headerEnd < 0 {
return nil, false, nil
}
contentLength := -1
for _, rawLine := range bytes.Split(data[:headerEnd], []byte("\r\n")) {
line := strings.TrimSpace(string(rawLine))
if strings.HasPrefix(strings.ToLower(line), "content-length:") {
raw := strings.TrimSpace(line[len("content-length:"):])
n, err := strconv.Atoi(raw)
if err != nil {
return nil, false, fmt.Errorf("parse content length %q: %w", raw, err)
}
contentLength = n
break
}
}
if contentLength < 0 {
return nil, false, errors.New("missing Content-Length header")
}
bodyStart := headerEnd + len("\r\n\r\n")
if len(data[bodyStart:]) < contentLength {
return nil, false, nil
}
frame := make([]byte, contentLength)
copy(frame, data[bodyStart:bodyStart+contentLength])
buffer.Next(bodyStart + contentLength)
return frame, true, nil
}

View File

@@ -25,7 +25,9 @@ const (
type Config struct {
Host string
Port int
Transport lingmaipc.Transport
Pipe string
WebSocketURL string
Cwd string
CurrentFilePath string
Mode string
@@ -57,6 +59,8 @@ type ChatResult struct {
UsedTokens int
LimitTokens int
PipePath string
Endpoint string
Transport string
EffectiveSession SessionMode
}
@@ -77,6 +81,8 @@ type Model struct {
type State struct {
PipePath string `json:"pipe_path,omitempty"`
Endpoint string `json:"endpoint,omitempty"`
Transport string `json:"transport,omitempty"`
Connected bool `json:"connected"`
StickySessionID string `json:"sticky_session_id,omitempty"`
SessionMode SessionMode `json:"session_mode"`
@@ -87,6 +93,8 @@ type Service struct {
mu sync.Mutex
client *lingmaipc.Client
pipePath string
endpoint string
transport lingmaipc.Transport
stickySessionID string
stickyModelID string
}
@@ -114,6 +122,9 @@ func New(cfg Config) *Service {
if cfg.Timeout <= 0 {
cfg.Timeout = 120 * time.Second
}
if cfg.Transport == "" {
cfg.Transport = lingmaipc.TransportAuto
}
if cfg.SessionMode == "" {
cfg.SessionMode = SessionModeAuto
}
@@ -136,6 +147,8 @@ func (s *Service) State() State {
defer s.mu.Unlock()
return State{
PipePath: s.pipePath,
Endpoint: s.endpoint,
Transport: string(s.transport),
Connected: s.client != nil,
StickySessionID: s.stickySessionID,
SessionMode: s.cfg.SessionMode,
@@ -282,6 +295,7 @@ func (s *Service) buildChatResult(
runResult *promptRunResult,
effectiveMode SessionMode,
) *ChatResult {
endpoint := s.currentPipePath()
return &ChatResult{
Text: runResult.AssistantText,
Model: valueOr(strings.TrimSpace(req.Model), "lingma"),
@@ -293,7 +307,9 @@ func (s *Service) buildChatResult(
StopReason: nestedString(runResult.PromptResult, "stopReason"),
UsedTokens: int(nestedInt64(runResult.ContextUsage, "usedTokens")),
LimitTokens: int(nestedInt64(runResult.ContextUsage, "limitTokens")),
PipePath: s.currentPipePath(),
PipePath: endpoint,
Endpoint: endpoint,
Transport: string(s.currentTransport()),
EffectiveSession: effectiveMode,
}
}
@@ -309,11 +325,11 @@ func (s *Service) ensureConnectedLocked(ctx context.Context) (*lingmaipc.Client,
return s.client, nil
}
pipePath, err := lingmaipc.ResolvePipePath(s.cfg.Pipe)
dialOptions, err := lingmaipc.ResolveDialOptions(s.cfg.Transport, s.cfg.Pipe, s.cfg.WebSocketURL)
if err != nil {
return nil, err
}
client, err := lingmaipc.Connect(ctx, pipePath)
client, err := lingmaipc.Connect(ctx, dialOptions)
if err != nil {
return nil, err
}
@@ -327,19 +343,25 @@ func (s *Service) ensureConnectedLocked(ctx context.Context) (*lingmaipc.Client,
}
s.client = client
s.pipePath = pipePath
s.pipePath = dialOptions.PipePath
s.endpoint = client.Address()
s.transport = client.Transport()
return client, nil
}
func (s *Service) closeClientLocked() error {
if s.client == nil {
s.pipePath = ""
s.endpoint = ""
s.transport = ""
s.clearStickyLocked()
return nil
}
client := s.client
s.client = nil
s.pipePath = ""
s.endpoint = ""
s.transport = ""
s.clearStickyLocked()
return client.Close()
}
@@ -388,9 +410,18 @@ func (s *Service) clearStickyLocked() {
func (s *Service) currentPipePath() string {
s.mu.Lock()
defer s.mu.Unlock()
if strings.TrimSpace(s.endpoint) != "" {
return s.endpoint
}
return s.pipePath
}
func (s *Service) currentTransport() lingmaipc.Transport {
s.mu.Lock()
defer s.mu.Unlock()
return s.transport
}
func (s *Service) resolveSessionLocked(ctx context.Context, client *lingmaipc.Client, mode SessionMode) (string, error) {
if mode == SessionModeReuse && strings.TrimSpace(s.stickySessionID) != "" {
return s.stickySessionID, nil
@@ -436,18 +467,17 @@ func (s *Service) runPromptLocked(
notifications, cancel := client.Subscribe()
defer cancel()
promptResult := map[string]any{}
if err := client.Request(ctx, "session/prompt", map[string]any{
if err := client.Send("session/prompt", map[string]any{
"sessionId": sessionID,
"prompt": []map[string]any{
{"type": "text", "text": text},
},
"_meta": meta,
}, &promptResult); err != nil {
}); err != nil {
return nil, err
}
result := &promptRunResult{PromptResult: promptResult}
result := &promptRunResult{PromptResult: map[string]any{}}
var builder strings.Builder
for {