1403 lines
37 KiB
Go
1403 lines
37 KiB
Go
package goproxy
|
||
|
||
import (
|
||
"bufio"
|
||
"context"
|
||
"crypto/tls"
|
||
"fmt"
|
||
"io"
|
||
"log/slog"
|
||
"net"
|
||
"net/http"
|
||
"net/http/httptrace"
|
||
"net/http/httputil"
|
||
"net/url"
|
||
"strconv"
|
||
"strings"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"github.com/darkit/goproxy/config"
|
||
"github.com/darkit/goproxy/pkg/auth"
|
||
"github.com/darkit/goproxy/pkg/cache"
|
||
"github.com/darkit/goproxy/pkg/dns"
|
||
"github.com/darkit/goproxy/pkg/healthcheck"
|
||
"github.com/darkit/goproxy/pkg/loadbalance"
|
||
"github.com/darkit/goproxy/pkg/metrics"
|
||
"github.com/darkit/goproxy/pkg/rule"
|
||
"github.com/ouqiang/websocket"
|
||
"github.com/viki-org/dnscache"
|
||
)
|
||
|
||
// UnifiedProxy 统一代理接口
|
||
type UnifiedProxy interface {
|
||
// ServeHTTP 处理HTTP请求
|
||
ServeHTTP(w http.ResponseWriter, r *http.Request)
|
||
// Close 关闭代理
|
||
Close() error
|
||
// ClientConnNum 获取当前客户端连接数
|
||
ClientConnNum() int32
|
||
// SetDialContext 设置自定义的拨号上下文函数
|
||
SetDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error))
|
||
// DoRequest 执行HTTP请求
|
||
DoRequest(ctx *Context, responseFunc func(*http.Response, error))
|
||
}
|
||
|
||
// UnifiedProxyImpl 统一代理实现
|
||
type UnifiedProxyImpl struct {
|
||
// 配置
|
||
config *config.UnifiedConfig
|
||
// 委托
|
||
delegate Delegate
|
||
// 证书缓存
|
||
certCache CertificateCache
|
||
// HTTP缓存
|
||
httpCache cache.Cache
|
||
// 缓存适配器
|
||
cacheAdapter *CacheAdapter
|
||
// 负载均衡器
|
||
loadBalancer loadbalance.LoadBalancer
|
||
// 健康检查器
|
||
healthChecker *healthcheck.HealthChecker
|
||
// 监控指标
|
||
metrics metrics.MetricsCollector
|
||
// 客户端跟踪
|
||
clientTrace *httptrace.ClientTrace
|
||
// 基础传输(用于直接获取*http.Transport类型)
|
||
transport *http.Transport
|
||
// HTTP请求传输(可能被中间件包装)
|
||
httpTransport http.RoundTripper
|
||
// DNS缓存
|
||
dnsCache *dnscache.Resolver
|
||
// 客户端连接数
|
||
clientConnNum int32
|
||
// 证书管理器
|
||
certManager *CertManager
|
||
// 日志记录器
|
||
logger *slog.Logger
|
||
// 规则管理器 (反向代理使用)
|
||
ruleManager *rule.Manager
|
||
// 反向代理处理器 (反向代理使用)
|
||
reverseProxy *httputil.ReverseProxy
|
||
// 认证系统
|
||
auth *auth.Auth
|
||
// 自定义DNS解析器
|
||
dnsResolver *dns.CustomResolver
|
||
}
|
||
|
||
// UnifiedOptions 统一代理选项
|
||
type UnifiedOptions struct {
|
||
// 配置
|
||
Config *config.UnifiedConfig
|
||
// 委托
|
||
Delegate Delegate
|
||
// 证书缓存
|
||
CertCache CertificateCache
|
||
// HTTP缓存
|
||
HTTPCache cache.Cache
|
||
// 负载均衡器
|
||
LoadBalancer loadbalance.LoadBalancer
|
||
// 健康检查器
|
||
HealthChecker *healthcheck.HealthChecker
|
||
// 监控指标
|
||
Metrics metrics.MetricsCollector
|
||
// 客户端跟踪
|
||
ClientTrace *httptrace.ClientTrace
|
||
// 证书管理器
|
||
CertManager *CertManager
|
||
// 认证系统
|
||
Auth *auth.Auth
|
||
// DNS解析器
|
||
DNSResolver *dns.CustomResolver
|
||
}
|
||
|
||
// NewUnifiedProxy 创建统一代理
|
||
func NewUnifiedProxy(opts *UnifiedOptions) (UnifiedProxy, error) {
|
||
if opts == nil {
|
||
opts = &UnifiedOptions{}
|
||
}
|
||
|
||
if opts.Config == nil {
|
||
opts.Config = config.DefaultUnifiedConfig()
|
||
}
|
||
|
||
if opts.Delegate == nil {
|
||
opts.Delegate = &DefaultDelegate{}
|
||
}
|
||
|
||
proxy := &UnifiedProxyImpl{
|
||
config: opts.Config,
|
||
delegate: opts.Delegate,
|
||
certCache: opts.CertCache,
|
||
httpCache: opts.HTTPCache,
|
||
loadBalancer: opts.LoadBalancer,
|
||
healthChecker: opts.HealthChecker,
|
||
metrics: opts.Metrics,
|
||
clientTrace: opts.ClientTrace,
|
||
certManager: opts.CertManager,
|
||
auth: opts.Auth,
|
||
logger: opts.Config.Logger,
|
||
dnsResolver: opts.DNSResolver,
|
||
}
|
||
|
||
// 初始化日志记录器
|
||
if proxy.logger == nil {
|
||
proxy.logger = slog.Default()
|
||
}
|
||
|
||
// 创建证书管理器(如果需要)
|
||
if proxy.certManager == nil && opts.Config.DecryptHTTPS {
|
||
// 如果提供了CA证书和密钥,使用它们创建证书管理器
|
||
// 否则使用默认的证书管理器
|
||
proxy.certManager = NewCertManager(proxy.certCache, WithUseECDSA(opts.Config.UseECDSA))
|
||
}
|
||
|
||
// 根据代理模式初始化不同的组件
|
||
switch opts.Config.ProxyMode {
|
||
case config.ModeForward:
|
||
proxy.logger.Debug("初始化正向代理")
|
||
if err := proxy.initializeForwardProxy(); err != nil {
|
||
return nil, err
|
||
}
|
||
case config.ModeReverse:
|
||
proxy.logger.Debug("初始化反向代理")
|
||
if err := proxy.initializeReverseProxy(); err != nil {
|
||
return nil, err
|
||
}
|
||
case config.ModeTransparent:
|
||
proxy.logger.Debug("初始化透明代理")
|
||
if err := proxy.initializeTransparentProxy(); err != nil {
|
||
return nil, err
|
||
}
|
||
default:
|
||
return nil, fmt.Errorf("不支持的代理模式: %s", opts.Config.ProxyMode)
|
||
}
|
||
|
||
return proxy, nil
|
||
}
|
||
|
||
// 初始化正向代理
|
||
func (p *UnifiedProxyImpl) initializeForwardProxy() error {
|
||
// 创建DNS缓存
|
||
if p.config.DNSCacheTTL > 0 && p.dnsResolver == nil {
|
||
// 只有在没有自定义DNS解析器的情况下才创建DNS缓存
|
||
p.dnsCache = dnscache.New(p.config.DNSCacheTTL)
|
||
}
|
||
|
||
// 创建缓存适配器
|
||
if p.httpCache != nil {
|
||
p.cacheAdapter = NewCacheAdapter(p.httpCache)
|
||
}
|
||
|
||
// 创建传输层
|
||
p.transport = &http.Transport{
|
||
// DNS解析
|
||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||
// 首先检查是否有自定义DNS解析器
|
||
if p.dnsResolver != nil {
|
||
// 解析地址和端口
|
||
host, port, err := net.SplitHostPort(addr)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
portNum, err := strconv.Atoi(port)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 使用自定义DNS解析器解析地址
|
||
endpoint, err := p.dnsResolver.ResolveWithPort(host, portNum)
|
||
if err != nil {
|
||
// 如果解析失败,回退到原始地址
|
||
p.logger.Debug("[正向代理]自定义DNS解析失败,使用原始地址",
|
||
"host", host,
|
||
"error", err.Error(),
|
||
)
|
||
dialer := &net.Dialer{
|
||
Timeout: p.config.RequestTimeout,
|
||
KeepAlive: p.config.IdleTimeout,
|
||
}
|
||
return dialer.DialContext(ctx, network, addr)
|
||
}
|
||
|
||
// 使用解析后的地址创建连接
|
||
p.logger.Debug("使用自定义DNS解析",
|
||
"host", host,
|
||
"resolved", endpoint.String(),
|
||
)
|
||
dialer := &net.Dialer{
|
||
Timeout: p.config.RequestTimeout,
|
||
KeepAlive: p.config.IdleTimeout,
|
||
}
|
||
return dialer.DialContext(ctx, network, endpoint.String())
|
||
}
|
||
|
||
// 使用DNS缓存
|
||
if p.dnsCache != nil {
|
||
host, port, err := net.SplitHostPort(addr)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 从DNS缓存中解析IP
|
||
ips, err := p.dnsCache.FetchOneString(host)
|
||
if err != nil {
|
||
// 如果DNS缓存解析失败,尝试使用系统DNS
|
||
dialer := &net.Dialer{
|
||
Timeout: p.config.RequestTimeout,
|
||
KeepAlive: p.config.IdleTimeout,
|
||
}
|
||
return dialer.DialContext(ctx, network, addr)
|
||
}
|
||
|
||
// 使用解析后的IP创建连接
|
||
dialer := &net.Dialer{
|
||
Timeout: p.config.RequestTimeout,
|
||
KeepAlive: p.config.IdleTimeout,
|
||
}
|
||
return dialer.DialContext(ctx, network, net.JoinHostPort(ips, port))
|
||
}
|
||
|
||
// 不使用DNS缓存,直接使用系统DNS
|
||
dialer := &net.Dialer{
|
||
Timeout: p.config.RequestTimeout,
|
||
KeepAlive: p.config.IdleTimeout,
|
||
}
|
||
return dialer.DialContext(ctx, network, addr)
|
||
},
|
||
DisableKeepAlives: p.config.DisableKeepAlive,
|
||
MaxIdleConns: p.config.MaxIdleConns,
|
||
IdleConnTimeout: p.config.IdleTimeout,
|
||
TLSHandshakeTimeout: 10 * time.Second,
|
||
ExpectContinueTimeout: 1 * time.Second,
|
||
ForceAttemptHTTP2: true,
|
||
TLSClientConfig: nil, // 暂时不设置,在HTTPS请求时动态生成
|
||
}
|
||
|
||
// 设置HTTP请求传输
|
||
p.httpTransport = p.transport
|
||
|
||
return nil
|
||
}
|
||
|
||
// 初始化反向代理
|
||
func (p *UnifiedProxyImpl) initializeReverseProxy() error {
|
||
// 创建规则管理器
|
||
p.ruleManager = rule.NewManager(p.logger)
|
||
|
||
// 处理目标地址中可能包含的协议前缀
|
||
targetScheme := "http"
|
||
targetHost := p.config.TargetAddr
|
||
|
||
// 如果目标地址包含协议前缀,分离协议和主机部分
|
||
if strings.HasPrefix(p.config.TargetAddr, "http://") || strings.HasPrefix(p.config.TargetAddr, "https://") {
|
||
parsedURL, err := url.Parse(p.config.TargetAddr)
|
||
if err != nil {
|
||
p.logger.Error("解析目标地址URL失败",
|
||
"target", p.config.TargetAddr,
|
||
"error", err.Error(),
|
||
)
|
||
} else {
|
||
// 提取协议和主机部分
|
||
targetScheme = parsedURL.Scheme
|
||
targetHost = parsedURL.Host
|
||
p.logger.Debug("目标地址包含协议前缀,已分离",
|
||
"original", p.config.TargetAddr,
|
||
"scheme", targetScheme,
|
||
"host", targetHost,
|
||
)
|
||
}
|
||
}
|
||
|
||
// 创建传输层
|
||
p.transport = &http.Transport{
|
||
Proxy: http.ProxyFromEnvironment,
|
||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||
// 如果存在自定义DNS解析器,使用它
|
||
if p.dnsResolver != nil {
|
||
// 解析地址和端口
|
||
host, port, err := net.SplitHostPort(addr)
|
||
if err != nil {
|
||
// 可能是没有端口的地址,添加默认端口
|
||
if strings.Contains(err.Error(), "missing port") {
|
||
// 默认使用80端口
|
||
host = addr
|
||
port = "80"
|
||
addr = net.JoinHostPort(host, port)
|
||
p.logger.Debug("[反向代理]地址没有端口,添加默认端口",
|
||
"addr", addr,
|
||
)
|
||
} else {
|
||
return nil, fmt.Errorf("解析主机和端口失败: %w", err)
|
||
}
|
||
}
|
||
|
||
portNum, err := strconv.Atoi(port)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("端口号格式错误: %w", err)
|
||
}
|
||
|
||
// 使用DNS解析器解析地址
|
||
endpoint, err := p.dnsResolver.ResolveWithPort(host, portNum)
|
||
if err != nil {
|
||
// 如果解析失败,回退到使用原始地址
|
||
p.logger.Debug("[反向代理]自定义DNS解析失败,使用原始地址",
|
||
"host", host,
|
||
"error", err.Error(),
|
||
)
|
||
dialer := &net.Dialer{
|
||
Timeout: p.config.RequestTimeout,
|
||
KeepAlive: p.config.IdleTimeout,
|
||
}
|
||
return dialer.DialContext(ctx, network, addr)
|
||
}
|
||
|
||
// 使用解析后的地址创建连接
|
||
p.logger.Debug("[反向代理]使用自定义DNS解析",
|
||
"host", host,
|
||
"resolved", endpoint.String(),
|
||
)
|
||
dialer := &net.Dialer{
|
||
Timeout: p.config.RequestTimeout,
|
||
KeepAlive: p.config.IdleTimeout,
|
||
}
|
||
return dialer.DialContext(ctx, network, endpoint.String())
|
||
}
|
||
|
||
// 没有自定义DNS解析器,直接使用系统DNS
|
||
dialer := &net.Dialer{
|
||
Timeout: p.config.RequestTimeout,
|
||
KeepAlive: p.config.IdleTimeout,
|
||
}
|
||
return dialer.DialContext(ctx, network, addr)
|
||
},
|
||
ForceAttemptHTTP2: true,
|
||
MaxIdleConns: p.config.MaxIdleConns,
|
||
IdleConnTimeout: p.config.IdleTimeout,
|
||
TLSHandshakeTimeout: 10 * time.Second,
|
||
ExpectContinueTimeout: 1 * time.Second,
|
||
TLSClientConfig: nil, // 在需要时设置
|
||
}
|
||
|
||
// 如果需要跳过证书验证,设置TLS配置
|
||
if p.config.InsecureSkipVerify {
|
||
p.transport.TLSClientConfig = &tls.Config{
|
||
InsecureSkipVerify: true,
|
||
}
|
||
}
|
||
|
||
// 创建反向代理处理器
|
||
p.reverseProxy = &httputil.ReverseProxy{
|
||
Transport: p.transport,
|
||
Director: func(req *http.Request) {
|
||
// 应用规则
|
||
rules := p.ruleManager.ListRules(rule.RuleTypeRoute)
|
||
for _, r := range rules {
|
||
if r.Match(req) {
|
||
if err := r.Apply(req); err != nil {
|
||
p.logger.Error("应用规则失败",
|
||
"rule_id", r.GetID(),
|
||
"error", err.Error(),
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 设置目标地址
|
||
if targetHost != "" {
|
||
req.URL.Host = targetHost
|
||
// 同时设置 req.Host 字段,确保兼容性
|
||
req.Host = targetHost
|
||
// 使用 Set 替换 Add,确保只有一个 Host 头
|
||
req.Header.Set("Host", targetHost)
|
||
// 使用解析出的协议,确保一致性
|
||
req.URL.Scheme = targetScheme
|
||
|
||
p.logger.Debug("[反向代理]设置目标地址",
|
||
"target", req.URL.Host,
|
||
"scheme", req.URL.Scheme,
|
||
"path", req.URL.Path,
|
||
)
|
||
}
|
||
|
||
// 添加X-Forwarded-For头
|
||
if p.config.AddXForwardedFor {
|
||
req.Header.Add("X-Forwarded-For", req.RemoteAddr)
|
||
}
|
||
|
||
// 添加X-Real-IP头
|
||
if p.config.AddXRealIP {
|
||
req.Header.Add("X-Real-IP", req.RemoteAddr)
|
||
}
|
||
|
||
// 调用委托的修改请求方法
|
||
p.delegate.ModifyRequest(req)
|
||
},
|
||
ModifyResponse: func(resp *http.Response) error {
|
||
return p.delegate.ModifyResponse(resp)
|
||
},
|
||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||
// 检查是否支持连接劫持
|
||
if hijacker, ok := w.(http.Hijacker); ok {
|
||
// 尝试获取连接
|
||
conn, _, hijackErr := hijacker.Hijack()
|
||
if hijackErr == nil {
|
||
// 连接已经被劫持,使用直接写入方式
|
||
defer conn.Close()
|
||
|
||
// 构建简单的错误响应
|
||
errMsg := fmt.Sprintf("HTTP/1.1 502 Bad Gateway\r\nContent-Type: text/plain; charset=utf-8\r\nX-Proxy-Error: %s\r\n\r\n%s",
|
||
err.Error(), err.Error())
|
||
|
||
// 直接写入连接
|
||
conn.Write([]byte(errMsg))
|
||
return
|
||
}
|
||
// 如果劫持失败,回退到正常的错误处理
|
||
}
|
||
|
||
// 使用标准的错误处理
|
||
p.delegate.HandleError(w, r, err)
|
||
},
|
||
}
|
||
|
||
// 如果配置了规则文件,加载规则
|
||
if p.config.RulesFile != "" {
|
||
loader := rule.NewLoader(p.ruleManager, p.logger)
|
||
|
||
// 根据文件扩展名决定加载方式
|
||
switch {
|
||
case strings.HasSuffix(p.config.RulesFile, ".json"):
|
||
if err := loader.LoadFromJSON(p.config.RulesFile); err != nil {
|
||
p.logger.Error("加载规则文件失败",
|
||
"file", p.config.RulesFile,
|
||
"error", err.Error(),
|
||
)
|
||
return fmt.Errorf("加载规则文件失败: %w", err)
|
||
}
|
||
p.logger.Info("成功加载规则文件",
|
||
"file", p.config.RulesFile,
|
||
)
|
||
|
||
case strings.HasSuffix(p.config.RulesFile, ".hosts"):
|
||
if err := loader.LoadFromHosts(p.config.RulesFile); err != nil {
|
||
p.logger.Error("加载hosts文件失败",
|
||
"file", p.config.RulesFile,
|
||
"error", err.Error(),
|
||
)
|
||
return fmt.Errorf("加载hosts文件失败: %w", err)
|
||
}
|
||
p.logger.Info("成功加载hosts文件",
|
||
"file", p.config.RulesFile,
|
||
)
|
||
|
||
default:
|
||
return fmt.Errorf("不支持的规则文件格式: %s", p.config.RulesFile)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// 初始化透明代理
|
||
func (p *UnifiedProxyImpl) initializeTransparentProxy() error {
|
||
// 透明代理基本上是正向代理的一种特殊形式
|
||
return p.initializeForwardProxy()
|
||
}
|
||
|
||
// ServeHTTP 处理HTTP请求
|
||
func (p *UnifiedProxyImpl) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||
// 增加客户端连接数
|
||
atomic.AddInt32(&p.clientConnNum, 1)
|
||
defer atomic.AddInt32(&p.clientConnNum, -1)
|
||
|
||
// 更新请求计数指标
|
||
if p.metrics != nil {
|
||
p.metrics.IncRequestCount()
|
||
}
|
||
|
||
// 创建请求上下文
|
||
ctx := ctxPool.Get()
|
||
ctx.Reset(r)
|
||
defer ctxPool.Put(ctx)
|
||
|
||
// 调用连接事件
|
||
p.delegate.Connect(ctx, w)
|
||
|
||
// 认证检查
|
||
if p.auth != nil {
|
||
// 从请求头中获取Authorization信息
|
||
username, password, hasAuth := r.BasicAuth()
|
||
if hasAuth {
|
||
// 进行认证
|
||
token, err := p.auth.Authenticate(username, password)
|
||
if err != nil {
|
||
// 认证失败
|
||
p.logger.Debug("认证失败", "username", username, "error", err.Error())
|
||
w.Header().Set("WWW-Authenticate", `Basic realm="Proxy Authentication Required"`)
|
||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||
return
|
||
}
|
||
// 认证成功,设置认证头
|
||
ctx.Data["auth_token"] = token
|
||
} else {
|
||
// 没有提供认证信息
|
||
p.delegate.Auth(ctx, w)
|
||
if ctx.IsAborted() {
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// 根据代理模式处理请求
|
||
switch p.config.ProxyMode {
|
||
case config.ModeForward:
|
||
// 正向代理模式下URL必须是完整的
|
||
if r.URL.Host == "" {
|
||
http.Error(w, "Invalid request: no host in URL", http.StatusBadRequest)
|
||
return
|
||
}
|
||
p.serveForwardProxy(ctx, w)
|
||
case config.ModeReverse:
|
||
p.serveReverseProxy(ctx, w)
|
||
case config.ModeTransparent:
|
||
p.serveTransparentProxy(ctx, w)
|
||
default:
|
||
http.Error(w, "不支持的代理模式", http.StatusInternalServerError)
|
||
p.logger.Error("不支持的代理模式", "mode", p.config.ProxyMode)
|
||
}
|
||
|
||
// 调用完成事件
|
||
p.delegate.Finish(ctx)
|
||
}
|
||
|
||
// 处理正向代理请求
|
||
func (p *UnifiedProxyImpl) serveForwardProxy(ctx *Context, w http.ResponseWriter) {
|
||
req := ctx.Req
|
||
|
||
// 判断是否是隧道代理请求(CONNECT方法)
|
||
if req.Method == http.MethodConnect {
|
||
// 获取客户端连接
|
||
clientConn, err := hijackerImpl(w)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(err)
|
||
http.Error(w, "无法劫持连接", http.StatusServiceUnavailable)
|
||
return
|
||
}
|
||
defer clientConn.Close()
|
||
|
||
// 如果启用了HTTPS解密,解密HTTPS流量
|
||
if p.config.DecryptHTTPS {
|
||
// 发送隧道建立响应
|
||
if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")); err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("发送隧道建立响应失败: %w", err))
|
||
return
|
||
}
|
||
|
||
// 生成证书
|
||
certConfig, err := p.generateTLSConfig(req.Host)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("生成证书失败: %w", err))
|
||
return
|
||
}
|
||
|
||
// 创建TLS服务器连接
|
||
tlsConn := tls.Server(clientConn, certConfig)
|
||
defer tlsConn.Close()
|
||
|
||
// TLS握手
|
||
if err := tlsConn.Handshake(); err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("TLS握手失败: %w", err))
|
||
return
|
||
}
|
||
|
||
// 读取HTTPS请求
|
||
bufReader := bufio.NewReader(tlsConn)
|
||
httpsReq, err := http.ReadRequest(bufReader)
|
||
if err != nil {
|
||
if err != io.EOF {
|
||
p.delegate.ErrorLog(fmt.Errorf("读取HTTPS请求失败: %w", err))
|
||
}
|
||
return
|
||
}
|
||
|
||
// 更新请求信息
|
||
httpsReq.RemoteAddr = req.RemoteAddr
|
||
httpsReq.URL.Scheme = "https"
|
||
httpsReq.URL.Host = httpsReq.Host
|
||
|
||
// 检查是否是WebSocket请求
|
||
if isWebSocketRequestImpl(httpsReq) {
|
||
// 处理WebSocket请求
|
||
ctx.Req = httpsReq
|
||
p.websocketProxy(ctx, NewConnBuffer(tlsConn, bufReader))
|
||
return
|
||
}
|
||
|
||
// 处理普通HTTPS请求
|
||
ctx.Req = httpsReq
|
||
p.httpsProxy(ctx, NewConnBuffer(tlsConn, bufReader))
|
||
return
|
||
}
|
||
|
||
// 没有启用HTTPS解密,直接建立隧道
|
||
p.tunnelProxy(ctx, w)
|
||
return
|
||
}
|
||
|
||
// 处理WebSocket请求
|
||
if isWebSocketRequestImpl(req) && p.config.EnableWebSocket {
|
||
// 获取客户端连接
|
||
clientConn, err := hijackerImpl(w)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("获取WebSocket连接失败: %w", err))
|
||
http.Error(w, "无法处理WebSocket请求", http.StatusServiceUnavailable)
|
||
return
|
||
}
|
||
defer clientConn.Close()
|
||
|
||
// 处理WebSocket请求
|
||
p.websocketProxy(ctx, clientConn)
|
||
return
|
||
}
|
||
|
||
// 处理普通HTTP请求
|
||
p.handleHTTP(ctx, w)
|
||
}
|
||
|
||
// 处理反向代理请求
|
||
func (p *UnifiedProxyImpl) serveReverseProxy(ctx *Context, w http.ResponseWriter) {
|
||
// 调用BeforeRequest事件
|
||
p.delegate.BeforeRequest(ctx)
|
||
if ctx.IsAborted() {
|
||
return
|
||
}
|
||
|
||
// 使用反向代理处理请求
|
||
p.reverseProxy.ServeHTTP(w, ctx.Req)
|
||
}
|
||
|
||
// 处理透明代理请求
|
||
func (p *UnifiedProxyImpl) serveTransparentProxy(ctx *Context, w http.ResponseWriter) {
|
||
// 透明代理与正向代理的不同之处在于客户端不知道有代理的存在
|
||
// 这里简化处理,复用正向代理的逻辑
|
||
p.serveForwardProxy(ctx, w)
|
||
}
|
||
|
||
// 处理HTTP请求,复用自原始代理的handleHTTP方法
|
||
func (p *UnifiedProxyImpl) handleHTTP(ctx *Context, rw http.ResponseWriter) {
|
||
req := ctx.Req
|
||
|
||
// 调用BeforeRequest事件
|
||
p.delegate.BeforeRequest(ctx)
|
||
if ctx.IsAborted() {
|
||
return
|
||
}
|
||
|
||
// 如果是WebSocket请求且支持WebSocket,处理WebSocket
|
||
if isWebSocketRequestImpl(req) && p.config.EnableWebSocket {
|
||
// 获取目标地址
|
||
targetAddr := req.Host
|
||
if !strings.Contains(targetAddr, ":") {
|
||
if req.URL.Scheme == "https" || req.URL.Scheme == "wss" {
|
||
targetAddr += ":443"
|
||
} else {
|
||
targetAddr += ":80"
|
||
}
|
||
}
|
||
|
||
// 劫持连接
|
||
conn, err := hijackerImpl(rw)
|
||
if err != nil {
|
||
http.Error(rw, "无法劫持连接: "+err.Error(), http.StatusInternalServerError)
|
||
return
|
||
}
|
||
defer conn.Close()
|
||
|
||
// 处理WebSocket
|
||
p.websocketProxy(ctx, conn)
|
||
return
|
||
}
|
||
|
||
// 处理普通HTTP请求
|
||
p.DoRequest(ctx, func(resp *http.Response, err error) {
|
||
if err != nil {
|
||
// 调用BeforeResponse事件
|
||
p.delegate.BeforeResponse(ctx, nil, err)
|
||
|
||
// 处理错误
|
||
http.Error(rw, "请求失败: "+err.Error(), http.StatusBadGateway)
|
||
return
|
||
}
|
||
|
||
// 调用BeforeResponse事件
|
||
p.delegate.BeforeResponse(ctx, resp, nil)
|
||
if ctx.IsAborted() {
|
||
return
|
||
}
|
||
|
||
// 如果启用缓存且可以缓存响应
|
||
if p.config.EnableCache && p.cacheAdapter != nil && resp.StatusCode < 300 && canCacheMethodImpl(req.Method) && canCacheStatusImpl(resp.StatusCode) {
|
||
// 生成缓存键
|
||
cacheKey := generateCacheKeyImpl(req)
|
||
// 获取缓存TTL
|
||
cacheTTL := getCacheTTLImpl(resp)
|
||
if cacheTTL > 0 {
|
||
// 保存响应到缓存
|
||
p.cacheAdapter.Set(cacheKey, resp, cacheTTL)
|
||
}
|
||
}
|
||
|
||
// 设置响应头
|
||
copyHeaders(rw.Header(), resp.Header)
|
||
rw.WriteHeader(resp.StatusCode)
|
||
|
||
// 写入响应体
|
||
if resp.Body != nil {
|
||
defer resp.Body.Close()
|
||
// io.Copy(rw, resp.Body)
|
||
|
||
// 复制响应体
|
||
buf := bufPool.Get()
|
||
defer bufPool.Put(buf)
|
||
|
||
_, err = io.CopyBuffer(rw, resp.Body, buf)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("%s - 复制响应体错误: %s", req.URL.Host, err))
|
||
}
|
||
}
|
||
})
|
||
}
|
||
|
||
// DoRequest 执行HTTP请求
|
||
func (p *UnifiedProxyImpl) DoRequest(ctx *Context, responseFunc func(*http.Response, error)) {
|
||
req := ctx.Req
|
||
|
||
// 如果启用缓存且是GET请求
|
||
if p.config.EnableCache && p.cacheAdapter != nil && req.Method == http.MethodGet {
|
||
// 生成缓存键
|
||
cacheKey := generateCacheKeyImpl(req)
|
||
|
||
// 尝试从缓存中获取响应
|
||
if value, ok := p.cacheAdapter.Get(cacheKey); ok {
|
||
// 缓存命中
|
||
if resp, ok := value.(*http.Response); ok {
|
||
// 记录缓存命中指标
|
||
if p.metrics != nil && isCacheHitMetricsSupportedImpl(p.metrics) {
|
||
incrementCacheHitImpl(p.metrics)
|
||
}
|
||
|
||
// 调用BeforeResponse事件
|
||
p.delegate.BeforeResponse(ctx, resp, nil)
|
||
if ctx.IsAborted() {
|
||
return
|
||
}
|
||
|
||
// 返回缓存的响应
|
||
responseFunc(resp, nil)
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// 准备发送请求
|
||
retries := 0
|
||
var resp *http.Response
|
||
var err error
|
||
|
||
// 重试循环
|
||
for {
|
||
// 创建一个新的请求,以避免副作用
|
||
newReq := req.Clone(req.Context())
|
||
|
||
// 如果使用负载均衡,处理后端选择
|
||
if p.config.EnableLoadBalancing && p.loadBalancer != nil {
|
||
// 获取域名
|
||
hostname := newReq.URL.Hostname()
|
||
// 使用负载均衡器选择一个后端
|
||
backend, err := p.loadBalancer.Next(hostname)
|
||
if err != nil {
|
||
responseFunc(nil, fmt.Errorf("负载均衡选择后端失败: %w", err))
|
||
return
|
||
}
|
||
|
||
if backend != nil {
|
||
// 使用选中的后端地址
|
||
newReq.URL.Host = backend.Host
|
||
if backend.Scheme != "" {
|
||
newReq.URL.Scheme = backend.Scheme
|
||
}
|
||
}
|
||
}
|
||
|
||
// 调用BeforeRequest事件
|
||
p.delegate.BeforeRequest(ctx)
|
||
if ctx.IsAborted() {
|
||
return
|
||
}
|
||
|
||
// 记录请求开始时间
|
||
startTime := time.Now()
|
||
|
||
// 发送请求
|
||
resp, err = p.httpTransport.RoundTrip(newReq)
|
||
|
||
// 记录请求时间
|
||
if p.metrics != nil && err == nil {
|
||
p.metrics.ObserveRequestDuration(time.Since(startTime).Seconds())
|
||
}
|
||
|
||
// 检查是否需要重试
|
||
if err != nil && p.config.EnableRetry && retries < p.config.MaxRetries {
|
||
retries++
|
||
// 指数退避
|
||
backoff := p.config.RetryBackoff
|
||
if backoff == 0 {
|
||
backoff = time.Second
|
||
}
|
||
// 计算退避时间
|
||
sleepTime := backoff * time.Duration(1<<uint(retries-1))
|
||
if maxBackoff := p.config.MaxRetryBackoff; maxBackoff > 0 && sleepTime > maxBackoff {
|
||
sleepTime = maxBackoff
|
||
}
|
||
// 记录重试日志
|
||
p.logger.Debug("请求失败,进行重试",
|
||
"url", newReq.URL.String(),
|
||
"retry", retries,
|
||
"max_retries", p.config.MaxRetries,
|
||
"sleep", sleepTime,
|
||
"error", err.Error(),
|
||
)
|
||
// 等待一段时间再重试
|
||
time.Sleep(sleepTime)
|
||
continue
|
||
}
|
||
|
||
// 无需重试或达到最大重试次数,返回结果
|
||
break
|
||
}
|
||
|
||
// 调用BeforeResponse事件
|
||
p.delegate.BeforeResponse(ctx, resp, err)
|
||
if ctx.IsAborted() {
|
||
if resp != nil && resp.Body != nil {
|
||
resp.Body.Close()
|
||
}
|
||
return
|
||
}
|
||
|
||
// 如果有错误,返回错误
|
||
if err != nil {
|
||
responseFunc(nil, err)
|
||
return
|
||
}
|
||
|
||
// 如果启用缓存且可以缓存响应
|
||
if p.config.EnableCache && p.cacheAdapter != nil && resp.StatusCode < 300 && canCacheMethodImpl(req.Method) && canCacheStatusImpl(resp.StatusCode) {
|
||
// 生成缓存键
|
||
cacheKey := generateCacheKeyImpl(req)
|
||
// 获取缓存TTL
|
||
cacheTTL := getCacheTTLImpl(resp)
|
||
if cacheTTL > 0 {
|
||
// 保存响应到缓存
|
||
p.cacheAdapter.Set(cacheKey, resp, cacheTTL)
|
||
}
|
||
}
|
||
|
||
// 返回响应
|
||
responseFunc(resp, nil)
|
||
}
|
||
|
||
// Close 关闭代理
|
||
func (p *UnifiedProxyImpl) Close() error {
|
||
// 关闭健康检查器
|
||
if p.healthChecker != nil {
|
||
p.healthChecker.Stop()
|
||
}
|
||
|
||
// 关闭Transport连接
|
||
if p.transport != nil {
|
||
p.transport.CloseIdleConnections()
|
||
}
|
||
|
||
// 等待活跃连接完成
|
||
// 设置一个超时,防止永远等待
|
||
timeout := time.NewTimer(30 * time.Second)
|
||
defer timeout.Stop()
|
||
|
||
// 创建一个心跳定时器,定期检查连接数
|
||
ticker := time.NewTicker(500 * time.Millisecond)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-timeout.C:
|
||
// 超时,强制关闭
|
||
p.logger.Warn("关闭超时,强制关闭代理",
|
||
"active_connections", p.ClientConnNum(),
|
||
)
|
||
return nil
|
||
case <-ticker.C:
|
||
// 检查当前连接数
|
||
if p.ClientConnNum() <= 0 {
|
||
p.logger.Info("所有连接已关闭,代理成功关闭")
|
||
return nil
|
||
}
|
||
p.logger.Debug("等待连接关闭",
|
||
"active_connections", p.ClientConnNum(),
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
// ClientConnNum 获取当前客户端连接数
|
||
func (p *UnifiedProxyImpl) ClientConnNum() int32 {
|
||
return atomic.LoadInt32(&p.clientConnNum)
|
||
}
|
||
|
||
// 复制HTTP头部
|
||
func copyHeaders(dst, src http.Header) {
|
||
for k, vv := range src {
|
||
for _, v := range vv {
|
||
dst.Add(k, v)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 发送隧道已连接响应
|
||
func (p *UnifiedProxyImpl) tunnelConnected(ctx *Context, err error) {
|
||
// 实现调用tunnelConnectedImpl
|
||
tunnelConnectedImpl(ctx, err, nil)
|
||
}
|
||
|
||
// 处理HTTPS请求
|
||
func (p *UnifiedProxyImpl) httpsProxy(ctx *Context, srcConn *ConnBuffer) {
|
||
req := ctx.Req
|
||
|
||
// 检查是否是WebSocket请求
|
||
if isWebSocketRequestImpl(req) && p.config.EnableWebSocket {
|
||
p.websocketProxy(ctx, srcConn)
|
||
return
|
||
}
|
||
|
||
// 准备发送请求的上下文
|
||
if req.Body == nil {
|
||
req.Body = http.NoBody
|
||
}
|
||
|
||
// 创建请求上下文,支持超时
|
||
reqContext := req.Context()
|
||
if p.config.RequestTimeout > 0 {
|
||
var cancel context.CancelFunc
|
||
reqContext, cancel = context.WithTimeout(reqContext, p.config.RequestTimeout)
|
||
defer cancel()
|
||
}
|
||
|
||
// 创建HTTP客户端
|
||
client := &http.Client{
|
||
Transport: p.httpTransport,
|
||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||
// 允许最多10次重定向
|
||
if len(via) >= 10 {
|
||
return http.ErrUseLastResponse
|
||
}
|
||
return nil
|
||
},
|
||
}
|
||
|
||
// 设置上下文
|
||
req = req.WithContext(reqContext)
|
||
|
||
// 调用BeforeRequest事件
|
||
p.delegate.BeforeRequest(ctx)
|
||
if ctx.IsAborted() {
|
||
return
|
||
}
|
||
|
||
// 发送请求
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS请求失败: %w", req.URL.String(), err))
|
||
errorResp := &http.Response{
|
||
StatusCode: http.StatusBadGateway,
|
||
Status: "502 Bad Gateway",
|
||
Proto: "HTTP/1.1",
|
||
ProtoMajor: 1,
|
||
ProtoMinor: 1,
|
||
Header: make(http.Header),
|
||
Body: http.NoBody,
|
||
Request: req,
|
||
}
|
||
errorResp.Header.Set("Content-Type", "text/plain; charset=utf-8")
|
||
errorResp.Header.Set("X-Proxy-Error", err.Error())
|
||
|
||
// 写入错误响应
|
||
if writeErr := resp.Write(srcConn); writeErr != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("%s - 写入错误响应失败: %w", req.URL.String(), writeErr))
|
||
}
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 调用BeforeResponse事件
|
||
p.delegate.BeforeResponse(ctx, resp, nil)
|
||
if ctx.IsAborted() {
|
||
return
|
||
}
|
||
|
||
// 写入响应
|
||
if err := resp.Write(srcConn); err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("%s - 写入响应失败: %w", req.URL.String(), err))
|
||
}
|
||
}
|
||
|
||
// 建立隧道代理
|
||
func (p *UnifiedProxyImpl) tunnelProxy(ctx *Context, rw http.ResponseWriter) {
|
||
// 获取客户端连接
|
||
clientConn, err := hijackerImpl(rw)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(err)
|
||
//http.Error(rw, "无法劫持连接", http.StatusServiceUnavailable)
|
||
return
|
||
}
|
||
defer clientConn.Close()
|
||
|
||
// 获取目标主机
|
||
host := ctx.Req.Host
|
||
if !strings.Contains(host, ":") {
|
||
host = host + ":443"
|
||
}
|
||
|
||
// 尝试使用负载均衡器选择目标
|
||
var targetHost string
|
||
if p.config.EnableLoadBalancing && p.loadBalancer != nil {
|
||
hostname := strings.Split(host, ":")[0]
|
||
backend, err := p.loadBalancer.Next(hostname)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("%s - 负载均衡选择目标失败: %w", host, err))
|
||
// 继续使用原始主机
|
||
targetHost = host
|
||
} else if backend != nil {
|
||
targetHost = backend.Host
|
||
} else {
|
||
targetHost = host
|
||
}
|
||
} else {
|
||
targetHost = host
|
||
}
|
||
|
||
// 连接目标服务器
|
||
targetConn, err := net.DialTimeout("tcp", targetHost, 10*time.Second)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("%s - 连接目标服务器失败: %w", targetHost, err))
|
||
// 直接写入连接而不是使用 http.Error
|
||
clientConn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
|
||
return
|
||
}
|
||
defer targetConn.Close()
|
||
|
||
// 发送隧道建立成功响应
|
||
if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")); err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("%s - 发送隧道建立响应失败: %w", targetHost, err))
|
||
return
|
||
}
|
||
|
||
// 双向转发数据
|
||
p.transfer(clientConn, targetConn)
|
||
}
|
||
|
||
// 处理WebSocket请求
|
||
func (p *UnifiedProxyImpl) websocketProxy(ctx *Context, srcConn *ConnBuffer) {
|
||
req := ctx.Req
|
||
|
||
// 检查是否启用WebSocket拦截
|
||
if p.config.WebSocketIntercept {
|
||
// 使用WebSocket拦截模式
|
||
// 创建WebSocket升级器
|
||
upgrader := &websocket.Upgrader{
|
||
HandshakeTimeout: 10 * time.Second,
|
||
ReadBufferSize: 4096,
|
||
WriteBufferSize: 4096,
|
||
CheckOrigin: func(r *http.Request) bool {
|
||
return true
|
||
},
|
||
}
|
||
|
||
// 创建伪响应写入器,用于升级WebSocket连接
|
||
rw := &customResponseWriter{
|
||
conn: srcConn,
|
||
header: make(http.Header),
|
||
statusCode: http.StatusOK,
|
||
}
|
||
|
||
// 升级源连接为WebSocket连接
|
||
srcWSConn, err := upgrader.Upgrade(rw, req, nil)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("%s - 升级WebSocket连接失败: %w", req.URL.String(), err))
|
||
return
|
||
}
|
||
defer srcWSConn.Close()
|
||
|
||
// 构建目标URL
|
||
targetURL := url.URL{
|
||
Scheme: "ws",
|
||
Host: req.URL.Host,
|
||
Path: req.URL.Path,
|
||
RawQuery: req.URL.RawQuery,
|
||
}
|
||
if req.URL.Scheme == "https" || req.URL.Scheme == "wss" {
|
||
targetURL.Scheme = "wss"
|
||
}
|
||
|
||
// 创建目标WebSocket连接
|
||
dialer := &websocket.Dialer{
|
||
HandshakeTimeout: 10 * time.Second,
|
||
ReadBufferSize: 4096,
|
||
WriteBufferSize: 4096,
|
||
TLSClientConfig: &tls.Config{InsecureSkipVerify: p.config.InsecureSkipVerify},
|
||
}
|
||
|
||
// 连接到目标WebSocket服务器
|
||
targetWSConn, _, err := dialer.Dial(targetURL.String(), nil)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("%s - 连接目标WebSocket服务器失败: %w", targetURL.String(), err))
|
||
return
|
||
}
|
||
defer targetWSConn.Close()
|
||
|
||
// 转发WebSocket消息
|
||
p.transferWebSocket(ctx, srcWSConn, targetWSConn)
|
||
return
|
||
}
|
||
|
||
// 不使用WebSocket拦截,直接转发TCP流量
|
||
// 确定目标地址
|
||
targetAddr := req.URL.Host
|
||
if !strings.Contains(targetAddr, ":") {
|
||
if req.URL.Scheme == "wss" || req.URL.Scheme == "https" {
|
||
targetAddr += ":443"
|
||
} else {
|
||
targetAddr += ":80"
|
||
}
|
||
}
|
||
|
||
// 创建到目标服务器的连接
|
||
var targetConn net.Conn
|
||
var err error
|
||
|
||
if req.URL.Scheme == "wss" || req.URL.Scheme == "https" {
|
||
// 使用TLS连接
|
||
targetConn, err = tls.Dial("tcp", targetAddr, &tls.Config{
|
||
InsecureSkipVerify: p.config.InsecureSkipVerify,
|
||
})
|
||
} else {
|
||
// 使用普通TCP连接
|
||
targetConn, err = net.Dial("tcp", targetAddr)
|
||
}
|
||
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("%s - 连接WebSocket目标服务器失败: %w", targetAddr, err))
|
||
return
|
||
}
|
||
defer targetConn.Close()
|
||
|
||
// 将原始请求转发给目标服务器
|
||
err = req.Write(targetConn)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("%s - 写入WebSocket握手请求失败: %w", targetAddr, err))
|
||
return
|
||
}
|
||
|
||
// 启动双向数据传输
|
||
p.transfer(srcConn, targetConn)
|
||
}
|
||
|
||
// 生成TLS配置
|
||
func (p *UnifiedProxyImpl) generateTLSConfig(host string) (*tls.Config, error) {
|
||
// 如果没有证书管理器,则创建一个
|
||
if p.certManager == nil {
|
||
// 创建证书管理器,使用已有的证书缓存
|
||
options := []CertManagerOption{
|
||
WithDefaultPrivateKey(true), // 使用默认私钥提高性能
|
||
WithValidityYears(1), // 证书有效期1年
|
||
}
|
||
p.certManager = NewCertManager(p.certCache, options...)
|
||
}
|
||
|
||
// 1. 首先检查是否配置了自定义证书
|
||
if p.config.TLSCert != "" && p.config.TLSKey != "" {
|
||
cert, err := tls.LoadX509KeyPair(p.config.TLSCert, p.config.TLSKey)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("加载TLS证书失败: %s", err)
|
||
}
|
||
return &tls.Config{
|
||
Certificates: []tls.Certificate{cert},
|
||
}, nil
|
||
}
|
||
|
||
// 2. 检查是否配置了CA证书和密钥(用于动态生成证书)
|
||
if p.config.CACert != "" && p.config.CAKey != "" {
|
||
// 加载CA证书和私钥
|
||
caCert, caKey, err := LoadCAFromFiles(p.config.CACert, p.config.CAKey)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("加载CA证书和私钥失败: %s", err))
|
||
// 如果加载失败,使用默认CA
|
||
return p.certManager.GenerateTLSConfig(host)
|
||
}
|
||
|
||
// 使用自定义CA生成证书
|
||
cert, err := p.certManager.GenerateCertificate(host, caCert, caKey)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("为%s生成动态证书失败: %s", host, err))
|
||
return nil, err
|
||
}
|
||
|
||
return &tls.Config{
|
||
Certificates: []tls.Certificate{*cert},
|
||
}, nil
|
||
}
|
||
|
||
// 3. 使用默认CA生成证书
|
||
tlsConfig, err := p.certManager.GenerateTLSConfig(host)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("为%s使用默认CA生成证书失败: %s", host, err))
|
||
return nil, err
|
||
}
|
||
|
||
return tlsConfig, nil
|
||
}
|
||
|
||
// transfer 在两个连接之间双向转发数据
|
||
func (p *UnifiedProxyImpl) transfer(src net.Conn, dst net.Conn) {
|
||
// 创建完成通道
|
||
done := make(chan struct{}, 2)
|
||
|
||
// src -> dst
|
||
go func() {
|
||
buf := bufPool.Get()
|
||
written, err := io.CopyBuffer(dst, src, buf)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err))
|
||
}
|
||
|
||
// 记录传输字节数
|
||
if p.metrics != nil {
|
||
p.metrics.AddBytesTransferred("request", written)
|
||
}
|
||
|
||
bufPool.Put(buf)
|
||
dst.Close()
|
||
done <- struct{}{}
|
||
}()
|
||
|
||
// dst -> src
|
||
go func() {
|
||
buf := bufPool.Get()
|
||
written, err := io.CopyBuffer(src, dst, buf)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", dst.RemoteAddr().String(), src.RemoteAddr().String(), err))
|
||
}
|
||
|
||
// 记录传输字节数
|
||
if p.metrics != nil {
|
||
p.metrics.AddBytesTransferred("response", written)
|
||
}
|
||
|
||
bufPool.Put(buf)
|
||
src.Close()
|
||
done <- struct{}{}
|
||
}()
|
||
|
||
// 等待两个方向都结束
|
||
<-done
|
||
<-done
|
||
}
|
||
|
||
// transferWebSocket 在两个WebSocket连接之间双向转发数据
|
||
func (p *UnifiedProxyImpl) transferWebSocket(ctx *Context, srcWSConn, targetWSConn *websocket.Conn) {
|
||
doneCtx, cancel := context.WithCancel(context.Background())
|
||
defer cancel()
|
||
|
||
// 源到目标
|
||
go func() {
|
||
for {
|
||
if doneCtx.Err() != nil {
|
||
return
|
||
}
|
||
|
||
// 读取源消息,正确处理消息类型
|
||
msgType, msg, err := srcWSConn.ReadMessage()
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s",
|
||
srcWSConn.RemoteAddr().String(), targetWSConn.RemoteAddr().String(), err))
|
||
cancel() // 取消另一个goroutine
|
||
return
|
||
}
|
||
|
||
// 调用消息拦截接口
|
||
p.delegate.WebSocketSendMessage(ctx, &msgType, &msg)
|
||
|
||
// 写入目标,保留原始消息类型(文本/二进制)
|
||
err = targetWSConn.WriteMessage(msgType, msg)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcWSConn.RemoteAddr().String(), targetWSConn.RemoteAddr().String(), err))
|
||
cancel() // 取消另一个goroutine
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
|
||
// 目标到源
|
||
for {
|
||
if doneCtx.Err() != nil {
|
||
return
|
||
}
|
||
|
||
// 读取目标消息,正确处理消息类型
|
||
msgType, msg, err := targetWSConn.ReadMessage()
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetWSConn.RemoteAddr().String(), srcWSConn.RemoteAddr().String(), err))
|
||
cancel() // 取消另一个goroutine
|
||
return
|
||
}
|
||
|
||
// 调用消息拦截接口
|
||
p.delegate.WebSocketReceiveMessage(ctx, &msgType, &msg)
|
||
|
||
// 写入源,保留原始消息类型(文本/二进制)
|
||
err = srcWSConn.WriteMessage(msgType, msg)
|
||
if err != nil {
|
||
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetWSConn.RemoteAddr().String(), srcWSConn.RemoteAddr().String(), err))
|
||
cancel() // 取消另一个goroutine
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// SetDialContext 设置自定义的拨号上下文函数
|
||
func (p *UnifiedProxyImpl) SetDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) {
|
||
p.transport.DialContext = dialContext
|
||
}
|
||
|
||
// 自定义响应写入器,用于WebSocket连接升级
|
||
type customResponseWriter struct {
|
||
conn *ConnBuffer
|
||
header http.Header
|
||
statusCode int
|
||
}
|
||
|
||
// Header 实现 http.ResponseWriter 接口
|
||
func (rw *customResponseWriter) Header() http.Header {
|
||
return rw.header
|
||
}
|
||
|
||
// Write 实现 http.ResponseWriter 接口
|
||
func (rw *customResponseWriter) Write(b []byte) (int, error) {
|
||
return rw.conn.Write(b)
|
||
}
|
||
|
||
// WriteHeader 实现 http.ResponseWriter 接口
|
||
func (rw *customResponseWriter) WriteHeader(statusCode int) {
|
||
rw.statusCode = statusCode
|
||
}
|