Files
DarkiT d423ed5029 refactor: 将所有对象池升级为泛型实现
使用 Go 1.18+ 泛型特性优化对象池,提高类型安全性和性能。

- 将 bufPool 和 ctxPool 升级为使用已有的泛型池实现
- 移除所有 Get() 操作后的类型断言
- 保持 API 兼容性,确保现有代码无需大量修改
- 优化流数据传输中内存使用

性能改进:
- 减少运行时类型检查开销
- 消除了类型断言导致的潜在 panic 风险
- 优化了高并发场景下的内存分配模式
2025-03-14 01:37:55 +08:00

183 lines
5.3 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"flag"
"log"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/darkit/goproxy/internal/config"
"github.com/darkit/goproxy/internal/proxy"
"github.com/darkit/goproxy/internal/reverse"
"github.com/darkit/goproxy/internal/rule"
)
var (
// 监听地址
addr = flag.String("addr", ":8080", "反向代理服务器监听地址")
// 后端服务器
backend = flag.String("backend", "192.168.1.212:80", "后端服务器地址")
// 路由规则文件
routeFile = flag.String("route-file", "", "路由规则文件路径")
// 是否启用URL重写
enableRewrite = flag.Bool("enable-rewrite", false, "是否启用URL重写")
// 是否启用缓存
enableCache = flag.Bool("enable-cache", false, "是否启用缓存")
// 是否启用压缩
enableCompression = flag.Bool("enable-compression", false, "是否启用压缩")
// 是否启用监控
enableMetrics = flag.Bool("enable-metrics", false, "是否启用监控")
// 监控地址
metricsAddr = flag.String("metrics-addr", ":8082", "监控服务器监听地址")
// 是否添加X-Forwarded-For
addXForwardedFor = flag.Bool("add-x-forwarded-for", true, "是否添加X-Forwarded-For头")
// 是否添加X-Real-IP
addXRealIP = flag.Bool("add-x-real-ip", true, "是否添加X-Real-IP头")
// 是否启用CORS
enableCORS = flag.Bool("enable-cors", false, "是否启用CORS")
// 路径前缀
pathPrefix = flag.String("path-prefix", "", "路径前缀,将从请求路径中移除")
)
func main() {
// 解析命令行参数
flag.Parse()
// 创建配置
cfg := config.DefaultConfig()
cfg.ReverseProxy = true
cfg.EnableCache = *enableCache
cfg.EnableCompression = *enableCompression
cfg.AddXForwardedFor = *addXForwardedFor
cfg.AddXRealIP = *addXRealIP
cfg.EnableCORS = *enableCORS
cfg.ReverseProxyRulesFile = *routeFile
// 创建反向代理配置
reverseCfg := reverse.DefaultConfig()
reverseCfg.ListenAddr = *addr
reverseCfg.BaseConfig.EnableCompression = *enableCompression
reverseCfg.BaseConfig.EnableCORS = *enableCORS
reverseCfg.BaseConfig.AddXForwardedFor = *addXForwardedFor
reverseCfg.BaseConfig.AddXRealIP = *addXRealIP
// 创建规则管理器
ruleManager := rule.NewManager(nil)
// 如果有路径前缀,添加重写规则
if *pathPrefix != "" {
log.Printf("添加路径重写规则: 从请求路径移除前缀 %s\n", *pathPrefix)
rewriteRule := &rule.RewriteRule{
BaseRule: rule.BaseRule{
ID: "path-prefix-rewrite",
Type: rule.RuleTypeRewrite,
Priority: 100,
Pattern: *pathPrefix,
MatchType: rule.MatchTypePath,
Enabled: true,
},
Replacement: "",
}
if err := ruleManager.AddRule(rewriteRule); err != nil {
log.Printf("添加重写规则失败: %v", err)
}
}
// 创建反向代理
reverseProxy, err := reverse.New(reverseCfg)
if err != nil {
log.Fatalf("创建反向代理失败: %v", err)
}
// 创建HTTP服务器
server := &http.Server{
Addr: *addr,
Handler: reverseProxy,
}
// 启动HTTP服务器
go func() {
log.Printf("反向代理服务器启动在 %s后端服务器为 %s\n", *addr, *backend)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("代理服务器启动失败: %v", err)
}
}()
// 等待信号
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
log.Println("正在关闭代理服务器...")
server.Close()
log.Println("代理服务器已关闭")
}
// ReverseProxyDelegate 反向代理委托
type ReverseProxyDelegate struct {
proxy.DefaultDelegate
backend string
prefix string
}
// ResolveBackend 解析后端服务器
func (d *ReverseProxyDelegate) ResolveBackend(req *http.Request) (string, error) {
// 这里可以实现基于请求路径、主机名等的路由逻辑
return d.backend, nil
}
// ModifyRequest 修改请求
func (d *ReverseProxyDelegate) ModifyRequest(req *http.Request) {
// 移除路径前缀
if d.prefix != "" && strings.HasPrefix(req.URL.Path, d.prefix) {
req.URL.Path = strings.TrimPrefix(req.URL.Path, d.prefix)
if req.URL.Path == "" {
req.URL.Path = "/"
}
}
// 添加自定义请求头
req.Header.Set("X-Proxy-Time", time.Now().Format(time.RFC3339))
}
// ModifyResponse 修改响应
func (d *ReverseProxyDelegate) ModifyResponse(resp *http.Response) error {
// 添加自定义响应头
resp.Header.Set("X-Proxied-By", "GoProxy")
return nil
}
// Connect 连接事件
func (d *ReverseProxyDelegate) Connect(ctx *proxy.Context, rw http.ResponseWriter) {
log.Printf("收到连接: %s -> %s %s\n", ctx.Req.RemoteAddr, ctx.Req.Method, ctx.Req.URL.Path)
}
// BeforeRequest 请求前事件
func (d *ReverseProxyDelegate) BeforeRequest(ctx *proxy.Context) {
log.Printf("处理请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.Path)
}
// BeforeResponse 响应前事件
func (d *ReverseProxyDelegate) BeforeResponse(ctx *proxy.Context, resp *http.Response, err error) {
if err != nil {
log.Printf("响应错误: %v\n", err)
return
}
log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status)
}
// ErrorLog 错误日志
func (d *ReverseProxyDelegate) ErrorLog(err error) {
log.Printf("错误: %v\n", err)
}
// HandleError 处理错误
func (d *ReverseProxyDelegate) HandleError(rw http.ResponseWriter, req *http.Request, err error) {
log.Printf("处理错误: %v\n", err)
http.Error(rw, "代理服务器错误: "+err.Error(), http.StatusBadGateway)
}