package proxy

import (
	"context"
	"crypto/tls"
	"fmt"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"
	"time"

	"github.com/darkit/goproxy/internal/rewriter"
	"github.com/darkit/goproxy/internal/router"
)

// ReverseProxy forwards incoming HTTP requests to a backend resolved by the
// owning Proxy's delegate.
type ReverseProxy struct {
	// proxy is the owning Proxy; it supplies configuration, delegate and metrics.
	proxy *Proxy
	// router holds the routes registered through AddRoute.
	router *router.Router
	// rewriter applies the rewrite rules registered through AddRewriteRule.
	rewriter *rewriter.Rewriter
	// transport is the round tripper used for requests sent to the backend.
	transport http.RoundTripper
}

// NewReverseProxy returns a ReverseProxy bound to p. It gets a fresh router
// and rewriter, and an HTTP transport whose pool sizes, idle timeout and
// compression setting are taken from p's configuration.
func (p *Proxy) NewReverseProxy() *ReverseProxy {
	rp := &ReverseProxy{
		proxy:    p,
		router:   router.NewRouter(),
		rewriter: rewriter.NewRewriter(),
	}

	rp.transport = &http.Transport{
		// The parent (upstream) proxy for each request is chosen by the delegate.
		Proxy: func(req *http.Request) (*url.URL, error) {
			return p.delegate.ParentProxy(req)
		},
		DialContext: p.dialContextWithCache(),
		// SECURITY: certificate verification is disabled for every backend
		// connection made through this transport. Confirm this is intentional
		// before exposing the proxy to untrusted networks.
		TLSClientConfig: &tls.Config{
			InsecureSkipVerify: true,
		},
		MaxIdleConns:          p.config.ConnectionPoolSize,
		IdleConnTimeout:       p.config.IdleTimeout,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
		MaxIdleConnsPerHost:   p.config.ConnectionPoolSize,
		DisableCompression:    !p.config.EnableCompression,
	}

	if p.config.ReverseProxyRulesFile != "" {
		// TODO: loading rules from the configured file is not implemented here.
	}

	return rp
}

// ServeHTTP handles one reverse-proxy request.
//
// It runs the delegate hooks (Connect, Auth, BeforeRequest), resolves the
// backend, and then either tunnels a WebSocket upgrade or forwards the request
// through an httputil.ReverseProxy that uses rp.transport. Finish is reported
// to the delegate only after a normally proxied request.
func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	// Borrow a request context from the pool and hand it back on return.
	ctx := ctxPool.Get().(*Context)
	ctx.Reset(req)
	defer ctxPool.Put(ctx)

	rp.proxy.delegate.Connect(ctx, rw)

	// Auth and BeforeRequest may abort the context; stop as soon as one does.
	rp.proxy.delegate.Auth(ctx, rw)
	if ctx.IsAborted() {
		return
	}
	rp.proxy.delegate.BeforeRequest(ctx)
	if ctx.IsAborted() {
		return
	}

	backend, err := rp.proxy.delegate.ResolveBackend(req)
	if err != nil {
		rp.proxy.delegate.HandleError(rw, req, err)
		return
	}

	// The backend is always reached over plain HTTP on this path.
	forwarder := httputil.NewSingleHostReverseProxy(&url.URL{
		Scheme: "http",
		Host:   backend,
	})
	forwarder.Transport = rp.transport
	forwarder.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
		rp.proxy.delegate.HandleError(rw, req, err)
	}

	// Wrap the default director so the outgoing request is adjusted after the
	// standard target rewriting has been applied.
	baseDirector := forwarder.Director
	forwarder.Director = func(out *http.Request) {
		baseDirector(out)

		if rp.proxy.config.EnableURLRewrite {
			rp.rewriter.Rewrite(out)
		}
		if rp.proxy.config.RewriteHostHeader {
			out.Host = backend
		}

		// X-Forwarded-For is intentionally not written here. After Director
		// returns, httputil.ReverseProxy appends the client IP to any existing
		// X-Forwarded-For value itself; setting it here as well made the
		// client address appear twice in the header sent to the backend.
		// As a consequence the delegate's ModifyRequest sees the header
		// without this hop's address.
		// NOTE(review): the standard library adds the header regardless of
		// config.AddXForwardedFor; decide whether a disabled flag should
		// suppress it (setting the header key to nil does that).

		if rp.proxy.config.AddXRealIP {
			if clientIP, _, err := net.SplitHostPort(out.RemoteAddr); err == nil {
				out.Header.Set("X-Real-IP", clientIP)
			}
		}

		// Tell the backend which scheme the client used to reach the proxy.
		proto := "http"
		if out.TLS != nil {
			proto = "https"
		}
		out.Header.Set("X-Forwarded-Proto", proto)

		rp.proxy.delegate.ModifyRequest(out)
	}

	forwarder.ModifyResponse = func(resp *http.Response) error {
		// req is the inbound request, so req.Host is the host the client asked for.
		if rp.proxy.config.EnableURLRewrite && resp != nil {
			rp.rewriter.RewriteResponse(resp, req.Host)
		}

		if rp.proxy.config.EnableCORS && resp != nil {
			resp.Header.Set("Access-Control-Allow-Origin", "*")
			resp.Header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
			resp.Header.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
		}

		return rp.proxy.delegate.ModifyResponse(resp)
	}

	// Track the active connection and the request duration until this handler
	// returns (WebSocket tunnels included).
	if rp.proxy.metrics != nil {
		rp.proxy.metrics.IncActiveConnections()
		defer rp.proxy.metrics.DecActiveConnections()

		startTime := time.Now()
		defer func() {
			duration := time.Since(startTime)
			rp.proxy.metrics.ObserveRequestDuration(duration.Seconds())
			rp.proxy.metrics.IncRequestCount()
		}()
	}

	// WebSocket upgrades bypass the httputil proxy and are tunnelled directly.
	// NOTE(review): the delegate's Finish hook is not invoked on this path.
	if rp.proxy.config.SupportWebSocketUpgrade && isWebSocketRequest(req) {
		rp.handleWebSocketUpgrade(rw, req, backend)
		return
	}

	forwarder.ServeHTTP(rw, req)

	rp.proxy.delegate.Finish(ctx)
}

// handleWebSocketUpgrade tunnels a WebSocket upgrade request to backend.
//
// It dials the backend, writes the original request to it, takes over the
// client connection and then relays data in both directions. Failures are
// reported through the delegate's HandleError with the underlying error
// wrapped, so callers can inspect it with errors.Is / errors.As.
func (rp *ReverseProxy) handleWebSocketUpgrade(rw http.ResponseWriter, req *http.Request, backend string) {
	// NOTE(review): the backend scheme follows the inbound connection (wss when
	// the client used TLS), while regular requests always use plain http.
	target := &url.URL{
		Scheme: "ws",
		Host:   backend,
	}
	if req.TLS != nil {
		target.Scheme = "wss"
	}

	backendConn, err := rp.dialBackend(target.String(), req)
	if err != nil {
		rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("无法连接到后端WebSocket服务: %w", err))
		return
	}
	defer backendConn.Close()

	// Replay the client's request (including its upgrade headers) to the backend.
	if err := req.Write(backendConn); err != nil {
		rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("写入WebSocket请求错误: %w", err))
		return
	}

	// Take over the client connection from the HTTP server.
	clientConn, err := hijacker(rw)
	if err != nil {
		rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("升级WebSocket连接错误: %w", err))
		return
	}

	// Relay data between client and backend.
	rp.proxy.transfer(clientConn, backendConn)
}

// dialBackend opens a TCP connection (TLS for wss:// URLs) to the host named
// in a ws:// or wss:// URL. Anything from the first "/" onward is ignored.
// When the host carries no port, 80 (ws) or 443 (wss) is used. The dial is
// bounded by a 15 second timeout derived from the request's context.
func (rp *ReverseProxy) dialBackend(url string, req *http.Request) (net.Conn, error) {
	ctx, cancel := context.WithTimeout(req.Context(), 15*time.Second)
	defer cancel()

	secure := strings.HasPrefix(url, "wss://")

	// Reduce the URL to its host[:port] part.
	addr := strings.TrimPrefix(url, "ws://")
	addr = strings.TrimPrefix(addr, "wss://")
	if i := strings.IndexByte(addr, '/'); i >= 0 {
		addr = addr[:i]
	}

	// The dialers require an explicit port; supply the scheme default when the
	// backend was given without one (as is allowed on the regular HTTP path).
	if _, _, err := net.SplitHostPort(addr); err != nil {
		defaultPort := "80"
		if secure {
			defaultPort = "443"
		}
		// JoinHostPort adds brackets for IPv6 literals, so strip existing ones.
		host := strings.TrimSuffix(strings.TrimPrefix(addr, "["), "]")
		addr = net.JoinHostPort(host, defaultPort)
	}

	if secure {
		// tls.Dialer is used because crypto/tls has no DialWithContext function.
		// SECURITY: certificate verification is disabled for the backend.
		dialer := &tls.Dialer{
			Config: &tls.Config{
				InsecureSkipVerify: true,
			},
		}
		return dialer.DialContext(ctx, "tcp", addr)
	}

	var d net.Dialer
	return d.DialContext(ctx, "tcp", addr)
}

// AddRoute registers a routing rule built from pattern, routeType and target
// with the proxy's router.
func (rp *ReverseProxy) AddRoute(pattern string, routeType router.RouteType, target string) {
	route := &router.Route{
		Pattern: pattern,
		Type:    routeType,
		Target:  target,
	}
	rp.router.AddRoute(route)
}

// AddRewriteRule registers a rewrite rule with the proxy's rewriter and
// returns the rewriter's error, if any. useRegex is passed through unchanged.
func (rp *ReverseProxy) AddRewriteRule(pattern, replacement string, useRegex bool) error {
	return rp.rewriter.AddRule(pattern, replacement, useRegex)
}