update
This commit is contained in:
301
factory.go
Normal file
301
factory.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package goproxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/darkit/goproxy/config"
|
||||
"github.com/darkit/goproxy/pkg/auth"
|
||||
"github.com/darkit/goproxy/pkg/cache"
|
||||
"github.com/darkit/goproxy/pkg/dns"
|
||||
"github.com/darkit/goproxy/pkg/healthcheck"
|
||||
"github.com/darkit/goproxy/pkg/loadbalance"
|
||||
"github.com/darkit/goproxy/pkg/metrics"
|
||||
)
|
||||
|
||||
// ErrNotSupportHijacking 不支持劫持错误
|
||||
var ErrNotSupportHijacking = errors.New("connection does not support hijacking")
|
||||
|
||||
// CreateUnifiedProxy 创建统一代理
|
||||
// 为了避免与现有的NewProxy冲突
|
||||
func CreateUnifiedProxy(options ...UnifiedOption) (UnifiedProxy, error) {
|
||||
opts := &UnifiedOptions{}
|
||||
for _, option := range options {
|
||||
option(opts)
|
||||
}
|
||||
return NewUnifiedProxy(opts)
|
||||
}
|
||||
|
||||
// NewForwardProxy 创建正向代理
|
||||
func NewForwardProxy(listenAddr string, options ...UnifiedOption) (UnifiedProxy, error) {
|
||||
cfg := config.DefaultUnifiedConfig()
|
||||
cfg.ListenAddr = listenAddr
|
||||
cfg.ProxyMode = config.ModeForward
|
||||
|
||||
opts := &UnifiedOptions{
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
option(opts)
|
||||
}
|
||||
|
||||
return NewUnifiedProxy(opts)
|
||||
}
|
||||
|
||||
// NewReverseProxy 创建反向代理
|
||||
func NewReverseProxy(listenAddr, targetAddr string, options ...UnifiedOption) (UnifiedProxy, error) {
|
||||
cfg := config.DefaultUnifiedConfig()
|
||||
cfg.ListenAddr = listenAddr
|
||||
cfg.TargetAddr = targetAddr
|
||||
cfg.ProxyMode = config.ModeReverse
|
||||
|
||||
opts := &UnifiedOptions{
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
option(opts)
|
||||
}
|
||||
|
||||
return NewUnifiedProxy(opts)
|
||||
}
|
||||
|
||||
// NewTransparentProxy 创建透明代理
|
||||
func NewTransparentProxy(listenAddr string, options ...UnifiedOption) (UnifiedProxy, error) {
|
||||
cfg := config.DefaultUnifiedConfig()
|
||||
cfg.ListenAddr = listenAddr
|
||||
cfg.ProxyMode = config.ModeTransparent
|
||||
|
||||
opts := &UnifiedOptions{
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
option(opts)
|
||||
}
|
||||
|
||||
return NewUnifiedProxy(opts)
|
||||
}
|
||||
|
||||
// UnifiedOption 用于配置统一代理选项的函数类型
|
||||
type UnifiedOption func(*UnifiedOptions)
|
||||
|
||||
// WithUnifiedConfig 设置代理配置
|
||||
func WithUnifiedConfig(cfg *config.UnifiedConfig) UnifiedOption {
|
||||
return func(opt *UnifiedOptions) {
|
||||
opt.Config = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnifiedDelegate 设置委托类
|
||||
func WithUnifiedDelegate(delegate Delegate) UnifiedOption {
|
||||
return func(opt *UnifiedOptions) {
|
||||
opt.Delegate = delegate
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnifiedCertCache 设置证书缓存
|
||||
func WithUnifiedCertCache(cache CertificateCache) UnifiedOption {
|
||||
return func(opt *UnifiedOptions) {
|
||||
opt.CertCache = cache
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnifiedHTTPCache 设置HTTP缓存
|
||||
func WithUnifiedHTTPCache(c cache.Cache) UnifiedOption {
|
||||
return func(opt *UnifiedOptions) {
|
||||
opt.HTTPCache = c
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnifiedLoadBalancer 设置负载均衡器
|
||||
func WithUnifiedLoadBalancer(lb loadbalance.LoadBalancer) UnifiedOption {
|
||||
return func(opt *UnifiedOptions) {
|
||||
opt.LoadBalancer = lb
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnifiedHealthChecker 设置健康检查器
|
||||
func WithUnifiedHealthChecker(hc *healthcheck.HealthChecker) UnifiedOption {
|
||||
return func(opt *UnifiedOptions) {
|
||||
opt.HealthChecker = hc
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnifiedMetrics 设置监控指标
|
||||
func WithUnifiedMetrics(m metrics.MetricsCollector) UnifiedOption {
|
||||
return func(opt *UnifiedOptions) {
|
||||
opt.Metrics = m
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnifiedClientTrace 设置HTTP客户端跟踪
|
||||
func WithUnifiedClientTrace(t *httptrace.ClientTrace) UnifiedOption {
|
||||
return func(opt *UnifiedOptions) {
|
||||
opt.ClientTrace = t
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnifiedCertManager 设置证书管理器
|
||||
func WithUnifiedCertManager(cm *CertManager) UnifiedOption {
|
||||
return func(opt *UnifiedOptions) {
|
||||
opt.CertManager = cm
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnifiedAuth 设置认证系统
|
||||
func WithUnifiedAuth(a *auth.Auth) UnifiedOption {
|
||||
return func(opt *UnifiedOptions) {
|
||||
opt.Auth = a
|
||||
}
|
||||
}
|
||||
|
||||
// WithUnifiedDNSResolver 设置DNS解析器
|
||||
func WithUnifiedDNSResolver(resolver *dns.CustomResolver) UnifiedOption {
|
||||
return func(opt *UnifiedOptions) {
|
||||
opt.DNSResolver = resolver
|
||||
}
|
||||
}
|
||||
|
||||
// 处理劫持连接
|
||||
func hijackerImpl(rw http.ResponseWriter) (*ConnBuffer, error) {
|
||||
hijacker, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, ErrNotSupportHijacking
|
||||
}
|
||||
conn, bufrw, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewConnBuffer(conn, bufrw.Reader), nil
|
||||
}
|
||||
|
||||
// 检查是否是WebSocket请求
|
||||
func isWebSocketRequestImpl(req *http.Request) bool {
|
||||
if req.Method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
|
||||
connection := req.Header.Get("Connection")
|
||||
if connection == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
upgrade := req.Header.Get("Upgrade")
|
||||
return upgrade != ""
|
||||
}
|
||||
|
||||
// 发送隧道已连接响应
|
||||
func tunnelConnectedImpl(_ *Context, err error, rw http.ResponseWriter) error {
|
||||
// 此函数应根据原始代码实现
|
||||
if err != nil {
|
||||
if rw != nil {
|
||||
http.Error(rw, err.Error(), http.StatusServiceUnavailable)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 发送连接成功响应
|
||||
// 这里简化处理,实际应该根据原代码完整实现
|
||||
if rw != nil {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 可以缓存的HTTP方法
|
||||
func canCacheMethodImpl(method string) bool {
|
||||
return method == http.MethodGet
|
||||
}
|
||||
|
||||
// 可以缓存的状态码
|
||||
func canCacheStatusImpl(statusCode int) bool {
|
||||
return statusCode >= 200 && statusCode < 300
|
||||
}
|
||||
|
||||
// 生成缓存键
|
||||
func generateCacheKeyImpl(req *http.Request) string {
|
||||
return req.Method + ":" + req.URL.String()
|
||||
}
|
||||
|
||||
// 获取缓存TTL
|
||||
func getCacheTTLImpl(resp *http.Response) time.Duration {
|
||||
// 检查响应头中的Cache-Control
|
||||
cacheControl := resp.Header.Get("Cache-Control")
|
||||
if cacheControl != "" {
|
||||
// 解析Cache-Control头
|
||||
if strings.Contains(cacheControl, "no-store") || strings.Contains(cacheControl, "no-cache") {
|
||||
return 0 // 不缓存
|
||||
}
|
||||
|
||||
// 查找max-age指令
|
||||
if strings.Contains(cacheControl, "max-age=") {
|
||||
parts := strings.Split(cacheControl, "max-age=")
|
||||
if len(parts) > 1 {
|
||||
seconds := strings.Split(parts[1], ",")[0]
|
||||
if maxAge, err := strconv.ParseInt(seconds, 10, 64); err == nil {
|
||||
return time.Duration(maxAge) * time.Second
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查Expires头
|
||||
if expires := resp.Header.Get("Expires"); expires != "" {
|
||||
if expiresTime, err := time.Parse(time.RFC1123, expires); err == nil {
|
||||
return time.Until(expiresTime)
|
||||
}
|
||||
}
|
||||
|
||||
// 返回默认缓存时间
|
||||
return 5 * time.Minute
|
||||
}
|
||||
|
||||
// 检查是否支持缓存命中指标
|
||||
func isCacheHitMetricsSupportedImpl(m metrics.MetricsCollector) bool {
|
||||
// 首先检查是否为nil
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否实现了具体的缓存命中计数方法
|
||||
_, hasIncrementCacheHit := m.(interface {
|
||||
IncrementCacheHit()
|
||||
})
|
||||
|
||||
// 或者检查是否实现了通用的计数器增加方法
|
||||
_, hasIncrement := m.(interface {
|
||||
Increment(key string, value float64)
|
||||
})
|
||||
|
||||
// 如果实现了其中任何一个方法,就认为支持缓存命中计数
|
||||
return hasIncrementCacheHit || hasIncrement
|
||||
}
|
||||
|
||||
// 增加缓存命中计数
|
||||
func incrementCacheHitImpl(m metrics.MetricsCollector) {
|
||||
// 如果指标收集器实现了特定接口,则增加缓存命中计数
|
||||
if counterInterface, ok := m.(interface {
|
||||
IncrementCacheHit()
|
||||
}); ok {
|
||||
counterInterface.IncrementCacheHit()
|
||||
} else if counterInterface, ok := m.(interface {
|
||||
Increment(key string, value float64)
|
||||
}); ok {
|
||||
counterInterface.Increment("cache_hit", 1)
|
||||
}
|
||||
}
|
||||
|
||||
// NewTLSConfig 创建新的TLS配置
|
||||
func NewTLSConfig(serverName string, insecureSkipVerify bool) *tls.Config {
|
||||
return &tls.Config{
|
||||
ServerName: serverName,
|
||||
InsecureSkipVerify: insecureSkipVerify,
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user