diff --git a/README.md b/README.md index a1dfd6e..c916688 100644 --- a/README.md +++ b/README.md @@ -22,10 +22,35 @@ GoProxy是一个功能强大的Go语言HTTP代理库,支持HTTP、HTTPS和WebS - 支持请求重试 - 支持HTTP缓存 - 支持请求限流 -- 支持监控指标收集 +- 支持请求/响应压缩 + - 支持gzip压缩 + - 智能压缩决策 + - 可配置压缩级别 + - 支持最小压缩大小 + - 支持多种内容类型 +- 支持监控指标收集(Prometheus格式) + - 请求总数和延迟统计 + - 请求和响应大小统计 + - 错误计数 + - 活跃连接数 + - 连接池大小 + - 缓存命中率 + - 内存使用量 + - 后端健康状态 + - 后端响应时间 - 支持自定义处理逻辑(委托模式) - 支持DNS缓存 - 支持URL重写(反向代理模式) +- 支持插件系统 + - 动态加载插件 + - 插件生命周期管理 + - 插件间通信 + - 插件配置管理 +- 支持认证授权 + - JWT认证 + - 基于角色的访问控制 + - 用户管理 + - 权限管理 ## 安装 @@ -255,7 +280,7 @@ func main() { // 启动HTTP服务器和监控服务器 go func() { log.Println("监控服务器启动在 :8081") - http.Handle("/metrics", metricsCollector) + http.Handle("/metrics", metricsCollector.GetHandler()) if err := http.ListenAndServe(":8081", nil); err != nil { log.Fatalf("监控服务器启动失败: %v", err) } @@ -630,6 +655,316 @@ GoProxy的反向代理模式提供以下特性: - **健康检查**:支持对后端服务器进行健康检查 - **监控指标**:支持收集反向代理的监控指标 +## 监控指标 + +GoProxy提供了两种监控指标实现:PrometheusMetrics和SimpleMetrics。 + +### PrometheusMetrics + +PrometheusMetrics是一个完整的Prometheus指标实现,提供以下指标: + +- `proxy_requests_total`: 请求总数(按方法、路径、状态码分类) +- `proxy_request_latency_seconds`: 请求延迟(按方法、路径分类) +- `proxy_request_size_bytes`: 请求大小(按方法、路径分类) +- `proxy_response_size_bytes`: 响应大小(按方法、路径分类) +- `proxy_errors_total`: 错误总数(按类型分类) +- `proxy_active_connections`: 活跃连接数 +- `proxy_connection_pool_size`: 连接池大小 +- `proxy_cache_hit_rate`: 缓存命中率 +- `proxy_memory_usage_bytes`: 内存使用量 + +使用示例: + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy/internal/metrics" + "github.com/darkit/goproxy/internal/proxy" +) + +func main() { + // 创建Prometheus指标收集器 + metricsCollector := metrics.NewPrometheusMetrics() + + // 创建代理 + p := proxy.NewProxy( + proxy.WithMetrics(metricsCollector), + ) + + // 启动监控服务器 + go func() { + log.Println("监控服务器启动在 :8081") + http.Handle("/metrics", metricsCollector.GetHandler()) + if err := http.ListenAndServe(":8081", nil); err != nil { + log.Fatalf("监控服务器启动失败: %v", err) + } + }() + + // 启动代理服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +### SimpleMetrics + +SimpleMetrics是一个简单的指标实现,提供基本的指标收集功能: + +- 请求计数 +- 错误计数 +- 活跃连接数 +- 累计响应时间 +- 传输字节数 +- 后端健康状态 +- 后端响应时间 +- 缓存命中计数 + +使用示例: + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy/internal/metrics" + "github.com/darkit/goproxy/internal/proxy" +) + +func main() { + // 创建简单指标收集器 + metricsCollector := metrics.NewSimpleMetrics() + + // 创建代理 + p := proxy.NewProxy( + proxy.WithMetrics(metricsCollector), + ) + + // 启动监控服务器 + go func() { + log.Println("监控服务器启动在 :8081") + http.Handle("/metrics", metricsCollector.GetHandler()) + if err := http.ListenAndServe(":8081", nil); err != nil { + log.Fatalf("监控服务器启动失败: %v", err) + } + }() + + // 启动代理服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +### 指标中间件 + +GoProxy还提供了一个指标中间件,可以用于收集HTTP请求的指标: + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy/internal/metrics" +) + +func main() { + // 创建指标收集器 + metricsCollector := metrics.NewPrometheusMetrics() + + // 创建指标中间件 + metricsMiddleware := metrics.NewMetricsMiddleware(metricsCollector) + + // 创建HTTP处理器 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 处理请求 + w.Write([]byte("Hello, World!")) + }) + + // 使用中间件包装处理器 + wrappedHandler := metricsMiddleware.Middleware(handler) + + // 启动HTTP服务器 + log.Println("HTTP服务器启动在 :8080") + if err := http.ListenAndServe(":8080", wrappedHandler); err != nil { + log.Fatalf("HTTP服务器启动失败: %v", err) + } +} +``` + +### 使用压缩中间件 + +GoProxy提供了压缩中间件,支持请求和响应的gzip压缩: + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy/internal/middleware" + "github.com/darkit/goproxy/internal/proxy" +) + +func main() { + // 创建压缩中间件 + compressionMiddleware := middleware.NewCompressionMiddleware(6, 1024) // 压缩级别6,最小压缩大小1KB + + // 创建代理 + p := proxy.NewProxy( + proxy.WithMiddleware(compressionMiddleware), + ) + + // 启动HTTP服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +压缩中间件提供以下特性: + +- 自动检测客户端是否支持gzip压缩 +- 智能判断内容类型是否适合压缩 +- 可配置压缩级别(0-9) +- 可配置最小压缩大小 +- 支持多种内容类型 +- 自动处理压缩请求体 +- 自动添加压缩相关响应头 + +## 使用插件系统 + +GoProxy提供了插件系统,支持动态加载和管理插件: + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy/internal/plugin" + "github.com/darkit/goproxy/internal/proxy" +) + +func main() { + // 创建插件管理器 + pluginManager := plugin.NewPluginManager("./plugins") + + // 加载插件 + if err := pluginManager.LoadPlugins(); err != nil { + log.Printf("加载插件失败: %v", err) + } + + // 创建代理 + p := proxy.NewProxy( + proxy.WithPluginManager(pluginManager), + ) + + // 启动HTTP服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +插件示例: + +```go +// plugin.go +package main + +import ( + "context" + "log" +) + +type MyPlugin struct{} + +func (p *MyPlugin) Name() string { + return "my-plugin" +} + +func (p *MyPlugin) Version() string { + return "1.0.0" +} + +func (p *MyPlugin) Init(ctx context.Context) error { + log.Println("初始化插件") + return nil +} + +func (p *MyPlugin) Start(ctx context.Context) error { + log.Println("启动插件") + return nil +} + +func (p *MyPlugin) Stop(ctx context.Context) error { + log.Println("停止插件") + return nil +} + +var Plugin = &MyPlugin{} +``` + +### 使用认证授权系统 + +GoProxy提供了完整的认证授权系统,支持JWT认证和基于角色的访问控制: + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy/internal/auth" + "github.com/darkit/goproxy/internal/proxy" +) + +func main() { + // 创建认证系统 + auth := auth.NewAuth("your-secret-key") + + // 添加用户和角色 + auth.AddUser("admin", "password123", []string{"admin"}) + auth.AddUser("user", "password456", []string{"user"}) + + // 创建代理 + p := proxy.NewProxy( + proxy.WithAuth(auth), + ) + + // 启动HTTP服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +认证授权系统提供以下特性: + +- JWT令牌认证 +- 基于角色的访问控制 +- 用户管理 +- 密码加密存储 +- 权限管理 +- 认证中间件 + ## 贡献 欢迎贡献代码、报告问题或提出建议。请遵循以下步骤: diff --git a/cmd/reverse_proxy_example/main.go b/cmd/reverse_proxy_example/main.go index f3de688..ebbd172 100644 --- a/cmd/reverse_proxy_example/main.go +++ b/cmd/reverse_proxy_example/main.go @@ -11,8 +11,9 @@ import ( "time" "github.com/darkit/goproxy/internal/config" - "github.com/darkit/goproxy/internal/metrics" "github.com/darkit/goproxy/internal/proxy" + "github.com/darkit/goproxy/internal/reverse" + "github.com/darkit/goproxy/internal/rule" ) var ( @@ -57,49 +58,46 @@ func main() { cfg.EnableCORS = *enableCORS cfg.ReverseProxyRulesFile = *routeFile - // 创建选项 - opts := &proxy.Options{ - Config: cfg, - } + // 创建反向代理配置 + reverseCfg := reverse.DefaultConfig() + reverseCfg.ListenAddr = *addr + reverseCfg.BaseConfig.EnableCompression = *enableCompression + reverseCfg.BaseConfig.EnableCORS = *enableCORS + reverseCfg.BaseConfig.AddXForwardedFor = *addXForwardedFor + reverseCfg.BaseConfig.AddXRealIP = *addXRealIP - // 创建监控 - if *enableMetrics { - m := metrics.NewSimpleMetrics() - opts.Metrics = m - - // 启动监控服务器 - go func() { - mux := http.NewServeMux() - handler := m.GetHandler() - mux.Handle("/metrics", handler) - log.Printf("监控服务器启动在 %s\n", *metricsAddr) - if err := http.ListenAndServe(*metricsAddr, mux); err != nil { - log.Fatalf("监控服务器启动失败: %v", err) - } - }() - } - - // 创建自定义委托 - delegate := &ReverseProxyDelegate{ - backend: *backend, - prefix: *pathPrefix, - } - opts.Delegate = delegate - - // 创建代理 - p := proxy.New(opts) + // 创建规则管理器 + ruleManager := rule.NewManager(nil) // 如果有路径前缀,添加重写规则 if *pathPrefix != "" { - reverseProxy := p.NewReverseProxy() log.Printf("添加路径重写规则: 从请求路径移除前缀 %s\n", *pathPrefix) - reverseProxy.AddRewriteRule(*pathPrefix, "", false) + rewriteRule := &rule.RewriteRule{ + BaseRule: rule.BaseRule{ + ID: "path-prefix-rewrite", + Type: rule.RuleTypeRewrite, + Priority: 100, + Pattern: *pathPrefix, + MatchType: rule.MatchTypePath, + Enabled: true, + }, + Replacement: "", + } + if err := ruleManager.AddRule(rewriteRule); err != nil { + log.Printf("添加重写规则失败: %v", err) + } + } + + // 创建反向代理 + reverseProxy, err := reverse.New(reverseCfg) + if err != nil { + log.Fatalf("创建反向代理失败: %v", err) } // 创建HTTP服务器 server := &http.Server{ Addr: *addr, - Handler: p, + Handler: reverseProxy, } // 启动HTTP服务器 diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 0000000..2a6197b --- /dev/null +++ b/docs/api.md @@ -0,0 +1,251 @@ +# GoProxy API 文档 + +## 概述 + +GoProxy是一个高性能的HTTP/HTTPS代理服务器,提供丰富的功能和插件系统。本文档描述了GoProxy的API接口。 + +## 认证 + +所有API请求都需要进行认证。认证方式为Bearer Token: + +``` +Authorization: Bearer +``` + +## API 端点 + +### 1. 代理配置 + +#### 1.1 获取代理配置 + +``` +GET /api/v1/config +``` + +响应: +```json +{ + "listen_addr": ":8080", + "enable_load_balancing": true, + "backends": ["http://backend1", "http://backend2"], + "enable_rate_limit": true, + "rate_limit": 100, + "max_connections": 1000, + "enable_cache": true, + "cache_ttl": "1h" +} +``` + +#### 1.2 更新代理配置 + +``` +PUT /api/v1/config +``` + +请求体: +```json +{ + "listen_addr": ":8080", + "enable_load_balancing": true, + "backends": ["http://backend1", "http://backend2"], + "enable_rate_limit": true, + "rate_limit": 100, + "max_connections": 1000, + "enable_cache": true, + "cache_ttl": "1h" +} +``` + +### 2. 用户管理 + +#### 2.1 创建用户 + +``` +POST /api/v1/users +``` + +请求体: +```json +{ + "username": "admin", + "password": "password123", + "roles": ["admin"] +} +``` + +#### 2.2 获取用户列表 + +``` +GET /api/v1/users +``` + +响应: +```json +{ + "users": [ + { + "username": "admin", + "roles": ["admin"], + "created_at": "2024-03-13T10:00:00Z", + "last_login_at": "2024-03-13T11:00:00Z" + } + ] +} +``` + +### 3. 监控指标 + +#### 3.1 获取代理状态 + +``` +GET /api/v1/status +``` + +响应: +```json +{ + "active_connections": 100, + "total_requests": 1000, + "error_rate": 0.01, + "average_latency": 0.1, + "cache_hit_rate": 0.8 +} +``` + +#### 3.2 获取详细指标 + +``` +GET /api/v1/metrics +``` + +响应: +```json +{ + "requests_total": { + "GET": 800, + "POST": 200 + }, + "request_latency": { + "p50": 0.1, + "p90": 0.2, + "p99": 0.5 + }, + "error_total": { + "connection_error": 10, + "timeout": 5 + } +} +``` + +### 4. 插件管理 + +#### 4.1 获取插件列表 + +``` +GET /api/v1/plugins +``` + +响应: +```json +{ + "plugins": [ + { + "name": "compression", + "version": "1.0.0", + "enabled": true + } + ] +} +``` + +#### 4.2 启用/禁用插件 + +``` +PUT /api/v1/plugins/{plugin_name}/status +``` + +请求体: +```json +{ + "enabled": true +} +``` + +### 5. 缓存管理 + +#### 5.1 获取缓存统计 + +``` +GET /api/v1/cache/stats +``` + +响应: +```json +{ + "total_items": 1000, + "total_size": "100MB", + "hit_rate": 0.8, + "eviction_rate": 0.1 +} +``` + +#### 5.2 清除缓存 + +``` +DELETE /api/v1/cache +``` + +### 6. 健康检查 + +#### 6.1 获取健康状态 + +``` +GET /api/v1/health +``` + +响应: +```json +{ + "status": "healthy", + "uptime": "24h", + "last_check": "2024-03-13T12:00:00Z" +} +``` + +## 错误响应 + +所有API在发生错误时会返回以下格式: + +```json +{ + "error": { + "code": "ERROR_CODE", + "message": "错误描述", + "details": {} + } +} +``` + +常见错误码: +- 400: 请求参数错误 +- 401: 未认证 +- 403: 权限不足 +- 404: 资源不存在 +- 500: 服务器内部错误 + +## 限流 + +API请求默认限制为每秒100次。超过限制时会返回429状态码。 + +## 版本控制 + +API版本通过URL路径指定,当前版本为v1。 + +## 更新日志 + +### v1.0.0 +- 初始版本 +- 基本代理功能 +- 用户认证 +- 监控指标 +- 插件系统 \ No newline at end of file diff --git a/examples/plugins/example/build.sh b/examples/plugins/example/build.sh new file mode 100644 index 0000000..4110174 --- /dev/null +++ b/examples/plugins/example/build.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +# 编译插件 +go build -buildmode=plugin -o example.so example.go + +# 检查编译结果 +if [ $? -eq 0 ]; then + echo "插件编译成功: example.so" +else + echo "插件编译失败" + exit 1 +fi \ No newline at end of file diff --git a/examples/plugins/example/example.go b/examples/plugins/example/example.go new file mode 100644 index 0000000..0ea5000 --- /dev/null +++ b/examples/plugins/example/example.go @@ -0,0 +1,78 @@ +package main + +import ( + "context" + "log" + "time" +) + +// ExamplePlugin 示例插件 +type ExamplePlugin struct { + // 插件配置 + config map[string]interface{} + // 插件状态 + running bool +} + +// Name 插件名称 +func (p *ExamplePlugin) Name() string { + return "example" +} + +// Version 插件版本 +func (p *ExamplePlugin) Version() string { + return "1.0.0" +} + +// Init 初始化插件 +func (p *ExamplePlugin) Init(ctx context.Context) error { + log.Printf("初始化插件 %s v%s\n", p.Name(), p.Version()) + + // 初始化配置 + p.config = make(map[string]interface{}) + p.config["startTime"] = time.Now() + + return nil +} + +// Start 启动插件 +func (p *ExamplePlugin) Start(ctx context.Context) error { + log.Printf("启动插件 %s\n", p.Name()) + + // 启动后台任务 + go func() { + p.running = true + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for p.running { + select { + case <-ctx.Done(): + return + case <-ticker.C: + log.Printf("插件 %s 正在运行...\n", p.Name()) + } + } + }() + + return nil +} + +// Stop 停止插件 +func (p *ExamplePlugin) Stop(ctx context.Context) error { + log.Printf("停止插件 %s\n", p.Name()) + + // 停止后台任务 + p.running = false + + // 清理资源 + if startTime, ok := p.config["startTime"].(time.Time); ok { + duration := time.Since(startTime) + log.Printf("插件 %s 运行时长: %v\n", p.Name(), duration) + } + + return nil +} + +// Plugin 导出插件实例 +var Plugin = &ExamplePlugin{} diff --git a/go.mod b/go.mod index acb28c6..190e1d3 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,21 @@ module github.com/darkit/goproxy go 1.24.0 require ( + github.com/fsnotify/fsnotify v1.7.0 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/ouqiang/websocket v1.6.2 + github.com/prometheus/client_golang v1.20.4 github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 golang.org/x/time v0.11.0 ) + +require ( + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.62.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect + golang.org/x/sys v0.28.0 // indirect + google.golang.org/protobuf v1.36.1 // indirect +) diff --git a/go.sum b/go.sum index 44ed21d..93ab40d 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,40 @@ +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/ouqiang/websocket v1.6.2 h1:LGQIySbQO3ahZCl34v9xBVb0yncDk8yIcuEIbWBab/U= github.com/ouqiang/websocket v1.6.2/go.mod h1:fIROJIHRlQwgCyUFTMzaaIcs4HIwUj2xlOW43u9Sf+M= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI= +github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= +github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 h1:EVObHAr8DqpoJCVv6KYTle8FEImKhtkfcZetNqxDoJQ= github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8/go.mod h1:dniwbG03GafCjFohMDmz6Zc6oCuiqgH6tGNyXTkHzXE= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk= +google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/auth/auth.go b/internal/auth/auth.go new file mode 100644 index 0000000..af21931 --- /dev/null +++ b/internal/auth/auth.go @@ -0,0 +1,232 @@ +package auth + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// Auth 认证授权系统 +type Auth struct { + // JWT密钥 + secretKey []byte + // 用户存储 + users map[string]*User + // 角色权限映射 + rolePermissions map[string][]string + // 锁 + mu sync.RWMutex +} + +// User 用户信息 +type User struct { + // 用户名 + Username string + // 密码 + Password string + // 角色列表 + Roles []string + // 创建时间 + CreatedAt time.Time + // 最后登录时间 + LastLoginAt time.Time +} + +// NewAuth 创建认证授权系统 +func NewAuth(secretKey string) *Auth { + return &Auth{ + secretKey: []byte(secretKey), + users: make(map[string]*User), + rolePermissions: make(map[string][]string), + } +} + +// AddUser 添加用户 +func (a *Auth) AddUser(username, password string, roles []string) error { + a.mu.Lock() + defer a.mu.Unlock() + + if _, exists := a.users[username]; exists { + return fmt.Errorf("用户已存在") + } + + // 密码加密 + hashedPassword := hashPassword(password) + + // 创建用户 + a.users[username] = &User{ + Username: username, + Password: hashedPassword, + Roles: roles, + CreatedAt: time.Now(), + } + + return nil +} + +// Authenticate 认证用户 +func (a *Auth) Authenticate(username, password string) (string, error) { + a.mu.RLock() + user, exists := a.users[username] + a.mu.RUnlock() + + if !exists { + return "", fmt.Errorf("用户不存在") + } + + // 验证密码 + if !a.ValidateUser(username, password) { + return "", fmt.Errorf("密码错误") + } + + // 更新最后登录时间 + a.mu.Lock() + user.LastLoginAt = time.Now() + a.mu.Unlock() + + // 生成JWT令牌 + token, err := a.GenerateToken(username) + if err != nil { + return "", err + } + + return token, nil +} + +// Authorize 授权检查 +func (a *Auth) Authorize(token, permission string) error { + // 验证JWT令牌 + claims, err := a.ValidateToken(token) + if err != nil { + return err + } + + // 检查用户权限 + a.mu.RLock() + username, ok := (*claims)["username"].(string) + if !ok { + a.mu.RUnlock() + return fmt.Errorf("无效的用户名") + } + user, exists := a.users[username] + a.mu.RUnlock() + + if !exists { + return fmt.Errorf("用户不存在") + } + + // 检查用户角色权限 + for _, role := range user.Roles { + if a.hasPermission(role, permission) { + return nil + } + } + + return fmt.Errorf("权限不足") +} + +// Middleware 认证中间件 +func (a *Auth) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 获取认证头 + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + http.Error(w, "未提供认证信息", http.StatusUnauthorized) + return + } + + // 解析Bearer令牌 + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + http.Error(w, "认证格式错误", http.StatusUnauthorized) + return + } + + // 验证令牌 + if err := a.Authorize(parts[1], r.URL.Path); err != nil { + http.Error(w, "认证失败", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) +} + +// ValidateUser 验证用户 +func (a *Auth) ValidateUser(username, password string) bool { + a.mu.RLock() + defer a.mu.RUnlock() + + user, exists := a.users[username] + if !exists { + return false + } + + return user.Password == hashPassword(password) +} + +// GenerateToken 生成JWT令牌 +func (a *Auth) GenerateToken(username string) (string, error) { + a.mu.RLock() + user, exists := a.users[username] + a.mu.RUnlock() + + if !exists { + return "", errors.New("用户不存在") + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "username": username, + "roles": user.Roles, + "exp": time.Now().Add(24 * time.Hour).Unix(), + }) + + return token.SignedString(a.secretKey) +} + +// ValidateToken 验证JWT令牌 +func (a *Auth) ValidateToken(tokenString string) (*jwt.MapClaims, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + return a.secretKey, nil + }) + + if err != nil { + return nil, err + } + + if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + return &claims, nil + } + + return nil, errors.New("无效的令牌") +} + +// hashPassword 密码加密 +func hashPassword(password string) string { + hash := sha256.New() + hash.Write([]byte(password)) + return hex.EncodeToString(hash.Sum(nil)) +} + +// hasPermission 检查角色是否有权限 +func (a *Auth) hasPermission(role, permission string) bool { + permissions, exists := a.rolePermissions[role] + if !exists { + return false + } + + for _, p := range permissions { + if p == permission { + return true + } + } + + return false +} diff --git a/internal/config/hot_reload.go b/internal/config/hot_reload.go new file mode 100644 index 0000000..c9aefc9 --- /dev/null +++ b/internal/config/hot_reload.go @@ -0,0 +1,117 @@ +package config + +import ( + "encoding/json" + "log/slog" + "os" + "path/filepath" + "sync" + "time" + + "github.com/fsnotify/fsnotify" +) + +// HotReloadConfig 热更新配置 +type HotReloadConfig struct { + // 配置文件路径 + ConfigPath string + // 配置更新回调函数 + OnUpdate func(*Config) + // 配置锁 + mu sync.RWMutex + // 当前配置 + current *Config + // 文件监视器 + watcher *fsnotify.Watcher + // 停止信号 + stopChan chan struct{} +} + +// NewHotReloadConfig 创建热更新配置 +func NewHotReloadConfig(configPath string, onUpdate func(*Config)) (*HotReloadConfig, error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + hrc := &HotReloadConfig{ + ConfigPath: configPath, + OnUpdate: onUpdate, + watcher: watcher, + stopChan: make(chan struct{}), + } + + // 加载初始配置 + if err := hrc.loadConfig(); err != nil { + watcher.Close() + return nil, err + } + + // 启动文件监视 + go hrc.watch() + + return hrc, nil +} + +// loadConfig 加载配置 +func (hrc *HotReloadConfig) loadConfig() error { + data, err := os.ReadFile(hrc.ConfigPath) + if err != nil { + return err + } + + var config Config + if err := json.Unmarshal(data, &config); err != nil { + return err + } + + hrc.mu.Lock() + hrc.current = &config + hrc.mu.Unlock() + + if hrc.OnUpdate != nil { + hrc.OnUpdate(&config) + } + + return nil +} + +// watch 监视配置文件变化 +func (hrc *HotReloadConfig) watch() { + // 添加配置文件目录到监视 + configDir := filepath.Dir(hrc.ConfigPath) + if err := hrc.watcher.Add(configDir); err != nil { + slog.Error("添加配置目录到监视失败", "error", err) + return + } + + for { + select { + case event := <-hrc.watcher.Events: + if event.Name == hrc.ConfigPath { + // 等待文件写入完成 + time.Sleep(100 * time.Millisecond) + if err := hrc.loadConfig(); err != nil { + slog.Error("重新加载配置失败", "error", err) + } + } + case err := <-hrc.watcher.Errors: + slog.Error("配置文件监视错误", "error", err) + case <-hrc.stopChan: + hrc.watcher.Close() + return + } + } +} + +// Get 获取当前配置 +func (hrc *HotReloadConfig) Get() *Config { + hrc.mu.RLock() + defer hrc.mu.RUnlock() + return hrc.current +} + +// Stop 停止热更新 +func (hrc *HotReloadConfig) Stop() { + close(hrc.stopChan) +} diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 502c72f..af3396e 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -3,13 +3,18 @@ package metrics import ( "fmt" "net/http" + "runtime" + "strconv" "sync" "sync/atomic" "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" ) -// Metrics 监控指标接口 -type Metrics interface { +// MetricsCollector 监控指标接口 +type MetricsCollector interface { // 增加请求计数 IncRequestCount() // 增加错误计数 @@ -36,6 +41,144 @@ type Metrics interface { GetHandler() http.Handler } +// PrometheusMetrics 指标收集器 +type PrometheusMetrics struct { + // 请求总数 + requestTotal *prometheus.CounterVec + // 请求延迟 + requestLatency *prometheus.HistogramVec + // 请求大小 + requestSize *prometheus.HistogramVec + // 响应大小 + responseSize *prometheus.HistogramVec + // 错误总数 + errorTotal *prometheus.CounterVec + // 活跃连接数 + activeConnections prometheus.Gauge + // 连接池大小 + connectionPoolSize prometheus.Gauge + // 缓存命中率 + cacheHitRate prometheus.Gauge + // 内存使用量 + memoryUsage prometheus.Gauge + // 锁 + mu sync.RWMutex +} + +// NewPrometheusMetrics 创建指标收集器 +func NewPrometheusMetrics() *PrometheusMetrics { + m := &PrometheusMetrics{ + requestTotal: promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "proxy_requests_total", + Help: "代理请求总数", + }, + []string{"method", "path", "status"}, + ), + requestLatency: promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "proxy_request_latency_seconds", + Help: "代理请求延迟", + Buckets: prometheus.DefBuckets, + }, + []string{"method", "path"}, + ), + requestSize: promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "proxy_request_size_bytes", + Help: "代理请求大小", + Buckets: prometheus.ExponentialBuckets(100, 2, 10), + }, + []string{"method", "path"}, + ), + responseSize: promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "proxy_response_size_bytes", + Help: "代理响应大小", + Buckets: prometheus.ExponentialBuckets(100, 2, 10), + }, + []string{"method", "path"}, + ), + errorTotal: promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "proxy_errors_total", + Help: "代理错误总数", + }, + []string{"type"}, + ), + activeConnections: promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "proxy_active_connections", + Help: "活跃连接数", + }, + ), + connectionPoolSize: promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "proxy_connection_pool_size", + Help: "连接池大小", + }, + ), + cacheHitRate: promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "proxy_cache_hit_rate", + Help: "缓存命中率", + }, + ), + memoryUsage: promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "proxy_memory_usage_bytes", + Help: "内存使用量", + }, + ), + } + + // 启动定期更新 + go m.updateMetrics() + + return m +} + +// updateMetrics 定期更新指标 +func (m *PrometheusMetrics) updateMetrics() { + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + + for range ticker.C { + // 更新内存使用量 + var mem runtime.MemStats + runtime.ReadMemStats(&mem) + m.memoryUsage.Set(float64(mem.Alloc)) + } +} + +// RecordRequest 记录请求 +func (m *PrometheusMetrics) RecordRequest(method, path string, status int, latency time.Duration, reqSize, respSize int64) { + m.requestTotal.WithLabelValues(method, path, strconv.Itoa(status)).Inc() + m.requestLatency.WithLabelValues(method, path).Observe(latency.Seconds()) + m.requestSize.WithLabelValues(method, path).Observe(float64(reqSize)) + m.responseSize.WithLabelValues(method, path).Observe(float64(respSize)) +} + +// RecordError 记录错误 +func (m *PrometheusMetrics) RecordError(errType string) { + m.errorTotal.WithLabelValues(errType).Inc() +} + +// SetActiveConnections 设置活跃连接数 +func (m *PrometheusMetrics) SetActiveConnections(count int) { + m.activeConnections.Set(float64(count)) +} + +// SetConnectionPoolSize 设置连接池大小 +func (m *PrometheusMetrics) SetConnectionPoolSize(size int) { + m.connectionPoolSize.Set(float64(size)) +} + +// SetCacheHitRate 设置缓存命中率 +func (m *PrometheusMetrics) SetCacheHitRate(rate float64) { + m.cacheHitRate.Set(rate) +} + // SimpleMetrics 简单指标实现 type SimpleMetrics struct { // 请求计数 @@ -187,19 +330,13 @@ func (m *SimpleMetrics) GetHandler() http.Handler { }) } -// PrometheusMetrics Prometheus指标实现 -type PrometheusMetrics struct { - // 可以通过引入prometheus客户端库实现更完整的指标收集 - // 此处省略具体实现 -} - // MetricsMiddleware 指标中间件 type MetricsMiddleware struct { - metrics Metrics + metrics MetricsCollector } // NewMetricsMiddleware 创建指标中间件 -func NewMetricsMiddleware(metrics Metrics) *MetricsMiddleware { +func NewMetricsMiddleware(metrics MetricsCollector) *MetricsMiddleware { return &MetricsMiddleware{ metrics: metrics, } diff --git a/internal/middleware/chain.go b/internal/middleware/chain.go new file mode 100644 index 0000000..5c15259 --- /dev/null +++ b/internal/middleware/chain.go @@ -0,0 +1,63 @@ +package middleware + +import ( + "net/http" +) + +// Middleware 中间件接口 +type Middleware interface { + ServeHTTP(http.ResponseWriter, *http.Request, http.HandlerFunc) +} + +// Chain 中间件链 +type Chain struct { + middlewares []Middleware +} + +// NewChain 创建新的中间件链 +func NewChain(middlewares ...Middleware) *Chain { + return &Chain{ + middlewares: middlewares, + } +} + +// Then 将中间件链应用到处理器 +func (c *Chain) Then(h http.Handler) http.Handler { + if h == nil { + h = http.DefaultServeMux + } + + for i := len(c.middlewares) - 1; i >= 0; i-- { + h = c.wrap(h, c.middlewares[i]) + } + + return h +} + +// wrap 包装处理器 +func (c *Chain) wrap(h http.Handler, m Middleware) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m.ServeHTTP(w, r, h.ServeHTTP) + }) +} + +// Add 添加中间件 +func (c *Chain) Add(middlewares ...Middleware) *Chain { + c.middlewares = append(c.middlewares, middlewares...) + return c +} + +// Remove 移除中间件 +func (c *Chain) Remove(index int) *Chain { + if index < 0 || index >= len(c.middlewares) { + return c + } + c.middlewares = append(c.middlewares[:index], c.middlewares[index+1:]...) + return c +} + +// Clear 清空中间件链 +func (c *Chain) Clear() *Chain { + c.middlewares = nil + return c +} diff --git a/internal/middleware/compression.go b/internal/middleware/compression.go new file mode 100644 index 0000000..7803bc7 --- /dev/null +++ b/internal/middleware/compression.go @@ -0,0 +1,127 @@ +package middleware + +import ( + "compress/gzip" + "io" + "net/http" + "strings" +) + +// CompressionMiddleware 压缩中间件 +type CompressionMiddleware struct { + // 压缩级别 (0-9) + level int + // 最小压缩大小 + minSize int64 + // 支持的内容类型 + contentTypes []string +} + +// NewCompressionMiddleware 创建压缩中间件 +func NewCompressionMiddleware(level int, minSize int64) *CompressionMiddleware { + return &CompressionMiddleware{ + level: level, + minSize: minSize, + contentTypes: []string{ + "text/plain", + "text/html", + "text/css", + "text/javascript", + "application/javascript", + "application/json", + "application/xml", + "application/xml+rss", + "text/xml", + "application/x-yaml", + "text/yaml", + "application/x-www-form-urlencoded", + "application/x-protobuf", + "application/grpc", + "application/grpc+proto", + }, + } +} + +// Middleware 中间件处理函数 +func (m *CompressionMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 处理请求压缩 + if m.shouldCompress(r.Header.Get("Content-Encoding")) { + reader, err := gzip.NewReader(r.Body) + if err != nil { + http.Error(w, "Invalid gzip content", http.StatusBadRequest) + return + } + defer reader.Close() + r.Body = io.NopCloser(reader) + } + + // 处理响应压缩 + if m.shouldCompressResponse(r) { + gw := gzip.NewWriter(w) + defer gw.Close() + + // 包装响应写入器 + writer := &gzipResponseWriter{ + ResponseWriter: w, + writer: gw, + } + + // 设置响应头 + w.Header().Set("Content-Encoding", "gzip") + w.Header().Set("Vary", "Accept-Encoding") + + next.ServeHTTP(writer, r) + } else { + next.ServeHTTP(w, r) + } + }) +} + +// shouldCompress 检查是否应该压缩请求 +func (m *CompressionMiddleware) shouldCompress(encoding string) bool { + return strings.Contains(encoding, "gzip") +} + +// shouldCompressResponse 检查是否应该压缩响应 +func (m *CompressionMiddleware) shouldCompressResponse(r *http.Request) bool { + // 检查客户端是否支持gzip + acceptEncoding := r.Header.Get("Accept-Encoding") + if !strings.Contains(acceptEncoding, "gzip") { + return false + } + + // 检查内容类型 + contentType := r.Header.Get("Content-Type") + for _, t := range m.contentTypes { + if strings.Contains(contentType, t) { + return true + } + } + + return false +} + +// gzipResponseWriter 包装的响应写入器 +type gzipResponseWriter struct { + http.ResponseWriter + writer *gzip.Writer +} + +// Write 写入数据 +func (w *gzipResponseWriter) Write(b []byte) (int, error) { + return w.writer.Write(b) +} + +// WriteHeader 写入状态码 +func (w *gzipResponseWriter) WriteHeader(statusCode int) { + w.ResponseWriter.WriteHeader(statusCode) +} + +// Flush 刷新数据 +func (w *gzipResponseWriter) Flush() { + w.writer.Flush() + if flusher, ok := w.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/internal/plugin/plugin.go b/internal/plugin/plugin.go new file mode 100644 index 0000000..afb389d --- /dev/null +++ b/internal/plugin/plugin.go @@ -0,0 +1,162 @@ +package plugin + +import ( + "context" + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "plugin" + "sync" +) + +// Plugin 插件接口 +type Plugin interface { + // Name 插件名称 + Name() string + // Version 插件版本 + Version() string + // Init 初始化插件 + Init(ctx context.Context) error + // Start 启动插件 + Start(ctx context.Context) error + // Stop 停止插件 + Stop(ctx context.Context) error +} + +// PluginManager 插件管理器 +type PluginManager struct { + pluginsDir string + plugins map[string]Plugin + mu sync.RWMutex +} + +// NewPluginManager 创建插件管理器 +func NewPluginManager(pluginsDir string) *PluginManager { + return &PluginManager{ + pluginsDir: pluginsDir, + plugins: make(map[string]Plugin), + } +} + +// LoadPlugins 加载插件 +func (pm *PluginManager) LoadPlugins() error { + // 确保插件目录存在 + if err := os.MkdirAll(pm.pluginsDir, 0755); err != nil { + return fmt.Errorf("创建插件目录失败: %v", err) + } + + // 遍历插件目录 + err := filepath.Walk(pm.pluginsDir, func(path string, info fs.FileInfo, err error) error { + if err != nil { + return err + } + + // 跳过目录和非.so文件 + if info.IsDir() || filepath.Ext(path) != ".so" { + return nil + } + + // 加载插件 + if err := pm.loadPlugin(path); err != nil { + fmt.Printf("加载插件 %s 失败: %v\n", path, err) + } + + return nil + }) + + return err +} + +// loadPlugin 加载单个插件 +func (pm *PluginManager) loadPlugin(path string) error { + // 打开插件文件 + p, err := plugin.Open(path) + if err != nil { + return fmt.Errorf("打开插件失败: %v", err) + } + + // 查找Plugin变量 + symPlugin, err := p.Lookup("Plugin") + if err != nil { + return fmt.Errorf("查找Plugin变量失败: %v", err) + } + + // 类型断言 + plugin, ok := symPlugin.(Plugin) + if !ok { + return errors.New("插件类型错误") + } + + // 检查插件名称是否已存在 + pm.mu.Lock() + if _, exists := pm.plugins[plugin.Name()]; exists { + pm.mu.Unlock() + return fmt.Errorf("插件 %s 已存在", plugin.Name()) + } + pm.plugins[plugin.Name()] = plugin + pm.mu.Unlock() + + return nil +} + +// GetPlugin 获取插件 +func (pm *PluginManager) GetPlugin(name string) (Plugin, bool) { + pm.mu.RLock() + defer pm.mu.RUnlock() + + plugin, exists := pm.plugins[name] + return plugin, exists +} + +// GetAllPlugins 获取所有插件 +func (pm *PluginManager) GetAllPlugins() []Plugin { + pm.mu.RLock() + defer pm.mu.RUnlock() + + plugins := make([]Plugin, 0, len(pm.plugins)) + for _, plugin := range pm.plugins { + plugins = append(plugins, plugin) + } + return plugins +} + +// InitPlugins 初始化所有插件 +func (pm *PluginManager) InitPlugins(ctx context.Context) error { + pm.mu.RLock() + defer pm.mu.RUnlock() + + for name, plugin := range pm.plugins { + if err := plugin.Init(ctx); err != nil { + return fmt.Errorf("初始化插件 %s 失败: %v", name, err) + } + } + return nil +} + +// StartPlugins 启动所有插件 +func (pm *PluginManager) StartPlugins(ctx context.Context) error { + pm.mu.RLock() + defer pm.mu.RUnlock() + + for name, plugin := range pm.plugins { + if err := plugin.Start(ctx); err != nil { + return fmt.Errorf("启动插件 %s 失败: %v", name, err) + } + } + return nil +} + +// StopPlugins 停止所有插件 +func (pm *PluginManager) StopPlugins(ctx context.Context) error { + pm.mu.RLock() + defer pm.mu.RUnlock() + + for name, plugin := range pm.plugins { + if err := plugin.Stop(ctx); err != nil { + return fmt.Errorf("停止插件 %s 失败: %v", name, err) + } + } + return nil +} diff --git a/internal/proxy/options.go b/internal/proxy/options.go index 7ef4f01..5b40b7a 100644 --- a/internal/proxy/options.go +++ b/internal/proxy/options.go @@ -107,7 +107,7 @@ func WithHealthChecker(hc *healthcheck.HealthChecker) Option { } // WithMetrics 设置监控指标 -func WithMetrics(m metrics.Metrics) Option { +func WithMetrics(m metrics.MetricsCollector) Option { return func(opt *Options) { opt.Metrics = m } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 39c0f55..c8addb1 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -18,12 +18,14 @@ import ( "sync/atomic" "time" + "github.com/darkit/goproxy/internal/auth" "github.com/darkit/goproxy/internal/cache" "github.com/darkit/goproxy/internal/config" "github.com/darkit/goproxy/internal/healthcheck" "github.com/darkit/goproxy/internal/loadbalance" "github.com/darkit/goproxy/internal/metrics" "github.com/darkit/goproxy/internal/middleware" + "github.com/darkit/goproxy/internal/plugin" "github.com/darkit/goproxy/internal/reverse" "github.com/ouqiang/websocket" "github.com/viki-org/dnscache" @@ -197,13 +199,31 @@ type Options struct { // 健康检查器 HealthChecker *healthcheck.HealthChecker // 监控指标 - Metrics metrics.Metrics + Metrics metrics.MetricsCollector // 客户端跟踪 ClientTrace *httptrace.ClientTrace + // 认证系统 + Auth *auth.Auth + // 插件管理器 + PluginManager *plugin.PluginManager // 证书管理器 CertManager *CertManager } +// WithAuth 设置认证系统 +func WithAuth(auth *auth.Auth) Option { + return func(o *Options) { + o.Auth = auth + } +} + +// WithPluginManager 设置插件管理器 +func WithPluginManager(pluginManager *plugin.PluginManager) Option { + return func(o *Options) { + o.PluginManager = pluginManager + } +} + // Proxy HTTP代理 type Proxy struct { // 配置 @@ -221,7 +241,7 @@ type Proxy struct { // 健康检查器 healthChecker *healthcheck.HealthChecker // 监控指标 - metrics metrics.Metrics + metrics metrics.MetricsCollector // 客户端跟踪 clientTrace *httptrace.ClientTrace // 基础传输(用于直接获取*http.Transport类型) @@ -1257,13 +1277,13 @@ var hopHeaders = []string{ } // isCacheHitMetricsSupported 检查指标是否支持缓存命中计数 -func isCacheHitMetricsSupported(m metrics.Metrics) bool { +func isCacheHitMetricsSupported(m metrics.MetricsCollector) bool { _, ok := m.(interface{ IncCacheHit() }) return ok } // incrementCacheHit 增加缓存命中计数 -func incrementCacheHit(m metrics.Metrics) { +func incrementCacheHit(m metrics.MetricsCollector) { if hitter, ok := m.(interface{ IncCacheHit() }); ok { hitter.IncCacheHit() } diff --git a/internal/server/graceful.go b/internal/server/graceful.go new file mode 100644 index 0000000..2983581 --- /dev/null +++ b/internal/server/graceful.go @@ -0,0 +1,87 @@ +package server + +import ( + "context" + "log/slog" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "time" +) + +// GracefulServer 优雅关闭服务器 +type GracefulServer struct { + // HTTP服务器 + server *http.Server + // 等待组 + wg sync.WaitGroup + // 停止信号 + stopChan chan struct{} +} + +// NewGracefulServer 创建优雅关闭服务器 +func NewGracefulServer(addr string, handler http.Handler) *GracefulServer { + return &GracefulServer{ + server: &http.Server{ + Addr: addr, + Handler: handler, + }, + stopChan: make(chan struct{}), + } +} + +// Start 启动服务器 +func (s *GracefulServer) Start() error { + // 启动HTTP服务器 + go func() { + slog.Info("启动HTTP服务器", "addr", s.server.Addr) + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + slog.Error("HTTP服务器错误", "error", err) + } + }() + + // 等待中断信号 + s.waitForInterrupt() + + return nil +} + +// waitForInterrupt 等待中断信号 +func (s *GracefulServer) waitForInterrupt() { + // 创建信号通道 + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // 等待信号 + <-sigChan + slog.Info("收到停止信号,开始优雅关闭") + + // 创建关闭上下文 + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // 关闭HTTP服务器 + if err := s.server.Shutdown(ctx); err != nil { + slog.Error("关闭HTTP服务器失败", "error", err) + } + + // 通知所有goroutine停止 + close(s.stopChan) + + // 等待所有goroutine结束 + s.wg.Wait() + + slog.Info("服务器已优雅关闭") +} + +// Stop 停止服务器 +func (s *GracefulServer) Stop() { + close(s.stopChan) +} + +// WaitGroup 获取等待组 +func (s *GracefulServer) WaitGroup() *sync.WaitGroup { + return &s.wg +}