176 lines
4.5 KiB
Go
176 lines
4.5 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"flag"
|
||
"fmt"
|
||
"log"
|
||
"net/http"
|
||
"os"
|
||
"os/signal"
|
||
"syscall"
|
||
"time"
|
||
|
||
"github.com/darkit/goproxy"
|
||
"github.com/darkit/goproxy/config"
|
||
"github.com/darkit/goproxy/pkg/cache"
|
||
"github.com/darkit/goproxy/pkg/healthcheck"
|
||
"github.com/darkit/goproxy/pkg/loadbalance"
|
||
"github.com/darkit/goproxy/pkg/metrics"
|
||
)
|
||
|
||
// 命令行参数
|
||
var (
|
||
// 监听地址
|
||
listenAddr string
|
||
// 目标地址(后端服务)
|
||
targetAddr string
|
||
// 启用HTTPS
|
||
enableHTTPS bool
|
||
// TLS证书文件
|
||
certFile string
|
||
// TLS私钥文件
|
||
keyFile string
|
||
// 启用负载均衡
|
||
enableLoadBalancing bool
|
||
// 负载均衡目标地址列表
|
||
targets string
|
||
// 启用健康检查
|
||
enableHealthCheck bool
|
||
// 健康检查间隔
|
||
healthCheckInterval time.Duration
|
||
// 启用缓存
|
||
enableCache bool
|
||
// 启用压缩
|
||
enableCompression bool
|
||
// 启用CORS
|
||
enableCORS bool
|
||
)
|
||
|
||
func init() {
|
||
// 解析命令行参数
|
||
flag.StringVar(&listenAddr, "listen", ":8080", "监听地址")
|
||
flag.StringVar(&targetAddr, "target", "http://localhost:9090", "目标地址")
|
||
flag.BoolVar(&enableHTTPS, "https", false, "启用HTTPS")
|
||
flag.StringVar(&certFile, "cert", "", "TLS证书文件")
|
||
flag.StringVar(&keyFile, "key", "", "TLS私钥文件")
|
||
flag.BoolVar(&enableLoadBalancing, "lb", false, "启用负载均衡")
|
||
flag.StringVar(&targets, "targets", "", "负载均衡目标地址列表,用逗号分隔")
|
||
flag.BoolVar(&enableHealthCheck, "health", false, "启用健康检查")
|
||
flag.DurationVar(&healthCheckInterval, "health-interval", 10*time.Second, "健康检查间隔")
|
||
flag.BoolVar(&enableCache, "cache", false, "启用缓存")
|
||
flag.BoolVar(&enableCompression, "compression", true, "启用压缩")
|
||
flag.BoolVar(&enableCORS, "cors", false, "启用CORS")
|
||
}
|
||
|
||
func main() {
|
||
// 解析命令行参数
|
||
flag.Parse()
|
||
|
||
// 创建配置
|
||
cfg := config.DefaultConfig()
|
||
|
||
// 设置基本配置
|
||
cfg.ReverseProxy = true // 启用反向代理模式
|
||
cfg.ListenAddr = listenAddr
|
||
cfg.TargetAddr = targetAddr
|
||
cfg.DecryptHTTPS = enableHTTPS
|
||
cfg.TLSCert = certFile
|
||
cfg.TLSKey = keyFile
|
||
cfg.EnableCache = enableCache
|
||
cfg.EnableCompression = enableCompression
|
||
cfg.EnableCORS = enableCORS
|
||
cfg.PreserveClientIP = true // 保留客户端IP
|
||
cfg.AddXForwardedFor = true // 添加X-Forwarded-For头
|
||
cfg.AddXRealIP = true // 添加X-Real-IP头
|
||
|
||
// 健康检查配置
|
||
cfg.EnableHealthCheck = enableHealthCheck
|
||
cfg.HealthCheckInterval = healthCheckInterval
|
||
cfg.HealthCheckTimeout = 3 * time.Second
|
||
|
||
// 负载均衡配置
|
||
cfg.EnableLoadBalancing = enableLoadBalancing
|
||
|
||
// 创建代理选项
|
||
opts := &goproxy.Options{
|
||
Config: cfg,
|
||
}
|
||
|
||
// 如果启用负载均衡,创建负载均衡器
|
||
if enableLoadBalancing && targets != "" {
|
||
// 创建轮询负载均衡器
|
||
lb := loadbalance.NewRoundRobinBalancer()
|
||
lb.Add(targets, 1)
|
||
opts.LoadBalancer = lb
|
||
|
||
// 如果启用健康检查,创建健康检查器
|
||
if enableHealthCheck {
|
||
// 健康检查配置
|
||
healthCfg := &healthcheck.Config{
|
||
Interval: healthCheckInterval,
|
||
Timeout: 3 * time.Second,
|
||
MinSuccess: 1,
|
||
MaxFails: 3,
|
||
}
|
||
healthChecker := healthcheck.NewHealthChecker(healthCfg)
|
||
healthChecker.AddTarget(targets)
|
||
opts.HealthChecker = healthChecker
|
||
|
||
// 启动健康检查
|
||
healthChecker.Start()
|
||
defer healthChecker.Stop()
|
||
}
|
||
}
|
||
|
||
// 如果启用缓存,创建缓存
|
||
if enableCache {
|
||
// 创建一个内存缓存,TTL为5分钟
|
||
memCache := cache.NewMemoryCache(5*time.Minute, time.Second, 10000)
|
||
opts.HTTPCache = memCache
|
||
}
|
||
|
||
// 创建指标收集器
|
||
metricsCollector := metrics.NewSimpleMetrics()
|
||
opts.Metrics = metricsCollector
|
||
|
||
// 创建代理
|
||
proxy := goproxy.New(opts)
|
||
|
||
// 创建HTTP服务器
|
||
server := &http.Server{
|
||
Addr: listenAddr,
|
||
Handler: proxy,
|
||
}
|
||
|
||
// 启动HTTP服务器
|
||
go func() {
|
||
fmt.Printf("反向代理启动在 %s,目标地址为 %s\n", listenAddr, targetAddr)
|
||
var err error
|
||
if enableHTTPS && certFile != "" && keyFile != "" {
|
||
err = server.ListenAndServeTLS(certFile, keyFile)
|
||
} else {
|
||
err = server.ListenAndServe()
|
||
}
|
||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||
log.Fatalf("服务器启动失败: %v\n", err)
|
||
}
|
||
}()
|
||
|
||
// 优雅退出
|
||
quit := make(chan os.Signal, 1)
|
||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||
<-quit
|
||
|
||
// 关闭服务器
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
if err := server.Shutdown(ctx); err != nil {
|
||
log.Fatalf("服务器关闭失败: %v\n", err)
|
||
}
|
||
|
||
fmt.Println("服务器已关闭")
|
||
}
|