package middleware import ( "bufio" "context" "fmt" "net" "net/http" "time" "github.com/ouqiang/websocket" ) // WebSocketConn WebSocket连接接口 type WebSocketConn struct { conn *websocket.Conn } // NewWebSocketConn 创建WebSocket连接 func NewWebSocketConn(conn *websocket.Conn) *WebSocketConn { return &WebSocketConn{ conn: conn, } } // Close 关闭连接 func (c *WebSocketConn) Close() error { return c.conn.Close() } // ReadMessage 读取消息 func (c *WebSocketConn) ReadMessage() (int, []byte, error) { return c.conn.ReadMessage() } // WriteMessage 写入消息 func (c *WebSocketConn) WriteMessage(messageType int, data []byte) error { return c.conn.WriteMessage(messageType, data) } // RemoteAddr 获取远程地址 func (c *WebSocketConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } // WebSocketUpgrader WebSocket升级器 type WebSocketUpgrader struct { upgrader websocket.Upgrader } // NewWebSocketUpgrader 创建WebSocket升级器 func NewWebSocketUpgrader(timeout time.Duration) *WebSocketUpgrader { return &WebSocketUpgrader{ upgrader: websocket.Upgrader{ HandshakeTimeout: timeout, ReadBufferSize: 4096, WriteBufferSize: 4096, CheckOrigin: func(r *http.Request) bool { return true }, }, } } // Upgrade 升级HTTP连接为WebSocket连接 func (u *WebSocketUpgrader) Upgrade(conn net.Conn, req *http.Request) (*WebSocketConn, error) { bufConn, ok := conn.(interface { Hijack() (net.Conn, *bufio.ReadWriter, error) }) if !ok { return nil, fmt.Errorf("连接不支持Hijack") } netConn, bufrw, err := bufConn.Hijack() if err != nil { return nil, fmt.Errorf("hijack错误: %s", err) } wsConn, err := u.upgrader.Upgrade(newResponseWriter(netConn, bufrw), req, http.Header{}) if err != nil { return nil, err } return NewWebSocketConn(wsConn), nil } // WebSocketDialer WebSocket拨号器 type WebSocketDialer struct { dialer websocket.Dialer } // NewWebSocketDialer 创建WebSocket拨号器 func NewWebSocketDialer(timeout time.Duration) *WebSocketDialer { return &WebSocketDialer{ dialer: websocket.Dialer{ HandshakeTimeout: timeout, ReadBufferSize: 4096, WriteBufferSize: 4096, }, } } // Dial 连接到WebSocket服务器 func (d *WebSocketDialer) Dial(urlStr string, header http.Header) (*WebSocketConn, error) { ctx, cancel := context.WithTimeout(context.Background(), d.dialer.HandshakeTimeout) defer cancel() wsConn, _, err := d.dialer.DialContext(ctx, urlStr, header) if err != nil { return nil, err } return NewWebSocketConn(wsConn), nil } // 实现http.ResponseWriter接口,用于WebSocket升级 type responseWriter struct { conn net.Conn bufrw *bufio.ReadWriter header http.Header status int } func newResponseWriter(conn net.Conn, bufrw *bufio.ReadWriter) *responseWriter { return &responseWriter{ conn: conn, bufrw: bufrw, header: make(http.Header), status: http.StatusOK, } } func (rw *responseWriter) Header() http.Header { return rw.header } func (rw *responseWriter) Write(b []byte) (int, error) { return rw.bufrw.Write(b) } func (rw *responseWriter) WriteHeader(statusCode int) { rw.status = statusCode } func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return rw.conn, rw.bufrw, nil }