Files
demo/cmd/unified_proxy/main.go
2025-03-15 10:17:07 +00:00

259 lines
6.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"context"
"errors"
"flag"
"log/slog"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/darkit/goproxy"
"github.com/darkit/goproxy/config"
"github.com/darkit/goproxy/pkg/cache"
"github.com/darkit/goproxy/pkg/dns"
"github.com/darkit/goproxy/pkg/healthcheck"
"github.com/darkit/goproxy/pkg/loadbalance"
"github.com/darkit/goproxy/pkg/metrics"
)
func main() {
// 通用选项
proxyMode := flag.String("mode", "forward", "代理模式 [forward, reverse, transparent]")
listenAddr := flag.String("listen", ":8080", "监听地址")
targetAddr := flag.String("target", "http://192.168.1.240", "目标地址 (只在反向代理模式下使用)")
enableHTTPS := flag.Bool("https", false, "是否启用HTTPS")
certFile := flag.String("cert", "", "证书文件路径")
keyFile := flag.String("key", "", "密钥文件路径")
enableCache := flag.Bool("cache", false, "是否启用响应缓存")
cacheTTL := flag.Duration("cache-ttl", 5*time.Minute, "缓存过期时间")
insecure := flag.Bool("insecure", false, "是否跳过证书验证")
// 反向代理选项
rulesFile := flag.String("rules", "", "规则文件路径 (只在反向代理模式下使用)")
// 负载均衡选项
backends := flag.String("backends", "", "后端服务器列表,用逗号分隔")
lbStrategy := flag.String("lb-strategy", "round-robin", "负载均衡策略 [round-robin, random, weighted]")
healthCheck := flag.Bool("health-check", false, "是否启用健康检查")
healthInterval := flag.Duration("health-interval", 10*time.Second, "健康检查间隔时间")
// 其他选项
verbose := flag.Bool("verbose", false, "是否输出详细日志")
// 解析命令行参数
flag.Parse()
// 设置日志级别
var logLevel slog.Level
if *verbose {
logLevel = slog.LevelDebug
} else {
logLevel = slog.LevelInfo
}
logLevelVar := &logLevel
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
Level: logLevelVar,
}))
slog.SetDefault(logger)
// 创建统一配置
cfg := config.DefaultUnifiedConfig()
cfg.ListenAddr = *listenAddr
cfg.Logger = logger
// 设置代理模式
switch strings.ToLower(*proxyMode) {
case "forward", "proxy":
cfg.ProxyMode = config.ModeForward
case "reverse":
cfg.ProxyMode = config.ModeReverse
if *targetAddr != "" {
cfg.TargetAddr = *targetAddr
}
if *rulesFile != "" {
cfg.RulesFile = *rulesFile
}
case "transparent":
cfg.ProxyMode = config.ModeTransparent
default:
logger.Error("不支持的代理模式", "mode", *proxyMode)
os.Exit(1)
}
// HTTPS设置
cfg.EnableHTTPS = *enableHTTPS
if *enableHTTPS {
if *certFile == "" || *keyFile == "" {
logger.Error("启用HTTPS需要提供证书和密钥文件")
os.Exit(1)
}
cfg.TLSCert = *certFile
cfg.TLSKey = *keyFile
}
// 缓存设置
cfg.EnableCache = *enableCache
cfg.CacheTTL = *cacheTTL
// SSL验证设置
cfg.InsecureSkipVerify = *insecure
// 添加X-Forwarded-For头
cfg.AddXForwardedFor = true
// 添加X-Real-IP头
cfg.AddXRealIP = true
// 负载均衡设置
if *backends != "" {
backendList := strings.Split(*backends, ",")
cfg.EnableLoadBalancing = true
cfg.Backends = backendList
// 健康检查设置
if *healthCheck {
cfg.EnableHealthCheck = true
cfg.HealthCheckInterval = *healthInterval
cfg.HealthCheckTimeout = 5 * time.Second
}
}
// 创建统一代理选项
opts := &goproxy.UnifiedOptions{
Config: cfg,
}
// 根据需要添加额外组件
// 添加缓存
if *enableCache {
opts.HTTPCache = cache.NewMemoryCache(*cacheTTL, time.Minute, 10000)
}
// 添加负载均衡器
if cfg.EnableLoadBalancing {
var lb loadbalance.LoadBalancer
switch strings.ToLower(*lbStrategy) {
case "round-robin":
lb = loadbalance.NewRoundRobinBalancer()
case "random":
lb = loadbalance.NewRandomBalancer()
case "weighted":
lb = loadbalance.NewRoundRobinBalancer() // 实际上是加权轮询,但简化处理
default:
logger.Warn("不支持的负载均衡策略,使用默认的轮询策略", "strategy", *lbStrategy)
lb = loadbalance.NewRoundRobinBalancer()
}
// 添加后端服务器
for _, backend := range cfg.Backends {
lb.Add(backend, 1)
}
opts.LoadBalancer = lb
// 添加健康检查器
if cfg.EnableHealthCheck {
healthCfg := &healthcheck.Config{
Interval: cfg.HealthCheckInterval,
Timeout: cfg.HealthCheckTimeout,
MaxFails: 3,
MinSuccess: 1,
}
healthChecker := healthcheck.NewHealthChecker(healthCfg)
for _, backend := range cfg.Backends {
healthChecker.AddTarget(backend)
}
opts.HealthChecker = healthChecker
// 启动健康检查
healthChecker.Start()
defer healthChecker.Stop()
}
}
// 添加DNS解析器
if cfg.ProxyMode == config.ModeReverse {
resolver := dns.NewResolver(
dns.WithFallback(true), // 如果找不到自定义规则回退到系统DNS
dns.WithTTL(5*time.Minute), // 设置缓存TTL
)
// 将解析器添加到选项中
opts.DNSResolver = resolver
logger.Info("已创建自定义DNS解析器", "fallback", true, "ttl", "5m")
}
// 添加监控指标
opts.Metrics = metrics.NewSimpleMetrics()
// 创建统一代理
proxy, err := goproxy.NewUnifiedProxy(opts)
if err != nil {
logger.Error("创建代理失败", "error", err)
os.Exit(1)
}
// 打印启动信息
logger.Info("代理服务器启动", "mode", cfg.ProxyMode, "listen", cfg.ListenAddr, "https", cfg.EnableHTTPS)
if cfg.ProxyMode == config.ModeReverse {
logger.Info("反向代理设置", "target", cfg.TargetAddr, "rules", cfg.RulesFile)
}
if cfg.EnableLoadBalancing {
logger.Info("负载均衡设置", "strategy", *lbStrategy, "backends", cfg.Backends, "health_check", cfg.EnableHealthCheck)
}
// 创建HTTP服务器
server := &http.Server{
Addr: cfg.ListenAddr,
Handler: proxy,
}
// 启动HTTP服务器
go func() {
var err error
if cfg.EnableHTTPS {
err = server.ListenAndServeTLS(cfg.TLSCert, cfg.TLSKey)
} else {
err = server.ListenAndServe()
}
if err != nil && !errors.Is(err, http.ErrServerClosed) {
logger.Error("服务器运行失败", "error", err)
os.Exit(1)
}
}()
// 等待中断信号
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
logger.Info("正在关闭服务器...")
// 优雅关闭
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
logger.Error("服务器关闭失败", "error", err)
}
// 关闭代理
if err := proxy.Close(); err != nil {
logger.Error("代理关闭失败", "error", err)
}
logger.Info("服务器已关闭")
}