package main import ( "context" "errors" "flag" "fmt" "log" "net/http" "os" "os/signal" "syscall" "time" "github.com/darkit/goproxy" "github.com/darkit/goproxy/config" "github.com/darkit/goproxy/pkg/cache" "github.com/darkit/goproxy/pkg/healthcheck" "github.com/darkit/goproxy/pkg/loadbalance" "github.com/darkit/goproxy/pkg/metrics" ) // 命令行参数 var ( // 监听地址 listenAddr string // 目标地址(后端服务) targetAddr string // 启用HTTPS enableHTTPS bool // TLS证书文件 certFile string // TLS私钥文件 keyFile string // 启用负载均衡 enableLoadBalancing bool // 负载均衡目标地址列表 targets string // 启用健康检查 enableHealthCheck bool // 健康检查间隔 healthCheckInterval time.Duration // 启用缓存 enableCache bool // 启用压缩 enableCompression bool // 启用CORS enableCORS bool ) func init() { // 解析命令行参数 flag.StringVar(&listenAddr, "listen", ":8080", "监听地址") flag.StringVar(&targetAddr, "target", "http://localhost:9090", "目标地址") flag.BoolVar(&enableHTTPS, "https", false, "启用HTTPS") flag.StringVar(&certFile, "cert", "", "TLS证书文件") flag.StringVar(&keyFile, "key", "", "TLS私钥文件") flag.BoolVar(&enableLoadBalancing, "lb", false, "启用负载均衡") flag.StringVar(&targets, "targets", "", "负载均衡目标地址列表,用逗号分隔") flag.BoolVar(&enableHealthCheck, "health", false, "启用健康检查") flag.DurationVar(&healthCheckInterval, "health-interval", 10*time.Second, "健康检查间隔") flag.BoolVar(&enableCache, "cache", false, "启用缓存") flag.BoolVar(&enableCompression, "compression", true, "启用压缩") flag.BoolVar(&enableCORS, "cors", false, "启用CORS") } func main() { // 解析命令行参数 flag.Parse() // 创建配置 cfg := config.DefaultConfig() // 设置基本配置 cfg.ReverseProxy = true // 启用反向代理模式 cfg.ListenAddr = listenAddr cfg.TargetAddr = targetAddr cfg.DecryptHTTPS = enableHTTPS cfg.TLSCert = certFile cfg.TLSKey = keyFile cfg.EnableCache = enableCache cfg.EnableCompression = enableCompression cfg.EnableCORS = enableCORS cfg.PreserveClientIP = true // 保留客户端IP cfg.AddXForwardedFor = true // 添加X-Forwarded-For头 cfg.AddXRealIP = true // 添加X-Real-IP头 // 健康检查配置 cfg.EnableHealthCheck = enableHealthCheck cfg.HealthCheckInterval = healthCheckInterval cfg.HealthCheckTimeout = 3 * time.Second // 负载均衡配置 cfg.EnableLoadBalancing = enableLoadBalancing // 创建代理选项 opts := &goproxy.Options{ Config: cfg, } // 如果启用负载均衡,创建负载均衡器 if enableLoadBalancing && targets != "" { // 创建轮询负载均衡器 lb := loadbalance.NewRoundRobinBalancer() lb.Add(targets, 1) opts.LoadBalancer = lb // 如果启用健康检查,创建健康检查器 if enableHealthCheck { // 健康检查配置 healthCfg := &healthcheck.Config{ Interval: healthCheckInterval, Timeout: 3 * time.Second, MinSuccess: 1, MaxFails: 3, } healthChecker := healthcheck.NewHealthChecker(healthCfg) healthChecker.AddTarget(targets) opts.HealthChecker = healthChecker // 启动健康检查 healthChecker.Start() defer healthChecker.Stop() } } // 如果启用缓存,创建缓存 if enableCache { // 创建一个内存缓存,TTL为5分钟 memCache := cache.NewMemoryCache(5*time.Minute, time.Second, 10000) opts.HTTPCache = memCache } // 创建指标收集器 metricsCollector := metrics.NewSimpleMetrics() opts.Metrics = metricsCollector // 创建代理 proxy := goproxy.New(opts) // 创建HTTP服务器 server := &http.Server{ Addr: listenAddr, Handler: proxy, } // 启动HTTP服务器 go func() { fmt.Printf("反向代理启动在 %s,目标地址为 %s\n", listenAddr, targetAddr) var err error if enableHTTPS && certFile != "" && keyFile != "" { err = server.ListenAndServeTLS(certFile, keyFile) } else { err = server.ListenAndServe() } if err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("服务器启动失败: %v\n", err) } }() // 优雅退出 quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit // 关闭服务器 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := server.Shutdown(ctx); err != nil { log.Fatalf("服务器关闭失败: %v\n", err) } fmt.Println("服务器已关闭") }