package proxy

import (
	"bufio"
	"context"
	"crypto/elliptic"
	"crypto/tls"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/http/httptrace"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/darkit/goproxy/internal/auth"
	"github.com/darkit/goproxy/internal/cache"
	"github.com/darkit/goproxy/internal/config"
	"github.com/darkit/goproxy/internal/healthcheck"
	"github.com/darkit/goproxy/internal/loadbalance"
	"github.com/darkit/goproxy/internal/metrics"
	"github.com/darkit/goproxy/internal/middleware"
	"github.com/darkit/goproxy/internal/reverse"
	"github.com/ouqiang/websocket"
	"github.com/viki-org/dnscache"
)

const (
	// defaultTargetConnectTimeout bounds how long dialing a target server may take.
	defaultTargetConnectTimeout = 5 * time.Second
	// defaultTargetReadWriteTimeout is the read/write timeout for target servers.
	defaultTargetReadWriteTimeout = 10 * time.Second
)

// maxCacheDeltaSeconds caps Cache-Control delta-seconds values; RFC 9111
// §1.2.2 says values larger than 2^31 are to be treated as 2^31.
const maxCacheDeltaSeconds = 1 << 31

// tunnelEstablishedResponseLine is written to the client once a CONNECT tunnel is ready.
var tunnelEstablishedResponseLine = []byte("HTTP/1.1 200 Connection established\r\n\r\n")

// badGatewayResponse is a raw 502 response for connections that have already been hijacked.
var badGatewayResponse = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", http.StatusBadGateway, http.StatusText(http.StatusBadGateway)))

// Object pools.
var (
	// bufPool holds 32 KiB copy buffers used by transfer.
	bufPool = sync.Pool{
		New: func() interface{} {
			return make([]byte, 32*1024)
		},
	}
	// ctxPool recycles per-request Context values.
	ctxPool = sync.Pool{
		New: func() interface{} {
			return new(Context)
		},
	}
)

// CertificateCache stores generated certificates keyed by host.
type CertificateCache interface {
	// Get returns the certificate for host, or nil when none is cached.
	Get(host string) *tls.Certificate
	// Set stores cert for host.
	Set(host string, cert *tls.Certificate)
}

// MemCertCache is an in-memory CertificateCache backed by a sync.Map.
type MemCertCache struct {
	certs sync.Map
}

// Get returns the cached certificate for host, or nil when absent.
func (c *MemCertCache) Get(host string) *tls.Certificate {
	v, ok := c.certs.Load(host)
	if !ok {
		return nil
	}
	return v.(*tls.Certificate)
}

// Set stores cert for host.
func (c *MemCertCache) Set(host string, cert *tls.Certificate) {
	c.certs.Store(host, cert)
}

// CacheAdapter gives different cache implementations one Get/Set API by
// detecting, once, which method signatures the wrapped cache provides.
type CacheAdapter struct {
	cache interface{}

	// Detected method shapes of the wrapped cache.
	getMethodType int
	setMethodType int

	// Method-shape constants.
	getResponseBool  int
	getInterfaceBool int
	setResponseOnly  int
	setResponseTTL   int
	setInterfaceOnly int
}

// NewCacheAdapter wraps cache, detecting which Get/Set signatures it implements.
func NewCacheAdapter(cache interface{}) *CacheAdapter {
	adapter := &CacheAdapter{
		cache: cache,
		// Method-shape constants.
		getResponseBool:  1,
		getInterfaceBool: 2,
		setResponseOnly:  1,
		setResponseTTL:   2,
		setInterfaceOnly: 3,
	}

	// Detect the supported Get signature.
	if _, ok := cache.(interface {
		Get(string) (*http.Response, bool)
	}); ok {
		adapter.getMethodType = adapter.getResponseBool
	} else if _, ok := cache.(interface {
		Get(string) (interface{}, bool)
	}); ok {
		adapter.getMethodType = adapter.getInterfaceBool
	}

	// Detect the supported Set signature, preferring the TTL-aware one.
	if _, ok := cache.(interface {
		Set(string, *http.Response, time.Duration)
	}); ok {
		adapter.setMethodType = adapter.setResponseTTL
	} else if _, ok := cache.(interface {
		Set(string, *http.Response)
	}); ok {
		adapter.setMethodType = adapter.setResponseOnly
	} else if _, ok := cache.(interface {
		Set(string, interface{})
	}); ok {
		adapter.setMethodType = adapter.setInterfaceOnly
	}

	return adapter
}

// Get looks key up through whichever Get signature the cache supports.
// It returns (nil, false) when the cache supports none of them.
func (a *CacheAdapter) Get(key string) (interface{}, bool) {
	switch a.getMethodType {
	case a.getResponseBool:
		if getter, ok := a.cache.(interface {
			Get(string) (*http.Response, bool)
		}); ok {
			return getter.Get(key)
		}
	case a.getInterfaceBool:
		if getter, ok := a.cache.(interface {
			Get(string) (interface{}, bool)
		}); ok {
			return getter.Get(key)
		}
	}
	return nil, false
}

// Set stores value under key through whichever Set signature the cache
// supports. ttl is only passed to TTL-aware caches. Response-typed setters
// ignore non-*http.Response values.
func (a *CacheAdapter) Set(key string, value interface{}, ttl time.Duration) {
	resp, isResponse := value.(*http.Response)

	switch a.setMethodType {
	case a.setResponseTTL:
		if setter, ok := a.cache.(interface {
			Set(string, *http.Response, time.Duration)
		}); ok && isResponse {
			setter.Set(key, resp, ttl)
		}
	case a.setResponseOnly:
		if setter, ok := a.cache.(interface {
			Set(string, *http.Response)
		}); ok && isResponse {
			setter.Set(key, resp)
		}
	case a.setInterfaceOnly:
		if setter, ok := a.cache.(interface {
			Set(string, interface{})
		}); ok {
			setter.Set(key, value)
		}
	}
}

// Options configures New.
type Options struct {
	// Config is the proxy configuration; config.DefaultConfig() is used when nil.
	Config *config.Config
	// Delegate receives lifecycle callbacks; &DefaultDelegate{} is used when nil.
	Delegate Delegate
	// CertCache caches generated certificates.
	CertCache CertificateCache
	// HTTPCache caches HTTP responses.
	HTTPCache cache.Cache
	// LoadBalancer chooses the parent proxy when load balancing is enabled.
	LoadBalancer loadbalance.LoadBalancer
	// HealthChecker marks LoadBalancer targets up/down when both are set.
	HealthChecker *healthcheck.HealthChecker
	// Metrics collects proxy metrics.
	Metrics metrics.MetricsCollector
	// ClientTrace is attached to outgoing requests.
	ClientTrace *httptrace.ClientTrace
	// Auth is the authentication system.
	// NOTE(review): New does not read this field in this file; confirm it is
	// consumed elsewhere, otherwise WithAuth has no effect.
	Auth *auth.Auth
	// CertManager issues certificates for HTTPS decryption. When nil and
	// decryption is enabled, New creates one.
	CertManager *CertManager
}

// WithAuth sets the authentication system.
func WithAuth(auth *auth.Auth) Option {
	return func(o *Options) {
		o.Auth = auth
	}
}

// Proxy is an HTTP proxy. It handles plain HTTP, CONNECT tunnels (optionally
// decrypting HTTPS) and WebSocket traffic, or delegates to a reverse proxy
// when config.ReverseProxy is set.
type Proxy struct {
	// config is the proxy configuration.
	config *config.Config
	// delegate receives lifecycle callbacks.
	delegate Delegate
	// certCache caches generated certificates.
	certCache CertificateCache
	// httpCache caches HTTP responses.
	httpCache cache.Cache
	// cacheAdapter wraps httpCache with a uniform API.
	cacheAdapter *CacheAdapter
	// loadBalancer chooses parent proxies.
	loadBalancer loadbalance.LoadBalancer
	// healthChecker reports target health.
	healthChecker *healthcheck.HealthChecker
	// metrics collects proxy metrics.
	metrics metrics.MetricsCollector
	// clientTrace is attached to outgoing requests.
	clientTrace *httptrace.ClientTrace
	// transport is the underlying *http.Transport, kept so its fields can be
	// adjusted directly (see SetDialContext).
	transport *http.Transport
	// httpTransport is the RoundTripper used for requests; it may wrap
	// transport with middleware.
	httpTransport http.RoundTripper
	// dnsCache caches DNS lookups for the dialer.
	dnsCache *dnscache.Resolver
	// clientConnNum counts requests currently inside ServeHTTP.
	clientConnNum int32
	// certManager issues certificates for HTTPS decryption.
	certManager *CertManager
	// certManagerMu guards lazy creation of certManager in generateTLSConfig.
	certManagerMu sync.Mutex
	// logger is the structured logger; never nil after New.
	logger *slog.Logger
	// reverseHandler holds the http.Handler built from config for reverse-proxy
	// mode, created on first use.
	reverseHandler atomic.Value
	// reverseMu serializes construction of reverseHandler.
	reverseMu sync.Mutex
}

// New creates a Proxy from opts. A nil opts, Config or Delegate falls back
// to defaults.
func New(opts *Options) *Proxy {
	if opts == nil {
		opts = &Options{}
	}
	if opts.Config == nil {
		opts.Config = config.DefaultConfig()
	}
	if opts.Delegate == nil {
		opts.Delegate = &DefaultDelegate{}
	}

	p := &Proxy{
		config:        opts.Config,
		delegate:      opts.Delegate,
		certCache:     opts.CertCache,
		httpCache:     opts.HTTPCache,
		loadBalancer:  opts.LoadBalancer,
		healthChecker: opts.HealthChecker,
		metrics:       opts.Metrics,
		clientTrace:   opts.ClientTrace,
		logger:        opts.Config.Logger,
	}
	if p.logger == nil {
		// ServeHTTP logs through p.logger; a nil logger would panic there.
		p.logger = slog.Default()
	}

	if p.httpCache != nil {
		p.cacheAdapter = NewCacheAdapter(p.httpCache)
	}

	p.dnsCache = dnscache.New(opts.Config.DNSCacheTTL)

	// Use the supplied certificate manager, or build one when HTTPS
	// decryption is enabled.
	if opts.CertManager != nil {
		p.certManager = opts.CertManager
	} else if opts.Config.DecryptHTTPS {
		p.certManager = p.defaultCertManager()
	}

	httpTransport := &http.Transport{
		Proxy:                 p.proxyFromDelegate,
		DialContext:           p.dialContextWithCache(),
		MaxIdleConns:          opts.Config.ConnectionPoolSize,
		IdleConnTimeout:       opts.Config.IdleTimeout,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
		DisableKeepAlives:     false,
		DisableCompression:    false,
		ForceAttemptHTTP2:     true,
	}
	p.transport = httpTransport

	// Wrap the transport with middleware.
	var roundTripper http.RoundTripper = httpTransport
	if opts.Config.EnableRetry {
		policy := &middleware.RetryPolicy{
			MaxRetries:  opts.Config.MaxRetries,
			BaseBackoff: opts.Config.RetryBackoff,
			MaxBackoff:  opts.Config.MaxRetryBackoff,
		}
		retryMiddleware := middleware.NewRetryMiddleware(policy)
		roundTripper = retryMiddleware.Transport(roundTripper)
	}
	p.httpTransport = roundTripper

	// Feed health-check results into the load balancer.
	if p.healthChecker != nil && p.loadBalancer != nil {
		p.healthChecker.SetStatusChangeCallback(func(target string, healthy bool) {
			if healthy {
				p.loadBalancer.MarkUp(target)
			} else {
				p.loadBalancer.MarkDown(target)
			}
		})
	}

	return p
}

// defaultCertManager builds a CertManager over p.certCache with the options
// New uses when HTTPS decryption is enabled.
func (p *Proxy) defaultCertManager() *CertManager {
	certManagerOpts := []CertManagerOption{
		WithDefaultPrivateKey(true), // reuse the default private key for speed
		WithValidityYears(1),        // certificates are valid for one year
		WithUseECDSA(p.config.UseECDSA),
	}
	if p.config.UseECDSA {
		certManagerOpts = append(certManagerOpts, WithCurve(elliptic.P256()))
	}
	return NewCertManager(p.certCache, certManagerOpts...)
}

// NewProxy creates a Proxy using functional options applied over a default config.
func NewProxy(options ...Option) *Proxy {
	opts := &Options{
		Config: config.DefaultConfig(),
	}
	for _, option := range options {
		option(opts)
	}
	return New(opts)
}

// ServeHTTP handles one client request. In reverse-proxy mode it hands the
// request to the reverse proxy. Otherwise it runs the Connect/Auth delegate
// hooks and dispatches to CONNECT tunneling, WebSocket proxying or plain HTTP
// forwarding.
func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	// Track in-flight requests so ClientConnNum reports a live value.
	atomic.AddInt32(&p.clientConnNum, 1)
	defer atomic.AddInt32(&p.clientConnNum, -1)

	if p.config.ReverseProxy {
		handler, err := p.lazyReverseHandler()
		if err != nil {
			p.logger.Error("创建反向代理失败",
				"error", err.Error(),
			)
			http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
			return
		}
		handler.ServeHTTP(rw, req)
		return
	}

	if p.metrics != nil {
		p.metrics.IncRequestCount()
	}

	ctx := ctxPool.Get().(*Context)
	ctx.Reset(req)
	defer ctxPool.Put(ctx)

	p.delegate.Connect(ctx, rw)
	if ctx.IsAborted() {
		return
	}
	p.delegate.Auth(ctx, rw)
	if ctx.IsAborted() {
		return
	}

	switch {
	case req.Method == http.MethodConnect:
		p.tunnelProxy(ctx, rw)
	case isWebSocketRequest(req):
		clientConn, err := hijacker(rw)
		if err != nil {
			p.delegate.ErrorLog(err)
			http.Error(rw, "无法处理WebSocket请求", http.StatusInternalServerError)
			return
		}
		// After Hijack the server no longer owns the connection; close it here
		// so error paths inside websocketProxy do not leak it.
		defer clientConn.Close()
		p.websocketProxy(ctx, clientConn)
	default:
		p.handleHTTP(ctx, rw)
	}
}

// lazyReverseHandler returns the reverse-proxy handler, building it from
// p.config on first use. Building per request would redo all of
// reverse.New's setup for every request. A failed build is not cached, so a
// later request retries.
func (p *Proxy) lazyReverseHandler() (http.Handler, error) {
	if h, ok := p.reverseHandler.Load().(http.Handler); ok {
		return h, nil
	}
	p.reverseMu.Lock()
	defer p.reverseMu.Unlock()
	if h, ok := p.reverseHandler.Load().(http.Handler); ok {
		return h, nil
	}
	reverseProxy, err := reverse.New(convertToReverseConfig(p.config))
	if err != nil {
		return nil, err
	}
	var h http.Handler = http.HandlerFunc(reverseProxy.ServeHTTP)
	p.reverseHandler.Store(h)
	return h, nil
}

// handleHTTP forwards a plain HTTP request and streams the response to rw.
// It serves from the HTTP cache on a hit and stores storable responses on a
// miss. Hop-by-hop headers are stripped in both directions.
func (p *Proxy) handleHTTP(ctx *Context, rw http.ResponseWriter) {
	p.delegate.BeforeRequest(ctx)
	if ctx.IsAborted() {
		return
	}

	startTime := time.Now()

	// Record the chosen parent proxy on the context for delegates.
	parentProxy, err := p.proxyFromDelegate(ctx.Req)
	if err != nil {
		p.delegate.ErrorLog(fmt.Errorf("%s - 获取上级代理错误: %s", ctx.Req.URL.Host, err))
		ctx.ParentProxyURL = nil
	} else {
		ctx.ParentProxyURL = parentProxy
	}

	req := ctx.Req
	useCache := p.httpCache != nil && p.cacheAdapter != nil && p.config.EnableCache && canCacheMethod(req.Method)

	var resp *http.Response
	if useCache {
		if cached, ok := p.cacheAdapter.Get(generateCacheKey(req)); ok {
			// Checked assertion: an interface{}-valued cache may hold other types.
			if cachedResp, isResp := cached.(*http.Response); isResp && cachedResp != nil {
				// NOTE(review): the same *http.Response (and Body) is handed out on
				// every hit and its body is consumed below; confirm the cache
				// implementation gives each hit a fresh body.
				resp = cachedResp
				if p.metrics != nil {
					incrementCacheHit(p.metrics)
				}
			}
		}
	}

	if resp == nil {
		reqCtx := req.Context()
		if p.clientTrace != nil {
			reqCtx = httptrace.WithClientTrace(reqCtx, p.clientTrace)
		}
		if p.config.RequestTimeout > 0 {
			var cancel context.CancelFunc
			reqCtx, cancel = context.WithTimeout(reqCtx, p.config.RequestTimeout)
			// Deferred to function exit so the timeout also spans copying the body.
			defer cancel()
		}

		// Clone rather than WithContext so stripping headers leaves ctx.Req intact.
		// Hop-by-hop headers include the client's Proxy-Authorization, which must
		// not reach the origin server.
		outReq := req.Clone(reqCtx)
		stripHopByHopHeaders(outReq.Header)

		resp, err = p.httpTransport.RoundTrip(outReq)
		if err != nil {
			p.delegate.BeforeResponse(ctx, nil, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTP请求错误: %s", req.URL.Host, err))
			http.Error(rw, err.Error(), http.StatusBadGateway)
			return
		}

		if p.metrics != nil {
			p.metrics.ObserveRequestDuration(time.Since(startTime).Seconds())
		}

		stripHopByHopHeaders(resp.Header)

		if useCache && canCacheStatus(resp.StatusCode) && isSharedCacheStorable(req, resp) {
			// A zero TTL (e.g. max-age=0) means "do not store".
			if ttl := getCacheTTL(resp); ttl > 0 {
				p.cacheAdapter.Set(generateCacheKey(req), resp, ttl)
			}
		}
	}

	// Close the body on every exit path, including a delegate abort below.
	defer func() {
		if resp.Body != nil {
			resp.Body.Close()
		}
	}()

	p.delegate.BeforeResponse(ctx, resp, nil)
	if ctx.IsAborted() {
		return
	}

	header := rw.Header()
	for key, values := range resp.Header {
		for _, value := range values {
			header.Add(key, value)
		}
	}
	rw.WriteHeader(resp.StatusCode)

	if _, err := io.Copy(rw, resp.Body); err != nil {
		p.delegate.ErrorLog(fmt.Errorf("%s - 复制响应体错误: %s", req.URL.Host, err))
	}

	p.delegate.Finish(ctx)
}

// canCacheMethod reports whether responses to method may be cached (GET and HEAD).
func canCacheMethod(method string) bool {
	return method == http.MethodGet || method == http.MethodHead
}

// canCacheStatus reports whether a response with statusCode may be cached.
// Only statuses that HTTP defines as cacheable by default are accepted.
// 206 and 304 are excluded because they answer range or conditional requests
// and must not be replayed to other clients. Temporary redirects (302, 303,
// 307) are excluded too.
func canCacheStatus(statusCode int) bool {
	switch statusCode {
	case http.StatusOK,
		http.StatusNonAuthoritativeInfo,
		http.StatusNoContent,
		http.StatusMultipleChoices,
		http.StatusMovedPermanently,
		http.StatusPermanentRedirect:
		return true
	}
	return false
}

// generateCacheKey builds the cache key "<METHOD> <URL>" for req.
func generateCacheKey(req *http.Request) string {
	return req.Method + " " + req.URL.String()
}

// getCacheTTL returns how long resp may be cached. It uses s-maxage, falling
// back to max-age, and defaults to 5 minutes when neither is valid.
// Non-positive values yield 0, which callers treat as "do not store".
// Values are capped at 2^31 seconds.
func getCacheTTL(resp *http.Response) time.Duration {
	ttl := 5 * time.Minute

	directives := parseCacheControl(resp.Header)
	// A shared cache such as a proxy prefers s-maxage over max-age.
	for _, name := range []string{"s-maxage", "max-age"} {
		if value, ok := directives[name]; ok {
			if d, valid := parseCacheDeltaSeconds(value); valid {
				return d
			}
		}
	}
	return ttl
}

// parseCacheDeltaSeconds parses a Cache-Control delta-seconds value.
// It reports false for non-numeric input. Negative values become 0.
// Values beyond maxCacheDeltaSeconds, including int64 overflow, are capped.
func parseCacheDeltaSeconds(value string) (time.Duration, bool) {
	secs, err := strconv.ParseInt(value, 10, 64)
	if err != nil {
		// On positive overflow ParseInt returns MaxInt64 with ErrRange.
		if !errors.Is(err, strconv.ErrRange) || secs < 0 {
			return 0, false
		}
	}
	if secs <= 0 {
		return 0, true
	}
	if secs > maxCacheDeltaSeconds {
		secs = maxCacheDeltaSeconds
	}
	return time.Duration(secs) * time.Second, true
}

// parseCacheControl returns the Cache-Control directives in h, with names
// lower-cased, mapped to their unquoted argument ("" when there is none).
// The first occurrence of a directive wins.
func parseCacheControl(h http.Header) map[string]string {
	directives := make(map[string]string)
	for _, line := range h.Values("Cache-Control") {
		for _, part := range strings.Split(line, ",") {
			name, value, _ := strings.Cut(strings.TrimSpace(part), "=")
			name = strings.ToLower(strings.TrimSpace(name))
			if name == "" {
				continue
			}
			if _, seen := directives[name]; !seen {
				directives[name] = strings.Trim(strings.TrimSpace(value), `"`)
			}
		}
	}
	return directives
}

// isSharedCacheStorable reports whether a shared cache may store resp for req.
// It refuses when:
//   - either side sends no-store;
//   - the response is private or no-cache (this proxy does not revalidate);
//   - the request carried Authorization and the response does not explicitly
//     allow shared caching (public, s-maxage or must-revalidate).
func isSharedCacheStorable(req *http.Request, resp *http.Response) bool {
	if _, noStore := parseCacheControl(req.Header)["no-store"]; noStore {
		return false
	}
	respDirectives := parseCacheControl(resp.Header)
	for _, name := range []string{"no-store", "private", "no-cache"} {
		if _, ok := respDirectives[name]; ok {
			return false
		}
	}
	if req.Header.Get("Authorization") != "" {
		_, public := respDirectives["public"]
		_, sMaxAge := respDirectives["s-maxage"]
		_, mustRevalidate := respDirectives["must-revalidate"]
		if !public && !sMaxAge && !mustRevalidate {
			return false
		}
	}
	return true
}

// ClientConnNum returns the number of requests currently being served.
func (p *Proxy) ClientConnNum() int32 {
	return atomic.LoadInt32(&p.clientConnNum)
}

// DoRequest sends ctx.Req upstream, or serves it from the cache for GET, and
// passes the result to responseFunc. Delegate BeforeRequest/BeforeResponse
// hooks run around the exchange. An abort after BeforeResponse suppresses
// responseFunc.
func (p *Proxy) DoRequest(ctx *Context, responseFunc func(*http.Response, error)) {
	if ctx.Data == nil {
		ctx.Data = make(map[interface{}]interface{})
	}

	p.delegate.BeforeRequest(ctx)
	if ctx.IsAborted() {
		return
	}

	useCache := p.httpCache != nil && p.cacheAdapter != nil && p.config.EnableCache

	if useCache && ctx.Req.Method == http.MethodGet {
		if cached, ok := p.cacheAdapter.Get(cache.GenerateCacheKey(ctx.Req)); ok {
			// Checked assertion: an interface{}-valued cache may hold other types.
			if cachedResp, isResp := cached.(*http.Response); isResp && cachedResp != nil {
				p.delegate.BeforeResponse(ctx, cachedResp, nil)
				if !ctx.IsAborted() {
					responseFunc(cachedResp, nil)
				}
				if p.metrics != nil {
					incrementCacheHit(p.metrics)
				}
				return
			}
		}
	}

	// Build the outgoing request without mutating ctx.Req's headers.
	reqCtx := ctx.Req.Context()
	if p.clientTrace != nil {
		reqCtx = httptrace.WithClientTrace(reqCtx, p.clientTrace)
	}
	newReq := ctx.Req.Clone(reqCtx)
	stripHopByHopHeaders(newReq.Header)

	resp, err := p.httpTransport.RoundTrip(newReq)

	p.delegate.BeforeResponse(ctx, resp, err)
	if ctx.IsAborted() {
		if resp != nil {
			resp.Body.Close()
		}
		return
	}

	if err != nil {
		responseFunc(nil, err)
		return
	}

	stripHopByHopHeaders(resp.Header)

	if useCache && cache.ShouldCache(ctx.Req, resp) && isSharedCacheStorable(ctx.Req, resp) {
		// A zero TTL (e.g. max-age=0) means "do not store".
		if ttl := getCacheTTL(resp); ttl > 0 {
			p.cacheAdapter.Set(cache.GenerateCacheKey(ctx.Req), resp, ttl)
		}
	}

	responseFunc(resp, nil)
}

// httpsProxy serves one decrypted request read from tlsClientConn. WebSocket
// upgrades are handed to websocketProxy; everything else goes through
// DoRequest and the response is written straight onto the TLS connection.
func (p *Proxy) httpsProxy(ctx *Context, tlsClientConn *tls.Conn) {
	if isWebSocketRequest(ctx.Req) {
		p.websocketProxy(ctx, NewConnBuffer(tlsClientConn, nil))
		return
	}
	p.DoRequest(ctx, func(resp *http.Response, err error) {
		if err != nil {
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,请求错误: %s", ctx.Req.URL, err))
			tlsClientConn.Write(badGatewayResponse)
			return
		}
		if err := resp.Write(tlsClientConn); err != nil {
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,响应写入客户端失败: %s", ctx.Req.URL, err))
		}
		resp.Body.Close()
	})
}

// tunnelProxy handles a CONNECT request. It hijacks the client connection
// and acknowledges the tunnel. It then does one of three things: proxies a
// plain-text WebSocket upgrade sent through the tunnel; terminates TLS
// locally when DecryptHTTPS is enabled; or relays raw bytes to the target,
// optionally through a parent proxy.
func (p *Proxy) tunnelProxy(ctx *Context, rw http.ResponseWriter) {
	clientConn, err := hijacker(rw)
	if err != nil {
		p.delegate.ErrorLog(err)
		rw.WriteHeader(http.StatusBadGateway)
		return
	}
	defer clientConn.Close()

	if isWebSocketRequest(ctx.Req) {
		p.websocketProxy(ctx, clientConn)
		return
	}

	parentProxyURL, err := p.delegate.ParentProxy(ctx.Req)
	if ctx.ParentProxyURL != nil {
		parentProxyURL = ctx.ParentProxyURL
	}
	if err != nil {
		p.delegate.ErrorLog(fmt.Errorf("%s - 解析代理地址错误: %s", ctx.Req.URL.Host, err))
		// rw is unusable after Hijack; answer on the raw connection instead.
		_, _ = clientConn.Write(badGatewayResponse)
		return
	}

	// Acknowledge the CONNECT before reading from the client, whether or not a
	// parent proxy is used. The client sends nothing until it sees this line,
	// so skipping it would block the Peek below forever.
	if _, err = clientConn.Write(tunnelEstablishedResponseLine); err != nil {
		p.delegate.ErrorLog(fmt.Errorf("%s - 隧道连接成功,通知客户端错误: %s", ctx.Req.URL.Host, err))
		return
	}

	// A plain "GET" as the first bytes inside the tunnel is treated as an
	// unencrypted WebSocket upgrade.
	if methodBytes, peekErr := clientConn.Peek(3); peekErr == nil && string(methodBytes) == http.MethodGet {
		req, err := http.ReadRequest(clientConn.BufferReader())
		if err != nil {
			if err != io.EOF {
				p.delegate.ErrorLog(fmt.Errorf("%s - websocket读取客户端升级请求失败: %s", ctx.Req.URL.Host, err))
			}
			return
		}
		req.RemoteAddr = ctx.Req.RemoteAddr
		req.URL.Scheme = "http"
		req.URL.Host = req.Host
		ctx.Req = req
		p.websocketProxy(ctx, clientConn)
		return
	}

	// HTTPS decryption: terminate TLS with a generated certificate and read
	// the inner request.
	var tlsClientConn *tls.Conn
	if p.config.DecryptHTTPS {
		certConfig, err := p.generateTLSConfig(ctx.Req.URL.Host)
		if err != nil {
			p.tunnelConnected(ctx, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,生成证书失败: %s", ctx.Req.URL.Host, err))
			return
		}

		tlsClientConn = tls.Server(clientConn, certConfig)
		defer tlsClientConn.Close()

		if err := tlsClientConn.Handshake(); err != nil {
			p.tunnelConnected(ctx, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,握手失败: %s", ctx.Req.URL.Host, err))
			return
		}

		tlsReq, err := http.ReadRequest(bufio.NewReader(tlsClientConn))
		if err != nil {
			if err != io.EOF {
				p.tunnelConnected(ctx, err)
				p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,读取客户端请求失败: %s", ctx.Req.URL.Host, err))
			}
			return
		}

		tlsReq.RemoteAddr = ctx.Req.RemoteAddr
		tlsReq.URL.Scheme = "https"
		tlsReq.URL.Host = tlsReq.Host
		ctx.Req = tlsReq
	}

	// Choose the address to dial: an explicit override, else the parent proxy,
	// else the CONNECT target itself.
	targetAddr := ctx.Req.URL.Host
	defaultPort := "443"
	viaParent := false
	switch {
	case ctx.TargetAddr != "":
		targetAddr = ctx.TargetAddr
	case parentProxyURL != nil:
		targetAddr = parentProxyURL.Host
		viaParent = true
		if parentProxyURL.Scheme == "http" {
			defaultPort = "80"
		}
	}
	targetAddr = hostPortWithDefault(targetAddr, defaultPort)

	// NOTE(review): with DecryptHTTPS the request is served by httpsProxy via
	// the shared transport, so this connection only acts as a reachability
	// check on that path.
	targetConn, err := net.DialTimeout("tcp", targetAddr, defaultTargetConnectTimeout)
	if err != nil {
		p.tunnelConnected(ctx, err)
		p.delegate.ErrorLog(fmt.Errorf("%s - 隧道转发连接目标服务器失败: %s", ctx.Req.URL.Host, err))
		return
	}
	defer targetConn.Close()

	// upstream is what the relay reads from. After a parent CONNECT it wraps
	// the handshake reader so any bytes buffered past the parent's response
	// are not lost. This assumes ConnBuffer serves reads from the supplied
	// reader; TODO confirm.
	var upstream net.Conn = targetConn

	if parentProxyURL != nil {
		connectHost := hostPortWithDefault(ctx.Req.URL.Host, "443")
		var tunnelRequest strings.Builder
		fmt.Fprintf(&tunnelRequest, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n", connectHost, connectHost)
		// Send parent credentials only to the parent itself, never to a TargetAddr override.
		if viaParent {
			if credentials := parentProxyAuthHeader(parentProxyURL); credentials != "" {
				fmt.Fprintf(&tunnelRequest, "Proxy-Authorization: %s\r\n", credentials)
			}
		}
		tunnelRequest.WriteString("\r\n")

		if _, err = targetConn.Write([]byte(tunnelRequest.String())); err != nil {
			p.tunnelConnected(ctx, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - 向上级代理发送CONNECT请求失败: %s", ctx.Req.URL.Host, err))
			return
		}

		bufReader := bufio.NewReader(targetConn)
		resp, err := http.ReadResponse(bufReader, ctx.Req)
		if err != nil {
			p.tunnelConnected(ctx, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - 读取上级代理响应失败: %s", ctx.Req.URL.Host, err))
			return
		}
		defer resp.Body.Close()

		// Any 2xx to CONNECT means the tunnel is established.
		if resp.StatusCode < 200 || resp.StatusCode > 299 {
			p.tunnelConnected(ctx, fmt.Errorf("上级代理返回错误状态码: %d", resp.StatusCode))
			p.delegate.ErrorLog(fmt.Errorf("%s - 上级代理返回错误状态码: %d", ctx.Req.URL.Host, resp.StatusCode))
			return
		}
		upstream = NewConnBuffer(targetConn, bufReader)
	}

	if p.config.DecryptHTTPS {
		p.httpsProxy(ctx, tlsClientConn)
		return
	}
	p.tunnelConnected(ctx, nil)
	p.transfer(clientConn, upstream)
}

// hostPortWithDefault returns addr unchanged when it already has a port and
// otherwise appends port, bracketing IPv6 literals as required.
func hostPortWithDefault(addr, port string) string {
	if _, _, err := net.SplitHostPort(addr); err == nil {
		return addr
	}
	return net.JoinHostPort(strings.TrimSuffix(strings.TrimPrefix(addr, "["), "]"), port)
}

// parentProxyAuthHeader returns a Basic Proxy-Authorization value built from
// the user info in u, or "" when u carries none.
func parentProxyAuthHeader(u *url.URL) string {
	if u == nil || u.User == nil {
		return ""
	}
	password, _ := u.User.Password()
	return "Basic " + base64.StdEncoding.EncodeToString([]byte(u.User.Username()+":"+password))
}

// websocketProxy proxies a WebSocket upgrade request read from srcConn.
// Without WebSocketIntercept the handshake and frames are relayed as raw
// bytes. With it, both legs are terminated so the delegate can inspect every
// message.
func (p *Proxy) websocketProxy(ctx *Context, srcConn *ConnBuffer) {
	if !p.config.WebSocketIntercept {
		p.relayRawWebSocket(ctx, srcConn)
		return
	}

	upgrader := websocket.Upgrader{
		HandshakeTimeout: defaultTargetConnectTimeout,
		ReadBufferSize:   4096,
		WriteBufferSize:  4096,
		CheckOrigin: func(r *http.Request) bool {
			return true
		},
	}

	// Upgrade the client leg through a ResponseWriter over the raw connection.
	srcWSConn, err := upgrader.Upgrade(newResponseWriter(srcConn), ctx.Req, nil)
	if err != nil {
		p.tunnelConnected(ctx, err)
		p.delegate.ErrorLog(fmt.Errorf("%s - 升级WebSocket连接失败: %s", ctx.Req.URL.Host, err))
		return
	}
	defer srcWSConn.Close()

	scheme := "ws"
	if ctx.IsHTTPS() {
		scheme = "wss"
	}
	u := url.URL{
		Scheme:   scheme,
		Host:     ctx.Req.URL.Host,
		Path:     ctx.Req.URL.Path,
		RawPath:  ctx.Req.URL.RawPath,
		RawQuery: ctx.Req.URL.RawQuery,
	}

	// NOTE(review): upstream certificates are not verified here; consider
	// honoring config.InsecureSkipVerify.
	dialer := websocket.Dialer{
		HandshakeTimeout: defaultTargetConnectTimeout,
		ReadBufferSize:   4096,
		WriteBufferSize:  4096,
		TLSClientConfig:  &tls.Config{InsecureSkipVerify: true},
	}

	targetWSConn, _, err := dialer.Dial(u.String(), wsUpstreamHandshakeHeader(ctx.Req.Header))
	if err != nil {
		p.tunnelConnected(ctx, err)
		p.delegate.ErrorLog(fmt.Errorf("%s - 连接目标WebSocket服务器失败: %s", ctx.Req.URL.Host, err))
		return
	}
	defer targetWSConn.Close()

	p.tunnelConnected(ctx, nil)
	p.transferWebSocket(ctx, srcWSConn, targetWSConn)
}

// relayRawWebSocket dials the target, replays the client's upgrade request
// and then copies bytes in both directions without interpreting frames.
func (p *Proxy) relayRawWebSocket(ctx *Context, srcConn *ConnBuffer) {
	remoteAddr := ctx.Addr()
	dialer := &net.Dialer{Timeout: defaultTargetConnectTimeout}

	var (
		targetConn net.Conn
		err        error
	)
	if ctx.IsHTTPS() {
		// NOTE(review): upstream certificates are not verified here.
		targetConn, err = tls.DialWithDialer(dialer, "tcp", remoteAddr, &tls.Config{InsecureSkipVerify: true})
	} else {
		targetConn, err = dialer.Dial("tcp", remoteAddr)
	}
	if err != nil {
		p.tunnelConnected(ctx, err)
		p.delegate.ErrorLog(fmt.Errorf("%s - websocket连接目标服务器错误: %s", ctx.Req.URL.Host, err))
		return
	}
	// Also covers the write-failure path below; transfer closing it again is harmless.
	defer targetConn.Close()

	if err = ctx.Req.Write(targetConn); err != nil {
		p.tunnelConnected(ctx, err)
		p.delegate.ErrorLog(fmt.Errorf("%s - websocket协议转换请求写入目标服务器错误: %s", ctx.Req.URL.Host, err))
		return
	}

	p.tunnelConnected(ctx, nil)
	p.transfer(srcConn, targetConn)
}

// wsUpstreamHandshakeHeader copies the client's handshake headers for the
// upstream dial. It drops hop-by-hop fields and the handshake fields the
// websocket Dialer generates itself; gorilla-style dialers reject
// duplicates of those.
func wsUpstreamHandshakeHeader(src http.Header) http.Header {
	h := src.Clone()
	stripHopByHopHeaders(h)
	for _, name := range []string{"Sec-WebSocket-Key", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions"} {
		h.Del(name)
	}
	return h
}

// transferWebSocket relays messages both ways between srcConn (client) and
// targetConn (server), preserving message types. It lets the delegate
// inspect or modify each message. When either direction fails, both
// connections are closed so the other direction unblocks too. It returns
// after both directions have stopped.
func (p *Proxy) transferWebSocket(ctx *Context, srcConn *websocket.Conn, targetConn *websocket.Conn) {
	var closeOnce sync.Once
	closeBoth := func() {
		closeOnce.Do(func() {
			srcConn.Close()
			targetConn.Close()
		})
	}

	done := make(chan struct{})

	// client -> server
	go func() {
		defer close(done)
		p.pumpWebSocketMessages(srcConn, targetConn, func(msgType *int, msg *[]byte) {
			p.delegate.WebSocketSendMessage(ctx, msgType, msg)
		})
		closeBoth()
	}()

	// server -> client
	p.pumpWebSocketMessages(targetConn, srcConn, func(msgType *int, msg *[]byte) {
		p.delegate.WebSocketReceiveMessage(ctx, msgType, msg)
	})
	closeBoth()
	<-done
}

// pumpWebSocketMessages copies messages from src to dst, running intercept on
// each one, until a read or write fails. The failure is logged.
func (p *Proxy) pumpWebSocketMessages(src, dst *websocket.Conn, intercept func(msgType *int, msg *[]byte)) {
	for {
		msgType, msg, err := src.ReadMessage()
		if err != nil {
			p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err))
			return
		}

		intercept(&msgType, &msg)

		// Write with the original message type (text/binary).
		if err := dst.WriteMessage(msgType, msg); err != nil {
			p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err))
			return
		}
	}
}

// responseWriter is a minimal http.ResponseWriter over a raw client
// connection. It lets the websocket Upgrader complete the handshake.
type responseWriter struct {
	conn       *ConnBuffer
	header     http.Header
	statusCode int
}

func newResponseWriter(conn *ConnBuffer) *responseWriter {
	return &responseWriter{
		conn:       conn,
		header:     make(http.Header),
		statusCode: http.StatusOK,
	}
}

// Header returns the response header map.
func (rw *responseWriter) Header() http.Header {
	return rw.header
}

// Write writes b directly to the underlying connection.
func (rw *responseWriter) Write(b []byte) (int, error) {
	return rw.conn.Write(b)
}

// WriteHeader records statusCode; nothing is sent on the wire.
func (rw *responseWriter) WriteHeader(statusCode int) {
	rw.statusCode = statusCode
}

// Hijack hands the raw connection to the caller. Gorilla-style websocket
// upgraders require http.Hijacker. The returned reader starts empty; reads
// go through the ConnBuffer.
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	return rw.conn, bufio.NewReadWriter(bufio.NewReader(rw.conn), bufio.NewWriter(rw.conn)), nil
}

// transfer copies bytes both ways between src and dst using pooled buffers.
// It records transferred byte counts and returns once both directions have
// finished. Each direction closes its destination when done, which unblocks
// the opposite copy.
func (p *Proxy) transfer(src net.Conn, dst net.Conn) {
	done := make(chan struct{}, 2)

	pipe := func(to, from net.Conn, direction string) {
		buf := bufPool.Get().([]byte)
		written, err := io.CopyBuffer(to, from, buf)
		if err != nil {
			p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", from.RemoteAddr().String(), to.RemoteAddr().String(), err))
		}
		if p.metrics != nil {
			p.metrics.AddBytesTransferred(direction, written)
		}
		bufPool.Put(buf)
		to.Close()
		done <- struct{}{}
	}

	go pipe(dst, src, "request")
	go pipe(src, dst, "response")

	<-done
	<-done
}

// tunnelConnected marks ctx as a tunnel and runs the BeforeRequest and
// BeforeResponse delegate hooks. On failure the error is reported; otherwise
// a synthetic "200 Connection Established" response is reported.
func (p *Proxy) tunnelConnected(ctx *Context, err error) {
	ctx.TunnelProxy = true

	p.delegate.BeforeRequest(ctx)
	if err != nil {
		p.delegate.BeforeResponse(ctx, nil, err)
		return
	}
	resp := &http.Response{
		Status:     "200 Connection Established",
		StatusCode: http.StatusOK,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     http.Header{},
		Body:       http.NoBody,
	}
	p.delegate.BeforeResponse(ctx, resp, nil)
}

// dialContextWithCache returns a DialContext that resolves host names
// through p.dnsCache. IPv4 addresses are tried first, then the rest, until
// one connects or ctx is done. IP literals and a nil cache dial directly.
func (p *Proxy) dialContextWithCache() func(ctx context.Context, network, addr string) (net.Conn, error) {
	return func(ctx context.Context, network, addr string) (net.Conn, error) {
		dialer := &net.Dialer{
			Timeout:   defaultTargetConnectTimeout,
			KeepAlive: 30 * time.Second,
		}

		if p.dnsCache == nil {
			return dialer.DialContext(ctx, network, addr)
		}

		// SplitHostPort also handles bracketed IPv6 hosts such as "[::1]:443".
		host, port, err := net.SplitHostPort(addr)
		if err != nil {
			return nil, fmt.Errorf("invalid address: %s", addr)
		}
		if net.ParseIP(host) != nil {
			return dialer.DialContext(ctx, network, addr)
		}

		ips, err := p.dnsCache.Fetch(host)
		if err != nil {
			return nil, err
		}

		candidates := make([]string, 0, len(ips))
		var ipv6 []string
		for _, ip := range ips {
			switch {
			case ip == nil:
				// skip empty entries
			case ip.To4() != nil:
				candidates = append(candidates, ip.String())
			default:
				ipv6 = append(ipv6, ip.String())
			}
		}
		candidates = append(candidates, ipv6...)
		if len(candidates) == 0 {
			return nil, fmt.Errorf("no valid IP address found for: %s", host)
		}

		var lastErr error
		for _, ip := range candidates {
			// JoinHostPort brackets IPv6 addresses.
			conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(ip, port))
			if err == nil {
				return conn, nil
			}
			lastErr = err
			if ctx.Err() != nil {
				break
			}
		}
		return nil, lastErr
	}
}

// proxyFromDelegate returns the parent proxy for req. It uses the load
// balancer when enabled, otherwise the delegate.
func (p *Proxy) proxyFromDelegate(req *http.Request) (*url.URL, error) {
	if p.loadBalancer != nil && p.config.EnableLoadBalancing {
		return p.loadBalancer.Next(req.URL.Hostname())
	}
	return p.delegate.ParentProxy(req)
}

// lazyCertManager returns p.certManager, creating it under a lock on first
// use so concurrent tunnels cannot race on the field.
func (p *Proxy) lazyCertManager() *CertManager {
	p.certManagerMu.Lock()
	defer p.certManagerMu.Unlock()
	if p.certManager == nil {
		p.certManager = p.defaultCertManager()
	}
	return p.certManager
}

// generateTLSConfig returns the server TLS config used to intercept host.
// The sources, in order of precedence:
//  1. a fixed certificate/key pair from config;
//  2. a per-host certificate signed by a CA loaded from config files, falling
//     back to the default CA when loading fails;
//  3. a certificate from the certificate manager's default CA.
func (p *Proxy) generateTLSConfig(host string) (*tls.Config, error) {
	certManager := p.lazyCertManager()

	if p.config.TLSCert != "" && p.config.TLSKey != "" {
		cert, err := tls.LoadX509KeyPair(p.config.TLSCert, p.config.TLSKey)
		if err != nil {
			return nil, fmt.Errorf("加载TLS证书失败: %w", err)
		}
		return &tls.Config{
			Certificates: []tls.Certificate{cert},
		}, nil
	}

	if p.config.CACert != "" && p.config.CAKey != "" {
		caCert, caKey, err := LoadCAFromFiles(p.config.CACert, p.config.CAKey)
		if err != nil {
			p.delegate.ErrorLog(fmt.Errorf("加载CA证书和私钥失败: %s", err))
			return certManager.GenerateTLSConfig(host)
		}

		cert, err := certManager.GenerateCertificate(host, caCert, caKey)
		if err != nil {
			p.delegate.ErrorLog(fmt.Errorf("为%s生成动态证书失败: %s", host, err))
			return nil, err
		}
		return &tls.Config{
			Certificates: []tls.Certificate{*cert},
		}, nil
	}

	tlsConfig, err := certManager.GenerateTLSConfig(host)
	if err != nil {
		p.delegate.ErrorLog(fmt.Errorf("为%s使用默认CA生成证书失败: %s", host, err))
		return nil, err
	}
	return tlsConfig, nil
}

// hijacker takes over the client connection behind rw. It returns the
// connection wrapped together with the server's buffered reader, so bytes
// already buffered are not lost.
func hijacker(rw http.ResponseWriter) (*ConnBuffer, error) {
	hijacker, ok := rw.(http.Hijacker)
	if !ok {
		return nil, fmt.Errorf("http server不支持Hijacker")
	}
	conn, bufrw, err := hijacker.Hijack()
	if err != nil {
		return nil, fmt.Errorf("hijacker错误: %w", err)
	}
	return NewConnBuffer(conn, bufrw.Reader), nil
}

// isWebSocketRequest reports whether req asks for a WebSocket upgrade:
// Connection must contain "upgrade" and Upgrade must equal "websocket",
// both compared case-insensitively.
func isWebSocketRequest(req *http.Request) bool {
	if req == nil {
		return false
	}
	connection := strings.ToLower(req.Header.Get("Connection"))
	if !strings.Contains(connection, "upgrade") {
		return false
	}
	upgrade := strings.ToLower(req.Header.Get("Upgrade"))
	return upgrade == "websocket"
}

// hopHeaders lists hop-by-hop header fields that a proxy must not forward.
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}

// stripHopByHopHeaders removes the fields in hopHeaders from h, plus any
// field named in h's Connection header. Those names are read before
// Connection itself is deleted.
func stripHopByHopHeaders(h http.Header) {
	for _, value := range h.Values("Connection") {
		for _, name := range strings.Split(value, ",") {
			if name = strings.TrimSpace(name); name != "" {
				h.Del(name)
			}
		}
	}
	for _, name := range hopHeaders {
		h.Del(name)
	}
}

// isCacheHitMetricsSupported reports whether m has an IncCacheHit method.
func isCacheHitMetricsSupported(m metrics.MetricsCollector) bool {
	_, ok := m.(interface{ IncCacheHit() })
	return ok
}

// incrementCacheHit calls m.IncCacheHit when m supports it.
func incrementCacheHit(m metrics.MetricsCollector) {
	if hitter, ok := m.(interface{ IncCacheHit() }); ok {
		hitter.IncCacheHit()
	}
}

// SetDialContext replaces the dial function of the underlying transport.
// It mutates the shared *http.Transport, so call it before the proxy starts
// serving.
func (p *Proxy) SetDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) {
	p.transport.DialContext = dialContext
}

// convertToReverseConfig maps config.Config onto reverse.Config.
func convertToReverseConfig(cfg *config.Config) *reverse.Config {
	return &reverse.Config{
		BaseConfig: reverse.BaseConfig{
			ListenAddr:  cfg.ListenAddr,
			EnableHTTPS: cfg.DecryptHTTPS,
			TLSConfig: &reverse.TLSConfig{
				CertFile:           cfg.TLSCert,
				KeyFile:            cfg.TLSKey,
				InsecureSkipVerify: cfg.InsecureSkipVerify,
				UseECDSA:           cfg.UseECDSA,
			},
			EnableWebSocket:   cfg.SupportWebSocketUpgrade,
			EnableCompression: cfg.EnableCompression,
			EnableCORS:        cfg.EnableCORS,
			PreserveClientIP:  cfg.PreserveClientIP,
			AddXForwardedFor:  cfg.AddXForwardedFor,
			AddXRealIP:        cfg.AddXRealIP,
		},
		RulesFile:            cfg.ReverseProxyRulesFile,
		InsecureSkipVerify:   cfg.InsecureSkipVerify,
		EnableHealthCheck:    cfg.EnableHealthCheck,
		HealthCheckInterval:  cfg.HealthCheckInterval,
		HealthCheckTimeout:   cfg.HealthCheckTimeout,
		EnableRetry:          cfg.EnableRetry,
		MaxRetries:           cfg.MaxRetries,
		RetryBackoff:         cfg.RetryBackoff,
		MaxRetryBackoff:      cfg.MaxRetryBackoff,
		EnableMetrics:        cfg.EnableMetrics,
		EnableTracing:        cfg.EnableTracing,
		WebSocketIntercept:   cfg.WebSocketIntercept,
		DNSCacheTTL:          cfg.DNSCacheTTL,
		EnableCache:          cfg.EnableCache,
		CacheTTL:             cfg.CacheTTL,
		EnableConnectionPool: cfg.EnableConnectionPool,
		ConnectionPoolSize:   cfg.ConnectionPoolSize,
		IdleTimeout:          cfg.IdleTimeout,
		RequestTimeout:       cfg.RequestTimeout,
	}
}