Files
demo/factory.go
2025-03-15 10:17:07 +00:00

302 lines
7.3 KiB
Go

package goproxy
import (
"crypto/tls"
"errors"
"net/http"
"net/http/httptrace"
"strconv"
"strings"
"time"
"github.com/darkit/goproxy/config"
"github.com/darkit/goproxy/pkg/auth"
"github.com/darkit/goproxy/pkg/cache"
"github.com/darkit/goproxy/pkg/dns"
"github.com/darkit/goproxy/pkg/healthcheck"
"github.com/darkit/goproxy/pkg/loadbalance"
"github.com/darkit/goproxy/pkg/metrics"
)
var (
	// ErrNotSupportHijacking is returned by hijackerImpl when the given
	// http.ResponseWriter does not implement http.Hijacker, so the raw client
	// connection cannot be taken over.
	ErrNotSupportHijacking = errors.New("connection does not support hijacking")
)
// CreateUnifiedProxy creates a unified proxy configured purely through options.
// It is named differently from the existing NewProxy to avoid a name clash.
//
// Options are applied in the order given; nil options are ignored. Unless a
// WithUnifiedConfig option supplies one, the Config handed to NewUnifiedProxy
// is nil.
func CreateUnifiedProxy(options ...UnifiedOption) (UnifiedProxy, error) {
	return buildUnifiedProxyFromPreset(nil, options)
}

// NewForwardProxy creates a forward proxy listening on listenAddr.
//
// It starts from config.DefaultUnifiedConfig() with ListenAddr and ProxyMode
// (ModeForward) filled in, then applies options in order (nil options are
// ignored). A non-nil WithUnifiedConfig option replaces that preset
// configuration wholesale, including listenAddr and the mode.
func NewForwardProxy(listenAddr string, options ...UnifiedOption) (UnifiedProxy, error) {
	cfg := config.DefaultUnifiedConfig()
	cfg.ListenAddr = listenAddr
	cfg.ProxyMode = config.ModeForward
	return buildUnifiedProxyFromPreset(cfg, options)
}

// NewReverseProxy creates a reverse proxy listening on listenAddr that
// forwards to targetAddr.
//
// It starts from config.DefaultUnifiedConfig() with ListenAddr, TargetAddr and
// ProxyMode (ModeReverse) filled in, then applies options in order (nil
// options are ignored). A non-nil WithUnifiedConfig option replaces that
// preset configuration wholesale, including both addresses and the mode.
func NewReverseProxy(listenAddr, targetAddr string, options ...UnifiedOption) (UnifiedProxy, error) {
	cfg := config.DefaultUnifiedConfig()
	cfg.ListenAddr = listenAddr
	cfg.TargetAddr = targetAddr
	cfg.ProxyMode = config.ModeReverse
	return buildUnifiedProxyFromPreset(cfg, options)
}

// NewTransparentProxy creates a transparent proxy listening on listenAddr.
//
// It starts from config.DefaultUnifiedConfig() with ListenAddr and ProxyMode
// (ModeTransparent) filled in, then applies options in order (nil options are
// ignored). A non-nil WithUnifiedConfig option replaces that preset
// configuration wholesale, including listenAddr and the mode.
func NewTransparentProxy(listenAddr string, options ...UnifiedOption) (UnifiedProxy, error) {
	cfg := config.DefaultUnifiedConfig()
	cfg.ListenAddr = listenAddr
	cfg.ProxyMode = config.ModeTransparent
	return buildUnifiedProxyFromPreset(cfg, options)
}

// buildUnifiedProxyFromPreset applies options on top of preset (which may be
// nil) and passes the resulting options to NewUnifiedProxy.
func buildUnifiedProxyFromPreset(preset *config.UnifiedConfig, options []UnifiedOption) (UnifiedProxy, error) {
	opts := &UnifiedOptions{Config: preset}
	for _, apply := range options {
		if apply == nil {
			// Tolerate conditionally built option lists instead of panicking
			// on a nil func call.
			continue
		}
		apply(opts)
	}
	if opts.Config == nil {
		// WithUnifiedConfig(nil) must not silently discard the listen/target
		// address and mode the caller passed positionally.
		opts.Config = preset
	}
	return NewUnifiedProxy(opts)
}
// UnifiedOption mutates a UnifiedOptions value. The proxy constructors apply
// options in the order they are passed, so a later option overrides an
// earlier one that sets the same field.
type UnifiedOption func(*UnifiedOptions)

// WithUnifiedConfig replaces the proxy configuration with cfg.
func WithUnifiedConfig(cfg *config.UnifiedConfig) UnifiedOption {
	return func(o *UnifiedOptions) { o.Config = cfg }
}

// WithUnifiedDelegate sets the Delegate used by the proxy.
func WithUnifiedDelegate(delegate Delegate) UnifiedOption {
	return func(o *UnifiedOptions) { o.Delegate = delegate }
}

// WithUnifiedCertCache sets the certificate cache.
func WithUnifiedCertCache(certCache CertificateCache) UnifiedOption {
	return func(o *UnifiedOptions) { o.CertCache = certCache }
}

// WithUnifiedHTTPCache sets the HTTP response cache.
func WithUnifiedHTTPCache(c cache.Cache) UnifiedOption {
	return func(o *UnifiedOptions) { o.HTTPCache = c }
}

// WithUnifiedLoadBalancer sets the load balancer.
func WithUnifiedLoadBalancer(lb loadbalance.LoadBalancer) UnifiedOption {
	return func(o *UnifiedOptions) { o.LoadBalancer = lb }
}

// WithUnifiedHealthChecker sets the health checker.
func WithUnifiedHealthChecker(hc *healthcheck.HealthChecker) UnifiedOption {
	return func(o *UnifiedOptions) { o.HealthChecker = hc }
}

// WithUnifiedMetrics sets the metrics collector.
func WithUnifiedMetrics(m metrics.MetricsCollector) UnifiedOption {
	return func(o *UnifiedOptions) { o.Metrics = m }
}

// WithUnifiedClientTrace sets the httptrace.ClientTrace stored in the options.
func WithUnifiedClientTrace(t *httptrace.ClientTrace) UnifiedOption {
	return func(o *UnifiedOptions) { o.ClientTrace = t }
}

// WithUnifiedCertManager sets the certificate manager.
func WithUnifiedCertManager(cm *CertManager) UnifiedOption {
	return func(o *UnifiedOptions) { o.CertManager = cm }
}

// WithUnifiedAuth sets the authentication component.
func WithUnifiedAuth(a *auth.Auth) UnifiedOption {
	return func(o *UnifiedOptions) { o.Auth = a }
}

// WithUnifiedDNSResolver sets the custom DNS resolver.
func WithUnifiedDNSResolver(resolver *dns.CustomResolver) UnifiedOption {
	return func(o *UnifiedOptions) { o.DNSResolver = resolver }
}
// hijackerImpl takes over the client connection behind rw.
//
// It returns ErrNotSupportHijacking if rw does not implement http.Hijacker,
// and passes through any error from Hijack. On success it wraps the raw
// connection together with the server's buffered reader via NewConnBuffer.
// Per the net/http Hijacker docs, that reader may already hold client bytes,
// so it must be handed on rather than discarded.
func hijackerImpl(rw http.ResponseWriter) (*ConnBuffer, error) {
	h, canHijack := rw.(http.Hijacker)
	if !canHijack {
		return nil, ErrNotSupportHijacking
	}

	netConn, buffered, hijackErr := h.Hijack()
	if hijackErr != nil {
		return nil, hijackErr
	}
	return NewConnBuffer(netConn, buffered.Reader), nil
}
// isWebSocketRequestImpl reports whether req is a WebSocket opening handshake
// (RFC 6455 §4.1). It requires all of the following:
//   - the method is GET;
//   - the Connection header lists the "upgrade" token;
//   - the Upgrade header lists the "websocket" token.
//
// Tokens are compared case-insensitively. They may appear in comma-separated
// lists and across several header lines (e.g. "Connection: keep-alive, Upgrade").
// Merely having non-empty Connection/Upgrade headers (e.g. an h2c upgrade) is
// not enough.
func isWebSocketRequestImpl(req *http.Request) bool {
	if req.Method != http.MethodGet {
		return false
	}

	// hasToken reports whether any value of header name contains the
	// comma-separated token want (case-insensitive).
	hasToken := func(name, want string) bool {
		for _, line := range req.Header.Values(name) {
			for _, tok := range strings.Split(line, ",") {
				if strings.EqualFold(strings.TrimSpace(tok), want) {
					return true
				}
			}
		}
		return false
	}

	return hasToken("Connection", "upgrade") && hasToken("Upgrade", "websocket")
}
// tunnelConnectedImpl reports the outcome of establishing a tunnel back to the
// client. The *Context argument is currently unused.
//
// If err is non-nil, it writes err's message with 503 Service Unavailable
// (when rw is non-nil) and returns err unchanged. Otherwise it writes a bare
// 200 OK status (when rw is non-nil) and returns nil.
//
// NOTE(review): the original comments describe this as a simplified stand-in
// for an earlier implementation; confirm the success response is what tunnel
// (CONNECT) clients expect.
func tunnelConnectedImpl(_ *Context, err error, rw http.ResponseWriter) error {
	haveWriter := rw != nil

	if err == nil {
		if haveWriter {
			rw.WriteHeader(http.StatusOK)
		}
		return nil
	}

	if haveWriter {
		http.Error(rw, err.Error(), http.StatusServiceUnavailable)
	}
	return err
}
// canCacheMethodImpl reports whether responses to requests made with method
// may be cached. Only GET qualifies, and the comparison is case-sensitive.
func canCacheMethodImpl(method string) bool {
	switch method {
	case http.MethodGet:
		return true
	default:
		return false
	}
}
// canCacheStatusImpl reports whether a response with statusCode may be
// stored: any 2xx status except 206 Partial Content.
//
// 206 is excluded because the cache key is built from method and URL only,
// without the Range header. A stored partial body would otherwise be served
// as the complete resource to later non-range requests (RFC 9111 §3.3–3.4
// only allows incomplete responses in caches that understand ranges).
func canCacheStatusImpl(statusCode int) bool {
	if statusCode == http.StatusPartialContent {
		return false
	}
	return statusCode >= 200 && statusCode < 300
}
// generateCacheKeyImpl builds the key under which a response to req is
// cached: the request method, a colon, then the request URL.
//
// Absolute-form URLs (as sent to a forward proxy) already contain the host
// and are used unchanged. Server-side requests (reverse or transparent proxy)
// usually carry only a path in req.URL. In that case the Host header is
// folded into the URL, so identical paths on different hosts do not share a
// cache entry.
func generateCacheKeyImpl(req *http.Request) string {
	if req.URL == nil {
		return req.Method + ":" + req.Host
	}
	u := *req.URL // shallow copy; req.URL itself is left untouched
	if u.Host == "" {
		u.Host = req.Host
	}
	return req.Method + ":" + u.String()
}
// 获取缓存TTL
func getCacheTTLImpl(resp *http.Response) time.Duration {
// 检查响应头中的Cache-Control
cacheControl := resp.Header.Get("Cache-Control")
if cacheControl != "" {
// 解析Cache-Control头
if strings.Contains(cacheControl, "no-store") || strings.Contains(cacheControl, "no-cache") {
return 0 // 不缓存
}
// 查找max-age指令
if strings.Contains(cacheControl, "max-age=") {
parts := strings.Split(cacheControl, "max-age=")
if len(parts) > 1 {
seconds := strings.Split(parts[1], ",")[0]
if maxAge, err := strconv.ParseInt(seconds, 10, 64); err == nil {
return time.Duration(maxAge) * time.Second
}
}
}
}
// 检查Expires头
if expires := resp.Header.Get("Expires"); expires != "" {
if expiresTime, err := time.Parse(time.RFC1123, expires); err == nil {
return time.Until(expiresTime)
}
}
// 返回默认缓存时间
return 5 * time.Minute
}
// isCacheHitMetricsSupportedImpl reports whether m can record cache hits. The
// dynamic type of m must have either an IncrementCacheHit() method or a
// generic Increment(key string, value float64) method. A nil collector is not
// supported.
func isCacheHitMetricsSupportedImpl(m metrics.MetricsCollector) bool {
	switch m.(type) {
	case nil:
		return false
	case interface{ IncrementCacheHit() },
		interface {
			Increment(key string, value float64)
		}:
		return true
	default:
		return false
	}
}
// incrementCacheHitImpl records one cache hit on m, if m supports it.
//
// A dedicated IncrementCacheHit() method is preferred. Otherwise a generic
// Increment(key string, value float64) method is called with key "cache_hit"
// and value 1. If m has neither method (including when m is nil), nothing
// happens.
func incrementCacheHitImpl(m metrics.MetricsCollector) {
	// Type-switch cases are tried top to bottom, so IncrementCacheHit wins
	// when both methods exist.
	switch counter := m.(type) {
	case interface{ IncrementCacheHit() }:
		counter.IncrementCacheHit()
	case interface {
		Increment(key string, value float64)
	}:
		counter.Increment("cache_hit", 1)
	}
}
// NewTLSConfig returns a new *tls.Config in which only ServerName and
// InsecureSkipVerify are set; every other field keeps its zero value.
//
// Per the crypto/tls documentation, ServerName is used for SNI and hostname
// verification on client connections. Passing insecureSkipVerify=true
// disables verification of the peer's certificate chain and host name, so
// reserve it for testing or trusted networks.
func NewTLSConfig(serverName string, insecureSkipVerify bool) *tls.Config {
	cfg := new(tls.Config)
	cfg.ServerName = serverName
	cfg.InsecureSkipVerify = insecureSkipVerify
	return cfg
}