feat: add websocket transport support
This commit is contained in:
344
internal/lingmaipc/transport.go
Normal file
344
internal/lingmaipc/transport.go
Normal file
@@ -0,0 +1,344 @@
|
||||
package lingmaipc
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
winio "github.com/Microsoft/go-winio"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type Transport string
|
||||
|
||||
const (
|
||||
TransportAuto Transport = "auto"
|
||||
TransportPipe Transport = "pipe"
|
||||
TransportWebSocket Transport = "websocket"
|
||||
)
|
||||
|
||||
type DialOptions struct {
|
||||
Transport Transport
|
||||
PipePath string
|
||||
WebSocketURL string
|
||||
}
|
||||
|
||||
type framedTransport interface {
|
||||
ReadFrame() ([]byte, error)
|
||||
WriteFrame([]byte) error
|
||||
Close() error
|
||||
Address() string
|
||||
}
|
||||
|
||||
func ParseTransport(value string) (Transport, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "", string(TransportAuto):
|
||||
return TransportAuto, nil
|
||||
case string(TransportPipe):
|
||||
return TransportPipe, nil
|
||||
case "ws", string(TransportWebSocket):
|
||||
return TransportWebSocket, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid Lingma transport %q; expected auto, pipe, or websocket", value)
|
||||
}
|
||||
}
|
||||
|
||||
func ResolveDialOptions(transport Transport, explicitPipe string, explicitWebSocketURL string) (DialOptions, error) {
|
||||
switch transport {
|
||||
case "", TransportAuto:
|
||||
if hasConfiguredWebSocketURL(explicitWebSocketURL) {
|
||||
wsURL, err := ResolveWebSocketURL(explicitWebSocketURL)
|
||||
if err != nil {
|
||||
return DialOptions{}, err
|
||||
}
|
||||
return DialOptions{Transport: TransportWebSocket, WebSocketURL: wsURL}, nil
|
||||
}
|
||||
|
||||
pipePath, pipeErr := ResolvePipePath(explicitPipe)
|
||||
if pipeErr == nil {
|
||||
return DialOptions{Transport: TransportPipe, PipePath: pipePath}, nil
|
||||
}
|
||||
|
||||
wsURL, wsErr := ResolveWebSocketURL(explicitWebSocketURL)
|
||||
if wsErr == nil {
|
||||
return DialOptions{Transport: TransportWebSocket, WebSocketURL: wsURL}, nil
|
||||
}
|
||||
return DialOptions{}, fmt.Errorf("resolve Lingma transport automatically: pipe: %w; websocket: %v", pipeErr, wsErr)
|
||||
case TransportPipe:
|
||||
pipePath, err := ResolvePipePath(explicitPipe)
|
||||
if err != nil {
|
||||
return DialOptions{}, err
|
||||
}
|
||||
return DialOptions{Transport: TransportPipe, PipePath: pipePath}, nil
|
||||
case TransportWebSocket:
|
||||
wsURL, err := ResolveWebSocketURL(explicitWebSocketURL)
|
||||
if err != nil {
|
||||
return DialOptions{}, err
|
||||
}
|
||||
return DialOptions{Transport: TransportWebSocket, WebSocketURL: wsURL}, nil
|
||||
default:
|
||||
return DialOptions{}, fmt.Errorf("unsupported Lingma transport %q", transport)
|
||||
}
|
||||
}
|
||||
|
||||
func ResolvePipePath(explicit string) (string, error) {
|
||||
if runtime.GOOS != "windows" {
|
||||
return "", errors.New("Lingma pipe transport currently requires Windows")
|
||||
}
|
||||
|
||||
if pipe := strings.TrimSpace(explicit); pipe != "" {
|
||||
return normalizePipePath(pipe), nil
|
||||
}
|
||||
if pipe := strings.TrimSpace(os.Getenv("LINGMA_IPC_PIPE")); pipe != "" {
|
||||
return normalizePipePath(pipe), nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(PipeDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("enumerate Lingma named pipes: %w", err)
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if strings.HasPrefix(name, PipePrefix) {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(names)
|
||||
if len(names) == 0 {
|
||||
return "", errors.New("no active Lingma named pipe was found")
|
||||
}
|
||||
return PipeDir + names[len(names)-1], nil
|
||||
}
|
||||
|
||||
func ResolveWebSocketURL(explicit string) (string, error) {
|
||||
value := strings.TrimSpace(explicit)
|
||||
if value == "" {
|
||||
value = strings.TrimSpace(os.Getenv("LINGMA_PROXY_WS_URL"))
|
||||
}
|
||||
if value == "" {
|
||||
return "", errors.New("no Lingma websocket URL configured")
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(value)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse Lingma websocket URL %q: %w", value, err)
|
||||
}
|
||||
if parsed.Scheme != "ws" && parsed.Scheme != "wss" {
|
||||
return "", fmt.Errorf("Lingma websocket URL must start with ws:// or wss://: %q", value)
|
||||
}
|
||||
if parsed.Host == "" {
|
||||
return "", fmt.Errorf("Lingma websocket URL is missing a host: %q", value)
|
||||
}
|
||||
if parsed.Path == "" {
|
||||
parsed.Path = "/"
|
||||
}
|
||||
return parsed.String(), nil
|
||||
}
|
||||
|
||||
func hasConfiguredWebSocketURL(explicit string) bool {
|
||||
return strings.TrimSpace(explicit) != "" || strings.TrimSpace(os.Getenv("LINGMA_PROXY_WS_URL")) != ""
|
||||
}
|
||||
|
||||
func normalizePipePath(pipe string) string {
|
||||
if strings.HasPrefix(pipe, PipeDir) {
|
||||
return pipe
|
||||
}
|
||||
return PipeDir + pipe
|
||||
}
|
||||
|
||||
func connectTransport(ctx context.Context, opts DialOptions) (framedTransport, error) {
|
||||
switch opts.Transport {
|
||||
case TransportPipe:
|
||||
return connectPipeTransport(ctx, opts.PipePath)
|
||||
case TransportWebSocket:
|
||||
return connectWebSocketTransport(ctx, opts.WebSocketURL)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported Lingma transport %q", opts.Transport)
|
||||
}
|
||||
}
|
||||
|
||||
type pipeTransport struct {
|
||||
path string
|
||||
conn net.Conn
|
||||
reader *framedReader
|
||||
write sync.Mutex
|
||||
}
|
||||
|
||||
func connectPipeTransport(ctx context.Context, pipePath string) (*pipeTransport, error) {
|
||||
conn, err := winio.DialPipeContext(ctx, pipePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect Lingma IPC pipe %s: %w", pipePath, err)
|
||||
}
|
||||
return &pipeTransport{
|
||||
path: pipePath,
|
||||
conn: conn,
|
||||
reader: newFramedReader(conn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *pipeTransport) ReadFrame() ([]byte, error) {
|
||||
return t.reader.ReadFrame()
|
||||
}
|
||||
|
||||
func (t *pipeTransport) WriteFrame(body []byte) error {
|
||||
t.write.Lock()
|
||||
defer t.write.Unlock()
|
||||
|
||||
frame := []byte(fmt.Sprintf("Content-Length: %d\r\n\r\n", len(body)))
|
||||
if _, err := t.conn.Write(frame); err != nil {
|
||||
return fmt.Errorf("write frame header: %w", err)
|
||||
}
|
||||
if _, err := t.conn.Write(body); err != nil {
|
||||
return fmt.Errorf("write frame body: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *pipeTransport) Close() error {
|
||||
return t.conn.Close()
|
||||
}
|
||||
|
||||
func (t *pipeTransport) Address() string {
|
||||
return t.path
|
||||
}
|
||||
|
||||
type websocketTransport struct {
|
||||
url string
|
||||
conn *websocket.Conn
|
||||
buffer bytes.Buffer
|
||||
writeMu sync.Mutex
|
||||
}
|
||||
|
||||
func connectWebSocketTransport(ctx context.Context, wsURL string) (*websocketTransport, error) {
|
||||
dialer := websocket.Dialer{HandshakeTimeout: 5 * time.Second}
|
||||
conn, _, err := dialer.DialContext(ctx, wsURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect Lingma websocket %s: %w", wsURL, err)
|
||||
}
|
||||
return &websocketTransport{url: wsURL, conn: conn}, nil
|
||||
}
|
||||
|
||||
func (t *websocketTransport) ReadFrame() ([]byte, error) {
|
||||
for {
|
||||
if body, ok, err := tryReadBufferedFrame(&t.buffer); ok || err != nil {
|
||||
return body, err
|
||||
}
|
||||
|
||||
messageType, payload, err := t.conn.ReadMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if messageType != websocket.TextMessage && messageType != websocket.BinaryMessage {
|
||||
continue
|
||||
}
|
||||
t.buffer.Write(payload)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *websocketTransport) WriteFrame(body []byte) error {
|
||||
t.writeMu.Lock()
|
||||
defer t.writeMu.Unlock()
|
||||
|
||||
frame := []byte(fmt.Sprintf("Content-Length: %d\r\n\r\n", len(body)))
|
||||
frame = append(frame, body...)
|
||||
if err := t.conn.WriteMessage(websocket.TextMessage, frame); err != nil {
|
||||
return fmt.Errorf("write websocket frame: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *websocketTransport) Close() error {
|
||||
return t.conn.Close()
|
||||
}
|
||||
|
||||
func (t *websocketTransport) Address() string {
|
||||
return t.url
|
||||
}
|
||||
|
||||
type framedReader struct {
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func newFramedReader(r io.Reader) *framedReader {
|
||||
return &framedReader{reader: bufio.NewReader(r)}
|
||||
}
|
||||
|
||||
func (r *framedReader) ReadFrame() ([]byte, error) {
|
||||
contentLength := -1
|
||||
for {
|
||||
line, err := r.reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if line == "\r\n" {
|
||||
break
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(strings.ToLower(line), "content-length:") {
|
||||
raw := strings.TrimSpace(line[len("content-length:"):])
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse content length %q: %w", raw, err)
|
||||
}
|
||||
contentLength = n
|
||||
}
|
||||
}
|
||||
if contentLength < 0 {
|
||||
return nil, errors.New("missing Content-Length header")
|
||||
}
|
||||
|
||||
body := make([]byte, contentLength)
|
||||
if _, err := io.ReadFull(r.reader, body); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func tryReadBufferedFrame(buffer *bytes.Buffer) ([]byte, bool, error) {
|
||||
data := buffer.Bytes()
|
||||
headerEnd := bytes.Index(data, []byte("\r\n\r\n"))
|
||||
if headerEnd < 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
contentLength := -1
|
||||
for _, rawLine := range bytes.Split(data[:headerEnd], []byte("\r\n")) {
|
||||
line := strings.TrimSpace(string(rawLine))
|
||||
if strings.HasPrefix(strings.ToLower(line), "content-length:") {
|
||||
raw := strings.TrimSpace(line[len("content-length:"):])
|
||||
n, err := strconv.Atoi(raw)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("parse content length %q: %w", raw, err)
|
||||
}
|
||||
contentLength = n
|
||||
break
|
||||
}
|
||||
}
|
||||
if contentLength < 0 {
|
||||
return nil, false, errors.New("missing Content-Length header")
|
||||
}
|
||||
|
||||
bodyStart := headerEnd + len("\r\n\r\n")
|
||||
if len(data[bodyStart:]) < contentLength {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
frame := make([]byte, contentLength)
|
||||
copy(frame, data[bodyStart:bodyStart+contentLength])
|
||||
buffer.Next(bodyStart + contentLength)
|
||||
return frame, true, nil
|
||||
}
|
||||
Reference in New Issue
Block a user