148 lines
3.2 KiB
Go
148 lines
3.2 KiB
Go
package middleware
|
||
|
||
import (
|
||
"bufio"
|
||
"context"
|
||
"fmt"
|
||
"net"
|
||
"net/http"
|
||
"time"
|
||
|
||
"github.com/ouqiang/websocket"
|
||
)
|
||
|
||
// WebSocketConn WebSocket连接接口
|
||
type WebSocketConn struct {
|
||
conn *websocket.Conn
|
||
}
|
||
|
||
// NewWebSocketConn 创建WebSocket连接
|
||
func NewWebSocketConn(conn *websocket.Conn) *WebSocketConn {
|
||
return &WebSocketConn{
|
||
conn: conn,
|
||
}
|
||
}
|
||
|
||
// Close 关闭连接
|
||
func (c *WebSocketConn) Close() error {
|
||
return c.conn.Close()
|
||
}
|
||
|
||
// ReadMessage 读取消息
|
||
func (c *WebSocketConn) ReadMessage() (int, []byte, error) {
|
||
return c.conn.ReadMessage()
|
||
}
|
||
|
||
// WriteMessage 写入消息
|
||
func (c *WebSocketConn) WriteMessage(messageType int, data []byte) error {
|
||
return c.conn.WriteMessage(messageType, data)
|
||
}
|
||
|
||
// RemoteAddr 获取远程地址
|
||
func (c *WebSocketConn) RemoteAddr() net.Addr {
|
||
return c.conn.RemoteAddr()
|
||
}
|
||
|
||
// WebSocketUpgrader WebSocket升级器
|
||
type WebSocketUpgrader struct {
|
||
upgrader websocket.Upgrader
|
||
}
|
||
|
||
// NewWebSocketUpgrader 创建WebSocket升级器
|
||
func NewWebSocketUpgrader(timeout time.Duration) *WebSocketUpgrader {
|
||
return &WebSocketUpgrader{
|
||
upgrader: websocket.Upgrader{
|
||
HandshakeTimeout: timeout,
|
||
ReadBufferSize: 4096,
|
||
WriteBufferSize: 4096,
|
||
CheckOrigin: func(r *http.Request) bool {
|
||
return true
|
||
},
|
||
},
|
||
}
|
||
}
|
||
|
||
// Upgrade 升级HTTP连接为WebSocket连接
|
||
func (u *WebSocketUpgrader) Upgrade(conn net.Conn, req *http.Request) (*WebSocketConn, error) {
|
||
bufConn, ok := conn.(interface {
|
||
Hijack() (net.Conn, *bufio.ReadWriter, error)
|
||
})
|
||
if !ok {
|
||
return nil, fmt.Errorf("连接不支持Hijack")
|
||
}
|
||
|
||
netConn, bufrw, err := bufConn.Hijack()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("hijack错误: %s", err)
|
||
}
|
||
|
||
wsConn, err := u.upgrader.Upgrade(newResponseWriter(netConn, bufrw), req, http.Header{})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return NewWebSocketConn(wsConn), nil
|
||
}
|
||
|
||
// WebSocketDialer WebSocket拨号器
|
||
type WebSocketDialer struct {
|
||
dialer websocket.Dialer
|
||
}
|
||
|
||
// NewWebSocketDialer 创建WebSocket拨号器
|
||
func NewWebSocketDialer(timeout time.Duration) *WebSocketDialer {
|
||
return &WebSocketDialer{
|
||
dialer: websocket.Dialer{
|
||
HandshakeTimeout: timeout,
|
||
ReadBufferSize: 4096,
|
||
WriteBufferSize: 4096,
|
||
},
|
||
}
|
||
}
|
||
|
||
// Dial 连接到WebSocket服务器
|
||
func (d *WebSocketDialer) Dial(urlStr string, header http.Header) (*WebSocketConn, error) {
|
||
ctx, cancel := context.WithTimeout(context.Background(), d.dialer.HandshakeTimeout)
|
||
defer cancel()
|
||
|
||
wsConn, _, err := d.dialer.DialContext(ctx, urlStr, header)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return NewWebSocketConn(wsConn), nil
|
||
}
|
||
|
||
// responseWriter implements http.ResponseWriter and http.Hijacker on top of
// a connection that has already been hijacked. It is handed to the websocket
// upgrader so the upgrader can obtain conn and bufrw through Hijack.
//
// NOTE(review): WriteHeader only records the status code, and nothing in this
// type writes a status line or the header map. Write buffers into bufrw
// without flushing. Any error response the upgrader emits through this
// writer presumably never reaches the client — confirm this is acceptable.
type responseWriter struct {
	conn   net.Conn
	bufrw  *bufio.ReadWriter
	header http.Header
	status int // last code passed to WriteHeader; http.StatusOK initially
}

// Compile-time checks that responseWriter satisfies the interfaces the
// websocket upgrader consumes.
var (
	_ http.ResponseWriter = (*responseWriter)(nil)
	_ http.Hijacker       = (*responseWriter)(nil)
)

// newResponseWriter wraps conn and bufrw with an empty header map and a
// recorded status of 200 OK.
func newResponseWriter(conn net.Conn, bufrw *bufio.ReadWriter) *responseWriter {
	w := &responseWriter{conn: conn, bufrw: bufrw}
	w.header = http.Header{}
	w.status = http.StatusOK
	return w
}

// Header returns the writer's header map. The same map is returned on every
// call; nothing in this type serializes it.
func (w *responseWriter) Header() http.Header {
	return w.header
}

// Write appends p to bufrw's write buffer without flushing it.
func (w *responseWriter) Write(p []byte) (int, error) {
	return w.bufrw.Writer.Write(p)
}

// WriteHeader records statusCode; no bytes are written.
func (w *responseWriter) WriteHeader(statusCode int) {
	w.status = statusCode
}

// Hijack hands back the wrapped connection and buffered reader/writer. It
// never fails and may be called more than once.
func (w *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	return w.conn, w.bufrw, nil
}
|