package main import ( "context" "errors" "flag" "log/slog" "net/http" "os" "os/signal" "strings" "syscall" "time" "github.com/darkit/goproxy" "github.com/darkit/goproxy/config" "github.com/darkit/goproxy/pkg/cache" "github.com/darkit/goproxy/pkg/dns" "github.com/darkit/goproxy/pkg/healthcheck" "github.com/darkit/goproxy/pkg/loadbalance" "github.com/darkit/goproxy/pkg/metrics" ) func main() { // 通用选项 proxyMode := flag.String("mode", "forward", "代理模式 [forward, reverse, transparent]") listenAddr := flag.String("listen", ":8080", "监听地址") targetAddr := flag.String("target", "http://192.168.1.240", "目标地址 (只在反向代理模式下使用)") enableHTTPS := flag.Bool("https", false, "是否启用HTTPS") certFile := flag.String("cert", "", "证书文件路径") keyFile := flag.String("key", "", "密钥文件路径") enableCache := flag.Bool("cache", false, "是否启用响应缓存") cacheTTL := flag.Duration("cache-ttl", 5*time.Minute, "缓存过期时间") insecure := flag.Bool("insecure", false, "是否跳过证书验证") // 反向代理选项 rulesFile := flag.String("rules", "", "规则文件路径 (只在反向代理模式下使用)") // 负载均衡选项 backends := flag.String("backends", "", "后端服务器列表,用逗号分隔") lbStrategy := flag.String("lb-strategy", "round-robin", "负载均衡策略 [round-robin, random, weighted]") healthCheck := flag.Bool("health-check", false, "是否启用健康检查") healthInterval := flag.Duration("health-interval", 10*time.Second, "健康检查间隔时间") // 其他选项 verbose := flag.Bool("verbose", false, "是否输出详细日志") // 解析命令行参数 flag.Parse() // 设置日志级别 var logLevel slog.Level if *verbose { logLevel = slog.LevelDebug } else { logLevel = slog.LevelInfo } logLevelVar := &logLevel logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ Level: logLevelVar, })) slog.SetDefault(logger) // 创建统一配置 cfg := config.DefaultUnifiedConfig() cfg.ListenAddr = *listenAddr cfg.Logger = logger // 设置代理模式 switch strings.ToLower(*proxyMode) { case "forward", "proxy": cfg.ProxyMode = config.ModeForward case "reverse": cfg.ProxyMode = config.ModeReverse if *targetAddr != "" { cfg.TargetAddr = *targetAddr } if *rulesFile != "" { cfg.RulesFile = *rulesFile } case "transparent": cfg.ProxyMode = config.ModeTransparent default: logger.Error("不支持的代理模式", "mode", *proxyMode) os.Exit(1) } // HTTPS设置 cfg.EnableHTTPS = *enableHTTPS if *enableHTTPS { if *certFile == "" || *keyFile == "" { logger.Error("启用HTTPS需要提供证书和密钥文件") os.Exit(1) } cfg.TLSCert = *certFile cfg.TLSKey = *keyFile } // 缓存设置 cfg.EnableCache = *enableCache cfg.CacheTTL = *cacheTTL // SSL验证设置 cfg.InsecureSkipVerify = *insecure // 添加X-Forwarded-For头 cfg.AddXForwardedFor = true // 添加X-Real-IP头 cfg.AddXRealIP = true // 负载均衡设置 if *backends != "" { backendList := strings.Split(*backends, ",") cfg.EnableLoadBalancing = true cfg.Backends = backendList // 健康检查设置 if *healthCheck { cfg.EnableHealthCheck = true cfg.HealthCheckInterval = *healthInterval cfg.HealthCheckTimeout = 5 * time.Second } } // 创建统一代理选项 opts := &goproxy.UnifiedOptions{ Config: cfg, } // 根据需要添加额外组件 // 添加缓存 if *enableCache { opts.HTTPCache = cache.NewMemoryCache(*cacheTTL, time.Minute, 10000) } // 添加负载均衡器 if cfg.EnableLoadBalancing { var lb loadbalance.LoadBalancer switch strings.ToLower(*lbStrategy) { case "round-robin": lb = loadbalance.NewRoundRobinBalancer() case "random": lb = loadbalance.NewRandomBalancer() case "weighted": lb = loadbalance.NewRoundRobinBalancer() // 实际上是加权轮询,但简化处理 default: logger.Warn("不支持的负载均衡策略,使用默认的轮询策略", "strategy", *lbStrategy) lb = loadbalance.NewRoundRobinBalancer() } // 添加后端服务器 for _, backend := range cfg.Backends { lb.Add(backend, 1) } opts.LoadBalancer = lb // 添加健康检查器 if cfg.EnableHealthCheck { healthCfg := &healthcheck.Config{ Interval: cfg.HealthCheckInterval, Timeout: cfg.HealthCheckTimeout, MaxFails: 3, MinSuccess: 1, } healthChecker := healthcheck.NewHealthChecker(healthCfg) for _, backend := range cfg.Backends { healthChecker.AddTarget(backend) } opts.HealthChecker = healthChecker // 启动健康检查 healthChecker.Start() defer healthChecker.Stop() } } // 添加DNS解析器 if cfg.ProxyMode == config.ModeReverse { resolver := dns.NewResolver( dns.WithFallback(true), // 如果找不到自定义规则,回退到系统DNS dns.WithTTL(5*time.Minute), // 设置缓存TTL ) // 将解析器添加到选项中 opts.DNSResolver = resolver logger.Info("已创建自定义DNS解析器", "fallback", true, "ttl", "5m") } // 添加监控指标 opts.Metrics = metrics.NewSimpleMetrics() // 创建统一代理 proxy, err := goproxy.NewUnifiedProxy(opts) if err != nil { logger.Error("创建代理失败", "error", err) os.Exit(1) } // 打印启动信息 logger.Info("代理服务器启动", "mode", cfg.ProxyMode, "listen", cfg.ListenAddr, "https", cfg.EnableHTTPS) if cfg.ProxyMode == config.ModeReverse { logger.Info("反向代理设置", "target", cfg.TargetAddr, "rules", cfg.RulesFile) } if cfg.EnableLoadBalancing { logger.Info("负载均衡设置", "strategy", *lbStrategy, "backends", cfg.Backends, "health_check", cfg.EnableHealthCheck) } // 创建HTTP服务器 server := &http.Server{ Addr: cfg.ListenAddr, Handler: proxy, } // 启动HTTP服务器 go func() { var err error if cfg.EnableHTTPS { err = server.ListenAndServeTLS(cfg.TLSCert, cfg.TLSKey) } else { err = server.ListenAndServe() } if err != nil && !errors.Is(err, http.ErrServerClosed) { logger.Error("服务器运行失败", "error", err) os.Exit(1) } }() // 等待中断信号 quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit logger.Info("正在关闭服务器...") // 优雅关闭 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := server.Shutdown(ctx); err != nil { logger.Error("服务器关闭失败", "error", err) } // 关闭代理 if err := proxy.Close(); err != nil { logger.Error("代理关闭失败", "error", err) } logger.Info("服务器已关闭") }