Files
goproxy/internal/middleware/websocket.go

148 lines
3.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"bufio"
"context"
"fmt"
"net"
"net/http"
"time"
"github.com/ouqiang/websocket"
)
// WebSocketConn wraps a *websocket.Conn and forwards a small subset of its
// methods (Close, ReadMessage, WriteMessage, RemoteAddr) to it unchanged.
type WebSocketConn struct {
	conn *websocket.Conn
}

// NewWebSocketConn returns a WebSocketConn that delegates to conn.
func NewWebSocketConn(conn *websocket.Conn) *WebSocketConn {
	wrapped := new(WebSocketConn)
	wrapped.conn = conn
	return wrapped
}

// Close closes the underlying connection and returns its error.
func (ws *WebSocketConn) Close() error {
	underlying := ws.conn
	return underlying.Close()
}

// ReadMessage reads the next message from the underlying connection and
// returns its message type, payload and error exactly as reported.
func (ws *WebSocketConn) ReadMessage() (int, []byte, error) {
	msgType, payload, err := ws.conn.ReadMessage()
	return msgType, payload, err
}

// WriteMessage writes data as one message of the given messageType to the
// underlying connection.
func (ws *WebSocketConn) WriteMessage(messageType int, data []byte) error {
	err := ws.conn.WriteMessage(messageType, data)
	return err
}

// RemoteAddr reports the remote network address of the underlying connection.
func (ws *WebSocketConn) RemoteAddr() net.Addr {
	addr := ws.conn.RemoteAddr()
	return addr
}
// WebSocketUpgrader holds the websocket.Upgrader used to turn incoming HTTP
// requests into WebSocket connections.
type WebSocketUpgrader struct {
	upgrader websocket.Upgrader
}

// NewWebSocketUpgrader builds a WebSocketUpgrader whose handshake is bounded by
// timeout and whose read and write buffers are 4096 bytes each.
//
// Every request origin is accepted: CheckOrigin always reports true.
// NOTE(review): presumably intentional because a proxy must forward upgrades
// from arbitrary clients; confirm this is never exposed as a browser-facing
// endpoint, where it would permit cross-site WebSocket hijacking.
func NewWebSocketUpgrader(timeout time.Duration) *WebSocketUpgrader {
	const bufferSize = 4096
	allowAnyOrigin := func(*http.Request) bool { return true }

	var wu WebSocketUpgrader
	wu.upgrader.HandshakeTimeout = timeout
	wu.upgrader.ReadBufferSize = bufferSize
	wu.upgrader.WriteBufferSize = bufferSize
	wu.upgrader.CheckOrigin = allowAnyOrigin
	return &wu
}
// Upgrade hijacks conn and performs the server side of the WebSocket opening
// handshake for req, returning the resulting WebSocket connection.
//
// conn must additionally implement
// Hijack() (net.Conn, *bufio.ReadWriter, error); otherwise an error is returned
// and conn is left untouched. A Hijack failure is wrapped with %w so callers can
// inspect the cause with errors.Is / errors.As.
//
// Once the connection has been hijacked it can no longer be used as a plain
// HTTP connection, so if the handshake fails the hijacked connection is closed
// before the upgrader's error is returned.
func (u *WebSocketUpgrader) Upgrade(conn net.Conn, req *http.Request) (*WebSocketConn, error) {
	hijacker, ok := conn.(interface {
		Hijack() (net.Conn, *bufio.ReadWriter, error)
	})
	if !ok {
		return nil, fmt.Errorf("连接不支持Hijack")
	}
	netConn, bufrw, err := hijacker.Hijack()
	if err != nil {
		return nil, fmt.Errorf("hijack错误: %w", err)
	}
	wsConn, err := u.upgrader.Upgrade(newResponseWriter(netConn, bufrw), req, http.Header{})
	if err != nil {
		// Without this the hijacked connection leaks. Closing again is harmless
		// if the upgrader already closed it, so the Close error is ignored.
		_ = netConn.Close()
		return nil, err
	}
	return NewWebSocketConn(wsConn), nil
}
// WebSocketDialer holds the websocket.Dialer used to open client-side
// WebSocket connections.
type WebSocketDialer struct {
	dialer websocket.Dialer
}

// NewWebSocketDialer builds a WebSocketDialer with the given handshake timeout
// and 4096-byte read and write buffers. Every other websocket.Dialer field is
// left at its zero value.
func NewWebSocketDialer(timeout time.Duration) *WebSocketDialer {
	const bufferSize = 4096
	wd := new(WebSocketDialer)
	wd.dialer.HandshakeTimeout = timeout
	wd.dialer.ReadBufferSize = bufferSize
	wd.dialer.WriteBufferSize = bufferSize
	return wd
}
// Dial opens a client WebSocket connection to urlStr, sending header with the
// opening handshake. The HTTP response returned by the handshake is discarded.
//
// When the dialer's HandshakeTimeout is positive, the handshake is bounded by a
// context with that timeout. A zero or negative timeout adds no deadline here:
// passing it to context.WithTimeout would produce an already-expired context,
// so every dial from a dialer built with NewWebSocketDialer(0) would fail.
func (d *WebSocketDialer) Dial(urlStr string, header http.Header) (*WebSocketConn, error) {
	ctx := context.Background()
	if timeout := d.dialer.HandshakeTimeout; timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}
	wsConn, _, err := d.dialer.DialContext(ctx, urlStr, header)
	if err != nil {
		return nil, err
	}
	return NewWebSocketConn(wsConn), nil
}
// responseWriter adapts an already-hijacked connection to http.ResponseWriter
// (plus Hijack) so a websocket upgrader can be driven from a raw net.Conn.
//
// Hijack hands back the connection and buffered reader/writer it was built
// with. Any response written through Header/WriteHeader/Write (for example an
// error reply produced with http.Error) is serialized as a minimal HTTP/1.1
// response on bufrw and flushed, so the peer actually receives it.
// NOTE(review): confirm which upgrader code paths write through these methods.
type responseWriter struct {
	conn        net.Conn
	bufrw       *bufio.ReadWriter
	header      http.Header
	status      int  // status code sent by WriteHeader (http.StatusOK until then)
	wroteHeader bool // whether the status line and headers have been written
}

// newResponseWriter wraps conn and its buffered reader/writer; bufrw must be non-nil.
func newResponseWriter(conn net.Conn, bufrw *bufio.ReadWriter) *responseWriter {
	return &responseWriter{
		conn:   conn,
		bufrw:  bufrw,
		header: make(http.Header),
		status: http.StatusOK,
	}
}

// Header returns the header map that WriteHeader will send.
func (rw *responseWriter) Header() http.Header {
	return rw.header
}

// Write sends b as (part of) the response body and flushes it to the
// connection. As with net/http, the first Write sends a 200 status line and the
// headers if WriteHeader has not been called yet.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		rw.WriteHeader(http.StatusOK)
	}
	n, err := rw.bufrw.Write(b)
	if err != nil {
		return n, err
	}
	return n, rw.bufrw.Flush()
}

// WriteHeader sends the status line and headers exactly once; later calls are
// ignored, matching net/http semantics.
func (rw *responseWriter) WriteHeader(statusCode int) {
	if rw.wroteHeader {
		return
	}
	rw.wroteHeader = true
	rw.status = statusCode
	// No Content-Length is known in advance, so tell the peer the body ends
	// when the connection closes.
	rw.header.Set("Connection", "close")
	// WriteHeader cannot return errors. bufio.Writer errors are sticky, so a
	// broken connection surfaces on the next Write/Flush.
	fmt.Fprintf(rw.bufrw, "HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode))
	rw.header.Write(rw.bufrw)
	rw.bufrw.WriteString("\r\n")
	rw.bufrw.Flush()
}

// Hijack returns the connection and buffered reader/writer this writer was
// constructed with; it never fails.
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	return rw.conn, rw.bufrw, nil
}