259 lines
6.6 KiB
Go
259 lines
6.6 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"flag"
|
||
"log/slog"
|
||
"net/http"
|
||
"os"
|
||
"os/signal"
|
||
"strings"
|
||
"syscall"
|
||
"time"
|
||
|
||
"github.com/darkit/goproxy"
|
||
"github.com/darkit/goproxy/config"
|
||
"github.com/darkit/goproxy/pkg/cache"
|
||
"github.com/darkit/goproxy/pkg/dns"
|
||
"github.com/darkit/goproxy/pkg/healthcheck"
|
||
"github.com/darkit/goproxy/pkg/loadbalance"
|
||
"github.com/darkit/goproxy/pkg/metrics"
|
||
)
|
||
|
||
func main() {
|
||
// 通用选项
|
||
proxyMode := flag.String("mode", "forward", "代理模式 [forward, reverse, transparent]")
|
||
listenAddr := flag.String("listen", ":8080", "监听地址")
|
||
targetAddr := flag.String("target", "http://192.168.1.240", "目标地址 (只在反向代理模式下使用)")
|
||
enableHTTPS := flag.Bool("https", false, "是否启用HTTPS")
|
||
certFile := flag.String("cert", "", "证书文件路径")
|
||
keyFile := flag.String("key", "", "密钥文件路径")
|
||
enableCache := flag.Bool("cache", false, "是否启用响应缓存")
|
||
cacheTTL := flag.Duration("cache-ttl", 5*time.Minute, "缓存过期时间")
|
||
insecure := flag.Bool("insecure", false, "是否跳过证书验证")
|
||
|
||
// 反向代理选项
|
||
rulesFile := flag.String("rules", "", "规则文件路径 (只在反向代理模式下使用)")
|
||
|
||
// 负载均衡选项
|
||
backends := flag.String("backends", "", "后端服务器列表,用逗号分隔")
|
||
lbStrategy := flag.String("lb-strategy", "round-robin", "负载均衡策略 [round-robin, random, weighted]")
|
||
healthCheck := flag.Bool("health-check", false, "是否启用健康检查")
|
||
healthInterval := flag.Duration("health-interval", 10*time.Second, "健康检查间隔时间")
|
||
|
||
// 其他选项
|
||
verbose := flag.Bool("verbose", false, "是否输出详细日志")
|
||
|
||
// 解析命令行参数
|
||
flag.Parse()
|
||
|
||
// 设置日志级别
|
||
var logLevel slog.Level
|
||
if *verbose {
|
||
logLevel = slog.LevelDebug
|
||
} else {
|
||
logLevel = slog.LevelInfo
|
||
}
|
||
|
||
logLevelVar := &logLevel
|
||
|
||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
|
||
Level: logLevelVar,
|
||
}))
|
||
slog.SetDefault(logger)
|
||
|
||
// 创建统一配置
|
||
cfg := config.DefaultUnifiedConfig()
|
||
cfg.ListenAddr = *listenAddr
|
||
cfg.Logger = logger
|
||
|
||
// 设置代理模式
|
||
switch strings.ToLower(*proxyMode) {
|
||
case "forward", "proxy":
|
||
cfg.ProxyMode = config.ModeForward
|
||
case "reverse":
|
||
cfg.ProxyMode = config.ModeReverse
|
||
if *targetAddr != "" {
|
||
cfg.TargetAddr = *targetAddr
|
||
}
|
||
if *rulesFile != "" {
|
||
cfg.RulesFile = *rulesFile
|
||
}
|
||
case "transparent":
|
||
cfg.ProxyMode = config.ModeTransparent
|
||
default:
|
||
logger.Error("不支持的代理模式", "mode", *proxyMode)
|
||
os.Exit(1)
|
||
}
|
||
|
||
// HTTPS设置
|
||
cfg.EnableHTTPS = *enableHTTPS
|
||
if *enableHTTPS {
|
||
if *certFile == "" || *keyFile == "" {
|
||
logger.Error("启用HTTPS需要提供证书和密钥文件")
|
||
os.Exit(1)
|
||
}
|
||
cfg.TLSCert = *certFile
|
||
cfg.TLSKey = *keyFile
|
||
}
|
||
|
||
// 缓存设置
|
||
cfg.EnableCache = *enableCache
|
||
cfg.CacheTTL = *cacheTTL
|
||
|
||
// SSL验证设置
|
||
cfg.InsecureSkipVerify = *insecure
|
||
|
||
// 添加X-Forwarded-For头
|
||
cfg.AddXForwardedFor = true
|
||
// 添加X-Real-IP头
|
||
cfg.AddXRealIP = true
|
||
|
||
// 负载均衡设置
|
||
if *backends != "" {
|
||
backendList := strings.Split(*backends, ",")
|
||
cfg.EnableLoadBalancing = true
|
||
cfg.Backends = backendList
|
||
|
||
// 健康检查设置
|
||
if *healthCheck {
|
||
cfg.EnableHealthCheck = true
|
||
cfg.HealthCheckInterval = *healthInterval
|
||
cfg.HealthCheckTimeout = 5 * time.Second
|
||
}
|
||
}
|
||
|
||
// 创建统一代理选项
|
||
opts := &goproxy.UnifiedOptions{
|
||
Config: cfg,
|
||
}
|
||
|
||
// 根据需要添加额外组件
|
||
|
||
// 添加缓存
|
||
if *enableCache {
|
||
opts.HTTPCache = cache.NewMemoryCache(*cacheTTL, time.Minute, 10000)
|
||
}
|
||
|
||
// 添加负载均衡器
|
||
if cfg.EnableLoadBalancing {
|
||
var lb loadbalance.LoadBalancer
|
||
|
||
switch strings.ToLower(*lbStrategy) {
|
||
case "round-robin":
|
||
lb = loadbalance.NewRoundRobinBalancer()
|
||
case "random":
|
||
lb = loadbalance.NewRandomBalancer()
|
||
case "weighted":
|
||
lb = loadbalance.NewRoundRobinBalancer() // 实际上是加权轮询,但简化处理
|
||
default:
|
||
logger.Warn("不支持的负载均衡策略,使用默认的轮询策略", "strategy", *lbStrategy)
|
||
lb = loadbalance.NewRoundRobinBalancer()
|
||
}
|
||
|
||
// 添加后端服务器
|
||
for _, backend := range cfg.Backends {
|
||
lb.Add(backend, 1)
|
||
}
|
||
|
||
opts.LoadBalancer = lb
|
||
|
||
// 添加健康检查器
|
||
if cfg.EnableHealthCheck {
|
||
healthCfg := &healthcheck.Config{
|
||
Interval: cfg.HealthCheckInterval,
|
||
Timeout: cfg.HealthCheckTimeout,
|
||
MaxFails: 3,
|
||
MinSuccess: 1,
|
||
}
|
||
|
||
healthChecker := healthcheck.NewHealthChecker(healthCfg)
|
||
for _, backend := range cfg.Backends {
|
||
healthChecker.AddTarget(backend)
|
||
}
|
||
|
||
opts.HealthChecker = healthChecker
|
||
|
||
// 启动健康检查
|
||
healthChecker.Start()
|
||
defer healthChecker.Stop()
|
||
}
|
||
}
|
||
|
||
// 添加DNS解析器
|
||
if cfg.ProxyMode == config.ModeReverse {
|
||
resolver := dns.NewResolver(
|
||
dns.WithFallback(true), // 如果找不到自定义规则,回退到系统DNS
|
||
dns.WithTTL(5*time.Minute), // 设置缓存TTL
|
||
)
|
||
// 将解析器添加到选项中
|
||
opts.DNSResolver = resolver
|
||
logger.Info("已创建自定义DNS解析器", "fallback", true, "ttl", "5m")
|
||
}
|
||
|
||
// 添加监控指标
|
||
opts.Metrics = metrics.NewSimpleMetrics()
|
||
|
||
// 创建统一代理
|
||
proxy, err := goproxy.NewUnifiedProxy(opts)
|
||
if err != nil {
|
||
logger.Error("创建代理失败", "error", err)
|
||
os.Exit(1)
|
||
}
|
||
|
||
// 打印启动信息
|
||
logger.Info("代理服务器启动", "mode", cfg.ProxyMode, "listen", cfg.ListenAddr, "https", cfg.EnableHTTPS)
|
||
|
||
if cfg.ProxyMode == config.ModeReverse {
|
||
logger.Info("反向代理设置", "target", cfg.TargetAddr, "rules", cfg.RulesFile)
|
||
}
|
||
|
||
if cfg.EnableLoadBalancing {
|
||
logger.Info("负载均衡设置", "strategy", *lbStrategy, "backends", cfg.Backends, "health_check", cfg.EnableHealthCheck)
|
||
}
|
||
|
||
// 创建HTTP服务器
|
||
server := &http.Server{
|
||
Addr: cfg.ListenAddr,
|
||
Handler: proxy,
|
||
}
|
||
|
||
// 启动HTTP服务器
|
||
go func() {
|
||
var err error
|
||
if cfg.EnableHTTPS {
|
||
err = server.ListenAndServeTLS(cfg.TLSCert, cfg.TLSKey)
|
||
} else {
|
||
err = server.ListenAndServe()
|
||
}
|
||
|
||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||
logger.Error("服务器运行失败", "error", err)
|
||
os.Exit(1)
|
||
}
|
||
}()
|
||
|
||
// 等待中断信号
|
||
quit := make(chan os.Signal, 1)
|
||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||
<-quit
|
||
|
||
logger.Info("正在关闭服务器...")
|
||
|
||
// 优雅关闭
|
||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||
defer cancel()
|
||
|
||
if err := server.Shutdown(ctx); err != nil {
|
||
logger.Error("服务器关闭失败", "error", err)
|
||
}
|
||
|
||
// 关闭代理
|
||
if err := proxy.Close(); err != nil {
|
||
logger.Error("代理关闭失败", "error", err)
|
||
}
|
||
|
||
logger.Info("服务器已关闭")
|
||
}
|