feat: add websocket transport support
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package lingmaipc
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
@@ -9,17 +8,12 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
winio "github.com/Microsoft/go-winio"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -65,56 +59,18 @@ type responseEnvelope struct {
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
writeMu sync.Mutex
|
||||
pendingMu sync.Mutex
|
||||
pending map[int]chan responseEnvelope
|
||||
subsMu sync.RWMutex
|
||||
subs map[int]chan Notification
|
||||
nextID atomic.Int64
|
||||
nextSubID atomic.Int64
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
closeErr atomic.Value
|
||||
}
|
||||
|
||||
func ResolvePipePath(explicit string) (string, error) {
|
||||
if runtime.GOOS != "windows" {
|
||||
return "", errors.New("Lingma IPC proxy currently requires Windows")
|
||||
}
|
||||
|
||||
if pipe := strings.TrimSpace(explicit); pipe != "" {
|
||||
return normalizePipePath(pipe), nil
|
||||
}
|
||||
if pipe := strings.TrimSpace(os.Getenv("LINGMA_IPC_PIPE")); pipe != "" {
|
||||
return normalizePipePath(pipe), nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(PipeDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("enumerate Lingma named pipes: %w", err)
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if strings.HasPrefix(name, PipePrefix) {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(names)
|
||||
if len(names) == 0 {
|
||||
return "", errors.New("no active Lingma named pipe was found")
|
||||
}
|
||||
return PipeDir + names[len(names)-1], nil
|
||||
}
|
||||
|
||||
func normalizePipePath(pipe string) string {
|
||||
if strings.HasPrefix(pipe, PipeDir) {
|
||||
return pipe
|
||||
}
|
||||
return PipeDir + pipe
|
||||
transport framedTransport
|
||||
kind Transport
|
||||
pendingMu sync.Mutex
|
||||
pending map[int]chan responseEnvelope
|
||||
subsMu sync.RWMutex
|
||||
subs map[int]chan Notification
|
||||
nextID atomic.Int64
|
||||
nextSubID atomic.Int64
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
closeErr atomic.Value
|
||||
responseMu sync.Mutex
|
||||
}
|
||||
|
||||
func DefaultShellType() string {
|
||||
@@ -162,43 +118,27 @@ func CreateMeta(opts MetaOptions) map[string]any {
|
||||
return meta
|
||||
}
|
||||
|
||||
func Connect(ctx context.Context, pipePath string) (*Client, error) {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil, errors.New("Lingma IPC proxy currently requires Windows")
|
||||
}
|
||||
|
||||
conn, err := winio.DialPipeContext(ctx, pipePath)
|
||||
func Connect(ctx context.Context, opts DialOptions) (*Client, error) {
|
||||
transport, err := connectTransport(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect Lingma IPC pipe %s: %w", pipePath, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
conn: conn,
|
||||
reader: bufio.NewReader(conn),
|
||||
pending: make(map[int]chan responseEnvelope),
|
||||
subs: make(map[int]chan Notification),
|
||||
closed: make(chan struct{}),
|
||||
transport: transport,
|
||||
kind: opts.Transport,
|
||||
pending: make(map[int]chan responseEnvelope),
|
||||
subs: make(map[int]chan Notification),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
go client.readLoop()
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *Client) Request(ctx context.Context, method string, params any, out any) error {
|
||||
if params == nil {
|
||||
params = map[string]any{}
|
||||
}
|
||||
|
||||
id := int(c.nextID.Add(1))
|
||||
payload := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
payload, id, err := c.buildRequest(method, params)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal request %s: %w", method, err)
|
||||
return err
|
||||
}
|
||||
|
||||
responseCh := make(chan responseEnvelope, 1)
|
||||
@@ -206,7 +146,7 @@ func (c *Client) Request(ctx context.Context, method string, params any, out any
|
||||
c.pending[id] = responseCh
|
||||
c.pendingMu.Unlock()
|
||||
|
||||
if err := c.writeFrame(body); err != nil {
|
||||
if err := c.transport.WriteFrame(payload); err != nil {
|
||||
c.pendingMu.Lock()
|
||||
delete(c.pending, id)
|
||||
c.pendingMu.Unlock()
|
||||
@@ -235,6 +175,14 @@ func (c *Client) Request(ctx context.Context, method string, params any, out any
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Send(method string, params any) error {
|
||||
payload, _, err := c.buildRequest(method, params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.transport.WriteFrame(payload)
|
||||
}
|
||||
|
||||
func (c *Client) Subscribe() (<-chan Notification, func()) {
|
||||
id := int(c.nextSubID.Add(1))
|
||||
ch := make(chan Notification, 2048)
|
||||
@@ -253,10 +201,21 @@ func (c *Client) Subscribe() (<-chan Notification, func()) {
|
||||
return ch, cancel
|
||||
}
|
||||
|
||||
func (c *Client) Address() string {
|
||||
if c.transport == nil {
|
||||
return ""
|
||||
}
|
||||
return c.transport.Address()
|
||||
}
|
||||
|
||||
func (c *Client) Transport() Transport {
|
||||
return c.kind
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
close(c.closed)
|
||||
if err := c.conn.Close(); err != nil {
|
||||
if err := c.transport.Close(); err != nil {
|
||||
c.closeErr.Store(err)
|
||||
}
|
||||
c.failPending(io.EOF)
|
||||
@@ -268,26 +227,32 @@ func (c *Client) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) writeFrame(body []byte) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
func (c *Client) buildRequest(method string, params any) ([]byte, int, error) {
|
||||
if params == nil {
|
||||
params = map[string]any{}
|
||||
}
|
||||
|
||||
frame := []byte(fmt.Sprintf("Content-Length: %d\r\n\r\n", len(body)))
|
||||
if _, err := c.conn.Write(frame); err != nil {
|
||||
return fmt.Errorf("write frame header: %w", err)
|
||||
id := int(c.nextID.Add(1))
|
||||
payload := map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
if _, err := c.conn.Write(body); err != nil {
|
||||
return fmt.Errorf("write frame body: %w", err)
|
||||
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("marshal request %s: %w", method, err)
|
||||
}
|
||||
return nil
|
||||
return body, id, nil
|
||||
}
|
||||
|
||||
func (c *Client) readLoop() {
|
||||
defer c.Close()
|
||||
for {
|
||||
body, err := c.readFrame()
|
||||
body, err := c.transport.ReadFrame()
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
c.closeErr.Store(err)
|
||||
}
|
||||
return
|
||||
@@ -299,8 +264,11 @@ func (c *Client) readLoop() {
|
||||
return
|
||||
}
|
||||
|
||||
if envelope.Method != "" && envelope.ID == nil {
|
||||
if envelope.Method != "" {
|
||||
c.broadcast(Notification{JSONRPC: envelope.JSONRPC, Method: envelope.Method, Params: envelope.Params})
|
||||
if envelope.ID != nil {
|
||||
_ = c.sendEmptyResponse(*envelope.ID)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -321,35 +289,19 @@ func (c *Client) readLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) readFrame() ([]byte, error) {
|
||||
contentLength := -1
|
||||
for {
|
||||
line, err := c.reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if line == "\r\n" {
|
||||
break
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(strings.ToLower(line), "content-length:") {
|
||||
raw := strings.TrimSpace(line[len("content-length:"):])
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse content length %q: %w", raw, err)
|
||||
}
|
||||
contentLength = n
|
||||
}
|
||||
}
|
||||
if contentLength < 0 {
|
||||
return nil, errors.New("missing Content-Length header")
|
||||
}
|
||||
func (c *Client) sendEmptyResponse(id int) error {
|
||||
c.responseMu.Lock()
|
||||
defer c.responseMu.Unlock()
|
||||
|
||||
body := make([]byte, contentLength)
|
||||
if _, err := io.ReadFull(c.reader, body); err != nil {
|
||||
return nil, err
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": nil,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return body, nil
|
||||
return c.transport.WriteFrame(body)
|
||||
}
|
||||
|
||||
func (c *Client) broadcast(notification Notification) {
|
||||
|
||||
Reference in New Issue
Block a user