1.  监控指标收集
2.  中间件机制
3.  配置热更新
4.  优雅关闭
5.  插件系统
6.  API文档
7.  认证授权系统
8.  请求/响应压缩优化
This commit is contained in:
2025-03-13 22:58:39 +08:00
parent 35f492b1c5
commit 7efc72b362
16 changed files with 1718 additions and 51 deletions

339
README.md
View File

@@ -22,10 +22,35 @@ GoProxy是一个功能强大的Go语言HTTP代理库支持HTTP、HTTPS和WebS
- 支持请求重试
- 支持HTTP缓存
- 支持请求限流
- 支持监控指标收集
- 支持请求/响应压缩
- 支持gzip压缩
- 智能压缩决策
- 可配置压缩级别
- 支持最小压缩大小
- 支持多种内容类型
- 支持监控指标收集Prometheus格式
- 请求总数和延迟统计
- 请求和响应大小统计
- 错误计数
- 活跃连接数
- 连接池大小
- 缓存命中率
- 内存使用量
- 后端健康状态
- 后端响应时间
- 支持自定义处理逻辑(委托模式)
- 支持DNS缓存
- 支持URL重写反向代理模式
- 支持插件系统
- 动态加载插件
- 插件生命周期管理
- 插件间通信
- 插件配置管理
- 支持认证授权
- JWT认证
- 基于角色的访问控制
- 用户管理
- 权限管理
## 安装
@@ -255,7 +280,7 @@ func main() {
// 启动HTTP服务器和监控服务器
go func() {
log.Println("监控服务器启动在 :8081")
http.Handle("/metrics", metricsCollector)
http.Handle("/metrics", metricsCollector.GetHandler())
if err := http.ListenAndServe(":8081", nil); err != nil {
log.Fatalf("监控服务器启动失败: %v", err)
}
@@ -630,6 +655,316 @@ GoProxy的反向代理模式提供以下特性
- **健康检查**:支持对后端服务器进行健康检查
- **监控指标**:支持收集反向代理的监控指标
## 监控指标
GoProxy提供了两种监控指标实现PrometheusMetrics和SimpleMetrics。
### PrometheusMetrics
PrometheusMetrics是一个完整的Prometheus指标实现提供以下指标
- `proxy_requests_total`: 请求总数(按方法、路径、状态码分类)
- `proxy_request_latency_seconds`: 请求延迟(按方法、路径分类)
- `proxy_request_size_bytes`: 请求大小(按方法、路径分类)
- `proxy_response_size_bytes`: 响应大小(按方法、路径分类)
- `proxy_errors_total`: 错误总数(按类型分类)
- `proxy_active_connections`: 活跃连接数
- `proxy_connection_pool_size`: 连接池大小
- `proxy_cache_hit_rate`: 缓存命中率
- `proxy_memory_usage_bytes`: 内存使用量
使用示例:
```go
package main
import (
"log"
"net/http"
"github.com/darkit/goproxy/internal/metrics"
"github.com/darkit/goproxy/internal/proxy"
)
func main() {
// 创建Prometheus指标收集器
metricsCollector := metrics.NewPrometheusMetrics()
// 创建代理
p := proxy.NewProxy(
proxy.WithMetrics(metricsCollector),
)
// 启动监控服务器
go func() {
log.Println("监控服务器启动在 :8081")
http.Handle("/metrics", metricsCollector.GetHandler())
if err := http.ListenAndServe(":8081", nil); err != nil {
log.Fatalf("监控服务器启动失败: %v", err)
}
}()
// 启动代理服务器
log.Println("代理服务器启动在 :8080")
if err := http.ListenAndServe(":8080", p); err != nil {
log.Fatalf("代理服务器启动失败: %v", err)
}
}
```
### SimpleMetrics
SimpleMetrics是一个简单的指标实现提供基本的指标收集功能
- 请求计数
- 错误计数
- 活跃连接数
- 累计响应时间
- 传输字节数
- 后端健康状态
- 后端响应时间
- 缓存命中计数
使用示例:
```go
package main
import (
"log"
"net/http"
"github.com/darkit/goproxy/internal/metrics"
"github.com/darkit/goproxy/internal/proxy"
)
func main() {
// 创建简单指标收集器
metricsCollector := metrics.NewSimpleMetrics()
// 创建代理
p := proxy.NewProxy(
proxy.WithMetrics(metricsCollector),
)
// 启动监控服务器
go func() {
log.Println("监控服务器启动在 :8081")
http.Handle("/metrics", metricsCollector.GetHandler())
if err := http.ListenAndServe(":8081", nil); err != nil {
log.Fatalf("监控服务器启动失败: %v", err)
}
}()
// 启动代理服务器
log.Println("代理服务器启动在 :8080")
if err := http.ListenAndServe(":8080", p); err != nil {
log.Fatalf("代理服务器启动失败: %v", err)
}
}
```
### 指标中间件
GoProxy还提供了一个指标中间件可以用于收集HTTP请求的指标
```go
package main
import (
"log"
"net/http"
"github.com/darkit/goproxy/internal/metrics"
)
func main() {
// 创建指标收集器
metricsCollector := metrics.NewPrometheusMetrics()
// 创建指标中间件
metricsMiddleware := metrics.NewMetricsMiddleware(metricsCollector)
// 创建HTTP处理器
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 处理请求
w.Write([]byte("Hello, World!"))
})
// 使用中间件包装处理器
wrappedHandler := metricsMiddleware.Middleware(handler)
// 启动HTTP服务器
log.Println("HTTP服务器启动在 :8080")
if err := http.ListenAndServe(":8080", wrappedHandler); err != nil {
log.Fatalf("HTTP服务器启动失败: %v", err)
}
}
```
### 使用压缩中间件
GoProxy提供了压缩中间件支持请求和响应的gzip压缩
```go
package main
import (
"log"
"net/http"
"github.com/darkit/goproxy/internal/middleware"
"github.com/darkit/goproxy/internal/proxy"
)
func main() {
// 创建压缩中间件
compressionMiddleware := middleware.NewCompressionMiddleware(6, 1024) // 压缩级别6最小压缩大小1KB
// 创建代理
p := proxy.NewProxy(
proxy.WithMiddleware(compressionMiddleware),
)
// 启动HTTP服务器
log.Println("代理服务器启动在 :8080")
if err := http.ListenAndServe(":8080", p); err != nil {
log.Fatalf("代理服务器启动失败: %v", err)
}
}
```
压缩中间件提供以下特性:
- 自动检测客户端是否支持gzip压缩
- 智能判断内容类型是否适合压缩
- 可配置压缩级别0-9
- 可配置最小压缩大小
- 支持多种内容类型
- 自动处理压缩请求体
- 自动添加压缩相关响应头
## 使用插件系统
GoProxy提供了插件系统,支持动态加载和管理插件:
```go
package main
import (
"log"
"net/http"
"github.com/darkit/goproxy/internal/plugin"
"github.com/darkit/goproxy/internal/proxy"
)
func main() {
// 创建插件管理器
pluginManager := plugin.NewPluginManager("./plugins")
// 加载插件
if err := pluginManager.LoadPlugins(); err != nil {
log.Printf("加载插件失败: %v", err)
}
// 创建代理
p := proxy.NewProxy(
proxy.WithPluginManager(pluginManager),
)
// 启动HTTP服务器
log.Println("代理服务器启动在 :8080")
if err := http.ListenAndServe(":8080", p); err != nil {
log.Fatalf("代理服务器启动失败: %v", err)
}
}
```
插件示例:
```go
// plugin.go
package main
import (
"context"
"log"
)
type MyPlugin struct{}
func (p *MyPlugin) Name() string {
return "my-plugin"
}
func (p *MyPlugin) Version() string {
return "1.0.0"
}
func (p *MyPlugin) Init(ctx context.Context) error {
log.Println("初始化插件")
return nil
}
func (p *MyPlugin) Start(ctx context.Context) error {
log.Println("启动插件")
return nil
}
func (p *MyPlugin) Stop(ctx context.Context) error {
log.Println("停止插件")
return nil
}
var Plugin = &MyPlugin{}
```
### 使用认证授权系统
GoProxy提供了完整的认证授权系统,支持JWT认证和基于角色的访问控制:
```go
package main
import (
"log"
"net/http"
"github.com/darkit/goproxy/internal/auth"
"github.com/darkit/goproxy/internal/proxy"
)
func main() {
// 创建认证系统
auth := auth.NewAuth("your-secret-key")
// 添加用户和角色
auth.AddUser("admin", "password123", []string{"admin"})
auth.AddUser("user", "password456", []string{"user"})
// 创建代理
p := proxy.NewProxy(
proxy.WithAuth(auth),
)
// 启动HTTP服务器
log.Println("代理服务器启动在 :8080")
if err := http.ListenAndServe(":8080", p); err != nil {
log.Fatalf("代理服务器启动失败: %v", err)
}
}
```
认证授权系统提供以下特性:
- JWT令牌认证
- 基于角色的访问控制
- 用户管理
- 密码加密存储
- 权限管理
- 认证中间件
## 贡献
欢迎贡献代码、报告问题或提出建议。请遵循以下步骤:

View File

@@ -11,8 +11,9 @@ import (
"time"
"github.com/darkit/goproxy/internal/config"
"github.com/darkit/goproxy/internal/metrics"
"github.com/darkit/goproxy/internal/proxy"
"github.com/darkit/goproxy/internal/reverse"
"github.com/darkit/goproxy/internal/rule"
)
var (
@@ -57,49 +58,46 @@ func main() {
cfg.EnableCORS = *enableCORS
cfg.ReverseProxyRulesFile = *routeFile
// 创建选项
opts := &proxy.Options{
Config: cfg,
}
// 创建反向代理配置
reverseCfg := reverse.DefaultConfig()
reverseCfg.ListenAddr = *addr
reverseCfg.BaseConfig.EnableCompression = *enableCompression
reverseCfg.BaseConfig.EnableCORS = *enableCORS
reverseCfg.BaseConfig.AddXForwardedFor = *addXForwardedFor
reverseCfg.BaseConfig.AddXRealIP = *addXRealIP
// 创建监控
if *enableMetrics {
m := metrics.NewSimpleMetrics()
opts.Metrics = m
// 启动监控服务器
go func() {
mux := http.NewServeMux()
handler := m.GetHandler()
mux.Handle("/metrics", handler)
log.Printf("监控服务器启动在 %s\n", *metricsAddr)
if err := http.ListenAndServe(*metricsAddr, mux); err != nil {
log.Fatalf("监控服务器启动失败: %v", err)
}
}()
}
// 创建自定义委托
delegate := &ReverseProxyDelegate{
backend: *backend,
prefix: *pathPrefix,
}
opts.Delegate = delegate
// 创建代理
p := proxy.New(opts)
// 创建规则管理器
ruleManager := rule.NewManager(nil)
// 如果有路径前缀,添加重写规则
if *pathPrefix != "" {
reverseProxy := p.NewReverseProxy()
log.Printf("添加路径重写规则: 从请求路径移除前缀 %s\n", *pathPrefix)
reverseProxy.AddRewriteRule(*pathPrefix, "", false)
rewriteRule := &rule.RewriteRule{
BaseRule: rule.BaseRule{
ID: "path-prefix-rewrite",
Type: rule.RuleTypeRewrite,
Priority: 100,
Pattern: *pathPrefix,
MatchType: rule.MatchTypePath,
Enabled: true,
},
Replacement: "",
}
if err := ruleManager.AddRule(rewriteRule); err != nil {
log.Printf("添加重写规则失败: %v", err)
}
}
// 创建反向代理
reverseProxy, err := reverse.New(reverseCfg)
if err != nil {
log.Fatalf("创建反向代理失败: %v", err)
}
// 创建HTTP服务器
server := &http.Server{
Addr: *addr,
Handler: p,
Handler: reverseProxy,
}
// 启动HTTP服务器

251
docs/api.md Normal file
View File

@@ -0,0 +1,251 @@
# GoProxy API 文档
## 概述
GoProxy是一个高性能的HTTP/HTTPS代理服务器提供丰富的功能和插件系统。本文档描述了GoProxy的API接口。
## 认证
所有API请求都需要进行认证。认证方式为Bearer Token
```
Authorization: Bearer <your_token>
```
## API 端点
### 1. 代理配置
#### 1.1 获取代理配置
```
GET /api/v1/config
```
响应:
```json
{
"listen_addr": ":8080",
"enable_load_balancing": true,
"backends": ["http://backend1", "http://backend2"],
"enable_rate_limit": true,
"rate_limit": 100,
"max_connections": 1000,
"enable_cache": true,
"cache_ttl": "1h"
}
```
#### 1.2 更新代理配置
```
PUT /api/v1/config
```
请求体:
```json
{
"listen_addr": ":8080",
"enable_load_balancing": true,
"backends": ["http://backend1", "http://backend2"],
"enable_rate_limit": true,
"rate_limit": 100,
"max_connections": 1000,
"enable_cache": true,
"cache_ttl": "1h"
}
```
### 2. 用户管理
#### 2.1 创建用户
```
POST /api/v1/users
```
请求体:
```json
{
"username": "admin",
"password": "password123",
"roles": ["admin"]
}
```
#### 2.2 获取用户列表
```
GET /api/v1/users
```
响应:
```json
{
"users": [
{
"username": "admin",
"roles": ["admin"],
"created_at": "2024-03-13T10:00:00Z",
"last_login_at": "2024-03-13T11:00:00Z"
}
]
}
```
### 3. 监控指标
#### 3.1 获取代理状态
```
GET /api/v1/status
```
响应:
```json
{
"active_connections": 100,
"total_requests": 1000,
"error_rate": 0.01,
"average_latency": 0.1,
"cache_hit_rate": 0.8
}
```
#### 3.2 获取详细指标
```
GET /api/v1/metrics
```
响应:
```json
{
"requests_total": {
"GET": 800,
"POST": 200
},
"request_latency": {
"p50": 0.1,
"p90": 0.2,
"p99": 0.5
},
"error_total": {
"connection_error": 10,
"timeout": 5
}
}
```
### 4. 插件管理
#### 4.1 获取插件列表
```
GET /api/v1/plugins
```
响应:
```json
{
"plugins": [
{
"name": "compression",
"version": "1.0.0",
"enabled": true
}
]
}
```
#### 4.2 启用/禁用插件
```
PUT /api/v1/plugins/{plugin_name}/status
```
请求体:
```json
{
"enabled": true
}
```
### 5. 缓存管理
#### 5.1 获取缓存统计
```
GET /api/v1/cache/stats
```
响应:
```json
{
"total_items": 1000,
"total_size": "100MB",
"hit_rate": 0.8,
"eviction_rate": 0.1
}
```
#### 5.2 清除缓存
```
DELETE /api/v1/cache
```
### 6. 健康检查
#### 6.1 获取健康状态
```
GET /api/v1/health
```
响应:
```json
{
"status": "healthy",
"uptime": "24h",
"last_check": "2024-03-13T12:00:00Z"
}
```
## 错误响应
所有API在发生错误时会返回以下格式
```json
{
"error": {
"code": "ERROR_CODE",
"message": "错误描述",
"details": {}
}
}
```
常见错误码:
- 400: 请求参数错误
- 401: 未认证
- 403: 权限不足
- 404: 资源不存在
- 500: 服务器内部错误
## 限流
API请求默认限制为每秒100次。超过限制时会返回429状态码。
## 版本控制
API版本通过URL路径指定当前版本为v1。
## 更新日志
### v1.0.0
- 初始版本
- 基本代理功能
- 用户认证
- 监控指标
- 插件系统

View File

@@ -0,0 +1,12 @@
#!/bin/bash
# 编译插件
go build -buildmode=plugin -o example.so example.go
# 检查编译结果
if [ $? -eq 0 ]; then
echo "插件编译成功: example.so"
else
echo "插件编译失败"
exit 1
fi

View File

@@ -0,0 +1,78 @@
package main
import (
"context"
"log"
"time"
)
// ExamplePlugin 示例插件
type ExamplePlugin struct {
// 插件配置
config map[string]interface{}
// 插件状态
running bool
}
// Name 插件名称
func (p *ExamplePlugin) Name() string {
return "example"
}
// Version 插件版本
func (p *ExamplePlugin) Version() string {
return "1.0.0"
}
// Init 初始化插件
func (p *ExamplePlugin) Init(ctx context.Context) error {
log.Printf("初始化插件 %s v%s\n", p.Name(), p.Version())
// 初始化配置
p.config = make(map[string]interface{})
p.config["startTime"] = time.Now()
return nil
}
// Start 启动插件
func (p *ExamplePlugin) Start(ctx context.Context) error {
log.Printf("启动插件 %s\n", p.Name())
// 启动后台任务
go func() {
p.running = true
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for p.running {
select {
case <-ctx.Done():
return
case <-ticker.C:
log.Printf("插件 %s 正在运行...\n", p.Name())
}
}
}()
return nil
}
// Stop 停止插件
func (p *ExamplePlugin) Stop(ctx context.Context) error {
log.Printf("停止插件 %s\n", p.Name())
// 停止后台任务
p.running = false
// 清理资源
if startTime, ok := p.config["startTime"].(time.Time); ok {
duration := time.Since(startTime)
log.Printf("插件 %s 运行时长: %v\n", p.Name(), duration)
}
return nil
}
// Plugin 导出插件实例
var Plugin = &ExamplePlugin{}

14
go.mod
View File

@@ -3,7 +3,21 @@ module github.com/darkit/goproxy
go 1.24.0
require (
github.com/fsnotify/fsnotify v1.7.0
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/ouqiang/websocket v1.6.2
github.com/prometheus/client_golang v1.20.4
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8
golang.org/x/time v0.11.0
)
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
golang.org/x/sys v0.28.0 // indirect
google.golang.org/protobuf v1.36.1 // indirect
)

34
go.sum
View File

@@ -1,6 +1,40 @@
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ouqiang/websocket v1.6.2 h1:LGQIySbQO3ahZCl34v9xBVb0yncDk8yIcuEIbWBab/U=
github.com/ouqiang/websocket v1.6.2/go.mod h1:fIROJIHRlQwgCyUFTMzaaIcs4HIwUj2xlOW43u9Sf+M=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI=
github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 h1:EVObHAr8DqpoJCVv6KYTle8FEImKhtkfcZetNqxDoJQ=
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8/go.mod h1:dniwbG03GafCjFohMDmz6Zc6oCuiqgH6tGNyXTkHzXE=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

232
internal/auth/auth.go Normal file
View File

@@ -0,0 +1,232 @@
package auth
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
)
// Auth 认证授权系统
type Auth struct {
// JWT密钥
secretKey []byte
// 用户存储
users map[string]*User
// 角色权限映射
rolePermissions map[string][]string
// 锁
mu sync.RWMutex
}
// User 用户信息
type User struct {
// 用户名
Username string
// 密码
Password string
// 角色列表
Roles []string
// 创建时间
CreatedAt time.Time
// 最后登录时间
LastLoginAt time.Time
}
// NewAuth 创建认证授权系统
func NewAuth(secretKey string) *Auth {
return &Auth{
secretKey: []byte(secretKey),
users: make(map[string]*User),
rolePermissions: make(map[string][]string),
}
}
// AddUser 添加用户
func (a *Auth) AddUser(username, password string, roles []string) error {
a.mu.Lock()
defer a.mu.Unlock()
if _, exists := a.users[username]; exists {
return fmt.Errorf("用户已存在")
}
// 密码加密
hashedPassword := hashPassword(password)
// 创建用户
a.users[username] = &User{
Username: username,
Password: hashedPassword,
Roles: roles,
CreatedAt: time.Now(),
}
return nil
}
// Authenticate 认证用户
func (a *Auth) Authenticate(username, password string) (string, error) {
a.mu.RLock()
user, exists := a.users[username]
a.mu.RUnlock()
if !exists {
return "", fmt.Errorf("用户不存在")
}
// 验证密码
if !a.ValidateUser(username, password) {
return "", fmt.Errorf("密码错误")
}
// 更新最后登录时间
a.mu.Lock()
user.LastLoginAt = time.Now()
a.mu.Unlock()
// 生成JWT令牌
token, err := a.GenerateToken(username)
if err != nil {
return "", err
}
return token, nil
}
// Authorize 授权检查
func (a *Auth) Authorize(token, permission string) error {
// 验证JWT令牌
claims, err := a.ValidateToken(token)
if err != nil {
return err
}
// 检查用户权限
a.mu.RLock()
username, ok := (*claims)["username"].(string)
if !ok {
a.mu.RUnlock()
return fmt.Errorf("无效的用户名")
}
user, exists := a.users[username]
a.mu.RUnlock()
if !exists {
return fmt.Errorf("用户不存在")
}
// 检查用户角色权限
for _, role := range user.Roles {
if a.hasPermission(role, permission) {
return nil
}
}
return fmt.Errorf("权限不足")
}
// Middleware 认证中间件
func (a *Auth) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 获取认证头
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, "未提供认证信息", http.StatusUnauthorized)
return
}
// 解析Bearer令牌
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
http.Error(w, "认证格式错误", http.StatusUnauthorized)
return
}
// 验证令牌
if err := a.Authorize(parts[1], r.URL.Path); err != nil {
http.Error(w, "认证失败", http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
// ValidateUser 验证用户
func (a *Auth) ValidateUser(username, password string) bool {
a.mu.RLock()
defer a.mu.RUnlock()
user, exists := a.users[username]
if !exists {
return false
}
return user.Password == hashPassword(password)
}
// GenerateToken 生成JWT令牌
func (a *Auth) GenerateToken(username string) (string, error) {
a.mu.RLock()
user, exists := a.users[username]
a.mu.RUnlock()
if !exists {
return "", errors.New("用户不存在")
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"username": username,
"roles": user.Roles,
"exp": time.Now().Add(24 * time.Hour).Unix(),
})
return token.SignedString(a.secretKey)
}
// ValidateToken 验证JWT令牌
func (a *Auth) ValidateToken(tokenString string) (*jwt.MapClaims, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return a.secretKey, nil
})
if err != nil {
return nil, err
}
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
return &claims, nil
}
return nil, errors.New("无效的令牌")
}
// hashPassword 密码加密
func hashPassword(password string) string {
hash := sha256.New()
hash.Write([]byte(password))
return hex.EncodeToString(hash.Sum(nil))
}
// hasPermission 检查角色是否有权限
func (a *Auth) hasPermission(role, permission string) bool {
permissions, exists := a.rolePermissions[role]
if !exists {
return false
}
for _, p := range permissions {
if p == permission {
return true
}
}
return false
}

View File

@@ -0,0 +1,117 @@
package config
import (
"encoding/json"
"log/slog"
"os"
"path/filepath"
"sync"
"time"
"github.com/fsnotify/fsnotify"
)
// HotReloadConfig 热更新配置
type HotReloadConfig struct {
// 配置文件路径
ConfigPath string
// 配置更新回调函数
OnUpdate func(*Config)
// 配置锁
mu sync.RWMutex
// 当前配置
current *Config
// 文件监视器
watcher *fsnotify.Watcher
// 停止信号
stopChan chan struct{}
}
// NewHotReloadConfig 创建热更新配置
func NewHotReloadConfig(configPath string, onUpdate func(*Config)) (*HotReloadConfig, error) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}
hrc := &HotReloadConfig{
ConfigPath: configPath,
OnUpdate: onUpdate,
watcher: watcher,
stopChan: make(chan struct{}),
}
// 加载初始配置
if err := hrc.loadConfig(); err != nil {
watcher.Close()
return nil, err
}
// 启动文件监视
go hrc.watch()
return hrc, nil
}
// loadConfig 加载配置
func (hrc *HotReloadConfig) loadConfig() error {
data, err := os.ReadFile(hrc.ConfigPath)
if err != nil {
return err
}
var config Config
if err := json.Unmarshal(data, &config); err != nil {
return err
}
hrc.mu.Lock()
hrc.current = &config
hrc.mu.Unlock()
if hrc.OnUpdate != nil {
hrc.OnUpdate(&config)
}
return nil
}
// watch 监视配置文件变化
func (hrc *HotReloadConfig) watch() {
// 添加配置文件目录到监视
configDir := filepath.Dir(hrc.ConfigPath)
if err := hrc.watcher.Add(configDir); err != nil {
slog.Error("添加配置目录到监视失败", "error", err)
return
}
for {
select {
case event := <-hrc.watcher.Events:
if event.Name == hrc.ConfigPath {
// 等待文件写入完成
time.Sleep(100 * time.Millisecond)
if err := hrc.loadConfig(); err != nil {
slog.Error("重新加载配置失败", "error", err)
}
}
case err := <-hrc.watcher.Errors:
slog.Error("配置文件监视错误", "error", err)
case <-hrc.stopChan:
hrc.watcher.Close()
return
}
}
}
// Get 获取当前配置
func (hrc *HotReloadConfig) Get() *Config {
hrc.mu.RLock()
defer hrc.mu.RUnlock()
return hrc.current
}
// Stop 停止热更新
func (hrc *HotReloadConfig) Stop() {
close(hrc.stopChan)
}

View File

@@ -3,13 +3,18 @@ package metrics
import (
"fmt"
"net/http"
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
// Metrics 监控指标接口
type Metrics interface {
// MetricsCollector 监控指标接口
type MetricsCollector interface {
// 增加请求计数
IncRequestCount()
// 增加错误计数
@@ -36,6 +41,144 @@ type Metrics interface {
GetHandler() http.Handler
}
// PrometheusMetrics 指标收集器
type PrometheusMetrics struct {
// 请求总数
requestTotal *prometheus.CounterVec
// 请求延迟
requestLatency *prometheus.HistogramVec
// 请求大小
requestSize *prometheus.HistogramVec
// 响应大小
responseSize *prometheus.HistogramVec
// 错误总数
errorTotal *prometheus.CounterVec
// 活跃连接数
activeConnections prometheus.Gauge
// 连接池大小
connectionPoolSize prometheus.Gauge
// 缓存命中率
cacheHitRate prometheus.Gauge
// 内存使用量
memoryUsage prometheus.Gauge
// 锁
mu sync.RWMutex
}
// NewPrometheusMetrics 创建指标收集器
func NewPrometheusMetrics() *PrometheusMetrics {
m := &PrometheusMetrics{
requestTotal: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "proxy_requests_total",
Help: "代理请求总数",
},
[]string{"method", "path", "status"},
),
requestLatency: promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "proxy_request_latency_seconds",
Help: "代理请求延迟",
Buckets: prometheus.DefBuckets,
},
[]string{"method", "path"},
),
requestSize: promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "proxy_request_size_bytes",
Help: "代理请求大小",
Buckets: prometheus.ExponentialBuckets(100, 2, 10),
},
[]string{"method", "path"},
),
responseSize: promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "proxy_response_size_bytes",
Help: "代理响应大小",
Buckets: prometheus.ExponentialBuckets(100, 2, 10),
},
[]string{"method", "path"},
),
errorTotal: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "proxy_errors_total",
Help: "代理错误总数",
},
[]string{"type"},
),
activeConnections: promauto.NewGauge(
prometheus.GaugeOpts{
Name: "proxy_active_connections",
Help: "活跃连接数",
},
),
connectionPoolSize: promauto.NewGauge(
prometheus.GaugeOpts{
Name: "proxy_connection_pool_size",
Help: "连接池大小",
},
),
cacheHitRate: promauto.NewGauge(
prometheus.GaugeOpts{
Name: "proxy_cache_hit_rate",
Help: "缓存命中率",
},
),
memoryUsage: promauto.NewGauge(
prometheus.GaugeOpts{
Name: "proxy_memory_usage_bytes",
Help: "内存使用量",
},
),
}
// 启动定期更新
go m.updateMetrics()
return m
}
// updateMetrics 定期更新指标
func (m *PrometheusMetrics) updateMetrics() {
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
for range ticker.C {
// 更新内存使用量
var mem runtime.MemStats
runtime.ReadMemStats(&mem)
m.memoryUsage.Set(float64(mem.Alloc))
}
}
// RecordRequest 记录请求
func (m *PrometheusMetrics) RecordRequest(method, path string, status int, latency time.Duration, reqSize, respSize int64) {
m.requestTotal.WithLabelValues(method, path, strconv.Itoa(status)).Inc()
m.requestLatency.WithLabelValues(method, path).Observe(latency.Seconds())
m.requestSize.WithLabelValues(method, path).Observe(float64(reqSize))
m.responseSize.WithLabelValues(method, path).Observe(float64(respSize))
}
// RecordError 记录错误
func (m *PrometheusMetrics) RecordError(errType string) {
m.errorTotal.WithLabelValues(errType).Inc()
}
// SetActiveConnections 设置活跃连接数
func (m *PrometheusMetrics) SetActiveConnections(count int) {
m.activeConnections.Set(float64(count))
}
// SetConnectionPoolSize 设置连接池大小
func (m *PrometheusMetrics) SetConnectionPoolSize(size int) {
m.connectionPoolSize.Set(float64(size))
}
// SetCacheHitRate 设置缓存命中率
func (m *PrometheusMetrics) SetCacheHitRate(rate float64) {
m.cacheHitRate.Set(rate)
}
// SimpleMetrics 简单指标实现
type SimpleMetrics struct {
// 请求计数
@@ -187,19 +330,13 @@ func (m *SimpleMetrics) GetHandler() http.Handler {
})
}
// PrometheusMetrics Prometheus指标实现
type PrometheusMetrics struct {
// 可以通过引入prometheus客户端库实现更完整的指标收集
// 此处省略具体实现
}
// MetricsMiddleware 指标中间件
type MetricsMiddleware struct {
metrics Metrics
metrics MetricsCollector
}
// NewMetricsMiddleware 创建指标中间件
func NewMetricsMiddleware(metrics Metrics) *MetricsMiddleware {
func NewMetricsMiddleware(metrics MetricsCollector) *MetricsMiddleware {
return &MetricsMiddleware{
metrics: metrics,
}

View File

@@ -0,0 +1,63 @@
package middleware
import (
"net/http"
)
// Middleware 中间件接口
type Middleware interface {
ServeHTTP(http.ResponseWriter, *http.Request, http.HandlerFunc)
}
// Chain 中间件链
type Chain struct {
middlewares []Middleware
}
// NewChain 创建新的中间件链
func NewChain(middlewares ...Middleware) *Chain {
return &Chain{
middlewares: middlewares,
}
}
// Then 将中间件链应用到处理器
func (c *Chain) Then(h http.Handler) http.Handler {
if h == nil {
h = http.DefaultServeMux
}
for i := len(c.middlewares) - 1; i >= 0; i-- {
h = c.wrap(h, c.middlewares[i])
}
return h
}
// wrap 包装处理器
func (c *Chain) wrap(h http.Handler, m Middleware) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
m.ServeHTTP(w, r, h.ServeHTTP)
})
}
// Add 添加中间件
func (c *Chain) Add(middlewares ...Middleware) *Chain {
c.middlewares = append(c.middlewares, middlewares...)
return c
}
// Remove 移除中间件
func (c *Chain) Remove(index int) *Chain {
if index < 0 || index >= len(c.middlewares) {
return c
}
c.middlewares = append(c.middlewares[:index], c.middlewares[index+1:]...)
return c
}
// Clear 清空中间件链
func (c *Chain) Clear() *Chain {
c.middlewares = nil
return c
}

View File

@@ -0,0 +1,127 @@
package middleware
import (
"compress/gzip"
"io"
"net/http"
"strings"
)
// CompressionMiddleware 压缩中间件
type CompressionMiddleware struct {
// 压缩级别 (0-9)
level int
// 最小压缩大小
minSize int64
// 支持的内容类型
contentTypes []string
}
// NewCompressionMiddleware 创建压缩中间件
func NewCompressionMiddleware(level int, minSize int64) *CompressionMiddleware {
return &CompressionMiddleware{
level: level,
minSize: minSize,
contentTypes: []string{
"text/plain",
"text/html",
"text/css",
"text/javascript",
"application/javascript",
"application/json",
"application/xml",
"application/xml+rss",
"text/xml",
"application/x-yaml",
"text/yaml",
"application/x-www-form-urlencoded",
"application/x-protobuf",
"application/grpc",
"application/grpc+proto",
},
}
}
// Middleware 中间件处理函数
func (m *CompressionMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 处理请求压缩
if m.shouldCompress(r.Header.Get("Content-Encoding")) {
reader, err := gzip.NewReader(r.Body)
if err != nil {
http.Error(w, "Invalid gzip content", http.StatusBadRequest)
return
}
defer reader.Close()
r.Body = io.NopCloser(reader)
}
// 处理响应压缩
if m.shouldCompressResponse(r) {
gw := gzip.NewWriter(w)
defer gw.Close()
// 包装响应写入器
writer := &gzipResponseWriter{
ResponseWriter: w,
writer: gw,
}
// 设置响应头
w.Header().Set("Content-Encoding", "gzip")
w.Header().Set("Vary", "Accept-Encoding")
next.ServeHTTP(writer, r)
} else {
next.ServeHTTP(w, r)
}
})
}
// shouldCompress 检查是否应该压缩请求
func (m *CompressionMiddleware) shouldCompress(encoding string) bool {
return strings.Contains(encoding, "gzip")
}
// shouldCompressResponse 检查是否应该压缩响应
func (m *CompressionMiddleware) shouldCompressResponse(r *http.Request) bool {
// 检查客户端是否支持gzip
acceptEncoding := r.Header.Get("Accept-Encoding")
if !strings.Contains(acceptEncoding, "gzip") {
return false
}
// 检查内容类型
contentType := r.Header.Get("Content-Type")
for _, t := range m.contentTypes {
if strings.Contains(contentType, t) {
return true
}
}
return false
}
// gzipResponseWriter 包装的响应写入器
type gzipResponseWriter struct {
http.ResponseWriter
writer *gzip.Writer
}
// Write 写入数据
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
return w.writer.Write(b)
}
// WriteHeader 写入状态码
func (w *gzipResponseWriter) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
// Flush 刷新数据
func (w *gzipResponseWriter) Flush() {
w.writer.Flush()
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

162
internal/plugin/plugin.go Normal file
View File

@@ -0,0 +1,162 @@
package plugin
import (
"context"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"plugin"
"sync"
)
// Plugin 插件接口
type Plugin interface {
// Name 插件名称
Name() string
// Version 插件版本
Version() string
// Init 初始化插件
Init(ctx context.Context) error
// Start 启动插件
Start(ctx context.Context) error
// Stop 停止插件
Stop(ctx context.Context) error
}
// PluginManager 插件管理器
type PluginManager struct {
pluginsDir string
plugins map[string]Plugin
mu sync.RWMutex
}
// NewPluginManager 创建插件管理器
func NewPluginManager(pluginsDir string) *PluginManager {
return &PluginManager{
pluginsDir: pluginsDir,
plugins: make(map[string]Plugin),
}
}
// LoadPlugins 加载插件
func (pm *PluginManager) LoadPlugins() error {
// 确保插件目录存在
if err := os.MkdirAll(pm.pluginsDir, 0755); err != nil {
return fmt.Errorf("创建插件目录失败: %v", err)
}
// 遍历插件目录
err := filepath.Walk(pm.pluginsDir, func(path string, info fs.FileInfo, err error) error {
if err != nil {
return err
}
// 跳过目录和非.so文件
if info.IsDir() || filepath.Ext(path) != ".so" {
return nil
}
// 加载插件
if err := pm.loadPlugin(path); err != nil {
fmt.Printf("加载插件 %s 失败: %v\n", path, err)
}
return nil
})
return err
}
// loadPlugin 加载单个插件
func (pm *PluginManager) loadPlugin(path string) error {
// 打开插件文件
p, err := plugin.Open(path)
if err != nil {
return fmt.Errorf("打开插件失败: %v", err)
}
// 查找Plugin变量
symPlugin, err := p.Lookup("Plugin")
if err != nil {
return fmt.Errorf("查找Plugin变量失败: %v", err)
}
// 类型断言
plugin, ok := symPlugin.(Plugin)
if !ok {
return errors.New("插件类型错误")
}
// 检查插件名称是否已存在
pm.mu.Lock()
if _, exists := pm.plugins[plugin.Name()]; exists {
pm.mu.Unlock()
return fmt.Errorf("插件 %s 已存在", plugin.Name())
}
pm.plugins[plugin.Name()] = plugin
pm.mu.Unlock()
return nil
}
// GetPlugin 获取插件
func (pm *PluginManager) GetPlugin(name string) (Plugin, bool) {
pm.mu.RLock()
defer pm.mu.RUnlock()
plugin, exists := pm.plugins[name]
return plugin, exists
}
// GetAllPlugins 获取所有插件
func (pm *PluginManager) GetAllPlugins() []Plugin {
pm.mu.RLock()
defer pm.mu.RUnlock()
plugins := make([]Plugin, 0, len(pm.plugins))
for _, plugin := range pm.plugins {
plugins = append(plugins, plugin)
}
return plugins
}
// InitPlugins 初始化所有插件
func (pm *PluginManager) InitPlugins(ctx context.Context) error {
pm.mu.RLock()
defer pm.mu.RUnlock()
for name, plugin := range pm.plugins {
if err := plugin.Init(ctx); err != nil {
return fmt.Errorf("初始化插件 %s 失败: %v", name, err)
}
}
return nil
}
// StartPlugins 启动所有插件
func (pm *PluginManager) StartPlugins(ctx context.Context) error {
pm.mu.RLock()
defer pm.mu.RUnlock()
for name, plugin := range pm.plugins {
if err := plugin.Start(ctx); err != nil {
return fmt.Errorf("启动插件 %s 失败: %v", name, err)
}
}
return nil
}
// StopPlugins 停止所有插件
func (pm *PluginManager) StopPlugins(ctx context.Context) error {
pm.mu.RLock()
defer pm.mu.RUnlock()
for name, plugin := range pm.plugins {
if err := plugin.Stop(ctx); err != nil {
return fmt.Errorf("停止插件 %s 失败: %v", name, err)
}
}
return nil
}

View File

@@ -107,7 +107,7 @@ func WithHealthChecker(hc *healthcheck.HealthChecker) Option {
}
// WithMetrics 设置监控指标
func WithMetrics(m metrics.Metrics) Option {
func WithMetrics(m metrics.MetricsCollector) Option {
return func(opt *Options) {
opt.Metrics = m
}

View File

@@ -18,12 +18,14 @@ import (
"sync/atomic"
"time"
"github.com/darkit/goproxy/internal/auth"
"github.com/darkit/goproxy/internal/cache"
"github.com/darkit/goproxy/internal/config"
"github.com/darkit/goproxy/internal/healthcheck"
"github.com/darkit/goproxy/internal/loadbalance"
"github.com/darkit/goproxy/internal/metrics"
"github.com/darkit/goproxy/internal/middleware"
"github.com/darkit/goproxy/internal/plugin"
"github.com/darkit/goproxy/internal/reverse"
"github.com/ouqiang/websocket"
"github.com/viki-org/dnscache"
@@ -197,13 +199,31 @@ type Options struct {
// 健康检查器
HealthChecker *healthcheck.HealthChecker
// 监控指标
Metrics metrics.Metrics
Metrics metrics.MetricsCollector
// 客户端跟踪
ClientTrace *httptrace.ClientTrace
// 认证系统
Auth *auth.Auth
// 插件管理器
PluginManager *plugin.PluginManager
// 证书管理器
CertManager *CertManager
}
// WithAuth 设置认证系统
func WithAuth(auth *auth.Auth) Option {
return func(o *Options) {
o.Auth = auth
}
}
// WithPluginManager 设置插件管理器
func WithPluginManager(pluginManager *plugin.PluginManager) Option {
return func(o *Options) {
o.PluginManager = pluginManager
}
}
// Proxy HTTP代理
type Proxy struct {
// 配置
@@ -221,7 +241,7 @@ type Proxy struct {
// 健康检查器
healthChecker *healthcheck.HealthChecker
// 监控指标
metrics metrics.Metrics
metrics metrics.MetricsCollector
// 客户端跟踪
clientTrace *httptrace.ClientTrace
// 基础传输(用于直接获取*http.Transport类型
@@ -1257,13 +1277,13 @@ var hopHeaders = []string{
}
// isCacheHitMetricsSupported 检查指标是否支持缓存命中计数
func isCacheHitMetricsSupported(m metrics.Metrics) bool {
func isCacheHitMetricsSupported(m metrics.MetricsCollector) bool {
_, ok := m.(interface{ IncCacheHit() })
return ok
}
// incrementCacheHit 增加缓存命中计数
func incrementCacheHit(m metrics.Metrics) {
func incrementCacheHit(m metrics.MetricsCollector) {
if hitter, ok := m.(interface{ IncCacheHit() }); ok {
hitter.IncCacheHit()
}

View File

@@ -0,0 +1,87 @@
package server
import (
"context"
"log/slog"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
)
// GracefulServer 优雅关闭服务器
type GracefulServer struct {
// HTTP服务器
server *http.Server
// 等待组
wg sync.WaitGroup
// 停止信号
stopChan chan struct{}
}
// NewGracefulServer 创建优雅关闭服务器
func NewGracefulServer(addr string, handler http.Handler) *GracefulServer {
return &GracefulServer{
server: &http.Server{
Addr: addr,
Handler: handler,
},
stopChan: make(chan struct{}),
}
}
// Start 启动服务器
func (s *GracefulServer) Start() error {
// 启动HTTP服务器
go func() {
slog.Info("启动HTTP服务器", "addr", s.server.Addr)
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
slog.Error("HTTP服务器错误", "error", err)
}
}()
// 等待中断信号
s.waitForInterrupt()
return nil
}
// waitForInterrupt 等待中断信号
func (s *GracefulServer) waitForInterrupt() {
// 创建信号通道
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// 等待信号
<-sigChan
slog.Info("收到停止信号,开始优雅关闭")
// 创建关闭上下文
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 关闭HTTP服务器
if err := s.server.Shutdown(ctx); err != nil {
slog.Error("关闭HTTP服务器失败", "error", err)
}
// 通知所有goroutine停止
close(s.stopChan)
// 等待所有goroutine结束
s.wg.Wait()
slog.Info("服务器已优雅关闭")
}
// Stop 停止服务器
func (s *GracefulServer) Stop() {
close(s.stopChan)
}
// WaitGroup 获取等待组
func (s *GracefulServer) WaitGroup() *sync.WaitGroup {
return &s.wg
}