Files
demo/proxy.go
2025-03-14 18:50:49 +00:00

1302 lines
33 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package goproxy
import (
"bufio"
"context"
"crypto/elliptic"
"crypto/tls"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/http/httptrace"
"net/url"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/darkit/goproxy/config"
"github.com/darkit/goproxy/pkg/cache"
"github.com/darkit/goproxy/pkg/healthcheck"
"github.com/darkit/goproxy/pkg/loadbalance"
"github.com/darkit/goproxy/pkg/metrics"
"github.com/darkit/goproxy/pkg/middleware"
"github.com/darkit/goproxy/pkg/reverse"
"github.com/ouqiang/websocket"
"github.com/viki-org/dnscache"
)
// Timeouts applied to connections towards target/upstream servers.
const (
	// defaultTargetConnectTimeout bounds how long dialing a target may take.
	defaultTargetConnectTimeout = 5 * time.Second
	// defaultTargetReadWriteTimeout is the read/write timeout for target servers.
	defaultTargetReadWriteTimeout = 10 * time.Second
)

// Raw HTTP/1.1 responses written directly onto client connections.
var (
	// tunnelEstablishedResponseLine acknowledges a CONNECT request.
	tunnelEstablishedResponseLine = []byte("HTTP/1.1 200 Connection established\r\n\r\n")
	// badGatewayResponse is a minimal 502 reply with no headers and no body.
	badGatewayResponse = []byte("HTTP/1.1 " + strconv.Itoa(http.StatusBadGateway) + " " +
		http.StatusText(http.StatusBadGateway) + "\r\n\r\n")
)
// Package-level object pools that cut per-request allocations.
var (
	// bufPool hands out 32 KiB scratch buffers for body and tunnel copies.
	bufPool = newPool(func() []byte { return make([]byte, 32<<10) })
	// ctxPool recycles per-request Context values.
	ctxPool = newPool(func() *Context { return &Context{} })
)
// CertificateCache stores TLS certificates keyed by host name.
type CertificateCache interface {
	// Get returns the certificate cached for host, or nil when there is none.
	Get(host string) *tls.Certificate
	// Set stores cert as the certificate for host.
	Set(host string, cert *tls.Certificate)
}

// MemCertCache is an in-memory CertificateCache backed by a sync.Map, so it
// may be used from several goroutines. This type offers no eviction.
type MemCertCache struct {
	certs sync.Map
}

// Get returns the certificate stored for host, or nil if none was stored.
func (c *MemCertCache) Get(host string) *tls.Certificate {
	if v, found := c.certs.Load(host); found {
		return v.(*tls.Certificate)
	}
	return nil
}

// Set stores cert under host, replacing any previous entry.
func (c *MemCertCache) Set(host string, cert *tls.Certificate) {
	c.certs.Store(host, cert)
}
// CacheAdapter adapts cache implementations with differing Get/Set method
// shapes to a single Get/Set API.
//
// NewCacheAdapter probes the wrapped cache once for the shapes it supports.
// Get and Set then dispatch on the recorded shape. If the cache supports none
// of the known shapes, Get always misses and Set does nothing.
type CacheAdapter struct {
	// cache is the wrapped cache implementation.
	cache interface{}
	// getMethodType / setMethodType record the detected shape using the
	// identifiers below; 0 means no supported shape was found.
	getMethodType int
	setMethodType int
	// Shape identifiers, populated by NewCacheAdapter.
	getResponseBool  int // Get(string) (*http.Response, bool)
	getInterfaceBool int // Get(string) (interface{}, bool)
	setResponseOnly  int // Set(string, *http.Response)
	setResponseTTL   int // Set(string, *http.Response, time.Duration)
	setInterfaceOnly int // Set(string, interface{})
}

// NewCacheAdapter wraps cache. For Set, the TTL-aware *http.Response form is
// preferred over the TTL-less one, which is preferred over the generic form.
// A nil cache yields an adapter whose Get always misses and whose Set is a
// no-op.
func NewCacheAdapter(cache interface{}) *CacheAdapter {
	adapter := &CacheAdapter{
		cache:            cache,
		getResponseBool:  1,
		getInterfaceBool: 2,
		setResponseOnly:  1,
		setResponseTTL:   2,
		setInterfaceOnly: 3,
	}

	switch cache.(type) {
	case interface{ Get(string) (*http.Response, bool) }:
		adapter.getMethodType = adapter.getResponseBool
	case interface{ Get(string) (interface{}, bool) }:
		adapter.getMethodType = adapter.getInterfaceBool
	}

	switch cache.(type) {
	case interface {
		Set(string, *http.Response, time.Duration)
	}:
		adapter.setMethodType = adapter.setResponseTTL
	case interface{ Set(string, *http.Response) }:
		adapter.setMethodType = adapter.setResponseOnly
	case interface{ Set(string, interface{}) }:
		adapter.setMethodType = adapter.setInterfaceOnly
	}

	return adapter
}

// Get looks up key and reports whether a usable entry was found.
//
// For caches returning *http.Response, a nil response is reported as a miss.
// Otherwise it would reach callers as a typed-nil *http.Response inside a
// non-nil interface{}, which they would treat as a hit.
func (a *CacheAdapter) Get(key string) (interface{}, bool) {
	switch a.getMethodType {
	case a.getResponseBool:
		if getter, ok := a.cache.(interface {
			Get(string) (*http.Response, bool)
		}); ok {
			resp, found := getter.Get(key)
			if !found || resp == nil {
				return nil, false
			}
			return resp, true
		}
	case a.getInterfaceBool:
		if getter, ok := a.cache.(interface {
			Get(string) (interface{}, bool)
		}); ok {
			return getter.Get(key)
		}
	}
	return nil, false
}

// Set stores value under key.
//
// ttl is forwarded only to caches whose Set accepts a duration. Caches typed
// on *http.Response silently ignore values of any other type.
func (a *CacheAdapter) Set(key string, value interface{}, ttl time.Duration) {
	resp, isResponse := value.(*http.Response)
	switch a.setMethodType {
	case a.setResponseTTL:
		if setter, ok := a.cache.(interface {
			Set(string, *http.Response, time.Duration)
		}); ok && isResponse {
			setter.Set(key, resp, ttl)
		}
	case a.setResponseOnly:
		if setter, ok := a.cache.(interface {
			Set(string, *http.Response)
		}); ok && isResponse {
			setter.Set(key, resp)
		}
	case a.setInterfaceOnly:
		if setter, ok := a.cache.(interface {
			Set(string, interface{})
		}); ok {
			setter.Set(key, value)
		}
	}
}
// Proxy is an HTTP(S) proxy. It runs as a forward proxy, or as a reverse proxy
// when config.ReverseProxy is set. Optional features are HTTPS decryption
// (config.DecryptHTTPS), response caching, parent-proxy load balancing and
// WebSocket forwarding.
type Proxy struct {
// config is the proxy configuration.
config *config.Config
// delegate receives the lifecycle hooks (Connect, Auth, BeforeRequest,
// BeforeResponse, Finish, ParentProxy, ErrorLog, WebSocket message hooks).
delegate Delegate
// certCache is handed to the certificate manager when one is created here.
certCache CertificateCache
// httpCache is the HTTP response cache; nil disables caching.
httpCache cache.Cache
// cacheAdapter wraps httpCache with a uniform Get/Set API; it is only set
// when httpCache is non-nil.
cacheAdapter *CacheAdapter
// loadBalancer picks the parent proxy when config.EnableLoadBalancing is set.
loadBalancer loadbalance.LoadBalancer
// healthChecker's status changes mark loadBalancer targets up or down.
healthChecker *healthcheck.HealthChecker
// metrics is an optional collector; nil disables metrics.
metrics metrics.MetricsCollector
// clientTrace, when non-nil, is attached to outbound requests.
clientTrace *httptrace.ClientTrace
// transport is the unwrapped *http.Transport, kept so its fields (e.g.
// DialContext) can still be changed after middleware wrapping.
transport *http.Transport
// httpTransport is the RoundTripper actually used for outbound requests;
// it may wrap transport with middleware such as retries.
httpTransport http.RoundTripper
// dnsCache resolves host names for the transport's dialer.
dnsCache *dnscache.Resolver
// clientConnNum is the client connection count; access it atomically.
clientConnNum int32
// certManager generates certificates for HTTPS decryption; it may be nil
// until first needed.
certManager *CertManager
// logger is the structured logger used by the proxy.
logger *slog.Logger
}
// New creates a Proxy from opts.
//
// Missing pieces are defaulted:
//   - a nil opts, opts.Config or opts.Delegate is replaced with
//     config.DefaultConfig() / DefaultDelegate;
//   - a nil Config.Logger is replaced with slog.Default().
//
// A certificate manager is created when DecryptHTTPS is enabled and none was
// supplied. The outbound transport is wrapped with retry middleware when
// EnableRetry is set. If both a health checker and a load balancer are
// given, health changes mark balancer targets up or down.
func New(opts *Options) *Proxy {
	if opts == nil {
		opts = &Options{}
	}
	if opts.Config == nil {
		opts.Config = config.DefaultConfig()
	}
	if opts.Delegate == nil {
		opts.Delegate = &DefaultDelegate{}
	}

	logger := opts.Config.Logger
	if logger == nil {
		// *slog.Logger methods dereference their receiver, so a nil logger
		// would panic on the first log call (e.g. ServeHTTP's reverse path).
		logger = slog.Default()
	}

	p := &Proxy{
		config:        opts.Config,
		delegate:      opts.Delegate,
		certCache:     opts.CertCache,
		httpCache:     opts.HTTPCache,
		loadBalancer:  opts.LoadBalancer,
		healthChecker: opts.HealthChecker,
		metrics:       opts.Metrics,
		clientTrace:   opts.ClientTrace,
		logger:        logger,
	}

	if p.httpCache != nil {
		p.cacheAdapter = NewCacheAdapter(p.httpCache)
	}

	p.dnsCache = dnscache.New(opts.Config.DNSCacheTTL)

	// Certificate manager: use the supplied one, or build one when HTTPS
	// decryption is enabled.
	if opts.CertManager != nil {
		p.certManager = opts.CertManager
	} else if opts.Config.DecryptHTTPS {
		certManagerOpts := []CertManagerOption{
			WithDefaultPrivateKey(true),        // use the default private key (faster)
			WithValidityYears(1),               // certificates are valid for one year
			WithUseECDSA(opts.Config.UseECDSA), // key algorithm follows the config
		}
		if opts.Config.UseECDSA {
			certManagerOpts = append(certManagerOpts, WithCurve(elliptic.P256()))
		}
		p.certManager = NewCertManager(p.certCache, certManagerOpts...)
	}

	httpTransport := &http.Transport{
		Proxy:                 p.proxyFromDelegate,
		DialContext:           p.dialContextWithCache(),
		MaxIdleConns:          opts.Config.ConnectionPoolSize,
		IdleConnTimeout:       opts.Config.IdleTimeout,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
		DisableKeepAlives:     false,
		DisableCompression:    false,
		ForceAttemptHTTP2:     true,
	}
	// Keep the concrete transport so SetDialContext can reach it after wrapping.
	p.transport = httpTransport

	var roundTripper http.RoundTripper = httpTransport
	if opts.Config.EnableRetry {
		policy := &middleware.RetryPolicy{
			MaxRetries:  opts.Config.MaxRetries,
			BaseBackoff: opts.Config.BaseBackoff,
			MaxBackoff:  opts.Config.MaxBackoff,
		}
		roundTripper = middleware.NewRetryMiddleware(policy).Transport(roundTripper)
	}
	p.httpTransport = roundTripper

	// Keep the load balancer's view of target health in sync with the checker.
	if p.healthChecker != nil && p.loadBalancer != nil {
		p.healthChecker.SetStatusChangeCallback(func(target string, healthy bool) {
			if healthy {
				p.loadBalancer.MarkUp(target)
			} else {
				p.loadBalancer.MarkDown(target)
			}
		})
	}

	return p
}
// NewProxy builds a Proxy with the functional-options pattern. Each option is
// applied in order on top of config.DefaultConfig(), and the result is then
// passed to New.
func NewProxy(options ...Option) *Proxy {
	opts := &Options{Config: config.DefaultConfig()}
	for i := range options {
		options[i](opts)
	}
	return New(opts)
}
// ServeHTTP implements http.Handler.
//
// In reverse-proxy mode the request goes to a reverse proxy built from the
// current configuration. Otherwise it is a forward-proxy request:
//   - CONNECT requests are tunnelled;
//   - WebSocket upgrades are proxied over the hijacked connection;
//   - everything else goes through handleHTTP.
//
// Requests with no target host in the request line are not proxy requests
// and are rejected with 400.
func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	// Count this client for ClientConnNum while it is being served.
	atomic.AddInt32(&p.clientConnNum, 1)
	defer atomic.AddInt32(&p.clientConnNum, -1)

	if p.config.ReverseProxy {
		if req.URL.Scheme == "" {
			req.URL.Scheme = "http"
		}
		// NOTE(review): a reverse proxy is constructed for every request;
		// building it once and reusing it would avoid the repeated setup.
		reverseProxy, err := reverse.New(convertToReverseConfig(p.config))
		if err != nil {
			p.logger.Error("创建反向代理失败",
				"error", err.Error(),
			)
			http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
			return
		}
		reverseProxy.ServeHTTP(rw, req)
		return
	}

	if p.metrics != nil {
		p.metrics.IncRequestCount()
	}

	// Proxy-form requests ("GET http://host/...") and CONNECT requests carry
	// the target host in req.URL. net/http copies that host into req.Host, so
	// the two are always equal for valid proxy requests; comparing them would
	// reject every request. An empty req.URL.Host means the client addressed
	// the proxy directly.
	if req.URL.Host == "" {
		rw.Header().Set("X-Proxy-Error", "Invalid Request")
		http.Error(rw, "", http.StatusBadRequest)
		return
	}

	ctx := ctxPool.Get()
	ctx.Reset(req)
	defer ctxPool.Put(ctx)

	p.delegate.Connect(ctx, rw)
	p.delegate.Auth(ctx, rw)
	if ctx.IsAborted() {
		return
	}

	if req.Method == http.MethodConnect {
		p.tunnelProxy(ctx, rw)
		return
	}

	if isWebSocketRequest(req) {
		clientConn, err := hijacker(rw)
		if err != nil {
			p.delegate.ErrorLog(err)
			http.Error(rw, "无法处理WebSocket请求", http.StatusInternalServerError)
			return
		}
		// The connection is ours after hijacking; close it even when
		// websocketProxy gives up before handing it to a relay.
		defer clientConn.Close()
		p.websocketProxy(ctx, clientConn)
		return
	}

	p.handleHTTP(ctx, rw)
}
// handleHTTP forwards a plain HTTP proxy request (not CONNECT, not a
// WebSocket upgrade) and streams the response back to rw.
//
// The steps are:
//  1. BeforeRequest hook;
//  2. optional cache lookup;
//  3. upstream RoundTrip, with the optional client trace and RequestTimeout;
//  4. optional cache store;
//  5. BeforeResponse hook;
//  6. status, end-to-end headers and body copied to rw;
//  7. Finish hook.
//
// A hook that aborts ctx ends processing early.
func (p *Proxy) handleHTTP(ctx *Context, rw http.ResponseWriter) {
	p.delegate.BeforeRequest(ctx)
	if ctx.IsAborted() {
		return
	}

	startTime := time.Now()

	// Resolve the parent proxy so hooks can see it on ctx; the transport
	// resolves it again through its own Proxy callback.
	if parentProxy, err := p.proxyFromDelegate(ctx.Req); err != nil {
		p.delegate.ErrorLog(fmt.Errorf("%s - 获取上级代理错误: %s", ctx.Req.URL.Host, err))
		ctx.ParentProxyURL = nil
	} else {
		ctx.ParentProxyURL = parentProxy
	}

	req := ctx.Req
	cacheable := p.httpCache != nil && p.cacheAdapter != nil && p.config.EnableCache && canCacheMethod(req.Method)

	var resp *http.Response
	if cacheable {
		// NOTE(review): the cached *http.Response is shared between requests,
		// and its Body is the stream consumed and closed when it was first
		// relayed, unless the cache implementation buffers bodies. Confirm.
		if value, ok := p.cacheAdapter.Get(generateCacheKey(req)); ok {
			// Comma-ok: an interface-typed cache may hold other values.
			if cached, isResp := value.(*http.Response); isResp && cached != nil {
				resp = cached
				if p.metrics != nil && isCacheHitMetricsSupported(p.metrics) {
					incrementCacheHit(p.metrics)
				}
			}
		}
	}

	if resp == nil {
		reqCtx := req.Context()
		if p.clientTrace != nil {
			reqCtx = httptrace.WithClientTrace(reqCtx, p.clientTrace)
		}
		if p.config.RequestTimeout > 0 {
			var cancel context.CancelFunc
			reqCtx, cancel = context.WithTimeout(reqCtx, p.config.RequestTimeout)
			// Cancel only on return: the body is streamed further down.
			defer cancel()
		}

		// Clone deep-copies the headers. Stripping hop-by-hop headers stops
		// e.g. Proxy-Authorization leaking upstream, and leaves ctx.Req
		// untouched for the delegate hooks.
		outReq := req.Clone(reqCtx)
		for _, h := range hopHeaders {
			outReq.Header.Del(h)
		}

		var err error
		resp, err = p.httpTransport.RoundTrip(outReq)
		if err != nil {
			p.delegate.BeforeResponse(ctx, nil, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTP请求错误: %s", outReq.URL.Host, err))
			http.Error(rw, err.Error(), http.StatusBadGateway)
			return
		}

		if p.metrics != nil {
			p.metrics.ObserveRequestDuration(time.Since(startTime).Seconds())
		}

		if cacheable && canCacheStatus(resp.StatusCode) {
			p.cacheAdapter.Set(generateCacheKey(req), resp, getCacheTTL(resp))
		}
	}

	p.delegate.BeforeResponse(ctx, resp, nil)
	if ctx.IsAborted() {
		// Nobody will read the body any more; release the upstream connection.
		resp.Body.Close()
		return
	}

	// Relay end-to-end headers only. resp may be shared through the cache, so
	// hop-by-hop headers are skipped here rather than deleted from resp.Header.
	isHop := func(key string) bool {
		for _, h := range hopHeaders {
			if strings.EqualFold(h, key) {
				return true
			}
		}
		return false
	}
	dst := rw.Header()
	for key, values := range resp.Header {
		if isHop(key) {
			continue
		}
		for _, value := range values {
			dst.Add(key, value)
		}
	}
	rw.WriteHeader(resp.StatusCode)

	buf := bufPool.Get()
	defer bufPool.Put(buf)
	if _, err := io.CopyBuffer(rw, resp.Body, buf); err != nil {
		p.delegate.ErrorLog(fmt.Errorf("%s - 复制响应体错误: %s", req.URL.Host, err))
	}
	resp.Body.Close()

	p.delegate.Finish(ctx)
}
// canCacheMethod reports whether responses to method may be cached. Only GET
// and HEAD qualify; the comparison is case-sensitive.
func canCacheMethod(method string) bool {
	switch method {
	case http.MethodGet, http.MethodHead:
		return true
	}
	return false
}

// canCacheStatus reports whether a response with statusCode may be cached.
//
// It accepts the 2xx/3xx statuses that RFC 9110 §15.1 marks as heuristically
// cacheable, except 206. Two rejections matter in particular:
//   - 206 Partial Content: the body depends on the Range header, which is not
//     part of the cache key.
//   - 304 Not Modified: a bodiless reply to one client's conditional request.
//
// Replaying either to other clients would serve wrong content. Temporary
// redirects (302/303/307) are not cached either.
func canCacheStatus(statusCode int) bool {
	switch statusCode {
	case http.StatusOK,
		http.StatusNonAuthoritativeInfo,
		http.StatusNoContent,
		http.StatusMultipleChoices,
		http.StatusMovedPermanently,
		http.StatusPermanentRedirect:
		return true
	}
	return false
}

// generateCacheKey returns the cache key for req: its method and full URL
// separated by a single space.
func generateCacheKey(req *http.Request) string {
	return req.Method + " " + req.URL.String()
}
// getCacheTTL returns how long resp may be cached.
//
// Parsing of the Cache-Control header(s):
//   - s-maxage takes precedence over max-age, because this proxy is a shared
//     cache (RFC 9111 §5.2.2.10);
//   - directive names match case-insensitively, and quoted values are
//     accepted (RFC 9111 §5.2);
//   - every header line is read;
//   - only the first occurrence of each directive counts.
//
// If no directive yields a valid value, or resp is nil, it returns 5 minutes.
func getCacheTTL(resp *http.Response) time.Duration {
	const defaultTTL = 5 * time.Minute
	if resp == nil {
		return defaultTTL
	}

	parseSeconds := func(v string) (time.Duration, bool) {
		s, err := strconv.Atoi(v)
		if err != nil {
			return 0, false
		}
		return time.Duration(s) * time.Second, true
	}

	var (
		maxAge, sMaxAge         time.Duration
		haveMaxAge, haveSMaxAge bool // a valid value was parsed
		seenMaxAge, seenSMaxAge bool // the directive occurred at all
	)
	// Cache-Control is a list header that may be split across several lines.
	for _, line := range resp.Header.Values("Cache-Control") {
		for _, directive := range strings.Split(line, ",") {
			name, value, hasValue := strings.Cut(strings.TrimSpace(directive), "=")
			if !hasValue {
				continue
			}
			name = strings.ToLower(strings.TrimSpace(name))
			value = strings.Trim(strings.TrimSpace(value), `"`)
			switch name {
			case "max-age":
				if !seenMaxAge {
					seenMaxAge = true
					maxAge, haveMaxAge = parseSeconds(value)
				}
			case "s-maxage":
				if !seenSMaxAge {
					seenSMaxAge = true
					sMaxAge, haveSMaxAge = parseSeconds(value)
				}
			}
		}
	}

	switch {
	case haveSMaxAge:
		return sMaxAge
	case haveMaxAge:
		return maxAge
	default:
		return defaultTTL
	}
}
// ClientConnNum reports the current client connection count. The counter is
// read atomically, so this may be called from any goroutine.
func (p *Proxy) ClientConnNum() int32 {
	n := atomic.LoadInt32(&p.clientConnNum)
	return n
}
// DoRequest executes ctx.Req upstream and passes the outcome to responseFunc.
//
// responseFunc receives (resp, nil) on success or (nil, err) on a transport
// error. It is not called when a delegate hook aborts ctx. DoRequest does not
// close resp.Body on the success path; responseFunc is expected to.
//
// When caching is enabled, GET responses may be served from the cache and
// successful responses may be stored there. Hop-by-hop headers are removed
// from both the outbound request and the response.
func (p *Proxy) DoRequest(ctx *Context, responseFunc func(*http.Response, error)) {
	if ctx.Data == nil {
		ctx.Data = make(map[interface{}]interface{})
	}

	p.delegate.BeforeRequest(ctx)
	if ctx.IsAborted() {
		return
	}

	// Try the cache first for GET requests.
	if p.httpCache != nil && ctx.Req.Method == http.MethodGet && p.config.EnableCache && p.cacheAdapter != nil {
		if value, ok := p.cacheAdapter.Get(cache.GenerateCacheKey(ctx.Req)); ok {
			// Comma-ok plus nil check: an interface-typed cache may hold other
			// values, and a nil response must never be reported as success.
			if cached, isResp := value.(*http.Response); isResp && cached != nil {
				p.delegate.BeforeResponse(ctx, cached, nil)
				if !ctx.IsAborted() {
					responseFunc(cached, nil)
				}
				if p.metrics != nil && isCacheHitMetricsSupported(p.metrics) {
					incrementCacheHit(p.metrics)
				}
				return
			}
		}
	}

	// Clone so header edits do not touch ctx.Req.
	newReq := ctx.Req.Clone(ctx.Req.Context())
	for _, h := range hopHeaders {
		newReq.Header.Del(h)
	}
	if p.clientTrace != nil {
		newReq = newReq.WithContext(httptrace.WithClientTrace(newReq.Context(), p.clientTrace))
	}

	resp, err := p.httpTransport.RoundTrip(newReq)

	p.delegate.BeforeResponse(ctx, resp, err)
	if ctx.IsAborted() {
		if resp != nil {
			resp.Body.Close()
		}
		return
	}

	if err != nil {
		responseFunc(nil, err)
		return
	}

	for _, h := range hopHeaders {
		resp.Header.Del(h)
	}

	if p.httpCache != nil && p.config.EnableCache && cache.ShouldCache(ctx.Req, resp) {
		if p.cacheAdapter != nil {
			p.cacheAdapter.Set(cache.GenerateCacheKey(ctx.Req), resp, getCacheTTL(resp))
		}
	}

	responseFunc(resp, nil)
}
// httpsProxy serves one request read from a decrypted (intercepted) TLS client
// connection.
//
// WebSocket upgrades go to websocketProxy. Any other request runs through
// DoRequest, and the response is written straight back over the TLS
// connection; if the upstream request failed, a raw 502 is sent instead.
func (p *Proxy) httpsProxy(ctx *Context, tlsClientConn *tls.Conn) {
	if isWebSocketRequest(ctx.Req) {
		p.websocketProxy(ctx, NewConnBuffer(tlsClientConn, nil))
		return
	}

	writeBack := func(resp *http.Response, err error) {
		if err != nil {
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密请求错误: %s", ctx.Req.URL, err))
			tlsClientConn.Write(badGatewayResponse)
			return
		}
		defer resp.Body.Close()
		if writeErr := resp.Write(tlsClientConn); writeErr != nil {
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密响应写入客户端失败: %s", ctx.Req.URL, writeErr))
		}
	}
	p.DoRequest(ctx, writeBack)
}
// tunnelProxy handles a CONNECT request on the hijacked client connection.
//
// After acknowledging the tunnel it does one of three things:
//   - proxies a plain-text WebSocket handshake sent through the tunnel;
//   - terminates TLS with a generated certificate (DecryptHTTPS) and serves
//     the decrypted request via httpsProxy;
//   - relays raw bytes between client and target. With a parent proxy, the
//     target is reached through the parent's own CONNECT.
func (p *Proxy) tunnelProxy(ctx *Context, rw http.ResponseWriter) {
	clientConn, err := hijacker(rw)
	if err != nil {
		// Hijacking failed, so rw is still a normal ResponseWriter.
		p.delegate.ErrorLog(err)
		rw.WriteHeader(http.StatusBadGateway)
		return
	}
	defer clientConn.Close()

	// The CONNECT request itself carries WebSocket upgrade headers.
	if isWebSocketRequest(ctx.Req) {
		p.websocketProxy(ctx, clientConn)
		return
	}

	// Parent proxy chosen by the delegate; a value already on ctx wins.
	parentProxyURL, err := p.delegate.ParentProxy(ctx.Req)
	if ctx.ParentProxyURL != nil {
		parentProxyURL = ctx.ParentProxyURL
	}
	if err != nil {
		p.delegate.ErrorLog(fmt.Errorf("%s - 解析代理地址错误: %s", ctx.Req.URL.Host, err))
		// rw has been hijacked and can no longer be written through; answer
		// on the raw connection instead.
		clientConn.Write(badGatewayResponse)
		return
	}

	// Acknowledge the tunnel whether or not a parent proxy is used. The
	// parent's own reply is consumed further down and never relayed, and the
	// client sends nothing (so the Peek below would block) until it sees this.
	if _, err = clientConn.Write(tunnelEstablishedResponseLine); err != nil {
		p.delegate.ErrorLog(fmt.Errorf("%s - 隧道连接成功,通知客户端错误: %s", ctx.Req.URL.Host, err))
		return
	}

	// If the first tunnelled bytes spell "GET", treat them as a plain-text
	// WebSocket handshake. Anything else (e.g. a TLS ClientHello) falls through.
	if methodBytes, peekErr := clientConn.Peek(3); peekErr == nil && string(methodBytes) == http.MethodGet {
		req, err := http.ReadRequest(clientConn.BufferReader())
		if err != nil {
			if err != io.EOF {
				p.delegate.ErrorLog(fmt.Errorf("%s - websocket读取客户端升级请求失败: %s", ctx.Req.URL.Host, err))
			}
			return
		}
		req.RemoteAddr = ctx.Req.RemoteAddr
		req.URL.Scheme = "http"
		req.URL.Host = req.Host
		ctx.Req = req
		p.websocketProxy(ctx, clientConn)
		return
	}

	// HTTPS decryption: terminate TLS towards the client and read the request.
	// NOTE(review): only the first request on the decrypted connection is
	// read and served.
	var tlsClientConn *tls.Conn
	if p.config.DecryptHTTPS {
		certConfig, err := p.generateTLSConfig(ctx.Req.URL.Host)
		if err != nil {
			p.tunnelConnected(ctx, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密生成证书失败: %s", ctx.Req.URL.Host, err))
			return
		}

		tlsClientConn = tls.Server(clientConn, certConfig)
		defer tlsClientConn.Close()

		if err := tlsClientConn.Handshake(); err != nil {
			p.tunnelConnected(ctx, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密握手失败: %s", ctx.Req.URL.Host, err))
			return
		}

		tlsReq, err := http.ReadRequest(bufio.NewReader(tlsClientConn))
		if err != nil {
			if err != io.EOF {
				p.tunnelConnected(ctx, err)
				p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密读取客户端请求失败: %s", ctx.Req.URL.Host, err))
			}
			return
		}
		tlsReq.RemoteAddr = ctx.Req.RemoteAddr
		tlsReq.URL.Scheme = "https"
		tlsReq.URL.Host = tlsReq.Host
		ctx.Req = tlsReq
	}

	// Dial address precedence: explicit ctx.TargetAddr, then the parent
	// proxy, then the requested host (defaulting to port 443).
	targetAddr := ctx.Req.URL.Host
	if ctx.TargetAddr != "" {
		targetAddr = ctx.TargetAddr
	} else if parentProxyURL != nil {
		targetAddr = parentProxyURL.Host
	}
	if !strings.Contains(targetAddr, ":") {
		targetAddr += ":443"
	}

	targetConn, err := net.DialTimeout("tcp", targetAddr, defaultTargetConnectTimeout)
	if err != nil {
		p.tunnelConnected(ctx, err)
		p.delegate.ErrorLog(fmt.Errorf("%s - 隧道转发连接目标服务器失败: %s", ctx.Req.URL.Host, err))
		return
	}
	defer targetConn.Close()

	// Through a parent proxy: open its tunnel with our own CONNECT.
	if parentProxyURL != nil {
		tunnelRequestLine := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", ctx.Req.URL.Host, ctx.Req.URL.Host)
		if _, err = targetConn.Write([]byte(tunnelRequestLine)); err != nil {
			p.tunnelConnected(ctx, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - 向上级代理发送CONNECT请求失败: %s", ctx.Req.URL.Host, err))
			return
		}

		bufReader := bufio.NewReader(targetConn)
		resp, err := http.ReadResponse(bufReader, ctx.Req)
		if err != nil {
			p.tunnelConnected(ctx, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - 读取上级代理响应失败: %s", ctx.Req.URL.Host, err))
			return
		}
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			p.tunnelConnected(ctx, fmt.Errorf("上级代理返回错误状态码: %d", resp.StatusCode))
			p.delegate.ErrorLog(fmt.Errorf("%s - 上级代理返回错误状态码: %d", ctx.Req.URL.Host, resp.StatusCode))
			return
		}

		// bufReader may already hold bytes the target sent right behind the
		// parent's reply (e.g. a server-first greeting); keep reading through
		// it so they are not lost.
		targetConn = NewConnBuffer(targetConn, bufReader)
	}

	if p.config.DecryptHTTPS {
		p.httpsProxy(ctx, tlsClientConn)
		return
	}
	p.tunnelConnected(ctx, nil)
	p.transfer(clientConn, targetConn)
}
// websocketProxy proxies a WebSocket session for ctx.Req over srcConn.
//
// The mode depends on config.WebSocketIntercept:
//   - off: the handshake request is written to the target and bytes are
//     relayed verbatim;
//   - on: both sides are upgraded and messages are relayed individually, so
//     the delegate's WebSocket message hooks can inspect and rewrite them.
//
// Target TLS certificates are not verified in either mode.
func (p *Proxy) websocketProxy(ctx *Context, srcConn *ConnBuffer) {
	if !p.config.WebSocketIntercept {
		remoteAddr := ctx.Addr()
		// Bound connection setup like the other target dials in this file.
		dialer := &net.Dialer{Timeout: defaultTargetConnectTimeout}

		var (
			targetConn net.Conn
			err        error
		)
		if ctx.IsHTTPS() {
			targetConn, err = tls.DialWithDialer(dialer, "tcp", remoteAddr, &tls.Config{InsecureSkipVerify: true})
		} else {
			targetConn, err = dialer.Dial("tcp", remoteAddr)
		}
		if err != nil {
			p.tunnelConnected(ctx, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - websocket连接目标服务器错误: %s", ctx.Req.URL.Host, err))
			return
		}

		// Forward the client's handshake request to the target.
		if err = ctx.Req.Write(targetConn); err != nil {
			// transfer, which would otherwise close targetConn, is not reached.
			targetConn.Close()
			p.tunnelConnected(ctx, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - websocket协议转换请求写入目标服务器错误: %s", ctx.Req.URL.Host, err))
			return
		}

		p.tunnelConnected(ctx, nil)
		p.transfer(srcConn, targetConn)
		return
	}

	upgrader := websocket.Upgrader{
		HandshakeTimeout: defaultTargetConnectTimeout,
		ReadBufferSize:   4096,
		WriteBufferSize:  4096,
		CheckOrigin: func(r *http.Request) bool {
			return true
		},
	}

	// NOTE(review): gorilla-style Upgrader.Upgrade requires the writer to
	// implement http.Hijacker; responseWriter defines no Hijack method in this
	// file. Confirm one exists elsewhere, or that ouqiang/websocket relaxes it.
	rw := newResponseWriter(srcConn)
	srcWSConn, err := upgrader.Upgrade(rw, ctx.Req, nil)
	if err != nil {
		p.tunnelConnected(ctx, err)
		p.delegate.ErrorLog(fmt.Errorf("%s - 升级WebSocket连接失败: %s", ctx.Req.URL.Host, err))
		return
	}
	defer srcWSConn.Close()

	u := url.URL{
		Scheme: func() string {
			if ctx.IsHTTPS() {
				return "wss"
			}
			return "ws"
		}(),
		Host:     ctx.Req.URL.Host,
		Path:     ctx.Req.URL.Path,
		RawQuery: ctx.Req.URL.RawQuery,
	}

	// The Dialer generates its own handshake headers and rejects a request
	// header that already contains them ("duplicate header not allowed"), so
	// forward the client's headers without them.
	dialHeader := make(http.Header, len(ctx.Req.Header))
	for key, values := range ctx.Req.Header {
		switch http.CanonicalHeaderKey(key) {
		case "Upgrade", "Connection", "Sec-Websocket-Key", "Sec-Websocket-Version", "Sec-Websocket-Extensions":
			continue
		}
		dialHeader[key] = values
	}

	dialer := websocket.Dialer{
		HandshakeTimeout: defaultTargetConnectTimeout,
		ReadBufferSize:   4096,
		WriteBufferSize:  4096,
		TLSClientConfig:  &tls.Config{InsecureSkipVerify: true},
	}
	targetWSConn, _, err := dialer.Dial(u.String(), dialHeader)
	if err != nil {
		p.tunnelConnected(ctx, err)
		p.delegate.ErrorLog(fmt.Errorf("%s - 连接目标WebSocket服务器失败: %s", ctx.Req.URL.Host, err))
		return
	}
	defer targetWSConn.Close()

	p.tunnelConnected(ctx, nil)
	p.transferWebSocket(ctx, srcWSConn, targetWSConn)
}
// transferWebSocket relays messages both ways between srcConn (client side)
// and targetConn (server side) until either direction fails.
//
// Message hooks:
//   - client→server messages pass through delegate.WebSocketSendMessage;
//   - server→client messages pass through delegate.WebSocketReceiveMessage;
//   - both hooks receive pointers to the message type and payload, and the
//     (possibly modified) values are what gets written.
//
// When one direction stops, both connections are closed so the other
// direction's blocking read returns. The function waits for the background
// goroutine before returning, so the delegate is never called with ctx after
// this call completes.
func (p *Proxy) transferWebSocket(ctx *Context, srcConn *websocket.Conn, targetConn *websocket.Conn) {
	var (
		closing   int32 // set to 1 once shutdown has begun
		closeOnce sync.Once
	)
	shutdown := func() {
		closeOnce.Do(func() {
			atomic.StoreInt32(&closing, 1)
			srcConn.Close()
			targetConn.Close()
		})
	}

	// relay copies messages from `from` to `to`, passing each through hook,
	// until a read or write fails.
	relay := func(from, to *websocket.Conn, hook func(*int, *[]byte)) {
		for {
			msgType, msg, err := from.ReadMessage()
			if err == nil {
				hook(&msgType, &msg)
				// Keep the original message type (text/binary).
				err = to.WriteMessage(msgType, msg)
			}
			if err != nil {
				// Failures caused by our own shutdown are expected; don't log them.
				if atomic.LoadInt32(&closing) == 0 {
					p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s",
						from.RemoteAddr().String(), to.RemoteAddr().String(), err))
				}
				return
			}
		}
	}

	done := make(chan struct{})
	go func() {
		defer close(done)
		defer shutdown()
		relay(srcConn, targetConn, func(msgType *int, msg *[]byte) {
			p.delegate.WebSocketSendMessage(ctx, msgType, msg)
		})
	}()

	relay(targetConn, srcConn, func(msgType *int, msg *[]byte) {
		p.delegate.WebSocketReceiveMessage(ctx, msgType, msg)
	})
	shutdown()
	<-done
}
// responseWriter is a minimal http.ResponseWriter over a client *ConnBuffer.
// It lets the WebSocket upgrader run against a connection the proxy already
// owns. Only Write reaches the connection: this type never transmits the
// header map, and WriteHeader merely records the status code.
type responseWriter struct {
	conn       *ConnBuffer
	header     http.Header
	statusCode int
}

// newResponseWriter wraps conn, starting with an empty header map and status 200.
func newResponseWriter(conn *ConnBuffer) *responseWriter {
	w := &responseWriter{conn: conn}
	w.header = http.Header{}
	w.statusCode = http.StatusOK
	return w
}

// Header returns the response header map (never transmitted by this type).
func (rw *responseWriter) Header() http.Header { return rw.header }

// Write sends b straight to the underlying connection.
func (rw *responseWriter) Write(b []byte) (int, error) { return rw.conn.Write(b) }

// WriteHeader records statusCode without sending anything.
func (rw *responseWriter) WriteHeader(statusCode int) { rw.statusCode = statusCode }
// transfer pipes bytes both ways between src and dst and returns once both
// copies have finished.
//
// Each direction closes its destination when it ends, which also unblocks
// the opposite copy. Byte counts are reported to metrics as "request"
// (src→dst) and "response" (dst→src).
func (p *Proxy) transfer(src net.Conn, dst net.Conn) {
	done := make(chan struct{}, 2)

	pipe := func(to, from net.Conn, label string) {
		buf := bufPool.Get()
		n, err := io.CopyBuffer(to, from, buf)
		if err != nil {
			p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", from.RemoteAddr().String(), to.RemoteAddr().String(), err))
		}
		if p.metrics != nil {
			p.metrics.AddBytesTransferred(label, n)
		}
		bufPool.Put(buf)
		to.Close()
		done <- struct{}{}
	}

	go pipe(dst, src, "request")
	go pipe(src, dst, "response")

	for i := 0; i < 2; i++ {
		<-done
	}
}
// tunnelConnected marks ctx as a tunnel and runs the BeforeRequest and
// BeforeResponse hooks for it.
//
// On failure, BeforeResponse receives err and a nil response. On success it
// receives a synthetic "200 Connection Established" HTTP/1.1 response with
// empty headers and no body.
func (p *Proxy) tunnelConnected(ctx *Context, err error) {
	ctx.TunnelProxy = true
	p.delegate.BeforeRequest(ctx)

	if err == nil {
		established := &http.Response{
			Status:     "200 Connection Established",
			StatusCode: http.StatusOK,
			Proto:      "HTTP/1.1",
			ProtoMajor: 1,
			ProtoMinor: 1,
			Header:     http.Header{},
			Body:       http.NoBody,
		}
		p.delegate.BeforeResponse(ctx, established, nil)
		return
	}
	p.delegate.BeforeResponse(ctx, nil, err)
}
// dialContextWithCache returns a DialContext function that resolves host
// names through p.dnsCache before dialing.
//
// Resolution rules:
//   - with no DNS cache configured, addr is dialed as-is;
//   - IP literals are dialed directly;
//   - for names, the first IPv4 address is preferred, falling back to the
//     first address of any family;
//   - the chosen address is re-joined with the port via net.JoinHostPort, so
//     IPv6 addresses are bracketed correctly.
func (p *Proxy) dialContextWithCache() func(ctx context.Context, network, addr string) (net.Conn, error) {
	// net.Dialer only holds settings, so one instance serves every call.
	dialer := &net.Dialer{
		Timeout:   defaultTargetConnectTimeout,
		KeepAlive: 30 * time.Second,
	}

	return func(ctx context.Context, network, addr string) (net.Conn, error) {
		if p.dnsCache == nil {
			return dialer.DialContext(ctx, network, addr)
		}

		host, port, err := net.SplitHostPort(addr)
		if err != nil {
			return nil, fmt.Errorf("invalid address: %s", addr)
		}
		// Nothing to resolve for IP literals such as "[::1]:443".
		if net.ParseIP(host) != nil {
			return dialer.DialContext(ctx, network, addr)
		}

		ips, err := p.dnsCache.Fetch(host)
		if err != nil {
			return nil, err
		}

		var chosen net.IP
		for _, ip := range ips {
			if ip.To4() != nil {
				chosen = ip
				break
			}
			if chosen == nil {
				chosen = ip
			}
		}
		if chosen == nil {
			return nil, fmt.Errorf("no valid IP address found for: %s", host)
		}

		return dialer.DialContext(ctx, network, net.JoinHostPort(chosen.String(), port))
	}
}
// proxyFromDelegate picks the parent proxy URL for req. If load balancing is
// enabled and a balancer is configured, the balancer chooses based on the
// request's host name. Otherwise the delegate decides.
func (p *Proxy) proxyFromDelegate(req *http.Request) (*url.URL, error) {
	useBalancer := p.loadBalancer != nil && p.config.EnableLoadBalancing
	if !useBalancer {
		return p.delegate.ParentProxy(req)
	}
	return p.loadBalancer.Next(req.URL.Hostname())
}
// generateTLSConfig builds the server-side TLS config used to decrypt HTTPS
// traffic for host. Sources are tried in order:
//  1. the fixed certificate/key pair from config.TLSCert/TLSKey;
//  2. a per-host certificate signed by the CA in config.CACert/CAKey, falling
//     back to the manager's default CA if that CA cannot be loaded;
//  3. a per-host certificate signed by the manager's default CA.
//
// Errors wrap their cause with %w so callers can inspect it with errors.Is/As.
func (p *Proxy) generateTLSConfig(host string) (*tls.Config, error) {
	if p.certManager == nil {
		// Use the same options as New so UseECDSA is honoured here too.
		// NOTE(review): this lazy initialisation is unsynchronised; concurrent
		// CONNECT handlers could race on p.certManager.
		options := []CertManagerOption{
			WithDefaultPrivateKey(true), // use the default private key (faster)
			WithValidityYears(1),        // certificates are valid for one year
			WithUseECDSA(p.config.UseECDSA),
		}
		if p.config.UseECDSA {
			options = append(options, WithCurve(elliptic.P256()))
		}
		p.certManager = NewCertManager(p.certCache, options...)
	}

	// 1. Fixed certificate from the configuration.
	// NOTE(review): the files are re-read on every call; consider caching.
	if p.config.TLSCert != "" && p.config.TLSKey != "" {
		cert, err := tls.LoadX509KeyPair(p.config.TLSCert, p.config.TLSKey)
		if err != nil {
			return nil, fmt.Errorf("加载TLS证书失败: %w", err)
		}
		return &tls.Config{Certificates: []tls.Certificate{cert}}, nil
	}

	// 2. Dynamic certificate signed by a CA loaded from files.
	if p.config.CACert != "" && p.config.CAKey != "" {
		caCert, caKey, err := LoadCAFromFiles(p.config.CACert, p.config.CAKey)
		if err != nil {
			p.delegate.ErrorLog(fmt.Errorf("加载CA证书和私钥失败: %w", err))
			return p.certManager.GenerateTLSConfig(host)
		}
		cert, err := p.certManager.GenerateCertificate(host, caCert, caKey)
		if err != nil {
			p.delegate.ErrorLog(fmt.Errorf("为%s生成动态证书失败: %w", host, err))
			return nil, err
		}
		return &tls.Config{Certificates: []tls.Certificate{*cert}}, nil
	}

	// 3. Dynamic certificate signed by the manager's default CA.
	tlsConfig, err := p.certManager.GenerateTLSConfig(host)
	if err != nil {
		p.delegate.ErrorLog(fmt.Errorf("为%s使用默认CA生成证书失败: %w", host, err))
		return nil, err
	}
	return tlsConfig, nil
}
// hijacker takes over the client connection behind rw. It wraps the
// connection, together with the server's buffered reader (which may still
// hold unread client bytes), in a ConnBuffer. It fails if rw does not
// support hijacking or if Hijack returns an error; that error is wrapped
// with %w.
func hijacker(rw http.ResponseWriter) (*ConnBuffer, error) {
	hj, ok := rw.(http.Hijacker)
	if !ok {
		return nil, fmt.Errorf("http server不支持Hijacker")
	}
	conn, bufrw, err := hj.Hijack()
	if err != nil {
		return nil, fmt.Errorf("hijacker错误: %w", err)
	}
	return NewConnBuffer(conn, bufrw.Reader), nil
}
// isWebSocketRequest reports whether req is a WebSocket upgrade request. Two
// conditions must hold:
//   - some Connection header carries the "upgrade" token;
//   - some Upgrade header carries the "websocket" token.
//
// Both headers are comma-separated token lists that may be split across
// several header lines, so every line and token is checked, case-insensitively.
// A nil req is not a WebSocket request.
func isWebSocketRequest(req *http.Request) bool {
	if req == nil {
		return false
	}
	hasToken := func(name, token string) bool {
		for _, line := range req.Header.Values(name) {
			for _, item := range strings.Split(line, ",") {
				if strings.EqualFold(strings.TrimSpace(item), token) {
					return true
				}
			}
		}
		return false
	}
	return hasToken("Connection", "upgrade") && hasToken("Upgrade", "websocket")
}
// hopHeaders lists hop-by-hop headers (RFC 7230 §6.1, plus the non-standard
// Proxy-Connection). They apply to a single connection and must not be
// forwarded by a proxy. Names are in http.CanonicalHeaderKey form.
var hopHeaders = []string{
	"Connection", "Proxy-Connection", "Keep-Alive",
	"Proxy-Authenticate", "Proxy-Authorization",
	"Te", "Trailer", "Transfer-Encoding", "Upgrade",
}
// isCacheHitMetricsSupported reports whether m also provides an IncCacheHit
// method for counting cache hits. A nil m reports false.
func isCacheHitMetricsSupported(m metrics.MetricsCollector) bool {
	switch m.(type) {
	case interface{ IncCacheHit() }:
		return true
	default:
		return false
	}
}

// incrementCacheHit bumps m's cache-hit counter when m supports one and does
// nothing otherwise.
func incrementCacheHit(m metrics.MetricsCollector) {
	counter, ok := m.(interface{ IncCacheHit() })
	if !ok {
		return
	}
	counter.IncCacheHit()
}
// SetDialContext replaces the dial function of the underlying *http.Transport
// (the unwrapped transport beneath any middleware).
// NOTE(review): changing the transport while requests are in flight is a data
// race; call this before the proxy starts serving.
func (p *Proxy) SetDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) {
	t := p.transport
	t.DialContext = dialContext
}
// convertToReverseConfig maps the proxy-wide config.Config onto the
// reverse.Config used by the reverse package.
//
// Renamed fields:
//   - DecryptHTTPS → EnableHTTPS;
//   - SupportWebSocketUpgrade → EnableWebSocket;
//   - ReverseProxyRulesFile → RulesFile;
//   - BaseBackoff/MaxBackoff → RetryBackoff/MaxRetryBackoff.
//
// InsecureSkipVerify is copied both into the TLS settings and to the top level.
func convertToReverseConfig(cfg *config.Config) *reverse.Config {
	tlsSettings := &reverse.TLSConfig{
		CertFile:           cfg.TLSCert,
		KeyFile:            cfg.TLSKey,
		InsecureSkipVerify: cfg.InsecureSkipVerify,
		UseECDSA:           cfg.UseECDSA,
	}

	base := reverse.BaseConfig{
		ListenAddr:        cfg.ListenAddr,
		TargetAddr:        cfg.TargetAddr,
		EnableHTTPS:       cfg.DecryptHTTPS,
		TLSConfig:         tlsSettings,
		EnableWebSocket:   cfg.SupportWebSocketUpgrade,
		EnableCompression: cfg.EnableCompression,
		EnableCORS:        cfg.EnableCORS,
		PreserveClientIP:  cfg.PreserveClientIP,
		AddXForwardedFor:  cfg.AddXForwardedFor,
		AddXRealIP:        cfg.AddXRealIP,
	}

	out := &reverse.Config{BaseConfig: base}
	out.RulesFile = cfg.ReverseProxyRulesFile
	out.InsecureSkipVerify = cfg.InsecureSkipVerify

	// Health checking and retries.
	out.EnableHealthCheck = cfg.EnableHealthCheck
	out.HealthCheckInterval = cfg.HealthCheckInterval
	out.HealthCheckTimeout = cfg.HealthCheckTimeout
	out.EnableRetry = cfg.EnableRetry
	out.MaxRetries = cfg.MaxRetries
	out.RetryBackoff = cfg.BaseBackoff
	out.MaxRetryBackoff = cfg.MaxBackoff

	// Observability and WebSocket handling.
	out.EnableMetrics = cfg.EnableMetrics
	out.EnableTracing = cfg.EnableTracing
	out.WebSocketIntercept = cfg.WebSocketIntercept

	// Caching, connection pooling and timeouts.
	out.DNSCacheTTL = cfg.DNSCacheTTL
	out.EnableCache = cfg.EnableCache
	out.CacheTTL = cfg.CacheTTL
	out.EnableConnectionPool = cfg.EnableConnectionPool
	out.ConnectionPoolSize = cfg.ConnectionPoolSize
	out.IdleTimeout = cfg.IdleTimeout
	out.RequestTimeout = cfg.RequestTimeout

	return out
}