package wsgrpc import ( "bytes" "context" "crypto/tls" "errors" "fmt" "net" "net/http" "sync" "time" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" ) var ( WSGrpcError = errors.New("wsgrpc error") ) // --------------------------------------- // 通用 websocketConn 实现 net.Conn 接口 // --------------------------------------- type websocketConn struct { ws *websocket.Conn readMutex sync.Mutex writeMutex sync.Mutex // 缓存由于一次读取没有全部消耗完的数据 readBuffer bytes.Buffer } // Read 实现对 websocket 消息的分段读取 func (c *websocketConn) Read(p []byte) (int, error) { c.readMutex.Lock() defer c.readMutex.Unlock() // 若缓冲区为空,则阻塞读取下一条消息 if c.readBuffer.Len() == 0 { messageType, data, err := c.ws.ReadMessage() if err != nil { return 0, errors.Join(err, errors.New("wsgrpc read message error"), WSGrpcError) } // 只接受二进制数据 if messageType != websocket.BinaryMessage { return 0, errors.Join(fmt.Errorf("unexpected message type: %d", messageType), WSGrpcError) } c.readBuffer.Write(data) } if n, err := c.readBuffer.Read(p); err != nil { return n, errors.Join(err, WSGrpcError) } else { return n, nil } } // Write 将数据作为单条二进制消息发送 func (c *websocketConn) Write(p []byte) (int, error) { c.writeMutex.Lock() defer c.writeMutex.Unlock() err := c.ws.WriteMessage(websocket.BinaryMessage, p) if err != nil { return 0, errors.Join(err, errors.New("wsgrpc write message error"), WSGrpcError) } return len(p), nil } // Close 关闭 websocket 连接 func (c *websocketConn) Close() error { err := c.ws.Close() if err != nil { return errors.Join(err, errors.New("wsgrpc close error"), WSGrpcError) } return nil } // LocalAddr 返回本地地址,通过 websocket 底层连接获取 func (c *websocketConn) LocalAddr() net.Addr { if conn := c.ws.UnderlyingConn(); conn != nil { return conn.LocalAddr() } return nil } // RemoteAddr 返回远端地址 func (c *websocketConn) RemoteAddr() net.Addr { if conn := c.ws.UnderlyingConn(); conn != nil { return conn.RemoteAddr() } return nil } // SetDeadline 同时设置读写超时 func (c *websocketConn) SetDeadline(t time.Time) error { if err := c.ws.SetReadDeadline(t); err != nil { return errors.Join(err, errors.New("wsgrpc set read deadline error"), WSGrpcError) } if err := c.ws.SetWriteDeadline(t); err != nil { return errors.Join(err, errors.New("wsgrpc set write deadline error"), WSGrpcError) } return nil } // SetReadDeadline 设置读超时 func (c *websocketConn) SetReadDeadline(t time.Time) error { return c.ws.SetReadDeadline(t) } // SetWriteDeadline 设置写超时 func (c *websocketConn) SetWriteDeadline(t time.Time) error { return c.ws.SetWriteDeadline(t) } // --------------------------------------- // 客户端 WebSocket Dialer // --------------------------------------- type LogInterface interface { Infof(format string, args ...interface{}) Errorf(format string, args ...interface{}) Tracef(format string, args ...interface{}) } // WebsocketDialer 返回一个可以用于 grpc.WithContextDialer 的拨号函数;该函数通过 websocket 建立连接。 // 参数 url 表示 websocket 服务器地址;header 可用于传递额外的 header 参数。 func WebsocketDialer(url string, header http.Header, insecure bool, log LogInterface) func(ctx context.Context, addr string) (net.Conn, error) { return func(ctx context.Context, addr string) (net.Conn, error) { dialer := websocket.Dialer{ TLSClientConfig: &tls.Config{InsecureSkipVerify: insecure}, } log.Tracef("dialing websocket server [%s]", url) ws, _, err := dialer.DialContext(ctx, url, header) if err != nil { log.Errorf("wsgrpc dialer error: %v", err) return nil, errors.Join(err, errors.New("wsgrpc dialer error"), WSGrpcError) } log.Tracef("websocket connection connect done") return &websocketConn{ws: ws}, nil } } // --------------------------------------- // 服务端 WebSocket Listener 及 Gin Handler // --------------------------------------- // WSListener 实现了 net.Listener 接口,用于接收 websocket 升级后的连接。 // gRPC server 可直接传入 WSListener 实例作为监听器调用 Serve 方法。 type WSListener struct { connCh chan net.Conn mu sync.Mutex closed bool addr net.Addr done chan struct{} } // dummyAddr 用于 WSListener 的 Addr 实现 type dummyAddr struct { network string address string } func (d dummyAddr) Network() string { return d.network } func (d dummyAddr) String() string { return d.address } // NewWSListener 创建一个 WSListener 实例。 // 参数 addr 表示监听地址,network 建议为固定字符串(例如:"ws"),bufSize 为连接队列大小。 func NewWSListener(addr, network string, bufSize int) *WSListener { return &WSListener{ connCh: make(chan net.Conn, bufSize), addr: dummyAddr{network: network, address: addr}, done: make(chan struct{}), } } // Accept 等待并返回下一个连接 func (l *WSListener) Accept() (net.Conn, error) { select { case conn, ok := <-l.connCh: if !ok { return nil, errors.Join(fmt.Errorf("listener closed"), WSGrpcError) } return conn, nil case <-l.done: return nil, errors.Join(fmt.Errorf("listener closed"), WSGrpcError) } } // Close 关闭 WSListener func (l *WSListener) Close() error { l.mu.Lock() defer l.mu.Unlock() if l.closed { return nil } l.closed = true close(l.done) close(l.connCh) return nil } // Addr 返回本监听器的地址 func (l *WSListener) Addr() net.Addr { return l.addr } // GinWSHandler 返回一个 Gin 的 HandlerFunc,用于处理 HTTP 请求,将其升级为 WebSocket 连接 // 并包装为 websocketConn 后推送到 WSListener 中,以供 gRPC server 使用。 // 参数 upgrader 可对 websocket 升级过程进行自定义配置。 func GinWSHandler(listener *WSListener, upgrader *websocket.Upgrader) gin.HandlerFunc { return func(c *gin.Context) { ws, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { c.String(http.StatusInternalServerError, "ws upgrade error: %v", err) return } conn := &websocketConn{ws: ws} // 非阻塞方式将连接推送到 listener select { case listener.connCh <- conn: // 推送成功后,可选进行应答 default: // 队列满则关闭连接 ws.Close() c.String(http.StatusServiceUnavailable, "connection queue is full") return } } } // ------------------------------ // 使用示例 // ------------------------------ // 假设我们有这样一个 main 文件使用上述库: /* package main import ( "context" "log" "net/http" "github.com/gin-gonic/gin" "google.golang.org/grpc" "vaalacat/frp-panel/utils/wsgrpc" "github.com/gorilla/websocket" ) // 服务端实例 func main() { // 创建 WebSocket Listener,缓冲队列大小为 100,地址和网络标识可自定义 listener := wsgrpc.NewWSListener("ws-listener", "ws", 100) // 在单独的 goroutine 中启动 gRPC Server go func() { grpcServer := grpc.NewServer() // 在此注册你的 gRPC 服务… if err := grpcServer.Serve(listener); err != nil { log.Fatalf("gRPC server error: %v", err) } }() // 使用 Gin 创建 HTTP 服务器,并在某个路径下提供 WebSocket 功能 router := gin.Default() // 创建一个简单的 upgrader 实例;可根据需要自定义 CheckOrigin 等选项 upgrader := &websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } // 注册 WebSocket 处理 handler,路径可自定义,例如 /ws router.GET("/ws", wsgrpc.GinWSHandler(listener, upgrader)) // 启动 HTTP 服务 if err := router.Run(":8080"); err != nil { log.Fatalf("HTTP server error: %v", err) } // 示例中,当 HTTP 请求升级为 WebSocket 后,会将连接推入 listener, // gRPC Server 的 Accept 就会获取到该 net.Conn 连接,实现 gRPC 请求的代理。 } 客户端示例: func main() { // 定义 websocket 服务器地址和 header(如果有需要) wsURL := "ws://127.0.0.1:8080/ws" // 示例地址 header := http.Header{} // 创建 websocket dialer dialer := wsgrpc.WebsocketDialer(wsURL, header) // 使用 grpc.WithContextDialer 配置 GRPC Dial conn, err := grpc.DialContext(context.Background(), "ignored", grpc.WithContextDialer(dialer), grpc.WithInsecure(), // 示例中禁用 TLS,生产环境建议使用安全连接 ) if err != nil { log.Fatalf("failed to dial: %v", err) } defer conn.Close() // 接下来可使用 conn 创建 GRPC 客户端进行调用 } */