Files
goproxy/internal/proxy/reverse_proxy.go
2025-03-13 18:11:04 +08:00

278 lines
6.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package proxy
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/darkit/goproxy/internal/rewriter"
"github.com/darkit/goproxy/internal/router"
)
// ReverseProxy is the reverse-proxy front end of a Proxy. It resolves a
// backend through the owning Proxy's delegate and forwards requests to it.
type ReverseProxy struct {
	proxy     *Proxy             // owning proxy; source of config, delegate and metrics
	router    *router.Router     // route table, populated via AddRoute
	rewriter  *rewriter.Rewriter // URL rewrite rules, populated via AddRewriteRule
	transport http.RoundTripper  // upstream transport shared by every forwarded request
}
// NewReverseProxy builds a ReverseProxy bound to p.
//
// A single http.Transport is created here and shared by every request the
// returned ReverseProxy forwards. Its upstream proxy comes from p's delegate,
// its dialer from p.dialContextWithCache, and its idle-pool size, idle
// timeout and compression setting from p.config.
func (p *Proxy) NewReverseProxy() *ReverseProxy {
	// TODO: p.config.ReverseProxyRulesFile is currently ignored; loading
	// rules from that file has not been implemented.
	return &ReverseProxy{
		proxy:    p,
		router:   router.NewRouter(),
		rewriter: rewriter.NewRewriter(),
		transport: &http.Transport{
			// Asked per request, so the delegate decides the parent proxy
			// (a nil URL means "connect directly", per http.Transport).
			Proxy: func(r *http.Request) (*url.URL, error) {
				return p.delegate.ParentProxy(r)
			},
			DialContext: p.dialContextWithCache(),
			// NOTE(review): certificate verification is disabled for every
			// upstream TLS connection; consider making this configurable.
			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
			// The global and the per-host idle pools share one limit.
			MaxIdleConns:          p.config.ConnectionPoolSize,
			MaxIdleConnsPerHost:   p.config.ConnectionPoolSize,
			IdleConnTimeout:       p.config.IdleTimeout,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
			DisableCompression:    !p.config.EnableCompression,
		},
	}
}
// ServeHTTP handles one reverse-proxied request.
//
// The delegate's Connect, Auth and BeforeRequest hooks run in that order.
// Any of them may abort the request through ctx, which ends handling
// immediately. The delegate then resolves the backend address; a resolution
// error is reported through HandleError.
//
// When WebSocket upgrade support is enabled and the request is an upgrade,
// it is relayed by handleWebSocketUpgrade. Otherwise it is forwarded by the
// httputil.ReverseProxy built in newBackendProxy.
//
// Once a backend has been resolved, the delegate's Finish hook runs on every
// exit path. This includes the WebSocket path and the http.ErrAbortHandler
// panic that httputil.ReverseProxy raises when copying a response body fails.
func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	ctx := ctxPool.Get().(*Context)
	ctx.Reset(req)
	// Deferred first so it runs last: every later defer still sees a live ctx.
	defer ctxPool.Put(ctx)

	rp.proxy.delegate.Connect(ctx, rw)
	if ctx.IsAborted() {
		return
	}
	rp.proxy.delegate.Auth(ctx, rw)
	if ctx.IsAborted() {
		return
	}
	rp.proxy.delegate.BeforeRequest(ctx)
	if ctx.IsAborted() {
		return
	}

	backend, err := rp.proxy.delegate.ResolveBackend(req)
	if err != nil {
		rp.proxy.delegate.HandleError(rw, req, err)
		return
	}

	if rp.proxy.metrics != nil {
		rp.proxy.metrics.IncActiveConnections()
		defer rp.proxy.metrics.DecActiveConnections()
		startTime := time.Now()
		defer func() {
			rp.proxy.metrics.ObserveRequestDuration(time.Since(startTime).Seconds())
			rp.proxy.metrics.IncRequestCount()
		}()
	}

	// Deferred after the metric defers so it runs before them; the recorded
	// duration therefore includes Finish. Using defer guarantees Finish also
	// runs for WebSocket upgrades and when ReverseProxy panics mid-response.
	defer rp.proxy.delegate.Finish(ctx)

	if rp.proxy.config.SupportWebSocketUpgrade && isWebSocketRequest(req) {
		rp.handleWebSocketUpgrade(rw, req, backend)
		return
	}

	rp.newBackendProxy(req, backend).ServeHTTP(rw, req)
}

// newBackendProxy builds the httputil.ReverseProxy for one non-WebSocket
// request to http://backend.
//
// The proxy uses rp's shared transport and reports transport errors through
// the delegate's HandleError. It also applies the request and response
// rewrites enabled in rp.proxy.config. clientReq is the inbound request; its
// Host is the value handed to the response rewriter.
func (rp *ReverseProxy) newBackendProxy(clientReq *http.Request, backend string) *httputil.ReverseProxy {
	fwd := httputil.NewSingleHostReverseProxy(&url.URL{
		Scheme: "http",
		Host:   backend,
	})
	fwd.Transport = rp.transport
	fwd.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		rp.proxy.delegate.HandleError(w, r, err)
	}

	baseDirector := fwd.Director
	fwd.Director = func(out *http.Request) {
		// Points the outgoing request at backend (scheme, host, joined path).
		baseDirector(out)

		if rp.proxy.config.EnableURLRewrite {
			rp.rewriter.Rewrite(out)
		}
		if rp.proxy.config.RewriteHostHeader {
			out.Host = backend
		}

		// httputil.ReverseProxy appends the client IP to X-Forwarded-For
		// itself after Director returns, keeping any prior value. Adding it
		// here as well would list the client twice. When the option is
		// disabled, a nil value tells ReverseProxy to omit the header.
		if !rp.proxy.config.AddXForwardedFor {
			out.Header["X-Forwarded-For"] = nil
		}

		if rp.proxy.config.AddXRealIP {
			if clientIP, _, err := net.SplitHostPort(out.RemoteAddr); err == nil {
				out.Header.Set("X-Real-IP", clientIP)
			}
		}

		proto := "http"
		if out.TLS != nil {
			proto = "https"
		}
		out.Header.Set("X-Forwarded-Proto", proto)

		rp.proxy.delegate.ModifyRequest(out)
	}

	fwd.ModifyResponse = func(resp *http.Response) error {
		if rp.proxy.config.EnableURLRewrite && resp != nil {
			rp.rewriter.RewriteResponse(resp, clientReq.Host)
		}
		if rp.proxy.config.EnableCORS && resp != nil {
			resp.Header.Set("Access-Control-Allow-Origin", "*")
			resp.Header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
			resp.Header.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
		}
		return rp.proxy.delegate.ModifyResponse(resp)
	}

	return fwd
}
// handleWebSocketUpgrade relays a WebSocket upgrade request to backend over a
// raw connection, then hands both connections to rp.proxy.transfer.
//
// The backend is dialled as "wss" when the client connection arrived over TLS
// and as "ws" otherwise. Failures before the client connection is hijacked
// are reported through the delegate's HandleError. The underlying cause is
// wrapped with %w, so callers can inspect it with errors.Is / errors.As.
func (rp *ReverseProxy) handleWebSocketUpgrade(rw http.ResponseWriter, req *http.Request, backend string) {
	scheme := "ws"
	if req.TLS != nil {
		scheme = "wss"
	}
	target := &url.URL{Scheme: scheme, Host: backend}

	backendConn, err := rp.dialBackend(target.String(), req)
	if err != nil {
		rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("无法连接到后端WebSocket服务: %w", err))
		return
	}
	defer backendConn.Close()

	// Forward the client's upgrade request as-is. The backend's reply is not
	// read here; it stays on backendConn for transfer to relay.
	if err := req.Write(backendConn); err != nil {
		rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("写入WebSocket请求错误: %w", err))
		return
	}

	clientConn, err := hijacker(rw)
	if err != nil {
		rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("升级WebSocket连接错误: %w", err))
		return
	}
	// After a successful hijack this handler owns the client connection, so
	// release it here regardless of what transfer does. A repeated Close on
	// the standard net connection types only returns an error.
	defer clientConn.Close()

	rp.proxy.transfer(clientConn, backendConn)
}
// dialBackend opens a connection to the host named in rawURL.
//
// rawURL is expected to look like "ws://host[:port][/...]" or
// "wss://host[:port][/...]". A "wss://" URL is dialled with TLS; anything
// else uses plain TCP. A missing port defaults to 80 for ws and 443 for wss.
// The dial, including any TLS handshake, is bounded by 15 seconds and by
// req's context. The returned connection is not tied to that context after
// the dial completes.
func (rp *ReverseProxy) dialBackend(rawURL string, req *http.Request) (net.Conn, error) {
	ctx, cancel := context.WithTimeout(req.Context(), 15*time.Second)
	defer cancel()

	addr, secure := webSocketDialAddr(rawURL)
	if secure {
		// NOTE(review): certificate verification is disabled for the backend,
		// as it is for the HTTP transport; consider making this configurable.
		dialer := &tls.Dialer{
			Config: &tls.Config{
				InsecureSkipVerify: true,
			},
		}
		return dialer.DialContext(ctx, "tcp", addr)
	}

	var d net.Dialer
	return d.DialContext(ctx, "tcp", addr)
}

// webSocketDialAddr returns the host:port to dial for a ws:// or wss:// URL
// and reports whether TLS is required.
//
// Any path (from the first '/') is dropped. A missing port is filled in with
// 80 (ws) or 443 (wss). An empty host is returned unchanged, so the caller's
// dial fails instead of silently targeting the local machine. Input without
// either prefix is used as-is apart from the path and port handling.
func webSocketDialAddr(rawURL string) (addr string, secure bool) {
	secure = strings.HasPrefix(rawURL, "wss://")
	addr = strings.TrimPrefix(rawURL, "ws://")
	addr = strings.TrimPrefix(addr, "wss://")
	if i := strings.IndexByte(addr, '/'); i >= 0 {
		addr = addr[:i]
	}
	if addr == "" {
		return addr, secure
	}

	// url.URL understands bracketed IPv6 literals and optional ports.
	hostPort := &url.URL{Host: addr}
	if hostPort.Port() == "" {
		port := "80"
		if secure {
			port = "443"
		}
		addr = net.JoinHostPort(hostPort.Hostname(), port)
	}
	return addr, secure
}
// AddRoute registers a route of kind routeType that maps pattern to target
// in the reverse proxy's router.
func (rp *ReverseProxy) AddRoute(pattern string, routeType router.RouteType, target string) {
	rp.router.AddRoute(&router.Route{
		Pattern: pattern,
		Type:    routeType,
		Target:  target,
	})
}
// AddRewriteRule adds a URL rewrite rule to the reverse proxy's rewriter.
// pattern, replacement and useRegex are passed straight to the rewriter's
// AddRule, and its error is returned unchanged.
func (rp *ReverseProxy) AddRewriteRule(pattern, replacement string, useRegex bool) error {
	rw := rp.rewriter
	return rw.AddRule(pattern, replacement, useRegex)
}