Files
demo/unified_proxy.go
2025-03-15 10:17:07 +00:00

1403 lines
37 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package goproxy
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/http/httptrace"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/darkit/goproxy/config"
"github.com/darkit/goproxy/pkg/auth"
"github.com/darkit/goproxy/pkg/cache"
"github.com/darkit/goproxy/pkg/dns"
"github.com/darkit/goproxy/pkg/healthcheck"
"github.com/darkit/goproxy/pkg/loadbalance"
"github.com/darkit/goproxy/pkg/metrics"
"github.com/darkit/goproxy/pkg/rule"
"github.com/ouqiang/websocket"
"github.com/viki-org/dnscache"
)
// UnifiedProxy 统一代理接口
type UnifiedProxy interface {
// ServeHTTP 处理HTTP请求
ServeHTTP(w http.ResponseWriter, r *http.Request)
// Close 关闭代理
Close() error
// ClientConnNum 获取当前客户端连接数
ClientConnNum() int32
// SetDialContext 设置自定义的拨号上下文函数
SetDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error))
// DoRequest 执行HTTP请求
DoRequest(ctx *Context, responseFunc func(*http.Response, error))
}
// UnifiedProxyImpl 统一代理实现
type UnifiedProxyImpl struct {
// 配置
config *config.UnifiedConfig
// 委托
delegate Delegate
// 证书缓存
certCache CertificateCache
// HTTP缓存
httpCache cache.Cache
// 缓存适配器
cacheAdapter *CacheAdapter
// 负载均衡器
loadBalancer loadbalance.LoadBalancer
// 健康检查器
healthChecker *healthcheck.HealthChecker
// 监控指标
metrics metrics.MetricsCollector
// 客户端跟踪
clientTrace *httptrace.ClientTrace
// 基础传输(用于直接获取*http.Transport类型
transport *http.Transport
// HTTP请求传输可能被中间件包装
httpTransport http.RoundTripper
// DNS缓存
dnsCache *dnscache.Resolver
// 客户端连接数
clientConnNum int32
// 证书管理器
certManager *CertManager
// 日志记录器
logger *slog.Logger
// 规则管理器 (反向代理使用)
ruleManager *rule.Manager
// 反向代理处理器 (反向代理使用)
reverseProxy *httputil.ReverseProxy
// 认证系统
auth *auth.Auth
// 自定义DNS解析器
dnsResolver *dns.CustomResolver
}
// UnifiedOptions 统一代理选项
type UnifiedOptions struct {
// 配置
Config *config.UnifiedConfig
// 委托
Delegate Delegate
// 证书缓存
CertCache CertificateCache
// HTTP缓存
HTTPCache cache.Cache
// 负载均衡器
LoadBalancer loadbalance.LoadBalancer
// 健康检查器
HealthChecker *healthcheck.HealthChecker
// 监控指标
Metrics metrics.MetricsCollector
// 客户端跟踪
ClientTrace *httptrace.ClientTrace
// 证书管理器
CertManager *CertManager
// 认证系统
Auth *auth.Auth
// DNS解析器
DNSResolver *dns.CustomResolver
}
// NewUnifiedProxy 创建统一代理
func NewUnifiedProxy(opts *UnifiedOptions) (UnifiedProxy, error) {
if opts == nil {
opts = &UnifiedOptions{}
}
if opts.Config == nil {
opts.Config = config.DefaultUnifiedConfig()
}
if opts.Delegate == nil {
opts.Delegate = &DefaultDelegate{}
}
proxy := &UnifiedProxyImpl{
config: opts.Config,
delegate: opts.Delegate,
certCache: opts.CertCache,
httpCache: opts.HTTPCache,
loadBalancer: opts.LoadBalancer,
healthChecker: opts.HealthChecker,
metrics: opts.Metrics,
clientTrace: opts.ClientTrace,
certManager: opts.CertManager,
auth: opts.Auth,
logger: opts.Config.Logger,
dnsResolver: opts.DNSResolver,
}
// 初始化日志记录器
if proxy.logger == nil {
proxy.logger = slog.Default()
}
// 创建证书管理器(如果需要)
if proxy.certManager == nil && opts.Config.DecryptHTTPS {
// 如果提供了CA证书和密钥使用它们创建证书管理器
// 否则使用默认的证书管理器
proxy.certManager = NewCertManager(proxy.certCache, WithUseECDSA(opts.Config.UseECDSA))
}
// 根据代理模式初始化不同的组件
switch opts.Config.ProxyMode {
case config.ModeForward:
proxy.logger.Debug("初始化正向代理")
if err := proxy.initializeForwardProxy(); err != nil {
return nil, err
}
case config.ModeReverse:
proxy.logger.Debug("初始化反向代理")
if err := proxy.initializeReverseProxy(); err != nil {
return nil, err
}
case config.ModeTransparent:
proxy.logger.Debug("初始化透明代理")
if err := proxy.initializeTransparentProxy(); err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("不支持的代理模式: %s", opts.Config.ProxyMode)
}
return proxy, nil
}
// 初始化正向代理
func (p *UnifiedProxyImpl) initializeForwardProxy() error {
// 创建DNS缓存
if p.config.DNSCacheTTL > 0 && p.dnsResolver == nil {
// 只有在没有自定义DNS解析器的情况下才创建DNS缓存
p.dnsCache = dnscache.New(p.config.DNSCacheTTL)
}
// 创建缓存适配器
if p.httpCache != nil {
p.cacheAdapter = NewCacheAdapter(p.httpCache)
}
// 创建传输层
p.transport = &http.Transport{
// DNS解析
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// 首先检查是否有自定义DNS解析器
if p.dnsResolver != nil {
// 解析地址和端口
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
portNum, err := strconv.Atoi(port)
if err != nil {
return nil, err
}
// 使用自定义DNS解析器解析地址
endpoint, err := p.dnsResolver.ResolveWithPort(host, portNum)
if err != nil {
// 如果解析失败,回退到原始地址
p.logger.Debug("[正向代理]自定义DNS解析失败使用原始地址",
"host", host,
"error", err.Error(),
)
dialer := &net.Dialer{
Timeout: p.config.RequestTimeout,
KeepAlive: p.config.IdleTimeout,
}
return dialer.DialContext(ctx, network, addr)
}
// 使用解析后的地址创建连接
p.logger.Debug("使用自定义DNS解析",
"host", host,
"resolved", endpoint.String(),
)
dialer := &net.Dialer{
Timeout: p.config.RequestTimeout,
KeepAlive: p.config.IdleTimeout,
}
return dialer.DialContext(ctx, network, endpoint.String())
}
// 使用DNS缓存
if p.dnsCache != nil {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
// 从DNS缓存中解析IP
ips, err := p.dnsCache.FetchOneString(host)
if err != nil {
// 如果DNS缓存解析失败尝试使用系统DNS
dialer := &net.Dialer{
Timeout: p.config.RequestTimeout,
KeepAlive: p.config.IdleTimeout,
}
return dialer.DialContext(ctx, network, addr)
}
// 使用解析后的IP创建连接
dialer := &net.Dialer{
Timeout: p.config.RequestTimeout,
KeepAlive: p.config.IdleTimeout,
}
return dialer.DialContext(ctx, network, net.JoinHostPort(ips, port))
}
// 不使用DNS缓存直接使用系统DNS
dialer := &net.Dialer{
Timeout: p.config.RequestTimeout,
KeepAlive: p.config.IdleTimeout,
}
return dialer.DialContext(ctx, network, addr)
},
DisableKeepAlives: p.config.DisableKeepAlive,
MaxIdleConns: p.config.MaxIdleConns,
IdleConnTimeout: p.config.IdleTimeout,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ForceAttemptHTTP2: true,
TLSClientConfig: nil, // 暂时不设置在HTTPS请求时动态生成
}
// 设置HTTP请求传输
p.httpTransport = p.transport
return nil
}
// 初始化反向代理
func (p *UnifiedProxyImpl) initializeReverseProxy() error {
// 创建规则管理器
p.ruleManager = rule.NewManager(p.logger)
// 处理目标地址中可能包含的协议前缀
targetScheme := "http"
targetHost := p.config.TargetAddr
// 如果目标地址包含协议前缀,分离协议和主机部分
if strings.HasPrefix(p.config.TargetAddr, "http://") || strings.HasPrefix(p.config.TargetAddr, "https://") {
parsedURL, err := url.Parse(p.config.TargetAddr)
if err != nil {
p.logger.Error("解析目标地址URL失败",
"target", p.config.TargetAddr,
"error", err.Error(),
)
} else {
// 提取协议和主机部分
targetScheme = parsedURL.Scheme
targetHost = parsedURL.Host
p.logger.Debug("目标地址包含协议前缀,已分离",
"original", p.config.TargetAddr,
"scheme", targetScheme,
"host", targetHost,
)
}
}
// 创建传输层
p.transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// 如果存在自定义DNS解析器使用它
if p.dnsResolver != nil {
// 解析地址和端口
host, port, err := net.SplitHostPort(addr)
if err != nil {
// 可能是没有端口的地址,添加默认端口
if strings.Contains(err.Error(), "missing port") {
// 默认使用80端口
host = addr
port = "80"
addr = net.JoinHostPort(host, port)
p.logger.Debug("[反向代理]地址没有端口,添加默认端口",
"addr", addr,
)
} else {
return nil, fmt.Errorf("解析主机和端口失败: %w", err)
}
}
portNum, err := strconv.Atoi(port)
if err != nil {
return nil, fmt.Errorf("端口号格式错误: %w", err)
}
// 使用DNS解析器解析地址
endpoint, err := p.dnsResolver.ResolveWithPort(host, portNum)
if err != nil {
// 如果解析失败,回退到使用原始地址
p.logger.Debug("[反向代理]自定义DNS解析失败使用原始地址",
"host", host,
"error", err.Error(),
)
dialer := &net.Dialer{
Timeout: p.config.RequestTimeout,
KeepAlive: p.config.IdleTimeout,
}
return dialer.DialContext(ctx, network, addr)
}
// 使用解析后的地址创建连接
p.logger.Debug("[反向代理]使用自定义DNS解析",
"host", host,
"resolved", endpoint.String(),
)
dialer := &net.Dialer{
Timeout: p.config.RequestTimeout,
KeepAlive: p.config.IdleTimeout,
}
return dialer.DialContext(ctx, network, endpoint.String())
}
// 没有自定义DNS解析器直接使用系统DNS
dialer := &net.Dialer{
Timeout: p.config.RequestTimeout,
KeepAlive: p.config.IdleTimeout,
}
return dialer.DialContext(ctx, network, addr)
},
ForceAttemptHTTP2: true,
MaxIdleConns: p.config.MaxIdleConns,
IdleConnTimeout: p.config.IdleTimeout,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: nil, // 在需要时设置
}
// 如果需要跳过证书验证设置TLS配置
if p.config.InsecureSkipVerify {
p.transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
}
// 创建反向代理处理器
p.reverseProxy = &httputil.ReverseProxy{
Transport: p.transport,
Director: func(req *http.Request) {
// 应用规则
rules := p.ruleManager.ListRules(rule.RuleTypeRoute)
for _, r := range rules {
if r.Match(req) {
if err := r.Apply(req); err != nil {
p.logger.Error("应用规则失败",
"rule_id", r.GetID(),
"error", err.Error(),
)
}
}
}
// 设置目标地址
if targetHost != "" {
req.URL.Host = targetHost
// 同时设置 req.Host 字段,确保兼容性
req.Host = targetHost
// 使用 Set 替换 Add确保只有一个 Host 头
req.Header.Set("Host", targetHost)
// 使用解析出的协议,确保一致性
req.URL.Scheme = targetScheme
p.logger.Debug("[反向代理]设置目标地址",
"target", req.URL.Host,
"scheme", req.URL.Scheme,
"path", req.URL.Path,
)
}
// 添加X-Forwarded-For头
if p.config.AddXForwardedFor {
req.Header.Add("X-Forwarded-For", req.RemoteAddr)
}
// 添加X-Real-IP头
if p.config.AddXRealIP {
req.Header.Add("X-Real-IP", req.RemoteAddr)
}
// 调用委托的修改请求方法
p.delegate.ModifyRequest(req)
},
ModifyResponse: func(resp *http.Response) error {
return p.delegate.ModifyResponse(resp)
},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
// 检查是否支持连接劫持
if hijacker, ok := w.(http.Hijacker); ok {
// 尝试获取连接
conn, _, hijackErr := hijacker.Hijack()
if hijackErr == nil {
// 连接已经被劫持,使用直接写入方式
defer conn.Close()
// 构建简单的错误响应
errMsg := fmt.Sprintf("HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain; charset=utf-8\r\nX-Proxy-Error: %s\r\n\r\n%s",
err.Error(), err.Error())
// 直接写入连接
conn.Write([]byte(errMsg))
return
}
// 如果劫持失败,回退到正常的错误处理
}
// 使用标准的错误处理
p.delegate.HandleError(w, r, err)
},
}
// 如果配置了规则文件,加载规则
if p.config.RulesFile != "" {
loader := rule.NewLoader(p.ruleManager, p.logger)
// 根据文件扩展名决定加载方式
switch {
case strings.HasSuffix(p.config.RulesFile, ".json"):
if err := loader.LoadFromJSON(p.config.RulesFile); err != nil {
p.logger.Error("加载规则文件失败",
"file", p.config.RulesFile,
"error", err.Error(),
)
return fmt.Errorf("加载规则文件失败: %w", err)
}
p.logger.Info("成功加载规则文件",
"file", p.config.RulesFile,
)
case strings.HasSuffix(p.config.RulesFile, ".hosts"):
if err := loader.LoadFromHosts(p.config.RulesFile); err != nil {
p.logger.Error("加载hosts文件失败",
"file", p.config.RulesFile,
"error", err.Error(),
)
return fmt.Errorf("加载hosts文件失败: %w", err)
}
p.logger.Info("成功加载hosts文件",
"file", p.config.RulesFile,
)
default:
return fmt.Errorf("不支持的规则文件格式: %s", p.config.RulesFile)
}
}
return nil
}
// 初始化透明代理
func (p *UnifiedProxyImpl) initializeTransparentProxy() error {
// 透明代理基本上是正向代理的一种特殊形式
return p.initializeForwardProxy()
}
// ServeHTTP 处理HTTP请求
func (p *UnifiedProxyImpl) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// 增加客户端连接数
atomic.AddInt32(&p.clientConnNum, 1)
defer atomic.AddInt32(&p.clientConnNum, -1)
// 更新请求计数指标
if p.metrics != nil {
p.metrics.IncRequestCount()
}
// 创建请求上下文
ctx := ctxPool.Get()
ctx.Reset(r)
defer ctxPool.Put(ctx)
// 调用连接事件
p.delegate.Connect(ctx, w)
// 认证检查
if p.auth != nil {
// 从请求头中获取Authorization信息
username, password, hasAuth := r.BasicAuth()
if hasAuth {
// 进行认证
token, err := p.auth.Authenticate(username, password)
if err != nil {
// 认证失败
p.logger.Debug("认证失败", "username", username, "error", err.Error())
w.Header().Set("WWW-Authenticate", `Basic realm="Proxy Authentication Required"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
// 认证成功,设置认证头
ctx.Data["auth_token"] = token
} else {
// 没有提供认证信息
p.delegate.Auth(ctx, w)
if ctx.IsAborted() {
return
}
}
}
// 根据代理模式处理请求
switch p.config.ProxyMode {
case config.ModeForward:
// 正向代理模式下URL必须是完整的
if r.URL.Host == "" {
http.Error(w, "Invalid request: no host in URL", http.StatusBadRequest)
return
}
p.serveForwardProxy(ctx, w)
case config.ModeReverse:
p.serveReverseProxy(ctx, w)
case config.ModeTransparent:
p.serveTransparentProxy(ctx, w)
default:
http.Error(w, "不支持的代理模式", http.StatusInternalServerError)
p.logger.Error("不支持的代理模式", "mode", p.config.ProxyMode)
}
// 调用完成事件
p.delegate.Finish(ctx)
}
// 处理正向代理请求
func (p *UnifiedProxyImpl) serveForwardProxy(ctx *Context, w http.ResponseWriter) {
req := ctx.Req
// 判断是否是隧道代理请求(CONNECT方法)
if req.Method == http.MethodConnect {
// 获取客户端连接
clientConn, err := hijackerImpl(w)
if err != nil {
p.delegate.ErrorLog(err)
http.Error(w, "无法劫持连接", http.StatusServiceUnavailable)
return
}
defer clientConn.Close()
// 如果启用了HTTPS解密解密HTTPS流量
if p.config.DecryptHTTPS {
// 发送隧道建立响应
if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")); err != nil {
p.delegate.ErrorLog(fmt.Errorf("发送隧道建立响应失败: %w", err))
return
}
// 生成证书
certConfig, err := p.generateTLSConfig(req.Host)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("生成证书失败: %w", err))
return
}
// 创建TLS服务器连接
tlsConn := tls.Server(clientConn, certConfig)
defer tlsConn.Close()
// TLS握手
if err := tlsConn.Handshake(); err != nil {
p.delegate.ErrorLog(fmt.Errorf("TLS握手失败: %w", err))
return
}
// 读取HTTPS请求
bufReader := bufio.NewReader(tlsConn)
httpsReq, err := http.ReadRequest(bufReader)
if err != nil {
if err != io.EOF {
p.delegate.ErrorLog(fmt.Errorf("读取HTTPS请求失败: %w", err))
}
return
}
// 更新请求信息
httpsReq.RemoteAddr = req.RemoteAddr
httpsReq.URL.Scheme = "https"
httpsReq.URL.Host = httpsReq.Host
// 检查是否是WebSocket请求
if isWebSocketRequestImpl(httpsReq) {
// 处理WebSocket请求
ctx.Req = httpsReq
p.websocketProxy(ctx, NewConnBuffer(tlsConn, bufReader))
return
}
// 处理普通HTTPS请求
ctx.Req = httpsReq
p.httpsProxy(ctx, NewConnBuffer(tlsConn, bufReader))
return
}
// 没有启用HTTPS解密直接建立隧道
p.tunnelProxy(ctx, w)
return
}
// 处理WebSocket请求
if isWebSocketRequestImpl(req) && p.config.EnableWebSocket {
// 获取客户端连接
clientConn, err := hijackerImpl(w)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("获取WebSocket连接失败: %w", err))
http.Error(w, "无法处理WebSocket请求", http.StatusServiceUnavailable)
return
}
defer clientConn.Close()
// 处理WebSocket请求
p.websocketProxy(ctx, clientConn)
return
}
// 处理普通HTTP请求
p.handleHTTP(ctx, w)
}
// 处理反向代理请求
func (p *UnifiedProxyImpl) serveReverseProxy(ctx *Context, w http.ResponseWriter) {
// 调用BeforeRequest事件
p.delegate.BeforeRequest(ctx)
if ctx.IsAborted() {
return
}
// 使用反向代理处理请求
p.reverseProxy.ServeHTTP(w, ctx.Req)
}
// 处理透明代理请求
func (p *UnifiedProxyImpl) serveTransparentProxy(ctx *Context, w http.ResponseWriter) {
// 透明代理与正向代理的不同之处在于客户端不知道有代理的存在
// 这里简化处理,复用正向代理的逻辑
p.serveForwardProxy(ctx, w)
}
// 处理HTTP请求复用自原始代理的handleHTTP方法
func (p *UnifiedProxyImpl) handleHTTP(ctx *Context, rw http.ResponseWriter) {
req := ctx.Req
// 调用BeforeRequest事件
p.delegate.BeforeRequest(ctx)
if ctx.IsAborted() {
return
}
// 如果是WebSocket请求且支持WebSocket处理WebSocket
if isWebSocketRequestImpl(req) && p.config.EnableWebSocket {
// 获取目标地址
targetAddr := req.Host
if !strings.Contains(targetAddr, ":") {
if req.URL.Scheme == "https" || req.URL.Scheme == "wss" {
targetAddr += ":443"
} else {
targetAddr += ":80"
}
}
// 劫持连接
conn, err := hijackerImpl(rw)
if err != nil {
http.Error(rw, "无法劫持连接: "+err.Error(), http.StatusInternalServerError)
return
}
defer conn.Close()
// 处理WebSocket
p.websocketProxy(ctx, conn)
return
}
// 处理普通HTTP请求
p.DoRequest(ctx, func(resp *http.Response, err error) {
if err != nil {
// 调用BeforeResponse事件
p.delegate.BeforeResponse(ctx, nil, err)
// 处理错误
http.Error(rw, "请求失败: "+err.Error(), http.StatusBadGateway)
return
}
// 调用BeforeResponse事件
p.delegate.BeforeResponse(ctx, resp, nil)
if ctx.IsAborted() {
return
}
// 如果启用缓存且可以缓存响应
if p.config.EnableCache && p.cacheAdapter != nil && resp.StatusCode < 300 && canCacheMethodImpl(req.Method) && canCacheStatusImpl(resp.StatusCode) {
// 生成缓存键
cacheKey := generateCacheKeyImpl(req)
// 获取缓存TTL
cacheTTL := getCacheTTLImpl(resp)
if cacheTTL > 0 {
// 保存响应到缓存
p.cacheAdapter.Set(cacheKey, resp, cacheTTL)
}
}
// 设置响应头
copyHeaders(rw.Header(), resp.Header)
rw.WriteHeader(resp.StatusCode)
// 写入响应体
if resp.Body != nil {
defer resp.Body.Close()
// io.Copy(rw, resp.Body)
// 复制响应体
buf := bufPool.Get()
defer bufPool.Put(buf)
_, err = io.CopyBuffer(rw, resp.Body, buf)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 复制响应体错误: %s", req.URL.Host, err))
}
}
})
}
// DoRequest 执行HTTP请求
func (p *UnifiedProxyImpl) DoRequest(ctx *Context, responseFunc func(*http.Response, error)) {
req := ctx.Req
// 如果启用缓存且是GET请求
if p.config.EnableCache && p.cacheAdapter != nil && req.Method == http.MethodGet {
// 生成缓存键
cacheKey := generateCacheKeyImpl(req)
// 尝试从缓存中获取响应
if value, ok := p.cacheAdapter.Get(cacheKey); ok {
// 缓存命中
if resp, ok := value.(*http.Response); ok {
// 记录缓存命中指标
if p.metrics != nil && isCacheHitMetricsSupportedImpl(p.metrics) {
incrementCacheHitImpl(p.metrics)
}
// 调用BeforeResponse事件
p.delegate.BeforeResponse(ctx, resp, nil)
if ctx.IsAborted() {
return
}
// 返回缓存的响应
responseFunc(resp, nil)
return
}
}
}
// 准备发送请求
retries := 0
var resp *http.Response
var err error
// 重试循环
for {
// 创建一个新的请求,以避免副作用
newReq := req.Clone(req.Context())
// 如果使用负载均衡,处理后端选择
if p.config.EnableLoadBalancing && p.loadBalancer != nil {
// 获取域名
hostname := newReq.URL.Hostname()
// 使用负载均衡器选择一个后端
backend, err := p.loadBalancer.Next(hostname)
if err != nil {
responseFunc(nil, fmt.Errorf("负载均衡选择后端失败: %w", err))
return
}
if backend != nil {
// 使用选中的后端地址
newReq.URL.Host = backend.Host
if backend.Scheme != "" {
newReq.URL.Scheme = backend.Scheme
}
}
}
// 调用BeforeRequest事件
p.delegate.BeforeRequest(ctx)
if ctx.IsAborted() {
return
}
// 记录请求开始时间
startTime := time.Now()
// 发送请求
resp, err = p.httpTransport.RoundTrip(newReq)
// 记录请求时间
if p.metrics != nil && err == nil {
p.metrics.ObserveRequestDuration(time.Since(startTime).Seconds())
}
// 检查是否需要重试
if err != nil && p.config.EnableRetry && retries < p.config.MaxRetries {
retries++
// 指数退避
backoff := p.config.RetryBackoff
if backoff == 0 {
backoff = time.Second
}
// 计算退避时间
sleepTime := backoff * time.Duration(1<<uint(retries-1))
if maxBackoff := p.config.MaxRetryBackoff; maxBackoff > 0 && sleepTime > maxBackoff {
sleepTime = maxBackoff
}
// 记录重试日志
p.logger.Debug("请求失败,进行重试",
"url", newReq.URL.String(),
"retry", retries,
"max_retries", p.config.MaxRetries,
"sleep", sleepTime,
"error", err.Error(),
)
// 等待一段时间再重试
time.Sleep(sleepTime)
continue
}
// 无需重试或达到最大重试次数,返回结果
break
}
// 调用BeforeResponse事件
p.delegate.BeforeResponse(ctx, resp, err)
if ctx.IsAborted() {
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
return
}
// 如果有错误,返回错误
if err != nil {
responseFunc(nil, err)
return
}
// 如果启用缓存且可以缓存响应
if p.config.EnableCache && p.cacheAdapter != nil && resp.StatusCode < 300 && canCacheMethodImpl(req.Method) && canCacheStatusImpl(resp.StatusCode) {
// 生成缓存键
cacheKey := generateCacheKeyImpl(req)
// 获取缓存TTL
cacheTTL := getCacheTTLImpl(resp)
if cacheTTL > 0 {
// 保存响应到缓存
p.cacheAdapter.Set(cacheKey, resp, cacheTTL)
}
}
// 返回响应
responseFunc(resp, nil)
}
// Close 关闭代理
func (p *UnifiedProxyImpl) Close() error {
// 关闭健康检查器
if p.healthChecker != nil {
p.healthChecker.Stop()
}
// 关闭Transport连接
if p.transport != nil {
p.transport.CloseIdleConnections()
}
// 等待活跃连接完成
// 设置一个超时,防止永远等待
timeout := time.NewTimer(30 * time.Second)
defer timeout.Stop()
// 创建一个心跳定时器,定期检查连接数
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-timeout.C:
// 超时,强制关闭
p.logger.Warn("关闭超时,强制关闭代理",
"active_connections", p.ClientConnNum(),
)
return nil
case <-ticker.C:
// 检查当前连接数
if p.ClientConnNum() <= 0 {
p.logger.Info("所有连接已关闭,代理成功关闭")
return nil
}
p.logger.Debug("等待连接关闭",
"active_connections", p.ClientConnNum(),
)
}
}
}
// ClientConnNum 获取当前客户端连接数
func (p *UnifiedProxyImpl) ClientConnNum() int32 {
return atomic.LoadInt32(&p.clientConnNum)
}
// 复制HTTP头部
func copyHeaders(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
// 发送隧道已连接响应
func (p *UnifiedProxyImpl) tunnelConnected(ctx *Context, err error) {
// 实现调用tunnelConnectedImpl
tunnelConnectedImpl(ctx, err, nil)
}
// 处理HTTPS请求
func (p *UnifiedProxyImpl) httpsProxy(ctx *Context, srcConn *ConnBuffer) {
req := ctx.Req
// 检查是否是WebSocket请求
if isWebSocketRequestImpl(req) && p.config.EnableWebSocket {
p.websocketProxy(ctx, srcConn)
return
}
// 准备发送请求的上下文
if req.Body == nil {
req.Body = http.NoBody
}
// 创建请求上下文,支持超时
reqContext := req.Context()
if p.config.RequestTimeout > 0 {
var cancel context.CancelFunc
reqContext, cancel = context.WithTimeout(reqContext, p.config.RequestTimeout)
defer cancel()
}
// 创建HTTP客户端
client := &http.Client{
Transport: p.httpTransport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
// 允许最多10次重定向
if len(via) >= 10 {
return http.ErrUseLastResponse
}
return nil
},
}
// 设置上下文
req = req.WithContext(reqContext)
// 调用BeforeRequest事件
p.delegate.BeforeRequest(ctx)
if ctx.IsAborted() {
return
}
// 发送请求
resp, err := client.Do(req)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS请求失败: %w", req.URL.String(), err))
errorResp := &http.Response{
StatusCode: http.StatusBadGateway,
Status: "502 Bad Gateway",
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Body: http.NoBody,
Request: req,
}
errorResp.Header.Set("Content-Type", "text/plain; charset=utf-8")
errorResp.Header.Set("X-Proxy-Error", err.Error())
// 写入错误响应
if writeErr := resp.Write(srcConn); writeErr != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 写入错误响应失败: %w", req.URL.String(), writeErr))
}
return
}
defer resp.Body.Close()
// 调用BeforeResponse事件
p.delegate.BeforeResponse(ctx, resp, nil)
if ctx.IsAborted() {
return
}
// 写入响应
if err := resp.Write(srcConn); err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 写入响应失败: %w", req.URL.String(), err))
}
}
// 建立隧道代理
func (p *UnifiedProxyImpl) tunnelProxy(ctx *Context, rw http.ResponseWriter) {
// 获取客户端连接
clientConn, err := hijackerImpl(rw)
if err != nil {
p.delegate.ErrorLog(err)
//http.Error(rw, "无法劫持连接", http.StatusServiceUnavailable)
return
}
defer clientConn.Close()
// 获取目标主机
host := ctx.Req.Host
if !strings.Contains(host, ":") {
host = host + ":443"
}
// 尝试使用负载均衡器选择目标
var targetHost string
if p.config.EnableLoadBalancing && p.loadBalancer != nil {
hostname := strings.Split(host, ":")[0]
backend, err := p.loadBalancer.Next(hostname)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 负载均衡选择目标失败: %w", host, err))
// 继续使用原始主机
targetHost = host
} else if backend != nil {
targetHost = backend.Host
} else {
targetHost = host
}
} else {
targetHost = host
}
// 连接目标服务器
targetConn, err := net.DialTimeout("tcp", targetHost, 10*time.Second)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 连接目标服务器失败: %w", targetHost, err))
// 直接写入连接而不是使用 http.Error
clientConn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
return
}
defer targetConn.Close()
// 发送隧道建立成功响应
if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")); err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 发送隧道建立响应失败: %w", targetHost, err))
return
}
// 双向转发数据
p.transfer(clientConn, targetConn)
}
// 处理WebSocket请求
func (p *UnifiedProxyImpl) websocketProxy(ctx *Context, srcConn *ConnBuffer) {
req := ctx.Req
// 检查是否启用WebSocket拦截
if p.config.WebSocketIntercept {
// 使用WebSocket拦截模式
// 创建WebSocket升级器
upgrader := &websocket.Upgrader{
HandshakeTimeout: 10 * time.Second,
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
return true
},
}
// 创建伪响应写入器用于升级WebSocket连接
rw := &customResponseWriter{
conn: srcConn,
header: make(http.Header),
statusCode: http.StatusOK,
}
// 升级源连接为WebSocket连接
srcWSConn, err := upgrader.Upgrade(rw, req, nil)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 升级WebSocket连接失败: %w", req.URL.String(), err))
return
}
defer srcWSConn.Close()
// 构建目标URL
targetURL := url.URL{
Scheme: "ws",
Host: req.URL.Host,
Path: req.URL.Path,
RawQuery: req.URL.RawQuery,
}
if req.URL.Scheme == "https" || req.URL.Scheme == "wss" {
targetURL.Scheme = "wss"
}
// 创建目标WebSocket连接
dialer := &websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
ReadBufferSize: 4096,
WriteBufferSize: 4096,
TLSClientConfig: &tls.Config{InsecureSkipVerify: p.config.InsecureSkipVerify},
}
// 连接到目标WebSocket服务器
targetWSConn, _, err := dialer.Dial(targetURL.String(), nil)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 连接目标WebSocket服务器失败: %w", targetURL.String(), err))
return
}
defer targetWSConn.Close()
// 转发WebSocket消息
p.transferWebSocket(ctx, srcWSConn, targetWSConn)
return
}
// 不使用WebSocket拦截直接转发TCP流量
// 确定目标地址
targetAddr := req.URL.Host
if !strings.Contains(targetAddr, ":") {
if req.URL.Scheme == "wss" || req.URL.Scheme == "https" {
targetAddr += ":443"
} else {
targetAddr += ":80"
}
}
// 创建到目标服务器的连接
var targetConn net.Conn
var err error
if req.URL.Scheme == "wss" || req.URL.Scheme == "https" {
// 使用TLS连接
targetConn, err = tls.Dial("tcp", targetAddr, &tls.Config{
InsecureSkipVerify: p.config.InsecureSkipVerify,
})
} else {
// 使用普通TCP连接
targetConn, err = net.Dial("tcp", targetAddr)
}
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 连接WebSocket目标服务器失败: %w", targetAddr, err))
return
}
defer targetConn.Close()
// 将原始请求转发给目标服务器
err = req.Write(targetConn)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 写入WebSocket握手请求失败: %w", targetAddr, err))
return
}
// 启动双向数据传输
p.transfer(srcConn, targetConn)
}
// 生成TLS配置
func (p *UnifiedProxyImpl) generateTLSConfig(host string) (*tls.Config, error) {
// 如果没有证书管理器,则创建一个
if p.certManager == nil {
// 创建证书管理器,使用已有的证书缓存
options := []CertManagerOption{
WithDefaultPrivateKey(true), // 使用默认私钥提高性能
WithValidityYears(1), // 证书有效期1年
}
p.certManager = NewCertManager(p.certCache, options...)
}
// 1. 首先检查是否配置了自定义证书
if p.config.TLSCert != "" && p.config.TLSKey != "" {
cert, err := tls.LoadX509KeyPair(p.config.TLSCert, p.config.TLSKey)
if err != nil {
return nil, fmt.Errorf("加载TLS证书失败: %s", err)
}
return &tls.Config{
Certificates: []tls.Certificate{cert},
}, nil
}
// 2. 检查是否配置了CA证书和密钥用于动态生成证书
if p.config.CACert != "" && p.config.CAKey != "" {
// 加载CA证书和私钥
caCert, caKey, err := LoadCAFromFiles(p.config.CACert, p.config.CAKey)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("加载CA证书和私钥失败: %s", err))
// 如果加载失败使用默认CA
return p.certManager.GenerateTLSConfig(host)
}
// 使用自定义CA生成证书
cert, err := p.certManager.GenerateCertificate(host, caCert, caKey)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("为%s生成动态证书失败: %s", host, err))
return nil, err
}
return &tls.Config{
Certificates: []tls.Certificate{*cert},
}, nil
}
// 3. 使用默认CA生成证书
tlsConfig, err := p.certManager.GenerateTLSConfig(host)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("为%s使用默认CA生成证书失败: %s", host, err))
return nil, err
}
return tlsConfig, nil
}
// transfer 在两个连接之间双向转发数据
func (p *UnifiedProxyImpl) transfer(src net.Conn, dst net.Conn) {
// 创建完成通道
done := make(chan struct{}, 2)
// src -> dst
go func() {
buf := bufPool.Get()
written, err := io.CopyBuffer(dst, src, buf)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err))
}
// 记录传输字节数
if p.metrics != nil {
p.metrics.AddBytesTransferred("request", written)
}
bufPool.Put(buf)
dst.Close()
done <- struct{}{}
}()
// dst -> src
go func() {
buf := bufPool.Get()
written, err := io.CopyBuffer(src, dst, buf)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", dst.RemoteAddr().String(), src.RemoteAddr().String(), err))
}
// 记录传输字节数
if p.metrics != nil {
p.metrics.AddBytesTransferred("response", written)
}
bufPool.Put(buf)
src.Close()
done <- struct{}{}
}()
// 等待两个方向都结束
<-done
<-done
}
// transferWebSocket 在两个WebSocket连接之间双向转发数据
func (p *UnifiedProxyImpl) transferWebSocket(ctx *Context, srcWSConn, targetWSConn *websocket.Conn) {
doneCtx, cancel := context.WithCancel(context.Background())
defer cancel()
// 源到目标
go func() {
for {
if doneCtx.Err() != nil {
return
}
// 读取源消息,正确处理消息类型
msgType, msg, err := srcWSConn.ReadMessage()
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s",
srcWSConn.RemoteAddr().String(), targetWSConn.RemoteAddr().String(), err))
cancel() // 取消另一个goroutine
return
}
// 调用消息拦截接口
p.delegate.WebSocketSendMessage(ctx, &msgType, &msg)
// 写入目标,保留原始消息类型(文本/二进制)
err = targetWSConn.WriteMessage(msgType, msg)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcWSConn.RemoteAddr().String(), targetWSConn.RemoteAddr().String(), err))
cancel() // 取消另一个goroutine
return
}
}
}()
// 目标到源
for {
if doneCtx.Err() != nil {
return
}
// 读取目标消息,正确处理消息类型
msgType, msg, err := targetWSConn.ReadMessage()
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetWSConn.RemoteAddr().String(), srcWSConn.RemoteAddr().String(), err))
cancel() // 取消另一个goroutine
return
}
// 调用消息拦截接口
p.delegate.WebSocketReceiveMessage(ctx, &msgType, &msg)
// 写入源,保留原始消息类型(文本/二进制)
err = srcWSConn.WriteMessage(msgType, msg)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetWSConn.RemoteAddr().String(), srcWSConn.RemoteAddr().String(), err))
cancel() // 取消另一个goroutine
return
}
}
}
// SetDialContext 设置自定义的拨号上下文函数
func (p *UnifiedProxyImpl) SetDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) {
p.transport.DialContext = dialContext
}
// 自定义响应写入器用于WebSocket连接升级
type customResponseWriter struct {
conn *ConnBuffer
header http.Header
statusCode int
}
// Header 实现 http.ResponseWriter 接口
func (rw *customResponseWriter) Header() http.Header {
return rw.header
}
// Write 实现 http.ResponseWriter 接口
func (rw *customResponseWriter) Write(b []byte) (int, error) {
return rw.conn.Write(b)
}
// WriteHeader 实现 http.ResponseWriter 接口
func (rw *customResponseWriter) WriteHeader(statusCode int) {
rw.statusCode = statusCode
}