Files
demo/cmd/reverse_proxy_example/reverse_proxy_example.go
2025-03-15 10:17:07 +00:00

176 lines
4.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"context"
"errors"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/darkit/goproxy"
"github.com/darkit/goproxy/config"
"github.com/darkit/goproxy/pkg/cache"
"github.com/darkit/goproxy/pkg/healthcheck"
"github.com/darkit/goproxy/pkg/loadbalance"
"github.com/darkit/goproxy/pkg/metrics"
)
// 命令行参数
var (
// 监听地址
listenAddr string
// 目标地址(后端服务)
targetAddr string
// 启用HTTPS
enableHTTPS bool
// TLS证书文件
certFile string
// TLS私钥文件
keyFile string
// 启用负载均衡
enableLoadBalancing bool
// 负载均衡目标地址列表
targets string
// 启用健康检查
enableHealthCheck bool
// 健康检查间隔
healthCheckInterval time.Duration
// 启用缓存
enableCache bool
// 启用压缩
enableCompression bool
// 启用CORS
enableCORS bool
)
func init() {
// 解析命令行参数
flag.StringVar(&listenAddr, "listen", ":8080", "监听地址")
flag.StringVar(&targetAddr, "target", "http://localhost:9090", "目标地址")
flag.BoolVar(&enableHTTPS, "https", false, "启用HTTPS")
flag.StringVar(&certFile, "cert", "", "TLS证书文件")
flag.StringVar(&keyFile, "key", "", "TLS私钥文件")
flag.BoolVar(&enableLoadBalancing, "lb", false, "启用负载均衡")
flag.StringVar(&targets, "targets", "", "负载均衡目标地址列表,用逗号分隔")
flag.BoolVar(&enableHealthCheck, "health", false, "启用健康检查")
flag.DurationVar(&healthCheckInterval, "health-interval", 10*time.Second, "健康检查间隔")
flag.BoolVar(&enableCache, "cache", false, "启用缓存")
flag.BoolVar(&enableCompression, "compression", true, "启用压缩")
flag.BoolVar(&enableCORS, "cors", false, "启用CORS")
}
func main() {
// 解析命令行参数
flag.Parse()
// 创建配置
cfg := config.DefaultConfig()
// 设置基本配置
cfg.ReverseProxy = true // 启用反向代理模式
cfg.ListenAddr = listenAddr
cfg.TargetAddr = targetAddr
cfg.DecryptHTTPS = enableHTTPS
cfg.TLSCert = certFile
cfg.TLSKey = keyFile
cfg.EnableCache = enableCache
cfg.EnableCompression = enableCompression
cfg.EnableCORS = enableCORS
cfg.PreserveClientIP = true // 保留客户端IP
cfg.AddXForwardedFor = true // 添加X-Forwarded-For头
cfg.AddXRealIP = true // 添加X-Real-IP头
// 健康检查配置
cfg.EnableHealthCheck = enableHealthCheck
cfg.HealthCheckInterval = healthCheckInterval
cfg.HealthCheckTimeout = 3 * time.Second
// 负载均衡配置
cfg.EnableLoadBalancing = enableLoadBalancing
// 创建代理选项
opts := &goproxy.Options{
Config: cfg,
}
// 如果启用负载均衡,创建负载均衡器
if enableLoadBalancing && targets != "" {
// 创建轮询负载均衡器
lb := loadbalance.NewRoundRobinBalancer()
lb.Add(targets, 1)
opts.LoadBalancer = lb
// 如果启用健康检查,创建健康检查器
if enableHealthCheck {
// 健康检查配置
healthCfg := &healthcheck.Config{
Interval: healthCheckInterval,
Timeout: 3 * time.Second,
MinSuccess: 1,
MaxFails: 3,
}
healthChecker := healthcheck.NewHealthChecker(healthCfg)
healthChecker.AddTarget(targets)
opts.HealthChecker = healthChecker
// 启动健康检查
healthChecker.Start()
defer healthChecker.Stop()
}
}
// 如果启用缓存,创建缓存
if enableCache {
// 创建一个内存缓存TTL为5分钟
memCache := cache.NewMemoryCache(5*time.Minute, time.Second, 10000)
opts.HTTPCache = memCache
}
// 创建指标收集器
metricsCollector := metrics.NewSimpleMetrics()
opts.Metrics = metricsCollector
// 创建代理
proxy := goproxy.New(opts)
// 创建HTTP服务器
server := &http.Server{
Addr: listenAddr,
Handler: proxy,
}
// 启动HTTP服务器
go func() {
fmt.Printf("反向代理启动在 %s目标地址为 %s\n", listenAddr, targetAddr)
var err error
if enableHTTPS && certFile != "" && keyFile != "" {
err = server.ListenAndServeTLS(certFile, keyFile)
} else {
err = server.ListenAndServe()
}
if err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("服务器启动失败: %v\n", err)
}
}()
// 优雅退出
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
// 关闭服务器
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
log.Fatalf("服务器关闭失败: %v\n", err)
}
fmt.Println("服务器已关闭")
}