302 lines
7.3 KiB
Go
302 lines
7.3 KiB
Go
package goproxy
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptrace"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/darkit/goproxy/config"
|
|
"github.com/darkit/goproxy/pkg/auth"
|
|
"github.com/darkit/goproxy/pkg/cache"
|
|
"github.com/darkit/goproxy/pkg/dns"
|
|
"github.com/darkit/goproxy/pkg/healthcheck"
|
|
"github.com/darkit/goproxy/pkg/loadbalance"
|
|
"github.com/darkit/goproxy/pkg/metrics"
|
|
)
|
|
|
|
// ErrNotSupportHijacking is the error hijackerImpl returns when the
// http.ResponseWriter it receives does not implement http.Hijacker, so the
// underlying network connection cannot be taken over.
var (
	ErrNotSupportHijacking = errors.New("connection does not support hijacking")
)
|
|
|
|
// CreateUnifiedProxy 创建统一代理
|
|
// 为了避免与现有的NewProxy冲突
|
|
func CreateUnifiedProxy(options ...UnifiedOption) (UnifiedProxy, error) {
|
|
opts := &UnifiedOptions{}
|
|
for _, option := range options {
|
|
option(opts)
|
|
}
|
|
return NewUnifiedProxy(opts)
|
|
}
|
|
|
|
// NewForwardProxy 创建正向代理
|
|
func NewForwardProxy(listenAddr string, options ...UnifiedOption) (UnifiedProxy, error) {
|
|
cfg := config.DefaultUnifiedConfig()
|
|
cfg.ListenAddr = listenAddr
|
|
cfg.ProxyMode = config.ModeForward
|
|
|
|
opts := &UnifiedOptions{
|
|
Config: cfg,
|
|
}
|
|
|
|
for _, option := range options {
|
|
option(opts)
|
|
}
|
|
|
|
return NewUnifiedProxy(opts)
|
|
}
|
|
|
|
// NewReverseProxy 创建反向代理
|
|
func NewReverseProxy(listenAddr, targetAddr string, options ...UnifiedOption) (UnifiedProxy, error) {
|
|
cfg := config.DefaultUnifiedConfig()
|
|
cfg.ListenAddr = listenAddr
|
|
cfg.TargetAddr = targetAddr
|
|
cfg.ProxyMode = config.ModeReverse
|
|
|
|
opts := &UnifiedOptions{
|
|
Config: cfg,
|
|
}
|
|
|
|
for _, option := range options {
|
|
option(opts)
|
|
}
|
|
|
|
return NewUnifiedProxy(opts)
|
|
}
|
|
|
|
// NewTransparentProxy 创建透明代理
|
|
func NewTransparentProxy(listenAddr string, options ...UnifiedOption) (UnifiedProxy, error) {
|
|
cfg := config.DefaultUnifiedConfig()
|
|
cfg.ListenAddr = listenAddr
|
|
cfg.ProxyMode = config.ModeTransparent
|
|
|
|
opts := &UnifiedOptions{
|
|
Config: cfg,
|
|
}
|
|
|
|
for _, option := range options {
|
|
option(opts)
|
|
}
|
|
|
|
return NewUnifiedProxy(opts)
|
|
}
|
|
|
|
// UnifiedOption 用于配置统一代理选项的函数类型
|
|
type UnifiedOption func(*UnifiedOptions)
|
|
|
|
// WithUnifiedConfig 设置代理配置
|
|
func WithUnifiedConfig(cfg *config.UnifiedConfig) UnifiedOption {
|
|
return func(opt *UnifiedOptions) {
|
|
opt.Config = cfg
|
|
}
|
|
}
|
|
|
|
// WithUnifiedDelegate 设置委托类
|
|
func WithUnifiedDelegate(delegate Delegate) UnifiedOption {
|
|
return func(opt *UnifiedOptions) {
|
|
opt.Delegate = delegate
|
|
}
|
|
}
|
|
|
|
// WithUnifiedCertCache 设置证书缓存
|
|
func WithUnifiedCertCache(cache CertificateCache) UnifiedOption {
|
|
return func(opt *UnifiedOptions) {
|
|
opt.CertCache = cache
|
|
}
|
|
}
|
|
|
|
// WithUnifiedHTTPCache 设置HTTP缓存
|
|
func WithUnifiedHTTPCache(c cache.Cache) UnifiedOption {
|
|
return func(opt *UnifiedOptions) {
|
|
opt.HTTPCache = c
|
|
}
|
|
}
|
|
|
|
// WithUnifiedLoadBalancer 设置负载均衡器
|
|
func WithUnifiedLoadBalancer(lb loadbalance.LoadBalancer) UnifiedOption {
|
|
return func(opt *UnifiedOptions) {
|
|
opt.LoadBalancer = lb
|
|
}
|
|
}
|
|
|
|
// WithUnifiedHealthChecker 设置健康检查器
|
|
func WithUnifiedHealthChecker(hc *healthcheck.HealthChecker) UnifiedOption {
|
|
return func(opt *UnifiedOptions) {
|
|
opt.HealthChecker = hc
|
|
}
|
|
}
|
|
|
|
// WithUnifiedMetrics 设置监控指标
|
|
func WithUnifiedMetrics(m metrics.MetricsCollector) UnifiedOption {
|
|
return func(opt *UnifiedOptions) {
|
|
opt.Metrics = m
|
|
}
|
|
}
|
|
|
|
// WithUnifiedClientTrace 设置HTTP客户端跟踪
|
|
func WithUnifiedClientTrace(t *httptrace.ClientTrace) UnifiedOption {
|
|
return func(opt *UnifiedOptions) {
|
|
opt.ClientTrace = t
|
|
}
|
|
}
|
|
|
|
// WithUnifiedCertManager 设置证书管理器
|
|
func WithUnifiedCertManager(cm *CertManager) UnifiedOption {
|
|
return func(opt *UnifiedOptions) {
|
|
opt.CertManager = cm
|
|
}
|
|
}
|
|
|
|
// WithUnifiedAuth 设置认证系统
|
|
func WithUnifiedAuth(a *auth.Auth) UnifiedOption {
|
|
return func(opt *UnifiedOptions) {
|
|
opt.Auth = a
|
|
}
|
|
}
|
|
|
|
// WithUnifiedDNSResolver 设置DNS解析器
|
|
func WithUnifiedDNSResolver(resolver *dns.CustomResolver) UnifiedOption {
|
|
return func(opt *UnifiedOptions) {
|
|
opt.DNSResolver = resolver
|
|
}
|
|
}
|
|
|
|
// 处理劫持连接
|
|
func hijackerImpl(rw http.ResponseWriter) (*ConnBuffer, error) {
|
|
hijacker, ok := rw.(http.Hijacker)
|
|
if !ok {
|
|
return nil, ErrNotSupportHijacking
|
|
}
|
|
conn, bufrw, err := hijacker.Hijack()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return NewConnBuffer(conn, bufrw.Reader), nil
|
|
}
|
|
|
|
// isWebSocketRequestImpl reports whether req is a WebSocket opening handshake
// (RFC 6455 §4.1): a GET request whose Connection header lists the "upgrade"
// token and whose Upgrade header lists the "websocket" token.
//
// Tokens are compared case-insensitively. Header values may be
// comma-separated lists and may be spread over several header lines, e.g.
// "Connection: keep-alive, Upgrade". Other protocol upgrades (such as
// "Upgrade: h2c") are not reported as WebSocket. A nil req yields false.
func isWebSocketRequestImpl(req *http.Request) bool {
	if req == nil || req.Method != http.MethodGet {
		return false
	}

	// hasToken reports whether any value of the named header contains token
	// as one of its comma-separated elements.
	hasToken := func(name, token string) bool {
		for _, value := range req.Header.Values(name) {
			for _, elem := range strings.Split(value, ",") {
				if strings.EqualFold(strings.TrimSpace(elem), token) {
					return true
				}
			}
		}
		return false
	}

	return hasToken("Connection", "upgrade") && hasToken("Upgrade", "websocket")
}
|
|
|
|
// 发送隧道已连接响应
|
|
func tunnelConnectedImpl(_ *Context, err error, rw http.ResponseWriter) error {
|
|
// 此函数应根据原始代码实现
|
|
if err != nil {
|
|
if rw != nil {
|
|
http.Error(rw, err.Error(), http.StatusServiceUnavailable)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// 发送连接成功响应
|
|
// 这里简化处理,实际应该根据原代码完整实现
|
|
if rw != nil {
|
|
rw.WriteHeader(http.StatusOK)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// canCacheMethodImpl reports whether responses to requests using method may
// be stored in the HTTP cache. Only GET qualifies. The comparison is exact and
// case-sensitive, so "get" is rejected.
func canCacheMethodImpl(method string) bool {
	switch method {
	case http.MethodGet:
		return true
	default:
		return false
	}
}
|
|
|
|
// canCacheStatusImpl reports whether a response with statusCode may be stored
// in the HTTP cache: any 2xx status except 206 Partial Content.
//
// 206 is excluded because the cache key (method + URL) does not include the
// request's Range header. A stored partial body would otherwise be replayed
// to later requests for the full resource (RFC 7234 §3.1 forbids caching
// incomplete responses without Range support).
func canCacheStatusImpl(statusCode int) bool {
	if statusCode == http.StatusPartialContent {
		return false
	}
	return statusCode >= 200 && statusCode < 300
}
|
|
|
|
// generateCacheKeyImpl builds the HTTP cache key for req as "<METHOD>:<URL>".
//
// Proxy-style requests carry an absolute URL (e.g. "http://example.com/a"), so
// their key already identifies the origin and is unchanged. Requests received
// by an http.Server in origin form ("GET /a HTTP/1.1") have an empty
// req.URL.Host, and the authority lives only in req.Host. In that case req.Host
// is folded into a copy of the URL, so the same path on different virtual
// hosts does not share one cache entry. The caller's req.URL is never mutated.
func generateCacheKeyImpl(req *http.Request) string {
	u := req.URL
	if u.Host == "" && req.Host != "" {
		withHost := *u // shallow copy so the request itself is left untouched
		withHost.Host = req.Host
		u = &withHost
	}
	return req.Method + ":" + u.String()
}
|
|
|
|
// getCacheTTLImpl works out how long resp may be kept in the proxy's HTTP
// cache. A result of 0 means "do not cache"; the result is never negative.
//
// Rules, checked in order (following RFC 7234 for a shared cache):
//   - Cache-Control no-store, no-cache or private: 0. Directive names are
//     matched as whole, case-insensitive tokens across every Cache-Control
//     header line, not as substrings.
//   - Cache-Control s-maxage, otherwise max-age: that many seconds. Values
//     that do not parse are ignored, negative values clamp to 0, and values
//     too large for a time.Duration clamp to the maximum.
//   - Expires: time remaining until that date, or 0 if it is already past.
//     An Expires value that cannot be parsed (e.g. "0") counts as already
//     expired (RFC 7234 §5.3).
//   - Otherwise: a default of 5 minutes.
//
// A nil resp yields 0.
func getCacheTTLImpl(resp *http.Response) time.Duration {
	const defaultTTL = 5 * time.Minute
	// Largest whole-second count that still fits in a time.Duration.
	const maxSeconds = int64(time.Duration(1<<63-1) / time.Second)

	if resp == nil {
		return 0
	}

	// parseSeconds parses a delta-seconds directive value. ok is false when
	// the value is not an integer at all.
	parseSeconds := func(v string) (d time.Duration, ok bool) {
		secs, err := strconv.ParseInt(strings.Trim(strings.TrimSpace(v), `"`), 10, 64)
		// On overflow ParseInt still returns the clamped ±MaxInt64 value.
		if err != nil && !errors.Is(err, strconv.ErrRange) {
			return 0, false
		}
		switch {
		case secs <= 0:
			return 0, true
		case secs > maxSeconds:
			secs = maxSeconds
		}
		return time.Duration(secs) * time.Second, true
	}

	maxAge, sMaxAge := time.Duration(-1), time.Duration(-1)
	for _, line := range resp.Header.Values("Cache-Control") {
		for _, directive := range strings.Split(line, ",") {
			directive = strings.TrimSpace(directive)
			name, value := directive, ""
			if i := strings.IndexByte(directive, '='); i >= 0 {
				name, value = directive[:i], directive[i+1:]
			}
			switch strings.ToLower(strings.TrimSpace(name)) {
			case "no-store", "no-cache", "private":
				// private responses must not be stored by a shared cache.
				return 0
			case "max-age":
				if d, ok := parseSeconds(value); ok && maxAge < 0 {
					maxAge = d
				}
			case "s-maxage":
				if d, ok := parseSeconds(value); ok && sMaxAge < 0 {
					sMaxAge = d
				}
			}
		}
	}

	// s-maxage is specific to shared caches and overrides max-age.
	switch {
	case sMaxAge >= 0:
		return sMaxAge
	case maxAge >= 0:
		return maxAge
	}

	if expires := resp.Header.Get("Expires"); expires != "" {
		at, err := http.ParseTime(expires)
		if err != nil {
			return 0
		}
		if ttl := time.Until(at); ttl > 0 {
			return ttl
		}
		return 0
	}

	return defaultTTL
}
|
|
|
|
// 检查是否支持缓存命中指标
|
|
func isCacheHitMetricsSupportedImpl(m metrics.MetricsCollector) bool {
|
|
// 首先检查是否为nil
|
|
if m == nil {
|
|
return false
|
|
}
|
|
|
|
// 检查是否实现了具体的缓存命中计数方法
|
|
_, hasIncrementCacheHit := m.(interface {
|
|
IncrementCacheHit()
|
|
})
|
|
|
|
// 或者检查是否实现了通用的计数器增加方法
|
|
_, hasIncrement := m.(interface {
|
|
Increment(key string, value float64)
|
|
})
|
|
|
|
// 如果实现了其中任何一个方法,就认为支持缓存命中计数
|
|
return hasIncrementCacheHit || hasIncrement
|
|
}
|
|
|
|
// 增加缓存命中计数
|
|
func incrementCacheHitImpl(m metrics.MetricsCollector) {
|
|
// 如果指标收集器实现了特定接口,则增加缓存命中计数
|
|
if counterInterface, ok := m.(interface {
|
|
IncrementCacheHit()
|
|
}); ok {
|
|
counterInterface.IncrementCacheHit()
|
|
} else if counterInterface, ok := m.(interface {
|
|
Increment(key string, value float64)
|
|
}); ok {
|
|
counterInterface.Increment("cache_hit", 1)
|
|
}
|
|
}
|
|
|
|
// NewTLSConfig returns a fresh client-side *tls.Config with only ServerName
// and InsecureSkipVerify set; every other field keeps its zero value.
//
// serverName is used to verify the peer's host name (and is sent as SNI).
// Setting insecureSkipVerify disables verification of the server's
// certificate chain and host name. crypto/tls recommends it only for testing
// or together with a custom verification callback.
func NewTLSConfig(serverName string, insecureSkipVerify bool) *tls.Config {
	cfg := new(tls.Config)
	cfg.ServerName = serverName
	cfg.InsecureSkipVerify = insecureSkipVerify
	return cfg
}
|