package main import ( "flag" "log" "net/http" "os" "os/signal" "strings" "syscall" "time" "github.com/goproxy/internal/config" "github.com/goproxy/internal/metrics" "github.com/goproxy/internal/proxy" ) var ( // 监听地址 addr = flag.String("addr", ":8080", "反向代理服务器监听地址") // 后端服务器 backend = flag.String("backend", "localhost:8081", "后端服务器地址") // 路由规则文件 routeFile = flag.String("route-file", "", "路由规则文件路径") // 是否启用URL重写 enableRewrite = flag.Bool("enable-rewrite", false, "是否启用URL重写") // 是否启用缓存 enableCache = flag.Bool("enable-cache", false, "是否启用缓存") // 是否启用压缩 enableCompression = flag.Bool("enable-compression", false, "是否启用压缩") // 是否启用监控 enableMetrics = flag.Bool("enable-metrics", false, "是否启用监控") // 监控地址 metricsAddr = flag.String("metrics-addr", ":8082", "监控服务器监听地址") // 是否添加X-Forwarded-For addXForwardedFor = flag.Bool("add-x-forwarded-for", true, "是否添加X-Forwarded-For头") // 是否添加X-Real-IP addXRealIP = flag.Bool("add-x-real-ip", true, "是否添加X-Real-IP头") // 是否启用CORS enableCORS = flag.Bool("enable-cors", false, "是否启用CORS") // 路径前缀 pathPrefix = flag.String("path-prefix", "", "路径前缀,将从请求路径中移除") ) func main() { // 解析命令行参数 flag.Parse() // 创建配置 cfg := config.DefaultConfig() cfg.ReverseProxy = true cfg.EnableCache = *enableCache cfg.EnableCompression = *enableCompression cfg.EnableURLRewrite = *enableRewrite cfg.AddXForwardedFor = *addXForwardedFor cfg.AddXRealIP = *addXRealIP cfg.EnableCORS = *enableCORS cfg.ReverseProxyRulesFile = *routeFile // 创建选项 opts := &proxy.Options{ Config: cfg, } // 创建监控 if *enableMetrics { m := metrics.NewSimpleMetrics() opts.Metrics = m // 启动监控服务器 go func() { mux := http.NewServeMux() handler := m.GetHandler() mux.Handle("/metrics", handler) log.Printf("监控服务器启动在 %s\n", *metricsAddr) if err := http.ListenAndServe(*metricsAddr, mux); err != nil { log.Fatalf("监控服务器启动失败: %v", err) } }() } // 创建自定义委托 delegate := &ReverseProxyDelegate{ backend: *backend, prefix: *pathPrefix, } opts.Delegate = delegate // 创建代理 p := proxy.New(opts) // 如果有路径前缀,添加重写规则 if *pathPrefix != "" { reverseProxy := p.NewReverseProxy() log.Printf("添加路径重写规则: 从请求路径移除前缀 %s\n", *pathPrefix) reverseProxy.AddRewriteRule(*pathPrefix, "", false) } // 创建HTTP服务器 server := &http.Server{ Addr: *addr, Handler: p, } // 启动HTTP服务器 go func() { log.Printf("反向代理服务器启动在 %s,后端服务器为 %s\n", *addr, *backend) if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("代理服务器启动失败: %v", err) } }() // 等待信号 quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) <-quit log.Println("正在关闭代理服务器...") server.Close() log.Println("代理服务器已关闭") } // ReverseProxyDelegate 反向代理委托 type ReverseProxyDelegate struct { proxy.DefaultDelegate backend string prefix string } // ResolveBackend 解析后端服务器 func (d *ReverseProxyDelegate) ResolveBackend(req *http.Request) (string, error) { // 这里可以实现基于请求路径、主机名等的路由逻辑 return d.backend, nil } // ModifyRequest 修改请求 func (d *ReverseProxyDelegate) ModifyRequest(req *http.Request) { // 移除路径前缀 if d.prefix != "" && strings.HasPrefix(req.URL.Path, d.prefix) { req.URL.Path = strings.TrimPrefix(req.URL.Path, d.prefix) if req.URL.Path == "" { req.URL.Path = "/" } } // 添加自定义请求头 req.Header.Set("X-Proxy-Time", time.Now().Format(time.RFC3339)) } // ModifyResponse 修改响应 func (d *ReverseProxyDelegate) ModifyResponse(resp *http.Response) error { // 添加自定义响应头 resp.Header.Set("X-Proxied-By", "GoProxy") return nil } // Connect 连接事件 func (d *ReverseProxyDelegate) Connect(ctx *proxy.Context, rw http.ResponseWriter) { log.Printf("收到连接: %s -> %s %s\n", ctx.Req.RemoteAddr, ctx.Req.Method, ctx.Req.URL.Path) } // BeforeRequest 请求前事件 func (d *ReverseProxyDelegate) BeforeRequest(ctx *proxy.Context) { log.Printf("处理请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.Path) } // BeforeResponse 响应前事件 func (d *ReverseProxyDelegate) BeforeResponse(ctx *proxy.Context, resp *http.Response, err error) { if err != nil { log.Printf("响应错误: %v\n", err) return } log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status) } // ErrorLog 错误日志 func (d *ReverseProxyDelegate) ErrorLog(err error) { log.Printf("错误: %v\n", err) } // HandleError 处理错误 func (d *ReverseProxyDelegate) HandleError(rw http.ResponseWriter, req *http.Request, err error) { log.Printf("处理错误: %v\n", err) http.Error(rw, "代理服务器错误: "+err.Error(), http.StatusBadGateway) }