init
This commit is contained in:
895
README.md
Normal file
895
README.md
Normal file
@@ -0,0 +1,895 @@
|
||||
# GoProxy
|
||||
|
||||
GoProxy是一个功能强大的Go语言HTTP代理库,支持HTTP、HTTPS和WebSocket代理,并提供了丰富的功能和扩展点。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 支持HTTP、HTTPS和WebSocket代理
|
||||
- 支持正向代理和反向代理
|
||||
- 支持HTTPS解密(中间人模式)
|
||||
- 自定义CA证书和私钥
|
||||
- 动态证书生成与缓存
|
||||
- 通配符域名证书支持
|
||||
- 支持RSA和ECDSA证书算法选择
|
||||
- 支持上游代理链
|
||||
- 支持多后端DNS解析和负载均衡
|
||||
- 支持一个域名对应多个后端服务器
|
||||
- 支持多种负载均衡策略(轮询、随机、第一个可用)
|
||||
- 支持通配符域名解析
|
||||
- 支持自定义DNS解析规则
|
||||
- 支持动态添加/删除后端服务器
|
||||
- 支持健康检查
|
||||
- 支持请求重试
|
||||
- 支持HTTP缓存
|
||||
- 支持请求限流
|
||||
- 支持请求/响应压缩
|
||||
- 支持gzip压缩
|
||||
- 智能压缩决策
|
||||
- 可配置压缩级别
|
||||
- 支持最小压缩大小
|
||||
- 支持多种内容类型
|
||||
- 支持监控指标收集(Prometheus格式)
|
||||
- 请求总数和延迟统计
|
||||
- 请求和响应大小统计
|
||||
- 错误计数
|
||||
- 活跃连接数
|
||||
- 连接池大小
|
||||
- 缓存命中率
|
||||
- 内存使用量
|
||||
- 后端健康状态
|
||||
- 后端响应时间
|
||||
- 支持自定义处理逻辑(委托模式)
|
||||
- 支持DNS缓存
|
||||
- 支持认证授权
|
||||
- JWT认证
|
||||
- 基于角色的访问控制
|
||||
- 用户管理
|
||||
- 权限管理
|
||||
|
||||
## 安装
|
||||
|
||||
```bash
|
||||
go get github.com/darkit/goproxy
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 正向代理
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建代理
|
||||
p := goproxy.New(nil)
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 启用HTTPS解密
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/config"
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建配置
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.DecryptHTTPS = true
|
||||
cfg.CACert = "ca.crt" // CA证书路径
|
||||
cfg.CAKey = "ca.key" // CA私钥路径
|
||||
cfg.UseECDSA = true // 使用ECDSA生成证书(默认为false,使用RSA)
|
||||
|
||||
// 可选:使用自定义TLS证书
|
||||
// cfg.TLSCert = "server.crt"
|
||||
// cfg.TLSKey = "server.key"
|
||||
|
||||
// 创建证书缓存
|
||||
certCache := &goproxy.MemCertCache{}
|
||||
|
||||
// 创建代理
|
||||
p := goproxy.New(&goproxy.Options{
|
||||
Config: cfg,
|
||||
CertCache: certCache,
|
||||
})
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("HTTPS解密代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> **注意**: 使用HTTPS解密功能时,需要在客户端安装CA证书,否则会出现证书警告。
|
||||
|
||||
### 反向代理
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/config"
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建配置
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.ReverseProxy = true
|
||||
cfg.EnableURLRewrite = true
|
||||
cfg.AddXForwardedFor = true
|
||||
cfg.AddXRealIP = true
|
||||
|
||||
// 创建自定义委托
|
||||
delegate := &ReverseProxyDelegate{
|
||||
backend: "localhost:8081",
|
||||
}
|
||||
|
||||
// 创建代理
|
||||
p := goproxy.New(&goproxy.Options{
|
||||
Config: cfg,
|
||||
Delegate: delegate,
|
||||
})
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("反向代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ReverseProxyDelegate 反向代理委托
|
||||
type ReverseProxyDelegate struct {
|
||||
goproxy.DefaultDelegate
|
||||
backend string
|
||||
}
|
||||
|
||||
// ResolveBackend 解析后端服务器
|
||||
func (d *ReverseProxyDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||
return d.backend, nil
|
||||
}
|
||||
```
|
||||
|
||||
### 自定义委托
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建自定义委托
|
||||
delegate := &CustomDelegate{}
|
||||
|
||||
// 创建代理
|
||||
p := goproxy.New(&goproxy.Options{
|
||||
Delegate: delegate,
|
||||
})
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// CustomDelegate 自定义委托
|
||||
type CustomDelegate struct {
|
||||
goproxy.DefaultDelegate
|
||||
}
|
||||
|
||||
// BeforeRequest 请求前事件
|
||||
func (d *CustomDelegate) BeforeRequest(ctx *goproxy.Context) {
|
||||
log.Printf("请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.String())
|
||||
}
|
||||
|
||||
// BeforeResponse 响应前事件
|
||||
func (d *CustomDelegate) BeforeResponse(ctx *goproxy.Context, resp *http.Response, err error) {
|
||||
if err != nil {
|
||||
log.Printf("响应错误: %v\n", err)
|
||||
return
|
||||
}
|
||||
log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status)
|
||||
}
|
||||
```
|
||||
|
||||
### 完整示例
|
||||
|
||||
- 正向代理示例: [cmd/example/main.go](cmd/example/main.go)
|
||||
- 反向代理示例: [cmd/reverse_proxy_example/main.go](cmd/reverse_proxy_example/main.go)
|
||||
|
||||
### 使用函数式选项模式
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/metrics"
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建监控指标
|
||||
metricsCollector := metrics.NewSimpleMetrics()
|
||||
|
||||
// 创建证书缓存
|
||||
certCache := &goproxy.MemCertCache{}
|
||||
|
||||
// 使用函数式选项模式创建代理
|
||||
p := goproxy.NewProxy(
|
||||
// 启用HTTPS解密
|
||||
goproxy.WithDecryptHTTPS(certCache),
|
||||
goproxy.WithCACertAndKey("ca.crt", "ca.key"),
|
||||
|
||||
// 设置监控指标
|
||||
goproxy.WithMetrics(metricsCollector),
|
||||
|
||||
// 设置请求超时和连接池
|
||||
goproxy.WithRequestTimeout(30 * time.Second),
|
||||
goproxy.WithConnectionPoolSize(100),
|
||||
goproxy.WithIdleTimeout(90 * time.Second),
|
||||
|
||||
// 启用DNS缓存
|
||||
goproxy.WithDNSCacheTTL(10 * time.Minute),
|
||||
|
||||
// 启用请求重试
|
||||
goproxy.WithEnableRetry(3, 1*time.Second, 10*time.Second),
|
||||
|
||||
// 启用CORS支持
|
||||
goproxy.WithEnableCORS(true),
|
||||
)
|
||||
|
||||
// 启动HTTP服务器和监控服务器
|
||||
go func() {
|
||||
log.Println("监控服务器启动在 :8081")
|
||||
http.Handle("/metrics", metricsCollector.GetHandler())
|
||||
if err := http.ListenAndServe(":8081", nil); err != nil {
|
||||
log.Fatalf("监控服务器启动失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 多后端DNS解析和负载均衡
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/dns"
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建DNS解析器
|
||||
resolver := dns.NewResolver(
|
||||
dns.WithLoadBalanceStrategy(dns.RoundRobin), // 设置负载均衡策略
|
||||
dns.WithTTL(5*time.Minute), // 设置DNS缓存TTL
|
||||
)
|
||||
|
||||
// 添加多个后端服务器
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.2", 8080)
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.3", 8080)
|
||||
|
||||
// 添加通配符域名解析
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080)
|
||||
|
||||
// 创建自定义委托
|
||||
delegate := &CustomDelegate{
|
||||
resolver: resolver,
|
||||
}
|
||||
|
||||
// 创建代理
|
||||
p := goproxy.New(&goproxy.Options{
|
||||
Delegate: delegate,
|
||||
})
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// CustomDelegate 自定义委托
|
||||
type CustomDelegate struct {
|
||||
goproxy.DefaultDelegate
|
||||
resolver *dns.CustomResolver
|
||||
}
|
||||
|
||||
// ResolveBackend 解析后端服务器
|
||||
func (d *CustomDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||
// 从请求中获取目标主机
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
|
||||
// 解析域名获取后端服务器
|
||||
endpoint, err := d.resolver.ResolveWithPort(host, 80)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return endpoint.String(), nil
|
||||
}
|
||||
```
|
||||
|
||||
### 从配置文件加载DNS规则
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/dns"
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建DNS解析器
|
||||
resolver := dns.NewResolver()
|
||||
|
||||
// 从JSON文件加载DNS规则
|
||||
if err := loadDNSConfig(resolver, "dns_config.json"); err != nil {
|
||||
log.Fatalf("加载DNS配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建自定义委托
|
||||
delegate := &CustomDelegate{
|
||||
resolver: resolver,
|
||||
}
|
||||
|
||||
// 创建代理
|
||||
p := goproxy.New(&goproxy.Options{
|
||||
Delegate: delegate,
|
||||
})
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// DNSConfig DNS配置结构
|
||||
type DNSConfig struct {
|
||||
Records map[string][]string `json:"records"` // 域名到IP地址列表的映射
|
||||
Wildcards map[string][]string `json:"wildcards"` // 通配符域名到IP地址列表的映射
|
||||
}
|
||||
|
||||
func loadDNSConfig(resolver *dns.CustomResolver, filename string) error {
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var config DNSConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 加载精确匹配记录
|
||||
for host, ips := range config.Records {
|
||||
for _, ip := range ips {
|
||||
if err := resolver.Add(host, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 加载通配符记录
|
||||
for pattern, ips := range config.Wildcards {
|
||||
for _, ip := range ips {
|
||||
if err := resolver.AddWildcard(pattern, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
示例配置文件 `dns_config.json`:
|
||||
```json
|
||||
{
|
||||
"records": {
|
||||
"api.example.com": [
|
||||
"192.168.1.1:8080",
|
||||
"192.168.1.2:8080",
|
||||
"192.168.1.3:8080"
|
||||
]
|
||||
},
|
||||
"wildcards": {
|
||||
"*.example.com": [
|
||||
"192.168.1.1:8080",
|
||||
"192.168.1.2:8080"
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 动态管理后端服务器
|
||||
|
||||
```go
|
||||
// 添加新的后端服务器
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.4", 8080)
|
||||
|
||||
// 删除特定的后端服务器
|
||||
resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080)
|
||||
|
||||
// 删除整个域名记录
|
||||
resolver.Remove("api.example.com")
|
||||
|
||||
// 清除所有记录
|
||||
resolver.Clear()
|
||||
```
|
||||
|
||||
## 架构设计
|
||||
|
||||
GoProxy采用模块化设计,主要包含以下模块:
|
||||
|
||||
- **代理核心(Proxy)**:处理HTTP请求和响应,实现代理功能
|
||||
- **反向代理(ReverseProxy)**:处理反向代理请求,支持请求修改
|
||||
- **路由(Router)**:基于主机名、路径、正则表达式等规则路由请求到不同的后端
|
||||
- **代理上下文(Context)**:保存请求上下文信息,用于在处理过程中传递数据
|
||||
- **代理委托(Delegate)**:定义代理处理请求的各个阶段的回调方法,用于自定义处理逻辑
|
||||
- **连接缓冲区(ConnBuffer)**:封装网络连接,提供缓冲读写功能
|
||||
- **负载均衡(LoadBalancer)**:实现负载均衡算法,支持轮询、随机、权重等
|
||||
- **健康检查(HealthChecker)**:检查上游服务器的健康状态,自动剔除不健康的服务器
|
||||
- **缓存(Cache)**:实现HTTP缓存,减少重复请求
|
||||
- **缓存适配器(CacheAdapter)**:统一不同缓存实现的接口,提高代码可读性和性能
|
||||
- **证书生成(CertGenerator)**:动态生成TLS证书,支持HTTPS解密
|
||||
- **限流(RateLimit)**:实现请求限流,防止过载
|
||||
- **监控(Metrics)**:收集代理运行指标,用于监控和分析
|
||||
- **重试(Retry)**:实现请求重试,提高请求成功率
|
||||
|
||||
## 配置选项
|
||||
|
||||
GoProxy提供了丰富的配置选项,可以通过`Options`结构体进行配置:
|
||||
|
||||
```go
|
||||
type Options struct {
|
||||
// 配置
|
||||
Config *config.Config
|
||||
// 委托
|
||||
Delegate Delegate
|
||||
// 证书缓存
|
||||
CertCache CertificateCache
|
||||
// HTTP缓存
|
||||
HTTPCache cache.Cache
|
||||
// 负载均衡器
|
||||
LoadBalancer loadbalance.LoadBalancer
|
||||
// 健康检查器
|
||||
HealthChecker *healthcheck.HealthChecker
|
||||
// 监控指标
|
||||
Metrics metrics.Metrics
|
||||
// 客户端跟踪
|
||||
ClientTrace *httptrace.ClientTrace
|
||||
}
|
||||
```
|
||||
|
||||
### 函数式选项模式
|
||||
|
||||
GoProxy 现在支持函数式选项模式(Functional Options Pattern),通过一系列的 `With` 方法提供更加灵活和可读性更高的配置方式。此模式的优势在于:
|
||||
|
||||
- 参数配置更加直观和清晰
|
||||
- 可以灵活选择需要的配置项,不必记忆参数顺序
|
||||
- 代码可读性更高,便于维护
|
||||
- 可以逐步添加新的配置选项而不破坏兼容性
|
||||
|
||||
可以使用 `NewProxy` 函数和函数式选项创建代理:
|
||||
|
||||
```go
|
||||
// 创建一个简单的代理
|
||||
proxy := goproxy.NewProxy()
|
||||
|
||||
// 创建一个功能丰富的代理
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithConfig(config.DefaultConfig()),
|
||||
goproxy.WithHTTPCache(myCache),
|
||||
goproxy.WithDecryptHTTPS(myCertCache),
|
||||
goproxy.WithCACertAndKey("ca.crt", "ca.key"),
|
||||
goproxy.WithMetrics(myMetrics),
|
||||
goproxy.WithLoadBalancer(myLoadBalancer),
|
||||
goproxy.WithRequestTimeout(10 * time.Second),
|
||||
goproxy.WithEnableCORS(true)
|
||||
)
|
||||
```
|
||||
|
||||
### 可用的 With 方法
|
||||
|
||||
GoProxy 提供了以下 With 方法用于配置代理的各个方面:
|
||||
|
||||
#### 基础配置选项
|
||||
- `WithConfig(cfg *config.Config)`: 设置代理配置
|
||||
- `WithDisableKeepAlive(disableKeepAlive bool)`: 设置连接是否重用
|
||||
- `WithTransport(t *http.Transport)`: 使用自定义HTTP传输
|
||||
- `WithClientTrace(t *httptrace.ClientTrace)`: 设置HTTP客户端跟踪
|
||||
|
||||
#### 功能模块选项
|
||||
- `WithDelegate(delegate Delegate)`: 设置委托类
|
||||
- `WithHTTPCache(c cache.Cache)`: 设置HTTP缓存
|
||||
- `WithLoadBalancer(lb loadbalance.LoadBalancer)`: 设置负载均衡器
|
||||
- `WithHealthChecker(hc *healthcheck.HealthChecker)`: 设置健康检查器
|
||||
- `WithMetrics(m metrics.Metrics)`: 设置监控指标
|
||||
|
||||
#### 功能开启选项
|
||||
- `WithDecryptHTTPS(c CertificateCache)`: 启用中间人代理解密HTTPS
|
||||
- `WithEnableECDSA(enable bool)`: 启用ECDSA证书生成(默认使用RSA)
|
||||
- `WithEnableWebsocketIntercept()`: 启用WebSocket拦截
|
||||
- `WithReverseProxy(enable bool)`: 启用反向代理模式
|
||||
- `WithEnableRetry(maxRetries int, baseBackoff, maxBackoff time.Duration)`: 启用请求重试
|
||||
- `WithRateLimit(rps float64)`: 设置请求限流
|
||||
- `WithEnableCORS(enable bool)`: 启用CORS支持
|
||||
|
||||
#### 证书相关选项
|
||||
- `WithTLSCertAndKey(certPath, keyPath string)`: 设置TLS证书和密钥
|
||||
- `WithCACertAndKey(caCertPath, caKeyPath string)`: 设置CA证书和密钥
|
||||
|
||||
#### 性能和超时相关选项
|
||||
- `WithConnectionPoolSize(size int)`: 设置连接池大小
|
||||
- `WithIdleTimeout(timeout time.Duration)`: 设置空闲超时时间
|
||||
- `WithRequestTimeout(timeout time.Duration)`: 设置请求超时时间
|
||||
- `WithDNSCacheTTL(ttl time.Duration)`: 设置DNS缓存TTL
|
||||
|
||||
### 配置HTTPS解密
|
||||
|
||||
要启用HTTPS解密功能,需要在配置中设置以下选项:
|
||||
|
||||
```go
|
||||
config := &config.Config{
|
||||
// 启用HTTPS解密
|
||||
DecryptHTTPS: true,
|
||||
|
||||
// 方式一:使用CA证书和私钥动态生成证书
|
||||
CACert: "path/to/ca.crt", // CA证书路径
|
||||
CAKey: "path/to/ca.key", // CA私钥路径
|
||||
|
||||
// 选择证书生成算法(可选)
|
||||
UseECDSA: true, // 使用ECDSA生成证书(默认为false,使用RSA)
|
||||
|
||||
// 方式二:使用固定的TLS证书和私钥
|
||||
// TLSCert: "path/to/server.crt",
|
||||
// TLSKey: "path/to/server.key",
|
||||
}
|
||||
```
|
||||
|
||||
或者使用函数式选项模式:
|
||||
|
||||
```go
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithDecryptHTTPS(&goproxy.MemCertCache{}),
|
||||
goproxy.WithCACertAndKey("path/to/ca.crt", "path/to/ca.key"),
|
||||
goproxy.WithEnableECDSA(true), // 使用ECDSA生成证书
|
||||
// 或者使用静态TLS证书
|
||||
// goproxy.WithTLSCertAndKey("path/to/server.crt", "path/to/server.key")
|
||||
)
|
||||
```
|
||||
|
||||
同时,建议配置证书缓存以提高性能:
|
||||
|
||||
```go
|
||||
certCache := &goproxy.MemCertCache{}
|
||||
```
|
||||
|
||||
## 扩展点
|
||||
|
||||
GoProxy提供了多个扩展点,可以通过实现相应的接口进行扩展:
|
||||
|
||||
- **Delegate**:代理委托接口,用于自定义代理处理逻辑
|
||||
- **LoadBalancer**:负载均衡接口,用于实现自定义负载均衡算法
|
||||
- **Cache**:缓存接口,用于实现自定义缓存策略
|
||||
- **CertificateCache**:证书缓存接口,用于自定义证书存储方式
|
||||
- **Metrics**:监控接口,用于实现自定义监控指标收集
|
||||
|
||||
## 反向代理特性
|
||||
|
||||
GoProxy的反向代理模式提供以下特性:
|
||||
|
||||
- **路由规则**:支持基于主机名、路径、正则表达式等的路由规则
|
||||
- **请求修改**:支持修改发往后端服务器的请求
|
||||
- **响应修改**:支持修改来自后端服务器的响应
|
||||
- **保留客户端信息**:支持添加X-Forwarded-For和X-Real-IP头
|
||||
- **CORS支持**:支持自动添加CORS头
|
||||
- **WebSocket支持**:支持WebSocket协议的透明代理
|
||||
- **负载均衡**:支持多种负载均衡算法
|
||||
- **健康检查**:支持对后端服务器进行健康检查
|
||||
- **监控指标**:支持收集反向代理的监控指标
|
||||
|
||||
## 监控指标
|
||||
|
||||
GoProxy提供了两种监控指标实现:PrometheusMetrics和SimpleMetrics。
|
||||
|
||||
### PrometheusMetrics
|
||||
|
||||
PrometheusMetrics是一个完整的Prometheus指标实现,提供以下指标:
|
||||
|
||||
- `proxy_requests_total`: 请求总数(按方法、路径、状态码分类)
|
||||
- `proxy_request_latency_seconds`: 请求延迟(按方法、路径分类)
|
||||
- `proxy_request_size_bytes`: 请求大小(按方法、路径分类)
|
||||
- `proxy_response_size_bytes`: 响应大小(按方法、路径分类)
|
||||
- `proxy_errors_total`: 错误总数(按类型分类)
|
||||
- `proxy_active_connections`: 活跃连接数
|
||||
- `proxy_connection_pool_size`: 连接池大小
|
||||
- `proxy_cache_hit_rate`: 缓存命中率
|
||||
- `proxy_memory_usage_bytes`: 内存使用量
|
||||
|
||||
使用示例:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/metrics"
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建Prometheus指标收集器
|
||||
metricsCollector := metrics.NewPrometheusMetrics()
|
||||
|
||||
// 创建代理
|
||||
p := goproxy.NewProxy(
|
||||
goproxy.WithMetrics(metricsCollector),
|
||||
)
|
||||
|
||||
// 启动监控服务器
|
||||
go func() {
|
||||
log.Println("监控服务器启动在 :8081")
|
||||
http.Handle("/metrics", metricsCollector.GetHandler())
|
||||
if err := http.ListenAndServe(":8081", nil); err != nil {
|
||||
log.Fatalf("监控服务器启动失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### SimpleMetrics
|
||||
|
||||
SimpleMetrics是一个简单的指标实现,提供基本的指标收集功能:
|
||||
|
||||
- 请求计数
|
||||
- 错误计数
|
||||
- 活跃连接数
|
||||
- 累计响应时间
|
||||
- 传输字节数
|
||||
- 后端健康状态
|
||||
- 后端响应时间
|
||||
- 缓存命中计数
|
||||
|
||||
使用示例:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/metrics"
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建简单指标收集器
|
||||
metricsCollector := metrics.NewSimpleMetrics()
|
||||
|
||||
// 创建代理
|
||||
p := goproxy.NewProxy(
|
||||
goproxy.WithMetrics(metricsCollector),
|
||||
)
|
||||
|
||||
// 启动监控服务器
|
||||
go func() {
|
||||
log.Println("监控服务器启动在 :8081")
|
||||
http.Handle("/metrics", metricsCollector.GetHandler())
|
||||
if err := http.ListenAndServe(":8081", nil); err != nil {
|
||||
log.Fatalf("监控服务器启动失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 指标中间件
|
||||
|
||||
GoProxy还提供了一个指标中间件,可以用于收集HTTP请求的指标:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/metrics"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建指标收集器
|
||||
metricsCollector := metrics.NewPrometheusMetrics()
|
||||
|
||||
// 创建指标中间件
|
||||
metricsMiddleware := metrics.NewMetricsMiddleware(metricsCollector)
|
||||
|
||||
// 创建HTTP处理器
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 处理请求
|
||||
w.Write([]byte("Hello, World!"))
|
||||
})
|
||||
|
||||
// 使用中间件包装处理器
|
||||
wrappedHandler := metricsMiddleware.Middleware(handler)
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("HTTP服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", wrappedHandler); err != nil {
|
||||
log.Fatalf("HTTP服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 使用压缩中间件
|
||||
|
||||
GoProxy提供了压缩中间件,支持请求和响应的gzip压缩:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/middleware"
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建压缩中间件
|
||||
compressionMiddleware := middleware.NewCompressionMiddleware(6, 1024) // 压缩级别6,最小压缩大小1KB
|
||||
|
||||
// 创建代理
|
||||
p := goproxy.NewProxy(
|
||||
goproxy.WithMiddleware(compressionMiddleware),
|
||||
)
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
压缩中间件提供以下特性:
|
||||
|
||||
- 自动检测客户端是否支持gzip压缩
|
||||
- 智能判断内容类型是否适合压缩
|
||||
- 可配置压缩级别(0-9)
|
||||
- 可配置最小压缩大小
|
||||
- 支持多种内容类型
|
||||
- 自动处理压缩请求体
|
||||
- 自动添加压缩相关响应头
|
||||
|
||||
### 使用认证授权系统
|
||||
|
||||
GoProxy提供了完整的认证授权系统,支持JWT认证和基于角色的访问控制:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/auth"
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建认证系统
|
||||
auth := auth.NewAuth("your-secret-key")
|
||||
|
||||
// 添加用户和角色
|
||||
auth.AddUser("admin", "password123", []string{"admin"})
|
||||
auth.AddUser("user", "password456", []string{"user"})
|
||||
|
||||
// 创建代理
|
||||
p := goproxy.NewProxy(
|
||||
goproxy.WithAuth(auth),
|
||||
)
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
认证授权系统提供以下特性:
|
||||
|
||||
- JWT令牌认证
|
||||
- 基于角色的访问控制
|
||||
- 用户管理
|
||||
- 密码加密存储
|
||||
- 权限管理
|
||||
- 认证中间件
|
||||
|
||||
## 贡献
|
||||
|
||||
欢迎贡献代码、报告问题或提出建议。请遵循以下步骤:
|
||||
|
||||
1. Fork 项目
|
||||
2. 创建特性分支 (`git checkout -b feature/amazing-feature`)
|
||||
3. 提交更改 (`git commit -m 'Add some amazing feature'`)
|
||||
4. 推送到分支 (`git push origin feature/amazing-feature`)
|
||||
5. 创建 Pull Request
|
||||
|
||||
## 许可证
|
||||
|
||||
本项目采用 MIT 许可证,详情请参阅 [LICENSE](LICENSE) 文件。
|
338
README_DNS.md
Normal file
338
README_DNS.md
Normal file
@@ -0,0 +1,338 @@
|
||||
# GoProxy自定义DNS解析功能
|
||||
|
||||
该功能允许GoProxy使用自定义DNS解析器,实现以下功能:
|
||||
|
||||
- 自定义域名解析
|
||||
- 动态变更后端服务器IP
|
||||
- 自定义后端服务器端口
|
||||
- 泛解析(通配符域名)支持
|
||||
- 多后端服务器支持
|
||||
- 一个域名对应多个后端服务器
|
||||
- 支持多种负载均衡策略(轮询、随机、第一个可用)
|
||||
- 支持动态添加/删除后端服务器
|
||||
- 负载均衡和故障转移
|
||||
- 绕过DNS污染
|
||||
- 高效的DNS缓存
|
||||
|
||||
## 特性
|
||||
|
||||
- **自定义记录**:直接设置域名到IP的映射
|
||||
- **自定义端口**:为每个域名指定自定义端口,无需在URL中指定
|
||||
- **泛解析**:支持通配符域名(如`*.example.com`)自动匹配多个子域名
|
||||
- **多级泛解析**:支持复杂的通配符模式(如`api.*.example.com`)
|
||||
- **多后端支持**:支持一个域名配置多个后端服务器
|
||||
- **负载均衡**:支持多种负载均衡策略
|
||||
- **备用解析**:在自定义记录未找到时可选择使用系统DNS
|
||||
- **DNS缓存**:缓存解析结果以提高性能
|
||||
- **自动重试**:解析失败时可配置重试策略
|
||||
- **加载配置**:从JSON文件或hosts格式文件加载配置
|
||||
- **自定义拨号器**:与net标准库兼容的拨号器
|
||||
|
||||
## 安装
|
||||
|
||||
确保已安装Go(建议1.22或更高版本):
|
||||
|
||||
```bash
|
||||
git clone github.com/darkit/goproxy
|
||||
cd goproxy
|
||||
go build ./...
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. HTTP代理(自定义DNS)
|
||||
|
||||
以下命令会启动一个HTTP代理(监听端口8080),它使用自定义DNS解析:
|
||||
|
||||
```bash
|
||||
go run cmd/custom_dns_proxy/main.go
|
||||
```
|
||||
|
||||
### 2. 自定义端口代理
|
||||
|
||||
支持为不同域名指定不同端口的代理:
|
||||
|
||||
```bash
|
||||
go run cmd/custom_port_proxy/main.go
|
||||
```
|
||||
|
||||
### 3. 泛解析DNS代理
|
||||
|
||||
支持通配符域名解析的代理:
|
||||
|
||||
```bash
|
||||
# 使用默认的示例泛解析规则
|
||||
go run cmd/wildcard_dns_proxy/main.go
|
||||
|
||||
# 透明代理模式(使用请求中的Host进行匹配)
|
||||
go run cmd/wildcard_dns_proxy/main.go -target ""
|
||||
|
||||
# 使用自定义配置文件
|
||||
go run cmd/wildcard_dns_proxy/main.go -dns examples/wildcard_dns_config.json
|
||||
|
||||
# 使用hosts格式配置文件
|
||||
go run cmd/wildcard_dns_proxy/main.go -hosts examples/wildcard_hosts.txt
|
||||
```
|
||||
|
||||
#### 参数说明
|
||||
|
||||
- `-listen`: 监听地址,默认 `:8080`
|
||||
- `-target`: 目标主机名,空字符串表示使用请求中的Host头
|
||||
- `-port`: 默认目标端口,默认 `443`
|
||||
- `-dns`: DNS配置文件(JSON格式)
|
||||
- `-hosts`: hosts格式配置文件
|
||||
|
||||
## DNS配置格式
|
||||
|
||||
### JSON格式
|
||||
|
||||
```json
|
||||
{
|
||||
"records": {
|
||||
"example.com": ["93.184.216.34", "93.184.216.35"],
|
||||
"api.example.com": ["93.184.216.35:8443", "93.184.216.36:8443"],
|
||||
|
||||
"*.github.com": ["140.82.121.3", "140.82.121.4"],
|
||||
"github.com": ["140.82.121.4", "140.82.121.5"],
|
||||
|
||||
"*.dev.local": ["127.0.0.1:3000", "127.0.0.1:3001"],
|
||||
"api.*.dev.local": ["127.0.0.1:3001", "127.0.0.1:3002"]
|
||||
},
|
||||
"use_fallback": true,
|
||||
"ttl": 300,
|
||||
"load_balance_strategy": "round_robin"
|
||||
}
|
||||
```
|
||||
|
||||
### Hosts格式
|
||||
|
||||
```
|
||||
# 精确匹配(多后端)
|
||||
93.184.216.34 example.com
|
||||
93.184.216.35 example.com
|
||||
93.184.216.35:8443 api.example.com
|
||||
93.184.216.36:8443 api.example.com
|
||||
|
||||
# 泛解析(多后端)
|
||||
140.82.121.3 *.github.com
|
||||
140.82.121.4 *.github.com
|
||||
127.0.0.1:3000 *.dev.local
|
||||
127.0.0.1:3001 *.dev.local
|
||||
127.0.0.1:3001 api.*.dev.local
|
||||
127.0.0.1:3002 api.*.dev.local
|
||||
```
|
||||
|
||||
## 编程接口
|
||||
|
||||
### 使用DNS解析器
|
||||
|
||||
```go
|
||||
import "github.com/darkit/goproxy/pkg/dns"
|
||||
|
||||
// 创建解析器
|
||||
resolver := dns.NewResolver(
|
||||
dns.WithLoadBalanceStrategy(dns.RoundRobin), // 设置负载均衡策略
|
||||
dns.WithTTL(5*time.Minute), // 设置DNS缓存TTL
|
||||
)
|
||||
|
||||
// 添加多个后端服务器
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.2", 8080)
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.3", 8080)
|
||||
|
||||
// 添加泛解析记录(多后端)
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080)
|
||||
|
||||
// 解析域名(使用负载均衡策略选择后端)
|
||||
endpoint, err := resolver.ResolveWithPort("api.example.com", 443)
|
||||
if err != nil {
|
||||
log.Fatalf("解析失败: %v", err)
|
||||
}
|
||||
fmt.Printf("解析结果: IP=%s, 端口=%d\n", endpoint.IP, endpoint.Port)
|
||||
|
||||
// 动态添加/删除后端服务器
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.4", 8080)
|
||||
resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080)
|
||||
```
|
||||
|
||||
### 负载均衡策略
|
||||
|
||||
GoProxy支持三种负载均衡策略:
|
||||
|
||||
1. **轮询策略(Round Robin)**
|
||||
```go
|
||||
resolver := dns.NewResolver(
|
||||
dns.WithLoadBalanceStrategy(dns.RoundRobin),
|
||||
)
|
||||
```
|
||||
|
||||
2. **随机策略(Random)**
|
||||
```go
|
||||
resolver := dns.NewResolver(
|
||||
dns.WithLoadBalanceStrategy(dns.Random),
|
||||
)
|
||||
```
|
||||
|
||||
3. **第一个可用策略(First Available)**
|
||||
```go
|
||||
resolver := dns.NewResolver(
|
||||
dns.WithLoadBalanceStrategy(dns.FirstAvailable),
|
||||
)
|
||||
```
|
||||
|
||||
### 使用DNS拨号器
|
||||
|
||||
```go
|
||||
import "github.com/darkit/goproxy/pkg/dns"
|
||||
|
||||
// 创建解析器
|
||||
resolver := dns.NewResolver(
|
||||
dns.WithLoadBalanceStrategy(dns.RoundRobin),
|
||||
)
|
||||
resolver.AddWithPort("example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWithPort("example.com", "192.168.1.2", 8080)
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.3", 8080)
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.4", 8080)
|
||||
|
||||
// 创建拨号器
|
||||
dialer := dns.NewDialer(resolver)
|
||||
|
||||
// 使用拨号器连接(会自动应用负载均衡)
|
||||
conn, err := dialer.Dial("tcp", "api.example.com:443")
|
||||
if err != nil {
|
||||
log.Fatalf("连接失败: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// 或者获取用于http.Transport的拨号上下文函数
|
||||
transport := &http.Transport{
|
||||
DialContext: dialer.DialContext,
|
||||
}
|
||||
client := &http.Client{Transport: transport}
|
||||
```
|
||||
|
||||
## 高级用法
|
||||
|
||||
### 多后端配置模式
|
||||
|
||||
多后端配置支持以下模式:
|
||||
|
||||
1. **精确匹配多后端**:
|
||||
```go
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.2", 8080)
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.3", 8080)
|
||||
```
|
||||
|
||||
2. **泛解析多后端**:
|
||||
```go
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080)
|
||||
```
|
||||
|
||||
3. **混合配置**:
|
||||
```go
|
||||
// 精确匹配优先于泛解析
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080)
|
||||
```
|
||||
|
||||
### 动态管理后端服务器
|
||||
|
||||
```go
|
||||
// 添加新的后端服务器
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.4", 8080)
|
||||
|
||||
// 删除特定的后端服务器
|
||||
resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080)
|
||||
|
||||
// 删除整个域名记录
|
||||
resolver.Remove("api.example.com")
|
||||
|
||||
// 清除所有记录
|
||||
resolver.Clear()
|
||||
```
|
||||
|
||||
### 健康检查集成
|
||||
|
||||
可以结合健康检查功能,自动剔除不健康的后端服务器:
|
||||
|
||||
```go
|
||||
// 创建健康检查器
|
||||
healthChecker := healthcheck.NewChecker(
|
||||
healthcheck.WithCheckInterval(30*time.Second),
|
||||
healthcheck.WithTimeout(5*time.Second),
|
||||
)
|
||||
|
||||
// 添加健康检查
|
||||
healthChecker.Add("api.example.com", "192.168.1.1:8080")
|
||||
healthChecker.Add("api.example.com", "192.168.1.2:8080")
|
||||
|
||||
// 在解析器中使用健康检查结果
|
||||
resolver.SetHealthChecker(healthChecker)
|
||||
```
|
||||
|
||||
## 应用场景
|
||||
|
||||
### 高可用部署
|
||||
|
||||
使用多后端配置实现高可用:
|
||||
|
||||
```
|
||||
# 主备模式
|
||||
192.168.1.1 api.example.com # 主服务器
|
||||
192.168.1.2 api.example.com # 备用服务器
|
||||
|
||||
# 负载均衡模式
|
||||
192.168.1.1 api.example.com # 服务器1
|
||||
192.168.1.2 api.example.com # 服务器2
|
||||
192.168.1.3 api.example.com # 服务器3
|
||||
```
|
||||
|
||||
### 多环境部署
|
||||
|
||||
为不同环境配置不同的后端服务器组:
|
||||
|
||||
```
|
||||
# 测试环境
|
||||
192.168.1.10 *.test.example.com
|
||||
192.168.1.11 *.test.example.com
|
||||
|
||||
# 预发布环境
|
||||
192.168.1.20 *.staging.example.com
|
||||
192.168.1.21 *.staging.example.com
|
||||
|
||||
# 生产环境
|
||||
192.168.1.30 *.production.example.com
|
||||
192.168.1.31 *.production.example.com
|
||||
```
|
||||
|
||||
### 微服务架构
|
||||
|
||||
为不同类型的微服务配置多个后端:
|
||||
|
||||
```
|
||||
# 认证服务
|
||||
10.0.0.1:8001 *.auth.internal
|
||||
10.0.0.2:8001 *.auth.internal
|
||||
|
||||
# 用户服务
|
||||
10.0.0.3:8002 *.user.internal
|
||||
10.0.0.4:8002 *.user.internal
|
||||
|
||||
# 支付服务
|
||||
10.0.0.5:8003 *.payment.internal
|
||||
10.0.0.6:8003 *.payment.internal
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 通配符域名只在我们的自定义DNS解析器中有效,不会影响系统DNS
|
||||
2. 泛解析规则的顺序会影响匹配结果,后添加的规则优先级更高
|
||||
3. 过多的泛解析规则可能会影响性能,建议合理组织规则
|
||||
4. 当域名同时匹配多个规则时,精确匹配优先于通配符匹配
|
||||
5. 自签名证书会导致浏览器警告,仅用于测试目的
|
||||
6. 多后端配置时,建议使用健康检查确保后端服务器可用性
|
||||
7. 负载均衡策略的选择应根据实际需求进行配置
|
||||
8. 动态添加/删除后端服务器时,需要考虑并发安全性
|
47
base.go
Normal file
47
base.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package goproxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// BaseProxy 基础代理接口
|
||||
type BaseProxy interface {
|
||||
// ServeHTTP 处理HTTP请求
|
||||
ServeHTTP(w http.ResponseWriter, r *http.Request)
|
||||
// Close 关闭代理
|
||||
Close() error
|
||||
}
|
||||
|
||||
// BaseConfig 基础配置
|
||||
type BaseConfig struct {
|
||||
// 监听地址
|
||||
ListenAddr string
|
||||
// 是否启用HTTPS
|
||||
EnableHTTPS bool
|
||||
// TLS配置
|
||||
TLSConfig *TLSConfig
|
||||
// 是否启用WebSocket
|
||||
EnableWebSocket bool
|
||||
// 是否启用压缩
|
||||
EnableCompression bool
|
||||
// 是否启用CORS
|
||||
EnableCORS bool
|
||||
// 是否保留客户端IP
|
||||
PreserveClientIP bool
|
||||
// 是否添加X-Forwarded-For头
|
||||
AddXForwardedFor bool
|
||||
// 是否添加X-Real-IP头
|
||||
AddXRealIP bool
|
||||
}
|
||||
|
||||
// TLSConfig TLS配置
|
||||
type TLSConfig struct {
|
||||
// 证书文件路径
|
||||
CertFile string
|
||||
// 密钥文件路径
|
||||
KeyFile string
|
||||
// 是否跳过证书验证
|
||||
InsecureSkipVerify bool
|
||||
// 是否使用ECDSA
|
||||
UseECDSA bool
|
||||
}
|
552
certificate.go
Normal file
552
certificate.go
Normal file
@@ -0,0 +1,552 @@
|
||||
package goproxy
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 内置默认CA证书和私钥
|
||||
var (
|
||||
// 默认根证书
|
||||
defaultRootCAPem = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIICJzCCAcygAwIBAgIITWWCIQf8/VIwCgYIKoZIzj0EAwIwUzEOMAwGA1UEBhMF
|
||||
Q2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0GA1UEBxMGWGlhbWVuMRAwDgYDVQQK
|
||||
EwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMB4XDTIyMDMyNTA1NDgwMFoXDTQyMDQy
|
||||
NTA1NDgwMFowUzEOMAwGA1UEBhMFQ2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0G
|
||||
A1UEBxMGWGlhbWVuMRAwDgYDVQQKEwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMFkw
|
||||
EwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQX
|
||||
mRgsFV5KHHmxOrVJBFC/nDetmGowkARShWtBsX1Irm4w6i6Qk2QliKOBiTCBhjAO
|
||||
BgNVHQ8BAf8EBAMCAQYwHQYDVR0lBBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIG
|
||||
A1UdEwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFBI5TkWYcvUIWsBAdffs833FnBrI
|
||||
MCIGA1UdEQQbMBmBF3FpbmdxaWFubHVkYW9AZ21haWwuY29tMAoGCCqGSM49BAMC
|
||||
A0kAMEYCIQCk1DhW7AmIW/n/QLftQq8BHZKLevWYJ813zdrNr5kXlwIhAIVvqglY
|
||||
9BkYWg4NEe/mVO4C5Vtu4FnzNU9I+rFpXVSO
|
||||
-----END CERTIFICATE-----
|
||||
`)
|
||||
// 默认根私钥
|
||||
defaultRootKeyPem = []byte(`-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIAXeEHO0FtFqQhTvsn/DT4g3rEos97+1Nibp9RfKOKhroAoGCCqGSM49
|
||||
AwEHoUQDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQXmRgsFV5KHHmxOrVJBFC/
|
||||
nDetmGowkARShWtBsX1Irm4w6i6Qk2QliA==
|
||||
-----END EC PRIVATE KEY-----
|
||||
`)
|
||||
)
|
||||
|
||||
// 加载和初始化默认根证书
|
||||
var (
|
||||
defaultRootCA *x509.Certificate
|
||||
defaultRootKey *ecdsa.PrivateKey
|
||||
)
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
block, _ := pem.Decode(defaultRootCAPem)
|
||||
if block == nil {
|
||||
panic("解析默认根证书PEM块失败")
|
||||
}
|
||||
defaultRootCA, err = x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("加载默认根证书失败: %s", err))
|
||||
}
|
||||
|
||||
block, _ = pem.Decode(defaultRootKeyPem)
|
||||
if block == nil {
|
||||
panic("解析默认根私钥PEM块失败")
|
||||
}
|
||||
defaultRootKey, err = x509.ParseECPrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("加载默认根私钥失败: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
// CertManager 证书管理器
|
||||
type CertManager struct {
|
||||
// 证书缓存
|
||||
cache CertificateCache
|
||||
// 默认私钥,可用于多个证书共享
|
||||
defaultPrivateKey interface{} // 改为interface{}以支持不同类型的私钥
|
||||
// 默认使用ECDSA P-256曲线
|
||||
curve elliptic.Curve
|
||||
// 证书有效期(年)
|
||||
validityYears int
|
||||
// 是否使用ECDSA(否则使用RSA)
|
||||
useECDSA bool
|
||||
}
|
||||
|
||||
// NewCertManager 创建证书管理器
|
||||
func NewCertManager(cache CertificateCache, options ...CertManagerOption) *CertManager {
|
||||
manager := &CertManager{
|
||||
cache: cache,
|
||||
curve: elliptic.P256(), // 默认使用P-256曲线
|
||||
validityYears: 1, // 默认证书有效期1年
|
||||
useECDSA: true, // 默认使用ECDSA
|
||||
}
|
||||
|
||||
// 应用选项
|
||||
for _, option := range options {
|
||||
option(manager)
|
||||
}
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
// CertManagerOption 证书管理器选项
|
||||
type CertManagerOption func(*CertManager)
|
||||
|
||||
// WithUseECDSA 设置是否使用ECDSA(否则使用RSA)
|
||||
func WithUseECDSA(useECDSA bool) CertManagerOption {
|
||||
return func(m *CertManager) {
|
||||
m.useECDSA = useECDSA
|
||||
}
|
||||
}
|
||||
|
||||
// WithDefaultPrivateKey 设置是否使用默认私钥
|
||||
func WithDefaultPrivateKey(enable bool) CertManagerOption {
|
||||
return func(m *CertManager) {
|
||||
if enable {
|
||||
if m.useECDSA {
|
||||
// 生成ECDSA私钥
|
||||
priv, err := ecdsa.GenerateKey(m.curve, rand.Reader)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("生成默认ECDSA私钥失败: %s", err))
|
||||
}
|
||||
m.defaultPrivateKey = priv
|
||||
} else {
|
||||
// 生成RSA私钥
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("生成默认RSA私钥失败: %s", err))
|
||||
}
|
||||
m.defaultPrivateKey = priv
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithCurve 设置椭圆曲线
|
||||
func WithCurve(curve elliptic.Curve) CertManagerOption {
|
||||
return func(m *CertManager) {
|
||||
m.curve = curve
|
||||
}
|
||||
}
|
||||
|
||||
// WithValidityYears 设置证书有效期(年)
|
||||
func WithValidityYears(years int) CertManagerOption {
|
||||
return func(m *CertManager) {
|
||||
if years > 0 {
|
||||
m.validityYears = years
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateTLSConfig 为指定主机生成TLS配置
|
||||
func (m *CertManager) GenerateTLSConfig(host string) (*tls.Config, error) {
|
||||
// 处理可能的端口
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
|
||||
// 检查证书缓存
|
||||
if m.cache != nil {
|
||||
// 检查主域名和子域名
|
||||
fields := strings.Split(host, ".")
|
||||
domains := []string{host}
|
||||
|
||||
// 添加父域名
|
||||
if len(fields) > 2 {
|
||||
domains = append(domains, strings.Join(fields[1:], "."))
|
||||
}
|
||||
|
||||
// 查找缓存
|
||||
for _, domain := range domains {
|
||||
if cert := m.cache.Get(domain); cert != nil {
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{*cert},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 生成新证书
|
||||
cert, err := m.GenerateCertificate(host, defaultRootCA, defaultRootKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 缓存证书
|
||||
if m.cache != nil {
|
||||
// 缓存主机名
|
||||
m.cache.Set(host, cert)
|
||||
|
||||
// 如果是IP地址,不进行其他处理
|
||||
if net.ParseIP(host) != nil {
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{*cert},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 缓存域名和子域名证书
|
||||
fields := strings.Split(host, ".")
|
||||
if len(fields) >= 2 {
|
||||
// 缓存主域名
|
||||
domain := strings.Join(fields[1:], ".")
|
||||
m.cache.Set(domain, cert)
|
||||
}
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{*cert},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateCertificate 生成证书
|
||||
func (m *CertManager) GenerateCertificate(host string, rootCA *x509.Certificate, rootKey *ecdsa.PrivateKey) (*tls.Certificate, error) {
|
||||
// 准备私钥
|
||||
var priv interface{}
|
||||
var pubKey interface{}
|
||||
var err error
|
||||
|
||||
// 使用默认私钥或生成新私钥
|
||||
if m.defaultPrivateKey != nil {
|
||||
priv = m.defaultPrivateKey
|
||||
} else if m.useECDSA {
|
||||
// 生成ECDSA私钥
|
||||
ecdsaKey, err := ecdsa.GenerateKey(m.curve, rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成ECDSA私钥失败: %s", err)
|
||||
}
|
||||
priv = ecdsaKey
|
||||
pubKey = &ecdsaKey.PublicKey
|
||||
} else {
|
||||
// 生成RSA私钥
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成RSA私钥失败: %s", err)
|
||||
}
|
||||
priv = rsaKey
|
||||
pubKey = &rsaKey.PublicKey
|
||||
}
|
||||
|
||||
// 获取公钥
|
||||
if pubKey == nil {
|
||||
switch k := priv.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
pubKey = &k.PublicKey
|
||||
case *rsa.PrivateKey:
|
||||
pubKey = &k.PublicKey
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的私钥类型")
|
||||
}
|
||||
}
|
||||
|
||||
// 创建证书模板
|
||||
template := m.createCertificateTemplate(host)
|
||||
|
||||
// 签名证书
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, template, rootCA, pubKey, rootKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建证书失败: %s", err)
|
||||
}
|
||||
|
||||
// 编码为PEM格式
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: derBytes,
|
||||
})
|
||||
|
||||
// 编码私钥
|
||||
var keyPEM []byte
|
||||
switch k := priv.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
privBytes, err := x509.MarshalECPrivateKey(k)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化ECDSA私钥失败: %s", err)
|
||||
}
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: privBytes,
|
||||
})
|
||||
case *rsa.PrivateKey:
|
||||
privBytes := x509.MarshalPKCS1PrivateKey(k)
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: privBytes,
|
||||
})
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的私钥类型")
|
||||
}
|
||||
|
||||
// 创建TLS证书
|
||||
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建TLS证书对失败: %s", err)
|
||||
}
|
||||
|
||||
return &tlsCert, nil
|
||||
}
|
||||
|
||||
// createCertificateTemplate 创建证书模板
|
||||
func (m *CertManager) createCertificateTemplate(host string) *x509.Certificate {
|
||||
// 使用基于主机名的哈希值作为序列号
|
||||
fv := fnv.New64a()
|
||||
fv.Write([]byte(host))
|
||||
serialNumber := big.NewInt(0).SetUint64(fv.Sum64())
|
||||
|
||||
// 准备模板
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: host,
|
||||
Organization: []string{"GoProxy Dynamic CA"},
|
||||
Country: []string{"CN"},
|
||||
Province: []string{"GuangDong"},
|
||||
Locality: []string{"Guangzhou"},
|
||||
},
|
||||
NotBefore: time.Now().Add(-10 * time.Minute), // 提前10分钟生效,容忍时间偏差
|
||||
NotAfter: time.Now().AddDate(m.validityYears, 0, 0),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: false,
|
||||
}
|
||||
|
||||
// 处理IP地址和域名
|
||||
ipAddr := net.ParseIP(host)
|
||||
if ipAddr != nil {
|
||||
template.IPAddresses = []net.IP{ipAddr}
|
||||
} else {
|
||||
// 移除可能的端口部分
|
||||
if strings.Contains(host, ":") {
|
||||
host = strings.Split(host, ":")[0]
|
||||
}
|
||||
|
||||
// 将主机名添加到DNS名称列表
|
||||
template.DNSNames = []string{host}
|
||||
|
||||
// 添加通配符域名支持
|
||||
fields := strings.Split(host, ".")
|
||||
fieldNum := len(fields)
|
||||
|
||||
// 为每一级子域名添加通配符
|
||||
for i := 0; i <= (fieldNum - 2); i++ {
|
||||
wildcardDomain := "*." + strings.Join(fields[i:], ".")
|
||||
// 避免重复
|
||||
if wildcardDomain != host {
|
||||
template.DNSNames = append(template.DNSNames, wildcardDomain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return template
|
||||
}
|
||||
|
||||
// LoadCAFromFiles 从文件加载CA证书和私钥
|
||||
func LoadCAFromFiles(certFile, keyFile string) (*x509.Certificate, *ecdsa.PrivateKey, error) {
|
||||
// 读取CA证书
|
||||
caCertPEM, err := os.ReadFile(certFile)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取CA证书文件失败: %s", err)
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(caCertPEM)
|
||||
if block == nil {
|
||||
return nil, nil, fmt.Errorf("解析CA证书PEM块失败")
|
||||
}
|
||||
|
||||
caCert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("解析CA证书失败: %s", err)
|
||||
}
|
||||
|
||||
// 读取CA私钥
|
||||
caKeyPEM, err := os.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取CA私钥文件失败: %s", err)
|
||||
}
|
||||
|
||||
block, _ = pem.Decode(caKeyPEM)
|
||||
if block == nil {
|
||||
return nil, nil, fmt.Errorf("解析CA私钥PEM块失败")
|
||||
}
|
||||
|
||||
var caKey *ecdsa.PrivateKey
|
||||
|
||||
// 尝试不同的私钥格式
|
||||
switch block.Type {
|
||||
case "EC PRIVATE KEY":
|
||||
caKey, err = x509.ParseECPrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("解析EC私钥失败: %s", err)
|
||||
}
|
||||
case "PRIVATE KEY":
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("解析PKCS8私钥失败: %s", err)
|
||||
}
|
||||
var ok bool
|
||||
caKey, ok = key.(*ecdsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("私钥不是ECDSA类型")
|
||||
}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("不支持的私钥类型: %s", block.Type)
|
||||
}
|
||||
|
||||
return caCert, caKey, nil
|
||||
}
|
||||
|
||||
// GetDefaultRootCA 获取默认根证书和私钥
|
||||
func GetDefaultRootCA() (*x509.Certificate, *ecdsa.PrivateKey) {
|
||||
return defaultRootCA, defaultRootKey
|
||||
}
|
||||
|
||||
// GenerateRootCA 生成新的根证书和私钥
|
||||
func GenerateRootCA(validYears int) (*x509.Certificate, *ecdsa.PrivateKey, error) {
|
||||
// 生成私钥
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("生成根证书私钥失败: %s", err)
|
||||
}
|
||||
|
||||
// 随机生成序列号
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("生成序列号失败: %s", err)
|
||||
}
|
||||
|
||||
// 创建根证书模板
|
||||
notBefore := time.Now().Add(-10 * time.Minute)
|
||||
notAfter := notBefore.AddDate(validYears, 0, 0)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: "GoProxy Root CA",
|
||||
Organization: []string{"GoProxy"},
|
||||
Country: []string{"CN"},
|
||||
Province: []string{"GuangDong"},
|
||||
Locality: []string{"Guangzhou"},
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
MaxPathLen: 2,
|
||||
}
|
||||
|
||||
// 自签名
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("创建根证书失败: %s", err)
|
||||
}
|
||||
|
||||
// 解析生成的证书
|
||||
cert, err := x509.ParseCertificate(derBytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("解析生成的根证书失败: %s", err)
|
||||
}
|
||||
|
||||
return cert, priv, nil
|
||||
}
|
||||
|
||||
// SaveCertificateToFile 将证书和私钥保存到文件
|
||||
func SaveCertificateToFile(cert *x509.Certificate, key *ecdsa.PrivateKey, certFile, keyFile string) error {
|
||||
// 保存证书
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
})
|
||||
if err := os.WriteFile(certFile, certPEM, 0o644); err != nil {
|
||||
return fmt.Errorf("保存证书到文件失败: %s", err)
|
||||
}
|
||||
|
||||
// 保存私钥
|
||||
keyBytes, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化私钥失败: %s", err)
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: keyBytes,
|
||||
})
|
||||
if err := os.WriteFile(keyFile, keyPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("保存私钥到文件失败: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateRootCA 生成根证书
|
||||
func (m *CertManager) GenerateRootCA() (*x509.Certificate, interface{}, error) {
|
||||
// 生成私钥
|
||||
var priv interface{}
|
||||
var pubKey interface{}
|
||||
var err error
|
||||
|
||||
if m.useECDSA {
|
||||
// 生成ECDSA私钥
|
||||
ecdsaKey, err := ecdsa.GenerateKey(m.curve, rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("生成ECDSA根私钥失败: %s", err)
|
||||
}
|
||||
priv = ecdsaKey
|
||||
pubKey = &ecdsaKey.PublicKey
|
||||
} else {
|
||||
// 生成RSA私钥
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("生成RSA根私钥失败: %s", err)
|
||||
}
|
||||
priv = rsaKey
|
||||
pubKey = &rsaKey.PublicKey
|
||||
}
|
||||
|
||||
// 创建根证书模板
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "GoProxy Root CA",
|
||||
Organization: []string{"GoProxy Authority"},
|
||||
Country: []string{"CN"},
|
||||
Province: []string{"GuangDong"},
|
||||
Locality: []string{"Guangzhou"},
|
||||
},
|
||||
NotBefore: time.Now().Add(-10 * time.Minute),
|
||||
NotAfter: time.Now().AddDate(10, 0, 0), // 10年有效期
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
MaxPathLen: 1,
|
||||
}
|
||||
|
||||
// 自签名
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, pubKey, priv)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("创建根证书失败: %s", err)
|
||||
}
|
||||
|
||||
// 解析证书
|
||||
cert, err := x509.ParseCertificate(derBytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("解析生成的根证书失败: %s", err)
|
||||
}
|
||||
|
||||
return cert, priv, nil
|
||||
}
|
86
cmd/cmd_reverse_proxy.go
Normal file
86
cmd/cmd_reverse_proxy.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/config"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 解析命令行参数
|
||||
listenAddr := flag.String("listen", ":8080", "监听地址")
|
||||
targetAddr := flag.String("target", "http://localhost:8080", "目标服务地址")
|
||||
enableHTTPS := flag.Bool("https", false, "启用HTTPS")
|
||||
certFile := flag.String("cert", "", "TLS证书文件")
|
||||
keyFile := flag.String("key", "", "TLS私钥文件")
|
||||
enableCache := flag.Bool("cache", false, "启用缓存")
|
||||
enableCompression := flag.Bool("compression", true, "启用压缩")
|
||||
enableCORS := flag.Bool("cors", false, "启用CORS")
|
||||
preserveClientIP := flag.Bool("preserve-ip", true, "保留客户端IP")
|
||||
addXForwardedFor := flag.Bool("x-forwarded-for", true, "添加X-Forwarded-For头")
|
||||
addXRealIP := flag.Bool("x-real-ip", true, "添加X-Real-IP头")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
// 创建配置
|
||||
cfg := config.DefaultConfig()
|
||||
|
||||
// 配置反向代理
|
||||
cfg.ReverseProxy = true // 启用反向代理模式
|
||||
cfg.ListenAddr = *listenAddr // 监听地址
|
||||
cfg.TargetAddr = *targetAddr // 目标地址
|
||||
cfg.DecryptHTTPS = *enableHTTPS // 启用HTTPS
|
||||
cfg.TLSCert = *certFile // 证书文件
|
||||
cfg.TLSKey = *keyFile // 私钥文件
|
||||
cfg.EnableCache = *enableCache // 启用缓存
|
||||
cfg.EnableCompression = *enableCompression // 启用压缩
|
||||
cfg.EnableCORS = *enableCORS // 启用CORS
|
||||
cfg.PreserveClientIP = *preserveClientIP // 保留客户端IP
|
||||
cfg.AddXForwardedFor = *addXForwardedFor // 添加X-Forwarded-For头
|
||||
cfg.AddXRealIP = *addXRealIP // 添加X-Real-IP头
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.New(&goproxy.Options{
|
||||
Config: cfg,
|
||||
})
|
||||
|
||||
// 创建HTTP服务器
|
||||
server := &http.Server{
|
||||
Addr: *listenAddr,
|
||||
Handler: proxy,
|
||||
}
|
||||
|
||||
// 优雅退出
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// 启动HTTP服务器
|
||||
go func() {
|
||||
fmt.Printf("反向代理启动在 %s,转发到 %s\n", *listenAddr, *targetAddr)
|
||||
var err error
|
||||
if *enableHTTPS && *certFile != "" && *keyFile != "" {
|
||||
err = server.ListenAndServeTLS(*certFile, *keyFile)
|
||||
} else {
|
||||
err = server.ListenAndServe()
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("服务器启动失败: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待退出信号
|
||||
<-quit
|
||||
fmt.Println("服务器正在关闭...")
|
||||
|
||||
// 关闭服务器
|
||||
server.Shutdown(nil) // 简化版,不使用context
|
||||
|
||||
fmt.Println("服务器已关闭")
|
||||
}
|
100
cmd/functional_options_proxy.go
Normal file
100
cmd/functional_options_proxy.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/config"
|
||||
"github.com/darkit/goproxy/pkg/cache"
|
||||
"github.com/darkit/goproxy/pkg/healthcheck"
|
||||
"github.com/darkit/goproxy/pkg/loadbalance"
|
||||
"github.com/darkit/goproxy/pkg/metrics"
|
||||
)
|
||||
|
||||
// 这是functional_options_proxy的main函数,改名避免与其他文件冲突
|
||||
func main_functional_options() {
|
||||
// 创建基本配置
|
||||
cfg := config.DefaultConfig()
|
||||
|
||||
// 配置反向代理
|
||||
cfg.ReverseProxy = true // 启用反向代理模式
|
||||
cfg.ListenAddr = ":8080" // 监听地址
|
||||
cfg.TargetAddr = "http://example.com" // 目标地址
|
||||
cfg.PreserveClientIP = true // 保留客户端IP
|
||||
cfg.AddXForwardedFor = true // 添加X-Forwarded-For头
|
||||
cfg.AddXRealIP = true // 添加X-Real-IP头
|
||||
cfg.EnableCompression = true // 启用压缩
|
||||
|
||||
// 创建负载均衡器
|
||||
// 根据linter错误,要检查正确的API
|
||||
backends := []string{
|
||||
"http://server1:8080",
|
||||
"http://server2:8080",
|
||||
"http://server3:8080",
|
||||
}
|
||||
lb := loadbalance.NewRoundRobinBalancer()
|
||||
// 设置后端服务器列表
|
||||
lb.AddList(backends, 1)
|
||||
|
||||
// 健康检查配置
|
||||
// 根据错误信息调整构造函数
|
||||
healthCfg := &healthcheck.Config{
|
||||
Interval: 10 * time.Second,
|
||||
Timeout: 3 * time.Second,
|
||||
MaxFails: 3,
|
||||
}
|
||||
// 假设NewHealthChecker只接受Config参数
|
||||
healthChecker := healthcheck.NewHealthChecker(healthCfg)
|
||||
// 然后单独设置后端
|
||||
healthChecker.AddTargetList(backends)
|
||||
|
||||
// 创建缓存
|
||||
// 根据错误信息修正参数
|
||||
memCache := cache.NewMemoryCache(5*time.Minute, 30*time.Second, 1000)
|
||||
|
||||
// 使用功能选项模式创建反向代理
|
||||
proxy := goproxy.NewWithOptions(
|
||||
// 将配置对象传入
|
||||
goproxy.WithConfig(cfg),
|
||||
|
||||
// 性能优化
|
||||
goproxy.WithConnectionPoolSize(100),
|
||||
goproxy.WithIdleTimeout(30*time.Second),
|
||||
goproxy.WithRequestTimeout(10*time.Second),
|
||||
|
||||
// 负载均衡
|
||||
goproxy.WithLoadBalancer(lb),
|
||||
|
||||
// 健康检查
|
||||
goproxy.WithHealthChecker(healthChecker),
|
||||
|
||||
// 缓存
|
||||
goproxy.WithHTTPCache(memCache),
|
||||
|
||||
// 指标收集
|
||||
goproxy.WithMetrics(metrics.NewSimpleMetrics()),
|
||||
|
||||
// 重试策略
|
||||
goproxy.WithEnableRetry(3, 1*time.Second, 10*time.Second),
|
||||
|
||||
// DNS缓存
|
||||
goproxy.WithDNSCacheTTL(5*time.Minute),
|
||||
|
||||
// 启用CORS
|
||||
goproxy.WithEnableCORS(true),
|
||||
)
|
||||
|
||||
// 启动服务器
|
||||
fmt.Println("反向代理启动在 :8080,转发到 http://example.com 并启用负载均衡")
|
||||
if err := http.ListenAndServe(":8080", proxy); err != nil {
|
||||
log.Fatalf("服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加实际的main函数,调用上面的示例函数
|
||||
func main() {
|
||||
main_functional_options()
|
||||
}
|
152
cmd/https_reverse_proxy.go
Normal file
152
cmd/https_reverse_proxy.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/config"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 解析命令行参数
|
||||
var (
|
||||
listenAddr = flag.String("listen", ":443", "HTTPS监听地址")
|
||||
httpAddr = flag.String("http", ":80", "HTTP监听地址,用于重定向到HTTPS")
|
||||
targetAddr = flag.String("target", "http://localhost:8080", "目标服务地址")
|
||||
certFile = flag.String("cert", "server.crt", "TLS证书文件")
|
||||
keyFile = flag.String("key", "server.key", "TLS私钥文件")
|
||||
autoRedirect = flag.Bool("redirect", true, "自动将HTTP请求重定向到HTTPS")
|
||||
insecureSkipVerify = flag.Bool("insecure", false, "跳过目标HTTPS验证")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
// 检查证书和私钥文件
|
||||
if !fileExists(*certFile) || !fileExists(*keyFile) {
|
||||
log.Printf("证书或私钥文件不存在: %s, %s", *certFile, *keyFile)
|
||||
if *autoRedirect {
|
||||
log.Printf("将只启动HTTP服务")
|
||||
} else {
|
||||
log.Fatalf("无法启动HTTPS服务,请提供有效的证书和私钥文件")
|
||||
}
|
||||
}
|
||||
|
||||
// 创建配置
|
||||
cfg := config.DefaultConfig()
|
||||
|
||||
// 配置反向代理
|
||||
cfg.ReverseProxy = true // 启用反向代理模式
|
||||
cfg.ListenAddr = *listenAddr // HTTPS监听地址
|
||||
cfg.TargetAddr = *targetAddr // 目标地址
|
||||
cfg.DecryptHTTPS = true // 启用HTTPS
|
||||
cfg.TLSCert = *certFile // 证书文件
|
||||
cfg.TLSKey = *keyFile // 私钥文件
|
||||
cfg.PreserveClientIP = true // 保留客户端IP
|
||||
cfg.AddXForwardedFor = true // 添加X-Forwarded-For头
|
||||
cfg.InsecureSkipVerify = *insecureSkipVerify // 是否跳过目标HTTPS验证
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.New(&goproxy.Options{
|
||||
Config: cfg,
|
||||
})
|
||||
|
||||
// 创建HTTPS服务器
|
||||
httpsServer := &http.Server{
|
||||
Addr: *listenAddr,
|
||||
Handler: proxy,
|
||||
TLSConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
}
|
||||
|
||||
// 创建HTTP重定向服务器
|
||||
var httpServer *http.Server
|
||||
if *autoRedirect {
|
||||
httpServer = &http.Server{
|
||||
Addr: *httpAddr,
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 构建重定向URL
|
||||
host := r.Host
|
||||
// 如果Host包含端口,去除端口
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
// 构建HTTPS URL
|
||||
target := "https://" + host
|
||||
// 如果HTTPS端口不是443,添加端口
|
||||
if *listenAddr != ":443" {
|
||||
_, port, _ := net.SplitHostPort(*listenAddr)
|
||||
if port != "" && port != "443" {
|
||||
target = "https://" + host + ":" + port
|
||||
}
|
||||
}
|
||||
// 添加路径和查询参数
|
||||
target += r.URL.Path
|
||||
if r.URL.RawQuery != "" {
|
||||
target += "?" + r.URL.RawQuery
|
||||
}
|
||||
// 重定向
|
||||
http.Redirect(w, r, target, http.StatusMovedPermanently)
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// 优雅退出
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// 启动HTTPS服务器
|
||||
if fileExists(*certFile) && fileExists(*keyFile) {
|
||||
go func() {
|
||||
fmt.Printf("HTTPS反向代理启动在 %s,转发到 %s\n", *listenAddr, *targetAddr)
|
||||
if err := httpsServer.ListenAndServeTLS(*certFile, *keyFile); err != nil && err != http.ErrServerClosed {
|
||||
log.Printf("HTTPS服务器启动失败: %v\n", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 启动HTTP重定向服务器
|
||||
if *autoRedirect && httpServer != nil {
|
||||
go func() {
|
||||
fmt.Printf("HTTP重定向服务启动在 %s\n", *httpAddr)
|
||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Printf("HTTP服务器启动失败: %v\n", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 等待退出信号
|
||||
<-quit
|
||||
fmt.Println("服务器正在关闭...")
|
||||
|
||||
// 关闭服务器
|
||||
if httpServer != nil {
|
||||
if err := httpServer.Close(); err != nil {
|
||||
log.Printf("HTTP服务器关闭失败: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
if fileExists(*certFile) && fileExists(*keyFile) {
|
||||
if err := httpsServer.Close(); err != nil {
|
||||
log.Printf("HTTPS服务器关闭失败: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("服务器已关闭")
|
||||
}
|
||||
|
||||
// 检查文件是否存在
|
||||
func fileExists(filename string) bool {
|
||||
info, err := os.Stat(filename)
|
||||
if os.IsNotExist(err) {
|
||||
return false
|
||||
}
|
||||
return !info.IsDir()
|
||||
}
|
93
cmd/middleware_reverse_proxy.go
Normal file
93
cmd/middleware_reverse_proxy.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/config"
|
||||
)
|
||||
|
||||
// 自定义代理委托实现
|
||||
type CustomDelegate struct {
|
||||
goproxy.DefaultDelegate
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// 创建自定义代理委托
|
||||
func NewCustomDelegate() *CustomDelegate {
|
||||
return &CustomDelegate{
|
||||
logger: log.New(log.Writer(), "[反向代理] ", log.LstdFlags),
|
||||
}
|
||||
}
|
||||
|
||||
// 连接处理
|
||||
func (d *CustomDelegate) Connect(ctx *goproxy.Context, rw http.ResponseWriter) {
|
||||
d.logger.Printf("新连接: %s -> %s", ctx.Req.RemoteAddr, ctx.Req.Host)
|
||||
}
|
||||
|
||||
// 请求前处理
|
||||
func (d *CustomDelegate) BeforeRequest(ctx *goproxy.Context) {
|
||||
// 添加自定义请求头
|
||||
ctx.Req.Header.Set("X-Proxy-Time", time.Now().Format(time.RFC3339))
|
||||
ctx.Req.Header.Set("X-Proxy-ID", "custom-proxy-1")
|
||||
|
||||
d.logger.Printf("处理请求: %s %s", ctx.Req.Method, ctx.Req.URL.String())
|
||||
}
|
||||
|
||||
// 响应前处理
|
||||
func (d *CustomDelegate) BeforeResponse(ctx *goproxy.Context, resp *http.Response, err error) {
|
||||
if err != nil {
|
||||
d.logger.Printf("请求错误: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
// 添加自定义响应头
|
||||
resp.Header.Set("X-Proxy-Server", "GoProxy")
|
||||
|
||||
// 记录响应状态
|
||||
d.logger.Printf("响应状态: %d %s", resp.StatusCode, http.StatusText(resp.StatusCode))
|
||||
}
|
||||
}
|
||||
|
||||
// 请求完成处理
|
||||
func (d *CustomDelegate) Finish(ctx *goproxy.Context) {
|
||||
d.logger.Printf("请求完成: %s %s", ctx.Req.Method, ctx.Req.URL.String())
|
||||
}
|
||||
|
||||
// 错误日志
|
||||
func (d *CustomDelegate) ErrorLog(err error) {
|
||||
d.logger.Printf("错误: %s", err.Error())
|
||||
}
|
||||
|
||||
func main() {
|
||||
// 创建基本配置
|
||||
cfg := config.DefaultConfig()
|
||||
|
||||
// 配置反向代理
|
||||
cfg.ReverseProxy = true // 启用反向代理模式
|
||||
cfg.ListenAddr = ":8080" // 监听本地8080端口
|
||||
cfg.TargetAddr = "http://example.com" // 转发到example.com
|
||||
|
||||
// 启用基本功能
|
||||
cfg.PreserveClientIP = true // 保留客户端IP
|
||||
cfg.AddXForwardedFor = true // 添加X-Forwarded-For头
|
||||
|
||||
// 创建自定义代理委托
|
||||
delegate := NewCustomDelegate()
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.New(&goproxy.Options{
|
||||
Config: cfg,
|
||||
Delegate: delegate, // 使用自定义委托
|
||||
})
|
||||
|
||||
// 启动服务器
|
||||
fmt.Println("反向代理启动在 :8080,转发到 http://example.com")
|
||||
if err := http.ListenAndServe(":8080", proxy); err != nil {
|
||||
log.Fatalf("服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
33
cmd/proxy_rules.json
Normal file
33
cmd/proxy_rules.json
Normal file
@@ -0,0 +1,33 @@
|
||||
{
|
||||
"rules": [
|
||||
{
|
||||
"path": "/api/",
|
||||
"target": "http://api-server:8000",
|
||||
"strip_prefix": true
|
||||
},
|
||||
{
|
||||
"path": "/admin/",
|
||||
"target": "http://admin-server:9000",
|
||||
"strip_prefix": false
|
||||
},
|
||||
{
|
||||
"path": "/static/",
|
||||
"target": "http://static-server:8080",
|
||||
"strip_prefix": true,
|
||||
"cache": true,
|
||||
"cache_ttl": 3600
|
||||
},
|
||||
{
|
||||
"host": "api.example.com",
|
||||
"target": "http://api-server:8000"
|
||||
},
|
||||
{
|
||||
"host": "admin.example.com",
|
||||
"target": "http://admin-server:9000"
|
||||
},
|
||||
{
|
||||
"default": true,
|
||||
"target": "http://default-server:8080"
|
||||
}
|
||||
]
|
||||
}
|
175
cmd/reverse_proxy_example.go
Normal file
175
cmd/reverse_proxy_example.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/config"
|
||||
"github.com/darkit/goproxy/pkg/cache"
|
||||
"github.com/darkit/goproxy/pkg/healthcheck"
|
||||
"github.com/darkit/goproxy/pkg/loadbalance"
|
||||
"github.com/darkit/goproxy/pkg/metrics"
|
||||
)
|
||||
|
||||
// 命令行参数
|
||||
var (
|
||||
// 监听地址
|
||||
listenAddr string
|
||||
// 目标地址(后端服务)
|
||||
targetAddr string
|
||||
// 启用HTTPS
|
||||
enableHTTPS bool
|
||||
// TLS证书文件
|
||||
certFile string
|
||||
// TLS私钥文件
|
||||
keyFile string
|
||||
// 启用负载均衡
|
||||
enableLoadBalancing bool
|
||||
// 负载均衡目标地址列表
|
||||
targets string
|
||||
// 启用健康检查
|
||||
enableHealthCheck bool
|
||||
// 健康检查间隔
|
||||
healthCheckInterval time.Duration
|
||||
// 启用缓存
|
||||
enableCache bool
|
||||
// 启用压缩
|
||||
enableCompression bool
|
||||
// 启用CORS
|
||||
enableCORS bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
// 解析命令行参数
|
||||
flag.StringVar(&listenAddr, "listen", ":8080", "监听地址")
|
||||
flag.StringVar(&targetAddr, "target", "http://localhost:9090", "目标地址")
|
||||
flag.BoolVar(&enableHTTPS, "https", false, "启用HTTPS")
|
||||
flag.StringVar(&certFile, "cert", "", "TLS证书文件")
|
||||
flag.StringVar(&keyFile, "key", "", "TLS私钥文件")
|
||||
flag.BoolVar(&enableLoadBalancing, "lb", false, "启用负载均衡")
|
||||
flag.StringVar(&targets, "targets", "", "负载均衡目标地址列表,用逗号分隔")
|
||||
flag.BoolVar(&enableHealthCheck, "health", false, "启用健康检查")
|
||||
flag.DurationVar(&healthCheckInterval, "health-interval", 10*time.Second, "健康检查间隔")
|
||||
flag.BoolVar(&enableCache, "cache", false, "启用缓存")
|
||||
flag.BoolVar(&enableCompression, "compression", true, "启用压缩")
|
||||
flag.BoolVar(&enableCORS, "cors", false, "启用CORS")
|
||||
}
|
||||
|
||||
func main() {
|
||||
// 解析命令行参数
|
||||
flag.Parse()
|
||||
|
||||
// 创建配置
|
||||
cfg := config.DefaultConfig()
|
||||
|
||||
// 设置基本配置
|
||||
cfg.ReverseProxy = true // 启用反向代理模式
|
||||
cfg.ListenAddr = listenAddr
|
||||
cfg.TargetAddr = targetAddr
|
||||
cfg.DecryptHTTPS = enableHTTPS
|
||||
cfg.TLSCert = certFile
|
||||
cfg.TLSKey = keyFile
|
||||
cfg.EnableCache = enableCache
|
||||
cfg.EnableCompression = enableCompression
|
||||
cfg.EnableCORS = enableCORS
|
||||
cfg.PreserveClientIP = true // 保留客户端IP
|
||||
cfg.AddXForwardedFor = true // 添加X-Forwarded-For头
|
||||
cfg.AddXRealIP = true // 添加X-Real-IP头
|
||||
|
||||
// 健康检查配置
|
||||
cfg.EnableHealthCheck = enableHealthCheck
|
||||
cfg.HealthCheckInterval = healthCheckInterval
|
||||
cfg.HealthCheckTimeout = 3 * time.Second
|
||||
|
||||
// 负载均衡配置
|
||||
cfg.EnableLoadBalancing = enableLoadBalancing
|
||||
|
||||
// 创建代理选项
|
||||
opts := &goproxy.Options{
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
// 如果启用负载均衡,创建负载均衡器
|
||||
if enableLoadBalancing && targets != "" {
|
||||
// 创建轮询负载均衡器
|
||||
lb := loadbalance.NewRoundRobinBalancer()
|
||||
lb.Add(targets, 1)
|
||||
opts.LoadBalancer = lb
|
||||
|
||||
// 如果启用健康检查,创建健康检查器
|
||||
if enableHealthCheck {
|
||||
// 健康检查配置
|
||||
healthCfg := &healthcheck.Config{
|
||||
Interval: healthCheckInterval,
|
||||
Timeout: 3 * time.Second,
|
||||
MinSuccess: 1,
|
||||
MaxFails: 3,
|
||||
}
|
||||
healthChecker := healthcheck.NewHealthChecker(healthCfg)
|
||||
healthChecker.AddTarget(targets)
|
||||
opts.HealthChecker = healthChecker
|
||||
|
||||
// 启动健康检查
|
||||
healthChecker.Start()
|
||||
defer healthChecker.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// 如果启用缓存,创建缓存
|
||||
if enableCache {
|
||||
// 创建一个内存缓存,TTL为5分钟
|
||||
memCache := cache.NewMemoryCache(5*time.Minute, time.Second, 10000)
|
||||
opts.HTTPCache = memCache
|
||||
}
|
||||
|
||||
// 创建指标收集器
|
||||
metricsCollector := metrics.NewSimpleMetrics()
|
||||
opts.Metrics = metricsCollector
|
||||
|
||||
// 创建代理
|
||||
proxy := goproxy.New(opts)
|
||||
|
||||
// 创建HTTP服务器
|
||||
server := &http.Server{
|
||||
Addr: listenAddr,
|
||||
Handler: proxy,
|
||||
}
|
||||
|
||||
// 启动HTTP服务器
|
||||
go func() {
|
||||
fmt.Printf("反向代理启动在 %s,目标地址为 %s\n", listenAddr, targetAddr)
|
||||
var err error
|
||||
if enableHTTPS && certFile != "" && keyFile != "" {
|
||||
err = server.ListenAndServeTLS(certFile, keyFile)
|
||||
} else {
|
||||
err = server.ListenAndServe()
|
||||
}
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
log.Fatalf("服务器启动失败: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 优雅退出
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
|
||||
// 关闭服务器
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
log.Fatalf("服务器关闭失败: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Println("服务器已关闭")
|
||||
}
|
121
cmd/rules_reverse_proxy.go
Normal file
121
cmd/rules_reverse_proxy.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/config"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 解析命令行参数
|
||||
var (
|
||||
listenAddr = flag.String("listen", ":8080", "监听地址")
|
||||
rulesFile = flag.String("rules", "proxy_rules.json", "代理规则文件")
|
||||
enableCache = flag.Bool("cache", false, "启用缓存")
|
||||
enableCORS = flag.Bool("cors", true, "启用CORS")
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
// 检查规则文件是否存在
|
||||
if _, err := os.Stat(*rulesFile); os.IsNotExist(err) {
|
||||
// 创建示例规则文件
|
||||
createExampleRulesFile(*rulesFile)
|
||||
fmt.Printf("已创建示例规则文件: %s\n", *rulesFile)
|
||||
}
|
||||
|
||||
// 创建配置
|
||||
cfg := config.DefaultConfig()
|
||||
|
||||
// 配置反向代理
|
||||
cfg.ReverseProxy = true // 启用反向代理模式
|
||||
cfg.ListenAddr = *listenAddr // 监听地址
|
||||
cfg.ReverseProxyRulesFile = *rulesFile // 规则文件
|
||||
cfg.EnableCache = *enableCache // 是否启用缓存
|
||||
cfg.EnableCORS = *enableCORS // 是否启用CORS
|
||||
cfg.PreserveClientIP = true // 保留客户端IP
|
||||
cfg.AddXForwardedFor = true // 添加X-Forwarded-For头
|
||||
cfg.AddXRealIP = true // 添加X-Real-IP头
|
||||
cfg.TargetAddr = "www.chipeak.com"
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.New(&goproxy.Options{
|
||||
Config: cfg,
|
||||
})
|
||||
|
||||
// 创建HTTP服务器
|
||||
server := &http.Server{
|
||||
Addr: *listenAddr,
|
||||
Handler: proxy,
|
||||
}
|
||||
|
||||
// 优雅退出
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// 启动服务
|
||||
go func() {
|
||||
fmt.Printf("反向代理启动在 %s,使用规则文件 %s\n", *listenAddr, *rulesFile)
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("服务器启动失败: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待退出信号
|
||||
<-quit
|
||||
fmt.Println("服务器正在关闭...")
|
||||
|
||||
// 关闭服务器
|
||||
if err := server.Close(); err != nil {
|
||||
log.Fatalf("服务器关闭失败: %v\n", err)
|
||||
}
|
||||
fmt.Println("服务器已关闭")
|
||||
}
|
||||
|
||||
// 创建示例规则文件
|
||||
func createExampleRulesFile(filename string) {
|
||||
content := `{
|
||||
"rules": [
|
||||
{
|
||||
"path": "/api/",
|
||||
"target": "http://api-server:8000",
|
||||
"strip_prefix": true
|
||||
},
|
||||
{
|
||||
"path": "/admin/",
|
||||
"target": "http://admin-server:9000",
|
||||
"strip_prefix": false
|
||||
},
|
||||
{
|
||||
"path": "/static/",
|
||||
"target": "http://static-server:8080",
|
||||
"strip_prefix": true,
|
||||
"cache": true,
|
||||
"cache_ttl": 3600
|
||||
},
|
||||
{
|
||||
"host": "api.example.com",
|
||||
"target": "http://api-server:8000"
|
||||
},
|
||||
{
|
||||
"host": "admin.example.com",
|
||||
"target": "http://admin-server:9000"
|
||||
},
|
||||
{
|
||||
"default": true,
|
||||
"target": "http://default-server:8080"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
err := os.WriteFile(filename, []byte(content), 0644)
|
||||
if err != nil {
|
||||
log.Fatalf("创建规则文件失败: %v", err)
|
||||
}
|
||||
}
|
35
cmd/simple_reverse_proxy.go
Normal file
35
cmd/simple_reverse_proxy.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/config"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建基本配置
|
||||
cfg := config.DefaultConfig()
|
||||
|
||||
// 配置反向代理
|
||||
cfg.ReverseProxy = true // 启用反向代理模式
|
||||
cfg.ListenAddr = ":8080" // 监听本地8080端口
|
||||
cfg.TargetAddr = "http://example.com" // 转发到example.com
|
||||
|
||||
// 启用一些基本功能
|
||||
cfg.PreserveClientIP = true // 保留客户端IP
|
||||
cfg.AddXForwardedFor = true // 添加X-Forwarded-For头
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.New(&goproxy.Options{
|
||||
Config: cfg,
|
||||
})
|
||||
|
||||
// 启动服务器
|
||||
fmt.Println("反向代理启动在 :8080,转发到 http://example.com")
|
||||
if err := http.ListenAndServe(":8080", proxy); err != nil {
|
||||
log.Fatalf("服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
40
cmd/simple_reverse_proxy_fixed.go
Normal file
40
cmd/simple_reverse_proxy_fixed.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/config"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建基本配置
|
||||
cfg := config.DefaultConfig()
|
||||
|
||||
// 配置反向代理
|
||||
cfg.ReverseProxy = true // 启用反向代理模式
|
||||
cfg.ListenAddr = ":8080" // 监听本地8080端口
|
||||
cfg.TargetAddr = "http://example.com" // 转发到example.com
|
||||
|
||||
// 设置HTTP头部选项
|
||||
cfg.PreserveClientIP = true // 保留客户端IP
|
||||
cfg.AddXForwardedFor = true // 添加X-Forwarded-For头
|
||||
cfg.AddXRealIP = true // 添加X-Real-IP头
|
||||
|
||||
// 设置性能选项
|
||||
cfg.EnableCompression = true // 启用压缩
|
||||
cfg.EnableCache = true // 启用缓存
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.New(&goproxy.Options{
|
||||
Config: cfg,
|
||||
})
|
||||
|
||||
// 启动服务器
|
||||
fmt.Println("反向代理启动在 :8080,转发到 http://example.com")
|
||||
if err := http.ListenAndServe(":8080", proxy); err != nil {
|
||||
log.Fatalf("服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
111
config/config.go
Normal file
111
config/config.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config 代理配置
|
||||
type Config struct {
|
||||
// 基本配置
|
||||
ListenAddr string `json:"listen_addr" yaml:"listen_addr" toml:"listen_addr"` // 监听地址
|
||||
TargetAddr string `json:"target_addr" yaml:"target_addr" toml:"target_addr"` // 目标地址
|
||||
DecryptHTTPS bool `json:"decrypt_https" yaml:"decrypt_https" toml:"decrypt_https"` // 是否启用HTTPS解密
|
||||
CACert string `json:"ca_cert" yaml:"ca_cert" toml:"ca_cert"` // CA证书文件路径(用于生成动态证书)
|
||||
CAKey string `json:"ca_key" yaml:"ca_key" toml:"ca_key"` // CA密钥文件路径(用于生成动态证书)
|
||||
UseECDSA bool `json:"use_ecdsa" yaml:"use_ecdsa" toml:"use_ecdsa"` // 是否使用ECDSA生成证书(默认使用RSA)
|
||||
TLSCert string `json:"tls_cert" yaml:"tls_cert" toml:"tls_cert"` // TLS证书文件路径
|
||||
TLSKey string `json:"tls_key" yaml:"tls_key" toml:"tls_key"` // TLS密钥文件路径
|
||||
|
||||
// 连接配置
|
||||
DisableKeepAlive bool `json:"disable_keep_alive" yaml:"disable_keep_alive" toml:"disable_keep_alive"` // 是否禁用连接复用
|
||||
RequestTimeout time.Duration `json:"request_timeout" yaml:"request_timeout" toml:"request_timeout"` // 请求超时时间
|
||||
EnableCache bool `json:"enable_cache" yaml:"enable_cache" toml:"enable_cache"` // 是否启用响应缓存
|
||||
IdleTimeout time.Duration `json:"idle_timeout" yaml:"idle_timeout" toml:"idle_timeout"` // 连接空闲超时时间
|
||||
MaxIdleConns int `json:"max_idle_conns" yaml:"max_idle_conns" toml:"max_idle_conns"` // 最大空闲连接数
|
||||
|
||||
// 缓存配置
|
||||
DNSCacheTTL time.Duration `json:"dns_cache_ttl" yaml:"dns_cache_ttl" toml:"dns_cache_ttl"` // DNS缓存过期时间
|
||||
CacheTTL time.Duration `json:"cache_ttl" yaml:"cache_ttl" toml:"cache_ttl"` // 缓存过期时间
|
||||
|
||||
// 重试配置
|
||||
EnableRetry bool `json:"enable_retry" yaml:"enable_retry" toml:"enable_retry"` // 是否启用重试机制
|
||||
MaxRetries int `json:"max_retries" yaml:"max_retries" toml:"max_retries"` // 最大重试次数
|
||||
BaseBackoff time.Duration `json:"base_backoff" yaml:"base_backoff" toml:"base_backoff"` // 重试间隔基数
|
||||
MaxBackoff time.Duration `json:"max_backoff" yaml:"max_backoff" toml:"max_backoff"` // 最大重试间隔
|
||||
|
||||
// 限流配置
|
||||
RateLimit float64 `json:"rate_limit" yaml:"rate_limit" toml:"rate_limit"` // 每秒请求速率限制
|
||||
|
||||
// 其他配置
|
||||
EnableCORS bool `json:"enable_cors" yaml:"enable_cors" toml:"enable_cors"` // 是否自动添加CORS头
|
||||
|
||||
// 负载均衡配置
|
||||
EnableLoadBalancing bool `json:"enable_load_balancing" yaml:"enable_load_balancing" toml:"enable_load_balancing"` // 是否启用负载均衡
|
||||
Backends []string `json:"backends" yaml:"backends" toml:"backends"` // 负载均衡后端列表
|
||||
EnableRateLimit bool `json:"enable_rate_limit" yaml:"enable_rate_limit" toml:"enable_rate_limit"` // 是否启用限流
|
||||
MaxBurst int `json:"max_burst" yaml:"max_burst" toml:"max_burst"` // 并发请求峰值限制
|
||||
MaxConnections int `json:"max_connections" yaml:"max_connections" toml:"max_connections"` // 最大连接数
|
||||
EnableConnectionPool bool `json:"enable_connection_pool" yaml:"enable_connection_pool" toml:"enable_connection_pool"` // 是否启用连接池
|
||||
ConnectionPoolSize int `json:"connection_pool_size" yaml:"connection_pool_size" toml:"connection_pool_size"` // 连接池大小
|
||||
EnableHealthCheck bool `json:"enable_health_check" yaml:"enable_health_check" toml:"enable_health_check"` // 是否启用健康检查
|
||||
HealthCheckInterval time.Duration `json:"health_check_interval" yaml:"health_check_interval" toml:"health_check_interval"` // 健康检查间隔时间
|
||||
HealthCheckTimeout time.Duration `json:"health_check_timeout" yaml:"health_check_timeout" toml:"health_check_timeout"` // 健康检查超时时间
|
||||
EnableMetrics bool `json:"enable_metrics" yaml:"enable_metrics" toml:"enable_metrics"` // 是否启用监控指标
|
||||
EnableTracing bool `json:"enable_tracing" yaml:"enable_tracing" toml:"enable_tracing"` // 是否启用请求追踪
|
||||
WebSocketIntercept bool `json:"websocket_intercept" yaml:"websocket_intercept" toml:"websocket_intercept"` // 是否拦截WebSocket
|
||||
ReverseProxy bool `json:"reverse_proxy" yaml:"reverse_proxy" toml:"reverse_proxy"` // 是否作为反向代理
|
||||
ReverseProxyRulesFile string `json:"reverse_proxy_rules_file" yaml:"reverse_proxy_rules_file" toml:"reverse_proxy_rules_file"` // 反向代理规则文件路径
|
||||
PreserveClientIP bool `json:"preserve_client_ip" yaml:"preserve_client_ip" toml:"preserve_client_ip"` // 是否保留客户端IP
|
||||
EnableCompression bool `json:"enable_compression" yaml:"enable_compression" toml:"enable_compression"` // 是否启用压缩
|
||||
RewriteHostHeader bool `json:"rewrite_host_header" yaml:"rewrite_host_header" toml:"rewrite_host_header"` // 重写Host头
|
||||
AddXForwardedFor bool `json:"add_x_forwarded_for" yaml:"add_x_forwarded_for" toml:"add_x_forwarded_for"` // 是否添加X-Forwarded-For头
|
||||
AddXRealIP bool `json:"add_x_real_ip" yaml:"add_x_real_ip" toml:"add_x_real_ip"` // 是否添加X-Real-IP头
|
||||
SupportWebSocketUpgrade bool `json:"support_websocket_upgrade" yaml:"support_websocket_upgrade" toml:"support_websocket_upgrade"` // 是否支持Websocket升级
|
||||
InsecureSkipVerify bool `json:"insecure_skip_verify" yaml:"insecure_skip_verify" toml:"insecure_skip_verify"` // 是否跳过TLS证书验证
|
||||
Logger *slog.Logger `json:"-" yaml:"-" toml:"-"` // 日志记录器
|
||||
}
|
||||
|
||||
// DefaultConfig 返回默认配置
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
ListenAddr: ":8080",
|
||||
DecryptHTTPS: false,
|
||||
UseECDSA: false,
|
||||
RequestTimeout: 30 * time.Second,
|
||||
IdleTimeout: 90 * time.Second,
|
||||
EnableCache: false,
|
||||
MaxIdleConns: 100,
|
||||
DNSCacheTTL: 5 * time.Minute,
|
||||
CacheTTL: 5 * time.Minute,
|
||||
EnableRetry: true,
|
||||
MaxRetries: 3,
|
||||
BaseBackoff: time.Second,
|
||||
MaxBackoff: 10 * time.Second,
|
||||
RateLimit: 0, // 0 表示不限流
|
||||
EnableCORS: true,
|
||||
EnableLoadBalancing: false,
|
||||
Backends: []string{},
|
||||
EnableRateLimit: false,
|
||||
MaxBurst: 50,
|
||||
MaxConnections: 1000,
|
||||
EnableConnectionPool: true,
|
||||
ConnectionPoolSize: 100,
|
||||
EnableHealthCheck: false,
|
||||
HealthCheckInterval: 10 * time.Second,
|
||||
HealthCheckTimeout: 5 * time.Second,
|
||||
EnableMetrics: false,
|
||||
EnableTracing: false,
|
||||
WebSocketIntercept: false,
|
||||
ReverseProxy: false,
|
||||
ReverseProxyRulesFile: "",
|
||||
PreserveClientIP: true,
|
||||
EnableCompression: false,
|
||||
RewriteHostHeader: false,
|
||||
AddXForwardedFor: true,
|
||||
AddXRealIP: true,
|
||||
SupportWebSocketUpgrade: true,
|
||||
InsecureSkipVerify: false,
|
||||
Logger: slog.Default(),
|
||||
}
|
||||
}
|
117
config/hot_reload.go
Normal file
117
config/hot_reload.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// HotReloadConfig 热更新配置
|
||||
type HotReloadConfig struct {
|
||||
// 配置文件路径
|
||||
ConfigPath string
|
||||
// 配置更新回调函数
|
||||
OnUpdate func(*Config)
|
||||
// 配置锁
|
||||
mu sync.RWMutex
|
||||
// 当前配置
|
||||
current *Config
|
||||
// 文件监视器
|
||||
watcher *fsnotify.Watcher
|
||||
// 停止信号
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// NewHotReloadConfig 创建热更新配置
|
||||
func NewHotReloadConfig(configPath string, onUpdate func(*Config)) (*HotReloadConfig, error) {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hrc := &HotReloadConfig{
|
||||
ConfigPath: configPath,
|
||||
OnUpdate: onUpdate,
|
||||
watcher: watcher,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
// 加载初始配置
|
||||
if err := hrc.loadConfig(); err != nil {
|
||||
watcher.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 启动文件监视
|
||||
go hrc.watch()
|
||||
|
||||
return hrc, nil
|
||||
}
|
||||
|
||||
// loadConfig 加载配置
|
||||
func (hrc *HotReloadConfig) loadConfig() error {
|
||||
data, err := os.ReadFile(hrc.ConfigPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hrc.mu.Lock()
|
||||
hrc.current = &config
|
||||
hrc.mu.Unlock()
|
||||
|
||||
if hrc.OnUpdate != nil {
|
||||
hrc.OnUpdate(&config)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// watch 监视配置文件变化
|
||||
func (hrc *HotReloadConfig) watch() {
|
||||
// 添加配置文件目录到监视
|
||||
configDir := filepath.Dir(hrc.ConfigPath)
|
||||
if err := hrc.watcher.Add(configDir); err != nil {
|
||||
slog.Error("添加配置目录到监视失败", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case event := <-hrc.watcher.Events:
|
||||
if event.Name == hrc.ConfigPath {
|
||||
// 等待文件写入完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if err := hrc.loadConfig(); err != nil {
|
||||
slog.Error("重新加载配置失败", "error", err)
|
||||
}
|
||||
}
|
||||
case err := <-hrc.watcher.Errors:
|
||||
slog.Error("配置文件监视错误", "error", err)
|
||||
case <-hrc.stopChan:
|
||||
hrc.watcher.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get 获取当前配置
|
||||
func (hrc *HotReloadConfig) Get() *Config {
|
||||
hrc.mu.RLock()
|
||||
defer hrc.mu.RUnlock()
|
||||
return hrc.current
|
||||
}
|
||||
|
||||
// Stop 停止热更新
|
||||
func (hrc *HotReloadConfig) Stop() {
|
||||
close(hrc.stopChan)
|
||||
}
|
134
conn_buffer.go
Normal file
134
conn_buffer.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package goproxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConnBuffer 连接缓冲区
|
||||
// 封装了底层网络连接和缓冲读取器,提供了更方便的读写接口
|
||||
type ConnBuffer struct {
|
||||
// 底层连接
|
||||
conn net.Conn
|
||||
// 缓冲读取器
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
// NewConnBuffer 创建连接缓冲区
|
||||
func NewConnBuffer(conn net.Conn, reader *bufio.Reader) *ConnBuffer {
|
||||
if reader == nil {
|
||||
reader = bufio.NewReader(conn)
|
||||
}
|
||||
return &ConnBuffer{
|
||||
conn: conn,
|
||||
reader: reader,
|
||||
}
|
||||
}
|
||||
|
||||
// Read 从连接读取数据
|
||||
func (c *ConnBuffer) Read(b []byte) (int, error) {
|
||||
return c.reader.Read(b)
|
||||
}
|
||||
|
||||
// Write 向连接写入数据
|
||||
func (c *ConnBuffer) Write(b []byte) (int, error) {
|
||||
return c.conn.Write(b)
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *ConnBuffer) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// LocalAddr 获取本地地址
|
||||
func (c *ConnBuffer) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr 获取远程地址
|
||||
func (c *ConnBuffer) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
// SetDeadline 设置读写超时
|
||||
func (c *ConnBuffer) SetDeadline(t time.Time) error {
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
// SetReadDeadline 设置读取超时
|
||||
func (c *ConnBuffer) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetWriteDeadline 设置写入超时
|
||||
func (c *ConnBuffer) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// BufferReader 获取缓冲读取器
|
||||
func (c *ConnBuffer) BufferReader() *bufio.Reader {
|
||||
return c.reader
|
||||
}
|
||||
|
||||
// Peek 查看缓冲区中的数据,但不消费
|
||||
func (c *ConnBuffer) Peek(n int) ([]byte, error) {
|
||||
return c.reader.Peek(n)
|
||||
}
|
||||
|
||||
// ReadByte 读取一个字节
|
||||
func (c *ConnBuffer) ReadByte() (byte, error) {
|
||||
return c.reader.ReadByte()
|
||||
}
|
||||
|
||||
// UnreadByte 将最后读取的字节放回缓冲区
|
||||
func (c *ConnBuffer) UnreadByte() error {
|
||||
return c.reader.UnreadByte()
|
||||
}
|
||||
|
||||
// ReadLine 读取一行数据
|
||||
func (c *ConnBuffer) ReadLine() (string, error) {
|
||||
line, isPrefix, err := c.reader.ReadLine()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 如果一行数据没有读取完整,继续读取
|
||||
if isPrefix {
|
||||
var buf []byte
|
||||
buf = append(buf, line...)
|
||||
for isPrefix && err == nil {
|
||||
line, isPrefix, err = c.reader.ReadLine()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
buf = append(buf, line...)
|
||||
}
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
return string(line), nil
|
||||
}
|
||||
|
||||
// ReadN 读取指定字节数的数据
|
||||
func (c *ConnBuffer) ReadN(n int) ([]byte, error) {
|
||||
buf := make([]byte, n)
|
||||
_, err := io.ReadFull(c.reader, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Flush 刷新缓冲区
|
||||
func (c *ConnBuffer) Flush() error {
|
||||
// 由于我们只有读取缓冲区,没有写入缓冲区,所以这里不需要实际操作
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset 重置连接缓冲区
|
||||
func (c *ConnBuffer) Reset(conn net.Conn) {
|
||||
c.conn = conn
|
||||
c.reader = bufio.NewReader(conn)
|
||||
}
|
132
context.go
Normal file
132
context.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package goproxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Context 代理上下文
|
||||
// 包含了代理请求的上下文信息,用于在代理处理过程中传递数据
|
||||
type Context struct {
|
||||
// 原始请求
|
||||
Req *http.Request
|
||||
// 请求开始时间
|
||||
StartTime time.Time
|
||||
// 上下文数据,用于在各个处理阶段传递数据
|
||||
Data map[interface{}]interface{}
|
||||
// 是否是隧道代理
|
||||
TunnelProxy bool
|
||||
// 请求ID
|
||||
RequestID string
|
||||
// 目标地址
|
||||
TargetAddr string
|
||||
// 上级代理地址
|
||||
ParentProxyURL *url.URL
|
||||
// 是否中断执行
|
||||
abort bool
|
||||
// 请求标签,用于标记请求类型
|
||||
Tags []string
|
||||
// 是否已中止
|
||||
aborted bool
|
||||
// 互斥锁
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// IsHTTPS 是否是HTTPS请求
|
||||
func (c *Context) IsHTTPS() bool {
|
||||
return c.Req.URL.Scheme == "https" || c.Req.Method == http.MethodConnect
|
||||
}
|
||||
|
||||
// defaultPorts 默认端口映射
|
||||
var defaultPorts = map[string]string{
|
||||
"https": "443",
|
||||
"http": "80",
|
||||
"": "80",
|
||||
}
|
||||
|
||||
// WebSocketURL 获取WebSocket URL
|
||||
func (c *Context) WebSocketURL() *url.URL {
|
||||
u := *c.Req.URL
|
||||
if c.IsHTTPS() {
|
||||
u.Scheme = "wss"
|
||||
} else {
|
||||
u.Scheme = "ws"
|
||||
}
|
||||
return &u
|
||||
}
|
||||
|
||||
// Addr 获取请求地址
|
||||
func (c *Context) Addr() string {
|
||||
addr := c.Req.Host
|
||||
|
||||
if !strings.Contains(c.Req.URL.Host, ":") {
|
||||
addr += ":" + defaultPorts[c.Req.URL.Scheme]
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
// AddTag 添加请求标签
|
||||
func (c *Context) AddTag(tag string) {
|
||||
c.Tags = append(c.Tags, tag)
|
||||
}
|
||||
|
||||
// HasTag 检查是否包含指定标签
|
||||
func (c *Context) HasTag(tag string) bool {
|
||||
for _, t := range c.Tags {
|
||||
if t == tag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Abort 中断执行
|
||||
func (c *Context) Abort() {
|
||||
c.aborted = true
|
||||
}
|
||||
|
||||
// IsAborted 是否已中断执行
|
||||
func (c *Context) IsAborted() bool {
|
||||
return c.aborted
|
||||
}
|
||||
|
||||
// Reset 重置上下文
|
||||
func (c *Context) Reset(req *http.Request) {
|
||||
c.Req = req
|
||||
c.StartTime = time.Now()
|
||||
c.Data = make(map[interface{}]interface{})
|
||||
c.abort = false
|
||||
c.TunnelProxy = false
|
||||
c.Tags = make([]string, 0)
|
||||
c.RequestID = ""
|
||||
c.TargetAddr = ""
|
||||
c.ParentProxyURL = nil
|
||||
c.aborted = false
|
||||
}
|
||||
|
||||
// Set 设置数据
|
||||
func (c *Context) Set(key, value interface{}) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.Data == nil {
|
||||
c.Data = make(map[interface{}]interface{})
|
||||
}
|
||||
c.Data[key] = value
|
||||
}
|
||||
|
||||
// Get 获取数据
|
||||
func (c *Context) Get(key interface{}) (interface{}, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if c.Data == nil {
|
||||
return nil, false
|
||||
}
|
||||
val, ok := c.Data[key]
|
||||
return val, ok
|
||||
}
|
225
coverage.out
Normal file
225
coverage.out
Normal file
@@ -0,0 +1,225 @@
|
||||
mode: set
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:24.55,26.16 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:26.16,28.3 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:29.2,33.42 4 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:33.42,35.3 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:37.2,37.12 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:41.56,43.16 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:43.16,45.3 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:46.2,55.47 4 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:55.47,57.3 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:59.2,59.20 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:63.61,65.16 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:65.16,67.3 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:68.2,78.21 5 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:78.21,83.49 3 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:83.49,84.12 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:87.3,88.22 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:88.22,89.12 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:92.3,97.21 4 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:97.21,98.12 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:101.3,106.20 4 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:106.20,108.4 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:110.3,110.34 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:110.34,112.38 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:112.38,113.10 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:117.4,117.34 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:121.2,121.38 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:121.38,123.3 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:125.2,125.20 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:129.63,131.20 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:131.20,133.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:133.8,135.3 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:137.2,145.17 3 1
|
||||
github.com/darkit/goproxy/pkg/dns/config.go:149.43,151.2 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:19.43,26.2 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:29.61,32.2 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:35.65,38.2 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:41.55,44.2 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:47.94,50.16 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:50.16,54.3 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:57.2,58.16 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:58.16,60.3 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:63.2,63.17 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:63.17,65.17 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:65.17,67.4 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:70.3,70.24 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:70.24,72.4 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:74.3,74.21 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:78.2,84.71 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:88.66,90.2 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:93.26,95.2 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:98.98,100.2 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:103.52,105.2 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/dialer.go:108.111,110.2 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:16.39,21.2 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:24.57,29.2 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:32.49,34.30 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:34.30,36.22 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:36.22,38.4 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:41.3,45.17 3 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:45.17,47.4 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:49.3,49.44 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:53.2,53.28 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:57.36,58.16 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:58.16,60.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:61.2,61.13 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:65.70,66.16 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:66.16,68.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/endpoint.go:69.2,69.48 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:73.53,84.33 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:84.33,86.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:88.2,88.10 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:92.63,94.16 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:94.16,96.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:97.2,97.25 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:101.91,106.64 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:106.64,109.3 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:112.2,112.60 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:112.60,115.3 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:118.2,118.36 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:118.36,119.41 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:119.41,122.4 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:124.2,127.16 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:127.16,129.17 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:129.17,131.4 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:134.3,135.28 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:135.28,136.39 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:136.39,138.10 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:142.3,142.15 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:142.15,144.4 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:147.3,148.22 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:148.22,150.4 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:153.3,160.23 4 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:163.2,163.76 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:167.91,168.25 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:168.25,170.3 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:173.2,173.25 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:173.25,175.22 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:175.22,177.4 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:178.3,178.18 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:182.2,183.22 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:184.18,186.68 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:187.14,189.68 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:190.22,192.26 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:195.2,195.40 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:195.40,197.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:198.2,198.17 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:202.65,206.39 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:206.39,207.48 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:207.48,209.4 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:212.2,212.12 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:216.64,218.41 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:218.41,220.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:223.2,223.38 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:223.38,225.29 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:225.29,226.12 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:230.3,230.38 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:230.38,232.4 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:235.2,235.13 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:239.53,241.2 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:244.71,245.28 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:245.28,247.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:249.2,253.50 4 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:253.50,255.31 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:255.31,256.54 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:256.54,258.5 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:260.3,260.54 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:261.8,263.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:264.2,264.12 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:268.71,270.2 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:273.89,274.28 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:274.28,276.3 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:279.2,279.44 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:279.44,281.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:284.2,290.39 4 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:290.39,291.37 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:291.37,294.37 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:294.37,295.55 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:295.55,297.6 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:300.4,301.14 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:306.2,315.12 3 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:319.52,324.34 3 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:324.34,327.3 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:330.2,330.39 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:330.39,331.27 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:331.27,335.4 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:338.2,338.44 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:342.74,347.42 3 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:347.42,349.31 2 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:349.31,350.36 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:350.36,352.5 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:354.3,354.29 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:354.29,356.4 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:356.9,358.4 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:359.3,359.13 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:363.2,363.39 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:363.39,364.27 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:364.27,366.37 2 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:366.37,367.37 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:367.37,369.6 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:371.4,371.30 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:371.30,373.5 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:373.10,375.5 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:376.4,376.14 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:380.2,380.44 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:384.34,391.2 5 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:397.41,398.33 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:398.33,400.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:404.40,405.33 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:405.33,407.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:411.67,412.33 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:412.33,414.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:418.71,422.35 3 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:422.35,424.34 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:424.34,426.18 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:426.18,428.5 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:430.4,430.39 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:430.39,432.5 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:435.4,436.41 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:436.41,437.29 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:437.29,439.39 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:439.39,440.57 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:440.57,442.13 2 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:445.6,445.16 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:445.16,447.7 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:448.6,448.11 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:452.4,452.14 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:452.14,460.5 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:462.9,465.18 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:465.18,467.5 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:469.4,469.39 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:469.39,471.5 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:474.4,474.52 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:474.52,476.33 2 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:476.33,477.56 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:477.56,479.12 2 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:482.5,482.15 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:482.15,484.6 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:485.10,487.5 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:492.2,494.12 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:498.46,500.46 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:500.46,509.39 5 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:509.39,511.4 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:514.3,514.45 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:519.41,521.29 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:521.29,522.18 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:522.18,524.4 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/resolver.go:526.2,526.14 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:10.47,11.37 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:11.37,13.3 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:15.2,16.21 2 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:16.21,18.3 1 0
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:20.2,23.54 3 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:23.54,25.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:27.2,27.54 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:27.54,29.3 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:31.2,31.13 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:41.72,45.2 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:48.65,52.42 3 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:52.42,53.35 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:53.35,55.4 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:58.2,58.70 1 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:62.62,67.2 4 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:70.57,75.2 4 1
|
||||
github.com/darkit/goproxy/pkg/dns/wildcard.go:78.42,83.2 4 1
|
120
delegate.go
Normal file
120
delegate.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package goproxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Delegate 代理委托接口
|
||||
// 定义了代理处理请求的各个阶段的回调方法
|
||||
type Delegate interface {
|
||||
// Connect 连接事件
|
||||
Connect(ctx *Context, rw http.ResponseWriter)
|
||||
|
||||
// Auth 认证事件
|
||||
Auth(ctx *Context, rw http.ResponseWriter)
|
||||
|
||||
// BeforeRequest 请求前事件
|
||||
BeforeRequest(ctx *Context)
|
||||
|
||||
// BeforeResponse 响应前事件
|
||||
BeforeResponse(ctx *Context, resp *http.Response, err error)
|
||||
|
||||
// WebSocketSendMessage websocket发送消息拦截
|
||||
WebSocketSendMessage(ctx *Context, messageType *int, p *[]byte)
|
||||
|
||||
// WebSocketReceiveMessage websocket接收消息拦截
|
||||
WebSocketReceiveMessage(ctx *Context, messageType *int, p *[]byte)
|
||||
|
||||
// ParentProxy 获取上级代理
|
||||
ParentProxy(req *http.Request) (*url.URL, error)
|
||||
|
||||
// ErrorLog 错误日志
|
||||
ErrorLog(err error)
|
||||
|
||||
// Finish 完成事件
|
||||
Finish(ctx *Context)
|
||||
|
||||
// 以下是反向代理相关的方法
|
||||
|
||||
// ResolveBackend 解析后端服务器
|
||||
// 在反向代理模式下,根据请求确定应该转发到哪个后端服务器
|
||||
ResolveBackend(req *http.Request) (string, error)
|
||||
|
||||
// ModifyRequest 修改请求
|
||||
// 在反向代理模式下,可以修改发往后端服务器的请求
|
||||
ModifyRequest(req *http.Request)
|
||||
|
||||
// ModifyResponse 修改响应
|
||||
// 在反向代理模式下,可以修改来自后端服务器的响应
|
||||
ModifyResponse(resp *http.Response) error
|
||||
|
||||
// HandleError 处理错误
|
||||
// 在反向代理模式下,可以自定义错误处理逻辑
|
||||
HandleError(rw http.ResponseWriter, req *http.Request, err error)
|
||||
}
|
||||
|
||||
// DefaultDelegate 默认代理委托
|
||||
type DefaultDelegate struct{}
|
||||
|
||||
// Connect 连接事件
|
||||
func (d *DefaultDelegate) Connect(ctx *Context, rw http.ResponseWriter) {}
|
||||
|
||||
// Auth 认证事件
|
||||
func (d *DefaultDelegate) Auth(ctx *Context, rw http.ResponseWriter) {}
|
||||
|
||||
// BeforeRequest 请求前事件
|
||||
func (d *DefaultDelegate) BeforeRequest(ctx *Context) {}
|
||||
|
||||
// BeforeResponse 响应前事件
|
||||
func (d *DefaultDelegate) BeforeResponse(ctx *Context, resp *http.Response, err error) {
|
||||
// 默认实现不做任何处理
|
||||
}
|
||||
|
||||
// ParentProxy 获取上级代理
|
||||
func (d *DefaultDelegate) ParentProxy(req *http.Request) (*url.URL, error) {
|
||||
// 默认实现不使用上级代理
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// WebSocketSendMessage websocket发送消息拦截
|
||||
func (h *DefaultDelegate) WebSocketSendMessage(ctx *Context, messageType *int, payload *[]byte) {
|
||||
// 默认实现不做任何处理
|
||||
}
|
||||
|
||||
// WebSocketReceiveMessage websocket接收消息拦截
|
||||
func (h *DefaultDelegate) WebSocketReceiveMessage(ctx *Context, messageType *int, payload *[]byte) {
|
||||
// 默认实现不做任何处理
|
||||
}
|
||||
|
||||
// ErrorLog 错误日志
|
||||
func (d *DefaultDelegate) ErrorLog(err error) {
|
||||
// 默认实现不做任何处理
|
||||
}
|
||||
|
||||
// Finish 完成事件
|
||||
func (d *DefaultDelegate) Finish(ctx *Context) {
|
||||
// 默认实现不做任何处理
|
||||
}
|
||||
|
||||
// ResolveBackend 解析后端服务器
|
||||
func (d *DefaultDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||
// 默认实现返回请求中的主机
|
||||
return req.Host, nil
|
||||
}
|
||||
|
||||
// ModifyRequest 修改请求
|
||||
func (d *DefaultDelegate) ModifyRequest(req *http.Request) {
|
||||
// 默认实现不做任何修改
|
||||
}
|
||||
|
||||
// ModifyResponse 修改响应
|
||||
func (d *DefaultDelegate) ModifyResponse(resp *http.Response) error {
|
||||
// 默认实现不做任何修改
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleError 处理错误
|
||||
func (d *DefaultDelegate) HandleError(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
http.Error(rw, err.Error(), http.StatusBadGateway)
|
||||
}
|
29
examples/auth_proxy/main.go
Normal file
29
examples/auth_proxy/main.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/pkg/auth"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建认证系统
|
||||
auths := auth.NewAuth("1234")
|
||||
auths.Authenticate("admin", "password")
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithAuth(auths),
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("认证代理服务器启动在 :8080")
|
||||
log.Println("认证配置:")
|
||||
log.Printf("- 用户名: admin\n")
|
||||
log.Printf("- 密码: password\n")
|
||||
if err := http.ListenAndServe(":8080", proxy); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
59
examples/custom_dns_https_proxy/main.go
Normal file
59
examples/custom_dns_https_proxy/main.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/pkg/dns"
|
||||
)
|
||||
|
||||
// CustomDNSHTTPSDelegate 自定义 DNS HTTPS 代理委托
|
||||
type CustomDNSHTTPSDelegate struct {
|
||||
goproxy.DefaultDelegate
|
||||
dnsResolver *dns.CustomResolver
|
||||
}
|
||||
|
||||
// ResolveBackend 解析后端服务器
|
||||
func (d *CustomDNSHTTPSDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||
return d.dnsResolver.Resolve(req.URL.Host)
|
||||
}
|
||||
|
||||
func main() {
|
||||
// 创建证书缓存
|
||||
certCache := &goproxy.MemCertCache{}
|
||||
|
||||
// 创建自定义 DNS 解析器
|
||||
resolver := dns.NewResolver(dns.WithFallback(true))
|
||||
|
||||
// 添加 DNS 记录
|
||||
resolver.LoadFromMap(map[string]string{
|
||||
"example.com": "http://backend1.example.com",
|
||||
"test.com": "http://backend2.test.com",
|
||||
})
|
||||
|
||||
// 创建自定义 DNS HTTPS 代理委托
|
||||
delegate := &CustomDNSHTTPSDelegate{
|
||||
dnsResolver: resolver,
|
||||
}
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithDelegate(delegate),
|
||||
goproxy.WithDecryptHTTPS(certCache),
|
||||
goproxy.WithCACertAndKey("ca.crt", "ca.key"),
|
||||
goproxy.WithEnableECDSA(true),
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("自定义 DNS HTTPS 代理服务器启动在 :8443")
|
||||
log.Println("配置说明:")
|
||||
log.Printf("- 支持 HTTPS 解密\n")
|
||||
log.Printf("- 使用 ECDSA 证书\n")
|
||||
log.Println("DNS 配置:")
|
||||
log.Printf("- example.com -> backend1.example.com\n")
|
||||
log.Printf("- test.com -> backend2.test.com\n")
|
||||
if err := http.ListenAndServeTLS(":8443", "server.crt", "server.key", proxy); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
49
examples/custom_dns_proxy/main.go
Normal file
49
examples/custom_dns_proxy/main.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package custom_dns_proxy
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/pkg/dns"
|
||||
)
|
||||
|
||||
// CustomDNSDelegate 自定义 DNS 代理委托
|
||||
type CustomDNSDelegate struct {
|
||||
goproxy.DefaultDelegate
|
||||
dnsResolver dns.Resolver
|
||||
}
|
||||
|
||||
// ResolveBackend 解析后端服务器
|
||||
func (d *CustomDNSDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||
return d.dnsResolver.Resolve(req.URL.Host)
|
||||
}
|
||||
|
||||
// RunCustomDNSProxy 运行自定义 DNS 代理服务器
|
||||
func RunCustomDNSProxy() error {
|
||||
// 创建自定义 DNS 解析器
|
||||
resolver := dns.NewResolver(dns.WithFallback(true))
|
||||
|
||||
// 添加 DNS 记录
|
||||
resolver.LoadFromMap(map[string]string{
|
||||
"example.com": "http://backend1.example.com",
|
||||
"test.com": "http://backend2.test.com",
|
||||
})
|
||||
|
||||
// 创建自定义 DNS 代理委托
|
||||
delegate := &CustomDNSDelegate{
|
||||
dnsResolver: resolver,
|
||||
}
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithDelegate(delegate),
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("自定义 DNS 代理服务器启动在 :8080")
|
||||
log.Println("DNS 配置:")
|
||||
log.Printf("- example.com -> backend1.example.com\n")
|
||||
log.Printf("- test.com -> backend2.test.com\n")
|
||||
return http.ListenAndServe(":8080", proxy)
|
||||
}
|
124
examples/custom_dns_reverse_proxy/README.md
Normal file
124
examples/custom_dns_reverse_proxy/README.md
Normal file
@@ -0,0 +1,124 @@
|
||||
# 自定义DNS反向代理示例
|
||||
|
||||
这个示例展示了如何使用 goproxy 创建一个支持自定义 DNS 解析和负载均衡的反向代理服务器。
|
||||
|
||||
## 功能特点
|
||||
|
||||
- 支持自定义 DNS 解析规则
|
||||
- 支持多后端服务器负载均衡
|
||||
- 支持通配符域名解析
|
||||
- 支持 DNS 缓存
|
||||
- 支持优雅关闭
|
||||
- 支持配置文件
|
||||
|
||||
## 使用方法
|
||||
|
||||
1. 编译示例:
|
||||
|
||||
```bash
|
||||
go build -o custom_dns_proxy main.go
|
||||
```
|
||||
|
||||
2. 运行示例:
|
||||
|
||||
```bash
|
||||
# 使用默认配置
|
||||
./custom_dns_proxy
|
||||
|
||||
# 指定配置文件
|
||||
./custom_dns_proxy -config config.json
|
||||
```
|
||||
|
||||
## 配置说明
|
||||
|
||||
### 命令行参数
|
||||
|
||||
- `-addr`: 监听地址,默认为 ":8080"
|
||||
- `-config`: 配置文件路径,默认为 "config.json"
|
||||
|
||||
### 配置文件
|
||||
|
||||
配置文件支持 JSON 格式,包含以下主要配置项:
|
||||
|
||||
```json
|
||||
{
|
||||
"dns": {
|
||||
"records": {
|
||||
"api.example.com": [
|
||||
"192.168.1.1:8080",
|
||||
"192.168.1.2:8080",
|
||||
"192.168.1.3:8080"
|
||||
],
|
||||
"*.example.com": [
|
||||
"192.168.1.1:8080",
|
||||
"192.168.1.2:8080"
|
||||
]
|
||||
},
|
||||
"fallback": true,
|
||||
"ttl": 300,
|
||||
"strategy": "round-robin"
|
||||
},
|
||||
"proxy": {
|
||||
"listen_addr": ":8080",
|
||||
"enable_https": false,
|
||||
"enable_websocket": true,
|
||||
"enable_compression": true,
|
||||
"enable_cors": true,
|
||||
"preserve_client_ip": true,
|
||||
"add_x_forwarded_for": true,
|
||||
"add_x_real_ip": true,
|
||||
"insecure_skip_verify": false,
|
||||
"enable_health_check": false,
|
||||
"health_check_interval": 30,
|
||||
"health_check_timeout": 5,
|
||||
"enable_retry": true,
|
||||
"max_retries": 3,
|
||||
"retry_backoff": 1,
|
||||
"max_retry_backoff": 10,
|
||||
"enable_metrics": true,
|
||||
"enable_tracing": false,
|
||||
"websocket_intercept": false,
|
||||
"dns_cache_ttl": 300,
|
||||
"enable_cache": true,
|
||||
"cache_ttl": 300,
|
||||
"enable_connection_pool": true,
|
||||
"connection_pool_size": 100,
|
||||
"idle_timeout": 60,
|
||||
"request_timeout": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 负载均衡策略
|
||||
|
||||
当前示例使用轮询(Round Robin)负载均衡策略,支持以下特性:
|
||||
|
||||
1. 多后端服务器轮询
|
||||
2. 自动故障转移
|
||||
3. 健康检查
|
||||
4. 连接池管理
|
||||
|
||||
## 示例用法
|
||||
|
||||
1. 启动代理服务器:
|
||||
|
||||
```bash
|
||||
./custom_dns_proxy -addr :8080
|
||||
```
|
||||
|
||||
2. 使用 curl 测试:
|
||||
|
||||
```bash
|
||||
# 测试 API 访问
|
||||
curl http://api.example.com:8080/api/v1/users
|
||||
|
||||
# 测试通配符域名
|
||||
curl http://test.example.com:8080/api/v1/users
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 确保配置文件中的 DNS 记录正确配置
|
||||
2. 确保后端服务器正常运行
|
||||
3. 建议在生产环境中启用 HTTPS
|
||||
4. 根据实际需求调整连接池大小和超时设置
|
23
examples/custom_dns_reverse_proxy/config.json
Normal file
23
examples/custom_dns_reverse_proxy/config.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"dns": {
|
||||
"records": {
|
||||
"*.shabi.in": [
|
||||
"192.168.1.1:80",
|
||||
"192.168.1.6:80",
|
||||
"192.168.1.252:80"
|
||||
],
|
||||
"api.example.com": [
|
||||
"192.168.1.1:8080",
|
||||
"192.168.1.2:8080",
|
||||
"192.168.1.3:8080"
|
||||
],
|
||||
"www.*.shabi.in": [
|
||||
"192.168.1.1:80",
|
||||
"192.168.1.252:80"
|
||||
]
|
||||
},
|
||||
"fallback": true,
|
||||
"ttl": 300,
|
||||
"strategy": "round-robin"
|
||||
}
|
||||
}
|
188
examples/custom_dns_reverse_proxy/main.go
Normal file
188
examples/custom_dns_reverse_proxy/main.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"syscall"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/dns"
|
||||
"github.com/darkit/goproxy/pkg/reverse"
|
||||
)
|
||||
|
||||
// 命令行参数
|
||||
var configFile = flag.String("config", "config.json", "配置文件路径")
|
||||
|
||||
// Config 配置文件结构
|
||||
type Config struct {
|
||||
DNS struct {
|
||||
Records map[string][]string `json:"records"` // 普通记录和泛解析记录
|
||||
Fallback bool `json:"fallback"` // 是否回退到系统DNS
|
||||
TTL int `json:"ttl"` // 缓存TTL,单位为秒
|
||||
Strategy string `json:"strategy"` // 负载均衡策略:round-robin, random, first-available
|
||||
} `json:"dns" yaml:"dns" toml:"dns"`
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
// 创建日志记录器
|
||||
logger := slog.Default()
|
||||
|
||||
// 加载配置文件
|
||||
config := &Config{}
|
||||
if *configFile != "" {
|
||||
data, err := os.ReadFile(*configFile)
|
||||
if err != nil {
|
||||
logger.Error("读取配置文件失败", "file", *configFile, "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, config); err != nil {
|
||||
logger.Error("解析配置文件失败", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
logger.Info("成功加载配置文件", "file", *configFile)
|
||||
}
|
||||
|
||||
// 选择负载均衡策略
|
||||
//var strategy dns.LoadBalanceStrategy
|
||||
//switch config.DNS.Strategy {
|
||||
//case "random":
|
||||
// strategy = dns.Random
|
||||
//case "first-available":
|
||||
// strategy = dns.FirstAvailable
|
||||
//default:
|
||||
// strategy = dns.RoundRobin
|
||||
//}
|
||||
|
||||
// 创建反向代理配置
|
||||
cfg := reverse.DefaultConfig()
|
||||
cfg.TargetAddr = "www.shabi.in"
|
||||
//cfg.DNSResolver = dns.NewResolver(
|
||||
// dns.WithFallback(config.DNS.Fallback), // 是否回退到系统DNS
|
||||
// dns.WithLoadBalanceStrategy(strategy), // 使用配置的负载均衡策略
|
||||
// dns.WithTTL(time.Duration(config.DNS.TTL)*time.Second), // 设置DNS缓存TTL
|
||||
//) // 创建DNS解析器
|
||||
|
||||
// 添加DNS记录,支持多个地址
|
||||
for domain, addrs := range config.DNS.Records {
|
||||
for _, addr := range addrs {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
// 地址没有端口,将整个地址作为IP使用
|
||||
if !dns.IsWildcardDomain(domain) {
|
||||
cfg.DNSResolver.Add(domain, addr)
|
||||
} else {
|
||||
cfg.DNSResolver.AddWildcard(domain, addr)
|
||||
}
|
||||
} else {
|
||||
tPort, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
logger.Error("无效的端口", "domain", domain, "addr", addr, "port", port, "error", err)
|
||||
continue
|
||||
}
|
||||
if !dns.IsWildcardDomain(domain) {
|
||||
cfg.DNSResolver.AddWithPort(domain, host, tPort)
|
||||
} else {
|
||||
cfg.DNSResolver.AddWildcardWithPort(domain, host, tPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建反向代理服务器
|
||||
proxy, err := reverse.New(cfg)
|
||||
if err != nil {
|
||||
logger.Error("创建反向代理服务器失败", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 创建HTTP服务器
|
||||
server := &http.Server{
|
||||
Addr: cfg.ListenAddr,
|
||||
Handler: proxy,
|
||||
}
|
||||
|
||||
// 启动服务器
|
||||
go func() {
|
||||
logger.Info("启动反向代理服务器", "addr", cfg.ListenAddr)
|
||||
if err = server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
logger.Error("服务器运行错误", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待中断信号
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
|
||||
// 优雅关闭
|
||||
logger.Info("正在关闭服务器...")
|
||||
if err := server.Close(); err != nil {
|
||||
logger.Error("关闭服务器失败", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 示例配置文件 config.json:
|
||||
/*
|
||||
{
|
||||
"dns": {
|
||||
"records": {
|
||||
"api.example.com": [
|
||||
"192.168.1.1:8080",
|
||||
"192.168.1.2:8080",
|
||||
"192.168.1.3:8080"
|
||||
],
|
||||
"*.example.com": [
|
||||
"192.168.1.1:8080",
|
||||
"192.168.1.2:8080"
|
||||
]
|
||||
},
|
||||
"fallback": true,
|
||||
"ttl": 300,
|
||||
"strategy": "round-robin"
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// 示例配置文件 config.yaml:
|
||||
/*
|
||||
dns:
|
||||
records:
|
||||
api.example.com:
|
||||
- 192.168.1.1:8080
|
||||
- 192.168.1.2:8080
|
||||
- 192.168.1.3:8080
|
||||
"*.example.com":
|
||||
- 192.168.1.1:8080
|
||||
- 192.168.1.2:8080
|
||||
fallback: true
|
||||
ttl: 300
|
||||
strategy: round-robin
|
||||
*/
|
||||
|
||||
// 示例配置文件 config.toml:
|
||||
/*
|
||||
[dns]
|
||||
[dns.records]
|
||||
api.example.com = [
|
||||
"192.168.1.1:8080",
|
||||
"192.168.1.2:8080",
|
||||
"192.168.1.3:8080"
|
||||
]
|
||||
"*.example.com" = [
|
||||
"192.168.1.1:8080",
|
||||
"192.168.1.2:8080"
|
||||
]
|
||||
fallback = true
|
||||
ttl = 300
|
||||
strategy = round-robin
|
||||
*/
|
27
examples/custom_port_proxy/main.go
Normal file
27
examples/custom_port_proxy/main.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/config"
|
||||
)
|
||||
|
||||
var port = flag.String("port", "8080", "代理服务器端口")
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithConfig(config.DefaultConfig()),
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Printf("代理服务器启动在 :%s\n", *port)
|
||||
if err := http.ListenAndServe(":"+*port, proxy); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
19
examples/goproxy/main.go
Normal file
19
examples/goproxy/main.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建代理
|
||||
p := goproxy.New(nil)
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
41
examples/http_to_https_proxy/main.go
Normal file
41
examples/http_to_https_proxy/main.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
// HTTPToHTTPSDelegate HTTP 到 HTTPS 代理委托
|
||||
type HTTPToHTTPSDelegate struct {
|
||||
goproxy.DefaultDelegate
|
||||
}
|
||||
|
||||
// BeforeRequest 请求前事件
|
||||
func (d *HTTPToHTTPSDelegate) BeforeRequest(ctx *goproxy.Context) {
|
||||
// 将 HTTP 请求转换为 HTTPS
|
||||
if ctx.Req.URL.Scheme == "http" {
|
||||
ctx.Req.URL.Scheme = "https"
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
// 创建 HTTP 到 HTTPS 代理委托
|
||||
delegate := &HTTPToHTTPSDelegate{}
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithDelegate(delegate),
|
||||
goproxy.WithDecryptHTTPS(&goproxy.MemCertCache{}),
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("HTTP 到 HTTPS 代理服务器启动在 :8080")
|
||||
log.Println("配置说明:")
|
||||
log.Printf("- 自动将 HTTP 请求转换为 HTTPS\n")
|
||||
log.Printf("- 支持 HTTPS 解密\n")
|
||||
if err := http.ListenAndServe(":8080", proxy); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
31
examples/https_to_https_proxy/main.go
Normal file
31
examples/https_to_https_proxy/main.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建证书缓存
|
||||
certCache := &goproxy.MemCertCache{}
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithDecryptHTTPS(certCache),
|
||||
goproxy.WithCACertAndKey("ca.crt", "ca.key"),
|
||||
goproxy.WithEnableECDSA(true),
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("HTTPS 到 HTTPS 代理服务器启动在 :8443")
|
||||
log.Println("配置说明:")
|
||||
log.Printf("- 支持 HTTPS 解密\n")
|
||||
log.Printf("- 使用 ECDSA 证书\n")
|
||||
log.Printf("- 最低 TLS 版本: 1.2\n")
|
||||
log.Printf("- 最高 TLS 版本: 1.3\n")
|
||||
if err := http.ListenAndServeTLS(":8443", "server.crt", "server.key", proxy); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
89
examples/other/README.md
Normal file
89
examples/other/README.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# GoProxy 示例
|
||||
|
||||
本目录包含了 GoProxy 库的各种使用示例,展示了不同的代理功能和配置选项。
|
||||
|
||||
## 示例列表
|
||||
|
||||
1. [forward_proxy.go](forward_proxy.go) - 基本正向代理
|
||||
- 展示最基本的正向代理功能
|
||||
- 适用于简单的 HTTP 代理需求
|
||||
|
||||
2. [https_proxy.go](https_proxy.go) - HTTPS 解密代理
|
||||
- 支持 HTTPS 解密(中间人模式)
|
||||
- 需要配置 CA 证书
|
||||
- 支持 ECDSA 证书算法
|
||||
|
||||
3. [reverse_proxy.go](reverse_proxy.go) - 反向代理
|
||||
- 支持反向代理功能
|
||||
- 支持 URL 重写
|
||||
- 支持 X-Forwarded-For 和 X-Real-IP 头
|
||||
|
||||
4. [custom_delegate.go](custom_delegate.go) - 自定义委托代理
|
||||
- 展示如何自定义代理行为
|
||||
- 支持请求和响应的自定义处理
|
||||
- 包含详细的日志记录
|
||||
|
||||
5. [load_balance.go](load_balance.go) - 负载均衡代理
|
||||
- 支持多后端服务器
|
||||
- 使用轮询算法进行负载均衡
|
||||
- 包含健康检查功能
|
||||
|
||||
6. [metrics_proxy.go](metrics_proxy.go) - 监控指标代理
|
||||
- 支持 Prometheus 格式的监控指标
|
||||
- 提供详细的性能统计
|
||||
- 包含独立的指标服务器
|
||||
|
||||
7. [cache_proxy.go](cache_proxy.go) - 缓存代理
|
||||
- 支持 HTTP 响应缓存
|
||||
- 使用内存缓存存储
|
||||
- 可配置缓存策略
|
||||
|
||||
8. [auth_proxy.go](auth_proxy.go) - 认证代理
|
||||
- 支持基本认证
|
||||
- 可配置用户名和密码
|
||||
- 保护代理访问
|
||||
|
||||
9. [websocket_proxy.go](websocket_proxy.go) - WebSocket 代理
|
||||
- 支持 WebSocket 协议
|
||||
- 支持 WebSocket 拦截
|
||||
- 适用于实时通信场景
|
||||
|
||||
10. [rate_limit_proxy.go](rate_limit_proxy.go) - 速率限制代理
|
||||
- 支持请求速率限制
|
||||
- 可配置最大请求速率
|
||||
- 防止服务器过载
|
||||
|
||||
## 使用方法
|
||||
|
||||
1. 编译示例:
|
||||
```bash
|
||||
go build -o forward_proxy examples/forward_proxy.go
|
||||
```
|
||||
|
||||
2. 运行示例:
|
||||
```bash
|
||||
./forward_proxy
|
||||
```
|
||||
|
||||
3. 配置代理:
|
||||
- 在浏览器中设置代理服务器为 `localhost:8080`
|
||||
- 或使用环境变量:
|
||||
```bash
|
||||
export http_proxy=http://localhost:8080
|
||||
export https_proxy=http://localhost:8080
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 使用 HTTPS 解密功能时,需要安装 CA 证书
|
||||
2. 某些功能可能需要额外的配置(如证书、密钥等)
|
||||
3. 建议在生产环境中使用更安全的配置
|
||||
4. 监控指标默认在 9090 端口提供
|
||||
|
||||
## 开发建议
|
||||
|
||||
1. 根据实际需求选择合适的示例作为起点
|
||||
2. 可以组合多个功能来满足复杂需求
|
||||
3. 注意处理错误和异常情况
|
||||
4. 在生产环境中添加适当的日志记录
|
||||
5. 考虑添加监控和告警机制
|
32
examples/other/cache/cache_proxy.go
vendored
Normal file
32
examples/other/cache/cache_proxy.go
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/pkg/cache"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建内存缓存
|
||||
memCache := cache.NewMemoryCache(time.Minute*5, time.Minute*5, 1000)
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithHTTPCache(memCache),
|
||||
goproxy.WithRequestTimeout(30*time.Second),
|
||||
goproxy.WithIdleTimeout(60*time.Second),
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("缓存代理服务器启动在 :8080")
|
||||
log.Println("缓存配置:")
|
||||
log.Printf("- 缓存类型: 内存缓存\n")
|
||||
log.Printf("- 请求超时: 30s\n")
|
||||
log.Printf("- 空闲超时: 60s\n")
|
||||
if err := http.ListenAndServe(":8080", proxy); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
56
examples/other/custom_delegate/custom_delegate.go
Normal file
56
examples/other/custom_delegate/custom_delegate.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
// CustomDelegate 自定义委托
|
||||
type CustomDelegate struct {
|
||||
goproxy.DefaultDelegate
|
||||
}
|
||||
|
||||
// BeforeRequest 请求前事件
|
||||
func (d *CustomDelegate) BeforeRequest(ctx *goproxy.Context) {
|
||||
log.Printf("收到请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.String())
|
||||
log.Printf("请求头: %v\n", ctx.Req.Header)
|
||||
}
|
||||
|
||||
// BeforeResponse 响应前事件
|
||||
func (d *CustomDelegate) BeforeResponse(ctx *goproxy.Context, resp *http.Response, err error) {
|
||||
if err != nil {
|
||||
log.Printf("响应错误: %v\n", err)
|
||||
return
|
||||
}
|
||||
log.Printf("收到响应: %d %s\n", resp.StatusCode, resp.Status)
|
||||
log.Printf("响应头: %v\n", resp.Header)
|
||||
}
|
||||
|
||||
// AfterResponse 响应后事件
|
||||
func (d *CustomDelegate) AfterResponse(ctx *goproxy.Context, resp *http.Response) {
|
||||
log.Printf("请求完成: %s %s, 耗时: %v\n",
|
||||
ctx.Req.Method,
|
||||
ctx.Req.URL.String(),
|
||||
time.Since(ctx.StartTime))
|
||||
}
|
||||
|
||||
func main() {
|
||||
// 创建自定义委托
|
||||
delegate := &CustomDelegate{}
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithDelegate(delegate),
|
||||
goproxy.WithRequestTimeout(30*time.Second),
|
||||
goproxy.WithIdleTimeout(60*time.Second),
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("自定义委托代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", proxy); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
22
examples/other/forward/forward_proxy.go
Normal file
22
examples/other/forward/forward_proxy.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/config"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithConfig(config.DefaultConfig()),
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("正向代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", proxy); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
27
examples/other/https/https_proxy.go
Normal file
27
examples/other/https/https_proxy.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建证书缓存
|
||||
certCache := &goproxy.MemCertCache{}
|
||||
|
||||
// 创建代理实例,启用 HTTPS 解密
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithDecryptHTTPS(certCache),
|
||||
goproxy.WithCACertAndKey("ca.crt", "ca.key"),
|
||||
goproxy.WithEnableECDSA(true),
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("HTTPS 解密代理服务器启动在 :8080")
|
||||
log.Println("请确保已安装 CA 证书 (ca.crt)")
|
||||
if err := http.ListenAndServe(":8080", proxy); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
67
examples/other/load_balance/load_balance.go
Normal file
67
examples/other/load_balance/load_balance.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/pkg/loadbalance"
|
||||
)
|
||||
|
||||
// LoadBalanceDelegate 负载均衡委托
|
||||
type LoadBalanceDelegate struct {
|
||||
goproxy.DefaultDelegate
|
||||
lb loadbalance.LoadBalancer
|
||||
}
|
||||
|
||||
// ResolveBackend 解析后端服务器
|
||||
func (d *LoadBalanceDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||
backend, err := d.lb.Next(req.URL.Host)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if backend == nil {
|
||||
return "", nil
|
||||
}
|
||||
return backend.String(), nil
|
||||
}
|
||||
|
||||
// 运行负载均衡代理服务器
|
||||
func main() {
|
||||
// 创建后端服务器列表
|
||||
backends := []string{
|
||||
"http://backend1:8081",
|
||||
"http://backend2:8082",
|
||||
"http://backend3:8083",
|
||||
}
|
||||
|
||||
// 创建负载均衡器
|
||||
lb := loadbalance.NewRoundRobinBalancer()
|
||||
|
||||
// 添加后端服务器
|
||||
for _, backend := range backends {
|
||||
if err := lb.Add(backend, 1); err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 创建负载均衡委托
|
||||
delegate := &LoadBalanceDelegate{
|
||||
lb: lb,
|
||||
}
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithDelegate(delegate),
|
||||
goproxy.WithLoadBalancer(lb),
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("负载均衡代理服务器启动在 :8080")
|
||||
log.Println("后端服务器列表:")
|
||||
for _, backend := range backends {
|
||||
log.Printf("- %s\n", backend)
|
||||
}
|
||||
http.ListenAndServe(":8080", proxy)
|
||||
}
|
67
examples/other/metrics/metrics_proxy.go
Normal file
67
examples/other/metrics/metrics_proxy.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/config"
|
||||
"github.com/darkit/goproxy/pkg/metrics"
|
||||
)
|
||||
|
||||
// CustomDelegate 自定义代理委托
|
||||
type CustomDelegate struct {
|
||||
goproxy.DefaultDelegate
|
||||
}
|
||||
|
||||
// Connect 处理连接事件
|
||||
func (d *CustomDelegate) Connect(ctx *goproxy.Context, rw http.ResponseWriter) {
|
||||
log.Printf("新的连接: %s", ctx.Req.RemoteAddr)
|
||||
}
|
||||
|
||||
// 运行监控代理服务器
|
||||
func main() {
|
||||
// 创建监控指标收集器
|
||||
metricsCollector := metrics.NewSimpleMetrics()
|
||||
|
||||
// 创建基础配置
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.EnableRetry = true
|
||||
cfg.MaxRetries = 3
|
||||
cfg.BaseBackoff = time.Second
|
||||
cfg.MaxBackoff = 10 * time.Second
|
||||
cfg.EnableCache = true
|
||||
cfg.CacheTTL = 5 * time.Minute
|
||||
cfg.ConnectionPoolSize = 100
|
||||
cfg.RateLimit = 1000
|
||||
cfg.EnableMetrics = true
|
||||
cfg.DecryptHTTPS = false // 禁用HTTPS解密
|
||||
cfg.InsecureSkipVerify = true // 跳过证书验证
|
||||
cfg.RequestTimeout = 30 * time.Second
|
||||
cfg.IdleTimeout = 60 * time.Second
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithMetrics(metricsCollector),
|
||||
goproxy.WithRequestTimeout(30*time.Second),
|
||||
goproxy.WithIdleTimeout(60*time.Second),
|
||||
goproxy.WithConfig(cfg),
|
||||
goproxy.WithDelegate(&CustomDelegate{}),
|
||||
)
|
||||
|
||||
// 启动监控指标服务器
|
||||
go func() {
|
||||
http.Handle("/metrics", metricsCollector.GetHandler())
|
||||
log.Println("监控指标服务器启动在 :9090")
|
||||
if err := http.ListenAndServe(":9090", nil); err != nil {
|
||||
log.Printf("监控指标服务器启动失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("监控代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", proxy); err != nil {
|
||||
log.Printf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
23
examples/other/rate_limit/rate_limit_proxy.go
Normal file
23
examples/other/rate_limit/rate_limit_proxy.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithRateLimit(100), // 每秒最多处理 100 个请求
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("速率限制代理服务器启动在 :8080")
|
||||
log.Println("速率限制配置:")
|
||||
log.Printf("- 最大请求速率: 100 请求/秒\n")
|
||||
if err := http.ListenAndServe(":8080", proxy); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
45
examples/other/reverse_proxy/reverse_proxy.go
Normal file
45
examples/other/reverse_proxy/reverse_proxy.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
// ReverseProxyDelegate 反向代理委托
|
||||
type ReverseProxyDelegate struct {
|
||||
goproxy.DefaultDelegate
|
||||
backendURL *url.URL
|
||||
}
|
||||
|
||||
// ResolveBackend 解析后端服务器
|
||||
func (d *ReverseProxyDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||
return d.backendURL.String(), nil
|
||||
}
|
||||
|
||||
// RunReverseProxy 运行反向代理服务器
|
||||
func main() {
|
||||
// 解析后端服务器地址
|
||||
backendURL, err := url.Parse("http://localhost:8081")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 创建反向代理委托
|
||||
delegate := &ReverseProxyDelegate{
|
||||
backendURL: backendURL,
|
||||
}
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithDelegate(delegate),
|
||||
goproxy.WithReverseProxy(true),
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("反向代理服务器启动在 :8080")
|
||||
log.Printf("请求将被转发到: %s\n", backendURL.String())
|
||||
http.ListenAndServe(":8080", proxy)
|
||||
}
|
23
examples/other/websocket/websocket_proxy.go
Normal file
23
examples/other/websocket/websocket_proxy.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithEnableWebsocketIntercept(),
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("WebSocket 代理服务器启动在 :8080")
|
||||
log.Println("WebSocket 配置:")
|
||||
log.Printf("- WebSocket 拦截: 启用\n")
|
||||
if err := http.ListenAndServe(":8080", proxy); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
121
examples/rewriter/http_server.go
Normal file
121
examples/rewriter/http_server.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/rewriter"
|
||||
)
|
||||
|
||||
// RewriteReverseProxy 重写反向代理
|
||||
// 在请求发送到后端服务器前重写URL
|
||||
type RewriteReverseProxy struct {
|
||||
// 后端服务器地址
|
||||
Target *url.URL
|
||||
// URL重写器
|
||||
Rewriter *rewriter.Rewriter
|
||||
// 反向代理
|
||||
Proxy *httputil.ReverseProxy
|
||||
}
|
||||
|
||||
// NewRewriteReverseProxy 创建重写反向代理
|
||||
func NewRewriteReverseProxy(target string, rewriteRulesFile string) (*RewriteReverseProxy, error) {
|
||||
// 解析目标URL
|
||||
targetURL, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建重写器
|
||||
rw := rewriter.NewRewriter()
|
||||
|
||||
// 加载重写规则
|
||||
if rewriteRulesFile != "" {
|
||||
if err := rw.LoadRulesFromFile(rewriteRulesFile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 创建反向代理
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
|
||||
// 修改默认的Director函数,添加URL重写逻辑
|
||||
defaultDirector := proxy.Director
|
||||
proxy.Director = func(req *http.Request) {
|
||||
// 先执行默认的Director函数
|
||||
defaultDirector(req)
|
||||
|
||||
// 然后执行URL重写
|
||||
rw.Rewrite(req)
|
||||
|
||||
// 记录重写后的URL
|
||||
log.Printf("请求重写: %s -> %s", req.URL.Path, req.URL.String())
|
||||
}
|
||||
|
||||
// 修改响应处理器,重写响应头
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// 重写响应
|
||||
rw.RewriteResponse(resp, targetURL.Host)
|
||||
return nil
|
||||
}
|
||||
|
||||
return &RewriteReverseProxy{
|
||||
Target: targetURL,
|
||||
Rewriter: rw,
|
||||
Proxy: proxy,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ServeHTTP 实现http.Handler接口
|
||||
func (rrp *RewriteReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("收到请求: %s %s", r.Method, r.URL.Path)
|
||||
rrp.Proxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// RewriteMiddleware 中间件:将重写器应用到处理链中
|
||||
func RewriteMiddleware(rw *rewriter.Rewriter, next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 重写URL
|
||||
originalPath := r.URL.Path
|
||||
rw.Rewrite(r)
|
||||
|
||||
if originalPath != r.URL.Path {
|
||||
log.Printf("URL重写: %s -> %s", originalPath, r.URL.Path)
|
||||
}
|
||||
|
||||
// 继续处理链
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func main() {
|
||||
// 创建重写器
|
||||
rw := rewriter.NewRewriter()
|
||||
|
||||
// 加载重写规则
|
||||
if err := rw.LoadRulesFromFile("rules.json"); err != nil {
|
||||
log.Fatalf("加载重写规则失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建反向代理
|
||||
proxy, err := NewRewriteReverseProxy("http://192.168.1.212:80", "rules.json")
|
||||
if err != nil {
|
||||
log.Fatalf("创建反向代理失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建静态文件服务器(模拟普通Web服务器)
|
||||
fileServer := http.FileServer(http.Dir("./static"))
|
||||
|
||||
// 使用中间件包装文件服务器
|
||||
rewrittenFileServer := RewriteMiddleware(rw, fileServer)
|
||||
|
||||
// 设置处理函数
|
||||
http.Handle("/api/", rewrittenFileServer)
|
||||
http.Handle("/", proxy)
|
||||
|
||||
// 启动服务器
|
||||
log.Println("服务器启动在 :8080...")
|
||||
log.Fatal(http.ListenAndServe(":8080", nil))
|
||||
}
|
88
examples/rewriter/main.go
Normal file
88
examples/rewriter/main.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/rewriter"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建URL重写器
|
||||
rw := rewriter.NewRewriter()
|
||||
|
||||
// 添加一些规则
|
||||
rw.AddRule("/api/v1/", "/api/v2/", false)
|
||||
rw.AddRuleWithDescription("/old/(.*)/page", "/new/$1/page", true, "旧页面重定向")
|
||||
|
||||
// 从文件加载规则
|
||||
exampleDir, _ := os.Getwd()
|
||||
jsonRulesPath := filepath.Join(exampleDir, "rules.json")
|
||||
textRulesPath := filepath.Join(exampleDir, "rules.txt")
|
||||
|
||||
// 先加载JSON格式规则
|
||||
fmt.Println("从JSON文件加载规则...")
|
||||
if err := rw.LoadRulesFromFile(jsonRulesPath); err != nil {
|
||||
fmt.Printf("从JSON文件加载规则失败: %v\n", err)
|
||||
}
|
||||
|
||||
// 然后加载文本格式规则
|
||||
fmt.Println("从文本文件加载规则...")
|
||||
if err := rw.LoadRulesFromFile(textRulesPath); err != nil {
|
||||
fmt.Printf("从文本文件加载规则失败: %v\n", err)
|
||||
}
|
||||
|
||||
// 打印所有规则
|
||||
fmt.Println("\n当前规则列表:")
|
||||
for i, rule := range rw.GetRules() {
|
||||
status := "启用"
|
||||
if !rule.Enabled {
|
||||
status = "禁用"
|
||||
}
|
||||
fmt.Printf("[%d] %s -> %s [%s] (%s)\n",
|
||||
i, rule.Pattern, rule.Replacement,
|
||||
map[bool]string{true: "正则", false: "前缀"}[rule.UseRegex],
|
||||
status)
|
||||
}
|
||||
|
||||
// 测试重写
|
||||
fmt.Println("\n测试URL重写:")
|
||||
testURLs := []string{
|
||||
"/api/v1/users",
|
||||
"/old/profile/page",
|
||||
"/legacy-files/document.pdf",
|
||||
"/en/about/company",
|
||||
}
|
||||
|
||||
for _, url := range testURLs {
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest("GET", "http://example.com"+url, nil)
|
||||
|
||||
// 重写URL
|
||||
fmt.Printf("原始URL: %s\n", req.URL.Path)
|
||||
rw.Rewrite(req)
|
||||
fmt.Printf("重写后: %s\n\n", req.URL.Path)
|
||||
}
|
||||
|
||||
// 测试响应重写
|
||||
fmt.Println("测试响应重写:")
|
||||
resp := &http.Response{
|
||||
Header: http.Header{},
|
||||
}
|
||||
resp.Header.Set("Location", "http://backend.example.com/old/profile/page")
|
||||
fmt.Printf("原始Location: %s\n", resp.Header.Get("Location"))
|
||||
rw.RewriteResponse(resp, "backend.example.com")
|
||||
fmt.Printf("重写后Location: %s\n", resp.Header.Get("Location"))
|
||||
|
||||
// 保存规则到新文件
|
||||
newRulesPath := filepath.Join(exampleDir, "new_rules.json")
|
||||
fmt.Printf("\n保存规则到文件: %s\n", newRulesPath)
|
||||
if err := rw.SaveRulesToFile(newRulesPath); err != nil {
|
||||
fmt.Printf("保存规则失败: %v\n", err)
|
||||
} else {
|
||||
fmt.Println("规则已保存")
|
||||
}
|
||||
}
|
63
examples/rewriter/new_rules.json
Normal file
63
examples/rewriter/new_rules.json
Normal file
@@ -0,0 +1,63 @@
|
||||
[
|
||||
{
|
||||
"pattern": "/api/v1/",
|
||||
"replacement": "/api/v2/",
|
||||
"use_regex": false,
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"pattern": "/old/(.*)/page",
|
||||
"replacement": "/new/$1/page",
|
||||
"use_regex": true,
|
||||
"description": "旧页面重定向",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"pattern": "/api/v1/",
|
||||
"replacement": "/api/v2/",
|
||||
"use_regex": false,
|
||||
"description": "将API v1请求重定向到v2",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"pattern": "/old/(.*)/page",
|
||||
"replacement": "/new/$1/page",
|
||||
"use_regex": true,
|
||||
"description": "旧页面格式重定向到新格式",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"pattern": "/legacy-files/",
|
||||
"replacement": "/files/",
|
||||
"use_regex": false,
|
||||
"description": "旧文件路径重定向"
|
||||
},
|
||||
{
|
||||
"pattern": "^/(en|zh|ja)/(.*)",
|
||||
"replacement": "/$2?lang=$1",
|
||||
"use_regex": true,
|
||||
"description": "将语言路径转换为查询参数",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"pattern": "/api/v1/",
|
||||
"replacement": "/api/v2/",
|
||||
"use_regex": false,
|
||||
"description": "将API v1请求重定向到v2",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"pattern": "/old/(.*)/page",
|
||||
"replacement": "/new/$1/page",
|
||||
"use_regex": true,
|
||||
"description": "旧页面格式重定向到新格式",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"pattern": "^/(en|zh|ja)/(.*)",
|
||||
"replacement": "/$2?lang=$1",
|
||||
"use_regex": true,
|
||||
"description": "将语言路径转换为查询参数",
|
||||
"enabled": true
|
||||
}
|
||||
]
|
30
examples/rewriter/rules.json
Normal file
30
examples/rewriter/rules.json
Normal file
@@ -0,0 +1,30 @@
|
||||
[
|
||||
{
|
||||
"pattern": "/api/v1/",
|
||||
"replacement": "/api/v2/",
|
||||
"use_regex": false,
|
||||
"description": "将API v1请求重定向到v2",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"pattern": "/old/(.*)/page",
|
||||
"replacement": "/new/$1/page",
|
||||
"use_regex": true,
|
||||
"description": "旧页面格式重定向到新格式",
|
||||
"enabled": true
|
||||
},
|
||||
{
|
||||
"pattern": "/legacy-files/",
|
||||
"replacement": "/files/",
|
||||
"use_regex": false,
|
||||
"description": "旧文件路径重定向",
|
||||
"enabled": false
|
||||
},
|
||||
{
|
||||
"pattern": "^/(en|zh|ja)/(.*)",
|
||||
"replacement": "/$2?lang=$1",
|
||||
"use_regex": true,
|
||||
"description": "将语言路径转换为查询参数",
|
||||
"enabled": true
|
||||
}
|
||||
]
|
7
examples/rewriter/rules.txt
Normal file
7
examples/rewriter/rules.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
# URL重写规则示例
|
||||
# 格式: pattern replacement [regex] [#description]
|
||||
|
||||
/api/v1/ /api/v2/ # 将API v1请求重定向到v2
|
||||
/old/(.*)/page /new/$1/page regex # 旧页面格式重定向到新格式
|
||||
# /legacy-files/ /files/ # 旧文件路径重定向 (已禁用)
|
||||
^/(en|zh|ja)/(.*) /$2?lang=$1 regex # 将语言路径转换为查询参数
|
259
examples/rewriter/web_admin.go
Normal file
259
examples/rewriter/web_admin.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/rewriter"
|
||||
)
|
||||
|
||||
// 全局重写器实例
|
||||
var rw *rewriter.Rewriter
|
||||
|
||||
// 规则配置文件路径
|
||||
var rulesFile = "rules.json"
|
||||
|
||||
// AdminHandler Web管理页面处理器
|
||||
func AdminHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// 定义模板
|
||||
tmpl := `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>URL重写规则管理</title>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; margin: 20px; }
|
||||
h1 { color: #333; }
|
||||
table { width: 100%; border-collapse: collapse; margin-top: 20px; }
|
||||
th, td { padding: 10px; text-align: left; border-bottom: 1px solid #ddd; }
|
||||
th { background-color: #f2f2f2; }
|
||||
.enabled { color: green; }
|
||||
.disabled { color: red; }
|
||||
.form-group { margin-bottom: 15px; }
|
||||
label { display: block; margin-bottom: 5px; }
|
||||
input[type="text"], textarea { width: 100%; padding: 8px; box-sizing: border-box; }
|
||||
input[type="checkbox"] { margin-right: 5px; }
|
||||
button { padding: 10px 15px; background-color: #4CAF50; color: white; border: none; cursor: pointer; }
|
||||
button:hover { background-color: #45a049; }
|
||||
.action-btn { padding: 5px 10px; margin-right: 5px; }
|
||||
.delete-btn { background-color: #f44336; }
|
||||
.edit-btn { background-color: #2196F3; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>URL重写规则管理</h1>
|
||||
|
||||
<h2>当前规则</h2>
|
||||
<table>
|
||||
<tr>
|
||||
<th>索引</th>
|
||||
<th>匹配模式</th>
|
||||
<th>替换模式</th>
|
||||
<th>类型</th>
|
||||
<th>描述</th>
|
||||
<th>状态</th>
|
||||
<th>操作</th>
|
||||
</tr>
|
||||
{{range $i, $rule := .Rules}}
|
||||
<tr>
|
||||
<td>{{$i}}</td>
|
||||
<td>{{$rule.Pattern}}</td>
|
||||
<td>{{$rule.Replacement}}</td>
|
||||
<td>{{if $rule.UseRegex}}正则表达式{{else}}前缀匹配{{end}}</td>
|
||||
<td>{{$rule.Description}}</td>
|
||||
<td class="{{if $rule.Enabled}}enabled{{else}}disabled{{end}}">
|
||||
{{if $rule.Enabled}}启用{{else}}禁用{{end}}
|
||||
</td>
|
||||
<td>
|
||||
{{if $rule.Enabled}}
|
||||
<a href="/disable/{{$i}}"><button class="action-btn">禁用</button></a>
|
||||
{{else}}
|
||||
<a href="/enable/{{$i}}"><button class="action-btn">启用</button></a>
|
||||
{{end}}
|
||||
<a href="/remove/{{$i}}"><button class="action-btn delete-btn">删除</button></a>
|
||||
</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
</table>
|
||||
|
||||
<h2>添加新规则</h2>
|
||||
<form action="/add" method="post">
|
||||
<div class="form-group">
|
||||
<label for="pattern">匹配模式:</label>
|
||||
<input type="text" id="pattern" name="pattern" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="replacement">替换模式:</label>
|
||||
<input type="text" id="replacement" name="replacement" required>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="regex">
|
||||
<input type="checkbox" id="regex" name="regex">
|
||||
使用正则表达式
|
||||
</label>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="description">描述:</label>
|
||||
<textarea id="description" name="description" rows="3"></textarea>
|
||||
</div>
|
||||
<button type="submit">添加规则</button>
|
||||
</form>
|
||||
|
||||
<h2>保存配置</h2>
|
||||
<form action="/save" method="post">
|
||||
<button type="submit">保存规则配置</button>
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
||||
`
|
||||
|
||||
// 解析模板
|
||||
t, err := template.New("admin").Parse(tmpl)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("模板解析错误: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 渲染模板
|
||||
data := struct {
|
||||
Rules []*rewriter.RewriteRule
|
||||
}{
|
||||
Rules: rw.GetRules(),
|
||||
}
|
||||
|
||||
if err := t.Execute(w, data); err != nil {
|
||||
http.Error(w, fmt.Sprintf("模板渲染错误: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// AddRuleHandler 添加规则处理器
|
||||
func AddRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
r.ParseForm()
|
||||
|
||||
pattern := r.Form.Get("pattern")
|
||||
replacement := r.Form.Get("replacement")
|
||||
useRegex := r.Form.Get("regex") == "on"
|
||||
description := r.Form.Get("description")
|
||||
|
||||
if pattern == "" || replacement == "" {
|
||||
http.Error(w, "匹配模式和替换模式不能为空", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := rw.AddRuleWithDescription(pattern, replacement, useRegex, description); err != nil {
|
||||
http.Error(w, fmt.Sprintf("添加规则失败: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// EnableRuleHandler 启用规则处理器
|
||||
func EnableRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// 解析URL中的规则索引
|
||||
indexStr := r.URL.Path[len("/enable/"):]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
http.Error(w, "无效的规则索引", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := rw.EnableRule(index); err != nil {
|
||||
http.Error(w, fmt.Sprintf("启用规则失败: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// DisableRuleHandler 禁用规则处理器
|
||||
func DisableRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// 解析URL中的规则索引
|
||||
indexStr := r.URL.Path[len("/disable/"):]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
http.Error(w, "无效的规则索引", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := rw.DisableRule(index); err != nil {
|
||||
http.Error(w, fmt.Sprintf("禁用规则失败: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// RemoveRuleHandler 删除规则处理器
|
||||
func RemoveRuleHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// 解析URL中的规则索引
|
||||
indexStr := r.URL.Path[len("/remove/"):]
|
||||
index, err := strconv.Atoi(indexStr)
|
||||
if err != nil {
|
||||
http.Error(w, "无效的规则索引", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := rw.RemoveRule(index); err != nil {
|
||||
http.Error(w, fmt.Sprintf("删除规则失败: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// SaveRulesHandler 保存规则处理器
|
||||
func SaveRulesHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
if err := rw.SaveRulesToFile(rulesFile); err != nil {
|
||||
http.Error(w, fmt.Sprintf("保存规则失败: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// APIHandler API处理器,返回JSON格式的规则列表
|
||||
func APIHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
rules := rw.GetRules()
|
||||
json.NewEncoder(w).Encode(rules)
|
||||
}
|
||||
|
||||
func main() {
|
||||
// 创建重写器
|
||||
rw = rewriter.NewRewriter()
|
||||
|
||||
// 尝试从文件加载规则
|
||||
if err := rw.LoadRulesFromFile(rulesFile); err != nil {
|
||||
log.Printf("加载规则失败: %v, 将使用空规则集", err)
|
||||
}
|
||||
|
||||
// 注册处理器
|
||||
http.HandleFunc("/", AdminHandler)
|
||||
http.HandleFunc("/add", AddRuleHandler)
|
||||
http.HandleFunc("/enable/", EnableRuleHandler)
|
||||
http.HandleFunc("/disable/", DisableRuleHandler)
|
||||
http.HandleFunc("/remove/", RemoveRuleHandler)
|
||||
http.HandleFunc("/save", SaveRulesHandler)
|
||||
http.HandleFunc("/api/rules", APIHandler)
|
||||
|
||||
// 启动服务器
|
||||
log.Println("管理界面启动在 :8080...")
|
||||
log.Fatal(http.ListenAndServe(":8080", nil))
|
||||
}
|
83
examples/rule/dns_config.json
Normal file
83
examples/rule/dns_config.json
Normal file
@@ -0,0 +1,83 @@
|
||||
{
|
||||
"records": {
|
||||
"www.github.com": "140.82.121.3",
|
||||
"github.com": "140.82.121.4",
|
||||
"api.github.com": "140.82.121.5",
|
||||
"assets-cdn.github.com": "185.199.108.153",
|
||||
"collector.github.com": "140.82.121.6",
|
||||
"codeload.github.com": "140.82.121.9",
|
||||
"gist.github.com": "140.82.121.4",
|
||||
"raw.githubusercontent.com": "185.199.108.133",
|
||||
"s3.amazonaws.com": "52.216.162.69",
|
||||
"www.google.com": "142.250.4.147",
|
||||
"google.com": "142.250.4.139",
|
||||
"ajax.googleapis.com": "142.250.4.95",
|
||||
"fonts.googleapis.com": "142.250.4.95",
|
||||
"www.baidu.com": "110.242.68.66",
|
||||
"baidu.com": "39.156.66.10",
|
||||
|
||||
"custom-port-example.com": "192.168.1.10:8080",
|
||||
"api.custom-port-example.com": "192.168.1.11:8443",
|
||||
"dev.example.com": "127.0.0.1:3000",
|
||||
"api.dev.example.com": "127.0.0.1:3001",
|
||||
"db.dev.example.com": "127.0.0.1:5432"
|
||||
},
|
||||
"use_fallback": true,
|
||||
"ttl": 300,
|
||||
"rules": [
|
||||
{
|
||||
"id": "dns-rule-1",
|
||||
"type": "dns",
|
||||
"priority": 100,
|
||||
"pattern": "example.com",
|
||||
"match_type": "exact",
|
||||
"enabled": true,
|
||||
"targets": [
|
||||
{
|
||||
"ip": "192.168.1.100",
|
||||
"port": 80
|
||||
},
|
||||
{
|
||||
"ip": "192.168.1.101",
|
||||
"port": 80
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "dns-rule-2",
|
||||
"type": "dns",
|
||||
"priority": 90,
|
||||
"pattern": "*.example.com",
|
||||
"match_type": "wildcard",
|
||||
"enabled": true,
|
||||
"targets": [
|
||||
{
|
||||
"ip": "192.168.1.102",
|
||||
"port": 80
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "route-rule-1",
|
||||
"type": "route",
|
||||
"priority": 100,
|
||||
"pattern": "/api/v1",
|
||||
"match_type": "path",
|
||||
"enabled": true,
|
||||
"target": "http://api.example.com",
|
||||
"header_modifier": {
|
||||
"X-Forwarded-Host": "api.example.com",
|
||||
"X-Real-IP": "${client_ip}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "rewrite-rule-1",
|
||||
"type": "rewrite",
|
||||
"priority": 90,
|
||||
"pattern": "/old/(.*)",
|
||||
"match_type": "regex",
|
||||
"enabled": true,
|
||||
"replacement": "/new/$1"
|
||||
}
|
||||
]
|
||||
}
|
60
examples/rule/hosts.txt
Normal file
60
examples/rule/hosts.txt
Normal file
@@ -0,0 +1,60 @@
|
||||
# GitHub相关域名解析
|
||||
140.82.121.3 www.github.com
|
||||
140.82.121.4 github.com
|
||||
140.82.121.5 api.github.com
|
||||
185.199.108.153 assets-cdn.github.com
|
||||
140.82.121.6 collector.github.com
|
||||
140.82.121.9 codeload.github.com
|
||||
140.82.121.4 gist.github.com
|
||||
185.199.108.133 raw.githubusercontent.com
|
||||
|
||||
# AWS服务
|
||||
52.216.162.69 s3.amazonaws.com
|
||||
|
||||
# Google相关域名解析
|
||||
142.250.4.147 www.google.com
|
||||
142.250.4.139 google.com
|
||||
142.250.4.95 ajax.googleapis.com
|
||||
142.250.4.95 fonts.googleapis.com
|
||||
|
||||
# 百度相关域名解析
|
||||
110.242.68.66 www.baidu.com
|
||||
39.156.66.10 baidu.com
|
||||
|
||||
# 自定义端口示例
|
||||
192.168.1.10:8080 custom-port-example.com
|
||||
192.168.1.11:8443 api.custom-port-example.com
|
||||
|
||||
# 本地开发环境
|
||||
127.0.0.1 localhost
|
||||
127.0.0.1:3000 dev.example.com
|
||||
127.0.0.1:3001 api.dev.example.com
|
||||
127.0.0.1:5432 db.dev.example.com
|
||||
|
||||
# 基本格式:IP 域名 [端口]
|
||||
192.168.1.100 example.com
|
||||
192.168.1.101 api.example.com
|
||||
192.168.1.102 test.example.com:8080
|
||||
|
||||
# 支持通配符域名
|
||||
192.168.1.103 *.example.com
|
||||
192.168.1.104 *.api.example.com
|
||||
|
||||
# 支持多个域名
|
||||
192.168.1.105 app1.example.com app2.example.com app3.example.com
|
||||
|
||||
# 支持注释
|
||||
# 开发环境
|
||||
192.168.1.106 dev.example.com
|
||||
# 测试环境
|
||||
192.168.1.107 test.example.com
|
||||
# 生产环境
|
||||
192.168.1.108 prod.example.com
|
||||
|
||||
# 支持IPv6
|
||||
2001:db8::1 ipv6.example.com
|
||||
2001:db8::2 ipv6-api.example.com:8080
|
||||
|
||||
# 支持混合IPv4和IPv6
|
||||
192.168.1.109 mixed.example.com
|
||||
2001:db8::3 mixed.example.com
|
23
examples/rule/wildcard_dns_config.json
Normal file
23
examples/rule/wildcard_dns_config.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"records": {
|
||||
"example.com": "93.184.216.34",
|
||||
"api.example.com": "93.184.216.35:8443",
|
||||
|
||||
"*.github.com": "140.82.121.3",
|
||||
"github.com": "140.82.121.4",
|
||||
"api.github.com": "140.82.121.5",
|
||||
|
||||
"*.s3.amazonaws.com": "52.216.162.69",
|
||||
"cdn-*.example.org": "203.0.113.10",
|
||||
|
||||
"*.dev.local": "127.0.0.1:3000",
|
||||
"api.*.dev.local": "127.0.0.1:3001",
|
||||
"db.*.dev.local": "127.0.0.1:5432",
|
||||
|
||||
"*.test.example.com": "192.168.1.100",
|
||||
"*.staging.example.com": "192.168.1.101",
|
||||
"*.production.example.com": "192.168.1.102"
|
||||
},
|
||||
"use_fallback": true,
|
||||
"ttl": 300
|
||||
}
|
27
examples/rule/wildcard_hosts.txt
Normal file
27
examples/rule/wildcard_hosts.txt
Normal file
@@ -0,0 +1,27 @@
|
||||
# 精确匹配记录
|
||||
93.184.216.34 example.com
|
||||
93.184.216.35:8443 api.example.com
|
||||
|
||||
# GitHub相关泛解析
|
||||
140.82.121.3 *.github.com
|
||||
140.82.121.4 github.com
|
||||
140.82.121.5 api.github.com
|
||||
|
||||
# AWS服务泛解析
|
||||
52.216.162.69 *.s3.amazonaws.com
|
||||
203.0.113.10 cdn-*.example.org
|
||||
|
||||
# 本地开发环境泛解析
|
||||
127.0.0.1:3000 *.dev.local
|
||||
127.0.0.1:3001 api.*.dev.local
|
||||
127.0.0.1:5432 db.*.dev.local
|
||||
|
||||
# 多环境测试泛解析
|
||||
192.168.1.100 *.test.example.com
|
||||
192.168.1.101 *.staging.example.com
|
||||
192.168.1.102 *.production.example.com
|
||||
|
||||
# 特定子域名泛解析示例
|
||||
10.0.0.1 *.api.service.com
|
||||
10.0.0.2 *.auth.service.com
|
||||
10.0.0.3 *.cdn.service.com
|
84
examples/scripts/generate_cert.ps1
Normal file
84
examples/scripts/generate_cert.ps1
Normal file
@@ -0,0 +1,84 @@
|
||||
# PowerShell 脚本:生成自签名证书
|
||||
|
||||
# 处理命令行参数
|
||||
param(
|
||||
[Parameter(HelpMessage="证书有效期(天数)")]
|
||||
[int]$days = 365,
|
||||
|
||||
[Parameter(HelpMessage="证书主题")]
|
||||
[string]$subject = "CN=localhost,OU=Test,O=GoProxy,L=Shanghai,S=Shanghai,C=CN",
|
||||
|
||||
[Parameter(HelpMessage="公用名(CN)")]
|
||||
[string]$cn = "",
|
||||
|
||||
[Parameter(HelpMessage="显示帮助信息")]
|
||||
[switch]$help
|
||||
)
|
||||
|
||||
# 帮助信息
|
||||
function Show-Help {
|
||||
Write-Host "生成自签名证书"
|
||||
Write-Host
|
||||
Write-Host "用法: .\generate_cert.ps1 [选项]"
|
||||
Write-Host
|
||||
Write-Host "选项:"
|
||||
Write-Host " -help 显示此帮助信息"
|
||||
Write-Host " -days DAYS 证书有效期(天数),默认: 365"
|
||||
Write-Host " -subject SUB 证书主题,默认: $subject"
|
||||
Write-Host " -cn CN 公用名(CN),将替换主题中的CN,默认: localhost"
|
||||
Write-Host
|
||||
Write-Host "示例:"
|
||||
Write-Host " .\generate_cert.ps1 -days 730 -cn example.com"
|
||||
Write-Host
|
||||
}
|
||||
|
||||
# 如果请求帮助,显示帮助信息并退出
|
||||
if ($help) {
|
||||
Show-Help
|
||||
exit 0
|
||||
}
|
||||
|
||||
# 如果指定了CN,替换主题中的CN部分
|
||||
if ($cn -ne "") {
|
||||
$subject = $subject -replace "CN=[^,]*", "CN=$cn"
|
||||
}
|
||||
|
||||
Write-Host "生成自签名证书..."
|
||||
Write-Host "有效期: $days 天"
|
||||
Write-Host "主题: $subject"
|
||||
|
||||
# 检查OpenSSL是否可用
|
||||
$openssl = Get-Command "openssl" -ErrorAction SilentlyContinue
|
||||
if (-not $openssl) {
|
||||
Write-Host "错误: 未找到OpenSSL命令。请安装OpenSSL并确保它在PATH环境变量中。" -ForegroundColor Red
|
||||
Write-Host "您可以从以下地址下载OpenSSL for Windows: https://slproweb.com/products/Win32OpenSSL.html" -ForegroundColor Yellow
|
||||
exit 1
|
||||
}
|
||||
|
||||
try {
|
||||
# 生成私钥
|
||||
Write-Host "正在生成私钥..." -ForegroundColor Cyan
|
||||
& openssl genrsa -out server.key 2048
|
||||
|
||||
# 生成证书请求
|
||||
Write-Host "正在生成证书请求..." -ForegroundColor Cyan
|
||||
& openssl req -new -key server.key -out server.csr -subj $subject.Replace(",", "/")
|
||||
|
||||
# 生成自签名证书
|
||||
Write-Host "正在生成自签名证书..." -ForegroundColor Cyan
|
||||
& openssl x509 -req -days $days -in server.csr -signkey server.key -out server.crt
|
||||
|
||||
# 删除证书请求文件
|
||||
Remove-Item server.csr -Force
|
||||
|
||||
Write-Host "完成!已生成以下文件:" -ForegroundColor Green
|
||||
Write-Host " - server.key: 私钥" -ForegroundColor Green
|
||||
Write-Host " - server.crt: 证书" -ForegroundColor Green
|
||||
Write-Host
|
||||
Write-Host "您可以使用这些文件启动HTTPS代理:" -ForegroundColor Cyan
|
||||
Write-Host "go run cmd/custom_dns_https_proxy/main.go -cert server.crt -key server.key" -ForegroundColor Cyan
|
||||
}
|
||||
catch {
|
||||
Write-Host "错误: 生成证书时发生错误: $_" -ForegroundColor Red
|
||||
exit 1
|
||||
}
|
81
examples/scripts/generate_cert.sh
Normal file
81
examples/scripts/generate_cert.sh
Normal file
@@ -0,0 +1,81 @@
|
||||
#!/bin/bash
|
||||
|
||||
# 退出时如果有任何命令失败
|
||||
set -e
|
||||
|
||||
# 默认值
|
||||
DAYS=365
|
||||
SUBJECT="/C=CN/ST=Shanghai/L=Shanghai/O=GoProxy/OU=Test/CN=localhost"
|
||||
|
||||
# 帮助信息
|
||||
function show_help {
|
||||
echo "生成自签名证书"
|
||||
echo
|
||||
echo "用法: $0 [选项]"
|
||||
echo
|
||||
echo "选项:"
|
||||
echo " -h, --help 显示此帮助信息"
|
||||
echo " -d, --days DAYS 证书有效期(天数),默认: 365"
|
||||
echo " -s, --subject SUB 证书主题,默认: $SUBJECT"
|
||||
echo " -c, --cn CN 公用名(CN),将替换主题中的CN,默认: localhost"
|
||||
echo
|
||||
echo "示例:"
|
||||
echo " $0 --days 730 --cn example.com"
|
||||
echo
|
||||
}
|
||||
|
||||
# 处理命令行参数
|
||||
while [[ $# -gt 0 ]]; do
|
||||
key="$1"
|
||||
case $key in
|
||||
-h|--help)
|
||||
show_help
|
||||
exit 0
|
||||
;;
|
||||
-d|--days)
|
||||
DAYS="$2"
|
||||
shift
|
||||
shift
|
||||
;;
|
||||
-s|--subject)
|
||||
SUBJECT="$2"
|
||||
shift
|
||||
shift
|
||||
;;
|
||||
-c|--cn)
|
||||
# 替换主题中的CN部分
|
||||
SUBJECT=$(echo $SUBJECT | sed "s/CN=[^\/]*/CN=$2/")
|
||||
shift
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
echo "未知选项: $1"
|
||||
show_help
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
echo "生成自签名证书..."
|
||||
echo "有效期: $DAYS 天"
|
||||
echo "主题: $SUBJECT"
|
||||
|
||||
# 生成私钥
|
||||
openssl genrsa -out server.key 2048
|
||||
|
||||
# 生成证书请求
|
||||
openssl req -new -key server.key -out server.csr -subj "$SUBJECT"
|
||||
|
||||
# 生成自签名证书
|
||||
openssl x509 -req -days $DAYS -in server.csr -signkey server.key -out server.crt
|
||||
|
||||
# 删除证书请求文件
|
||||
rm server.csr
|
||||
|
||||
echo "完成!已生成以下文件:"
|
||||
echo " - server.key: 私钥"
|
||||
echo " - server.crt: 证书"
|
||||
echo
|
||||
echo "启动HTTPS代理"
|
||||
|
||||
go run ../custom_dns_https_proxy/main.go -cert server.crt -key server.key
|
20
examples/scripts/server.crt
Normal file
20
examples/scripts/server.crt
Normal file
@@ -0,0 +1,20 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDVzCCAj8CFBKBjcPxJ7o8UnWKrYMpI6XWa9crMA0GCSqGSIb3DQEBCwUAMGgx
|
||||
CzAJBgNVBAYTAkNOMREwDwYDVQQIDAhTaGFuZ2hhaTERMA8GA1UEBwwIU2hhbmdo
|
||||
YWkxEDAOBgNVBAoMB0dvUHJveHkxDTALBgNVBAsMBFRlc3QxEjAQBgNVBAMMCWxv
|
||||
Y2FsaG9zdDAeFw0yNTAzMTQxMzA0NDlaFw0yNjAzMTQxMzA0NDlaMGgxCzAJBgNV
|
||||
BAYTAkNOMREwDwYDVQQIDAhTaGFuZ2hhaTERMA8GA1UEBwwIU2hhbmdoYWkxEDAO
|
||||
BgNVBAoMB0dvUHJveHkxDTALBgNVBAsMBFRlc3QxEjAQBgNVBAMMCWxvY2FsaG9z
|
||||
dDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMZWDfiB54iz+hUpUXfC
|
||||
V2OH674a6EEJTqQ6xZ3b9aKC+IUoerzyj8o/cxNvb6AekxiMlxMbAVK+CARqvqzE
|
||||
/6w+SZPB8TZGwLM1yPyDaz1+D05n/Am3slccDby/pkPG/igt1q/RVkizw35Mn9ct
|
||||
gz5niufM78gRQTMr1/8CfgNyfiDa5mZ02fIahUZLBjCotF2jtfN0hX1gagD06wlc
|
||||
9RL36Ms2hxK+A1J6VUMhXdH4u0PdksiRwtQMFW8A4M3fCTJp1a1H1Oj1gub9AXcL
|
||||
UOwFZMrZ6LJFNBDQNRU/e104Fqq0XpfnXEq4SC/AW/wkLuCcRtoOGIg7XQ0Np3dv
|
||||
sj8CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAqCp6RcBW6Q2PlbUeqOl4X9KQffRO
|
||||
N+ATvcve0hF3+Jr5tPYDvLwtHtyU1yYKyrM8RmqMcOHxXEmuxhlYaR0P4yUwTPOr
|
||||
l9ZwskIkoTd0nVVlS9nGQMtEc0n+AmWGICE9gqOF66Gup0OPY3OYGdvlqE8NBH43
|
||||
95jSB0grAMudd2TW71Ef+PvieOY7ksJwGiP9tusJqq51Bh9gyhUAk9xxeKbeP0Zp
|
||||
9dy8/kbTW9B5hyLYNYOKhBztKu665cQ1cNL4AA0Y5svBRWlymTqB2HKMIKTGJzDO
|
||||
6Jh6wXWr7Fx/aDROHqY2vOQ25i0FGJCu8/7TUcveCllyjHtJHhljkXht4g==
|
||||
-----END CERTIFICATE-----
|
27
examples/scripts/server.key
Normal file
27
examples/scripts/server.key
Normal file
@@ -0,0 +1,27 @@
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEowIBAAKCAQEAxlYN+IHniLP6FSlRd8JXY4frvhroQQlOpDrFndv1ooL4hSh6
|
||||
vPKPyj9zE29voB6TGIyXExsBUr4IBGq+rMT/rD5Jk8HxNkbAszXI/INrPX4PTmf8
|
||||
CbeyVxwNvL+mQ8b+KC3Wr9FWSLPDfkyf1y2DPmeK58zvyBFBMyvX/wJ+A3J+INrm
|
||||
ZnTZ8hqFRksGMKi0XaO183SFfWBqAPTrCVz1EvfoyzaHEr4DUnpVQyFd0fi7Q92S
|
||||
yJHC1AwVbwDgzd8JMmnVrUfU6PWC5v0BdwtQ7AVkytnoskU0ENA1FT97XTgWqrRe
|
||||
l+dcSrhIL8Bb/CQu4JxG2g4YiDtdDQ2nd2+yPwIDAQABAoIBAFLFGvN4kv2TzmwC
|
||||
YENQUVPyJ0mgxQhPMAiNlmb4opv9eGVprT8pIyTOMeIMgVMbL1vxYCLTBExZjdL6
|
||||
ETTcya5CGEaXi2iRQl4HtibbWWfCMfUQpDgR91UvGfSJLoPeibaO2qdo/0875fvR
|
||||
UmtkTP9ACtINzot51/HY/D0p9xjMdLTa2bzj3fzPSLGROy0OGbam0yWwkSVyL6kF
|
||||
xfyVWIG8Vw40L7XTxS8vnM1mSU17qqztrVsbB29eLbby3Vv7vHZ6xWWbv/jW0Hm4
|
||||
9Zl6GU4G4jB6/zafylqnMvUlLzPHRR3k5fxzgloyTmfJegj73ksTsAaM0qD6S2jQ
|
||||
QFil1WECgYEA75qX674QEp5J/C24jCef7vG/x2EvX9II9CXUnERUXmvFx2UFFzcM
|
||||
QAhewLwTywe8WG5BNyde/xv/IRfgOuMtDXB6aQJlibin0CRPlQ/IzEdLt9R5FBSB
|
||||
t3i2ffdJhYObusiTFfzoSV8jyVNyLSvsZXfKa/ckfb7txCe7AaChaxkCgYEA0+iH
|
||||
UkVsVVoXu8OLmXH8q7FWbVciqBMiGOEyZpjluJoCQVmCAogTYTR1qY9HnlLnVprA
|
||||
OS1Q6vjAZaTK6AKcCgBuRjswHi35jpZ77u+Vdd85dVTnlQIdHDyUxHYdAfP7tjKl
|
||||
VnIifDS2zeDH2QORjqzP4QsKd2gnpciguCOdCxcCgYBVvLDeF22y69c3mLiv1kIB
|
||||
g5oHYzxLgmHX022n2T+DZfcoqXpP20/T3erh9qryfLslvZYygTEaAk+h7OQ8zivB
|
||||
4ly7FLN2u4+5CDU99p74kg6DIlGNIOVl3JkYrBMv5m8kQD95n70S/CtXEDgL9+qo
|
||||
SFwzlAUHxflYtorRQ0RfiQKBgEcja7JJzgmFOix1g/raUlmNKheAxgioi6zQhOv+
|
||||
bjgfs5weoU+aQO9D/jATApb6++COCPPo655GLcixntBud9W/uUVof0nSY1Hj4O0g
|
||||
jwtICfECtM/IKt+c0tB1Wl2ae6j5rZmsrTkHNUs+J7kJwqakCxFgdH4LgCveg13t
|
||||
zr23AoGBALA9C/3NIC4TKD2jm0NFBttODyf8gbAT1/7akdiH5BttSbrPg8l4kYwL
|
||||
kvs1TxU0O6cpOHcqG8CEYnTAi4LEocRCQtsD3xMqANqjNLMNvdsV/cQvT6qqoUre
|
||||
IyGgq9+W0NSeeHwtA+qASUaAorIpnplNOkkhyu78pvEtxUFPCZ1c
|
||||
-----END RSA PRIVATE KEY-----
|
47
examples/wildcard_dns_proxy/main.go
Normal file
47
examples/wildcard_dns_proxy/main.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/darkit/goproxy"
|
||||
"github.com/darkit/goproxy/pkg/dns"
|
||||
)
|
||||
|
||||
// WildcardDNSDelegate 通配符 DNS 代理委托
|
||||
type WildcardDNSDelegate struct {
|
||||
goproxy.DefaultDelegate
|
||||
dnsResolver *dns.WildcardResolver
|
||||
}
|
||||
|
||||
// ResolveBackend 解析后端服务器
|
||||
func (d *WildcardDNSDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||
return d.dnsResolver.Resolve(req.URL.Host)
|
||||
}
|
||||
|
||||
func main() {
|
||||
// 创建 DNS 解析器
|
||||
resolver := dns.NewWildcardResolver(map[string]string{
|
||||
"*.example.com": "http://backend.example.com",
|
||||
"*.test.com": "http://backend.test.com",
|
||||
})
|
||||
|
||||
// 创建通配符 DNS 代理委托
|
||||
delegate := &WildcardDNSDelegate{
|
||||
dnsResolver: resolver,
|
||||
}
|
||||
|
||||
// 创建代理实例
|
||||
proxy := goproxy.NewProxy(
|
||||
goproxy.WithDelegate(delegate),
|
||||
)
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("通配符 DNS 代理服务器启动在 :8080")
|
||||
log.Println("DNS 配置:")
|
||||
log.Printf("- *.example.com -> backend.example.com\n")
|
||||
log.Printf("- *.test.com -> backend.test.com\n")
|
||||
if err := http.ListenAndServe(":8080", proxy); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
23
go.mod
Normal file
23
go.mod
Normal file
@@ -0,0 +1,23 @@
|
||||
module github.com/darkit/goproxy
|
||||
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/fsnotify/fsnotify v1.7.0
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/ouqiang/websocket v1.6.2
|
||||
github.com/prometheus/client_golang v1.20.4
|
||||
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8
|
||||
golang.org/x/time v0.11.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/prometheus/client_model v0.6.1 // indirect
|
||||
github.com/prometheus/common v0.62.0 // indirect
|
||||
github.com/prometheus/procfs v0.15.1 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
google.golang.org/protobuf v1.36.1 // indirect
|
||||
)
|
40
go.sum
Normal file
40
go.sum
Normal file
@@ -0,0 +1,40 @@
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/ouqiang/websocket v1.6.2 h1:LGQIySbQO3ahZCl34v9xBVb0yncDk8yIcuEIbWBab/U=
|
||||
github.com/ouqiang/websocket v1.6.2/go.mod h1:fIROJIHRlQwgCyUFTMzaaIcs4HIwUj2xlOW43u9Sf+M=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI=
|
||||
github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE=
|
||||
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
|
||||
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
|
||||
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 h1:EVObHAr8DqpoJCVv6KYTle8FEImKhtkfcZetNqxDoJQ=
|
||||
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8/go.mod h1:dniwbG03GafCjFohMDmz6Zc6oCuiqgH6tGNyXTkHzXE=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
|
||||
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
|
||||
google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
284
options.go
Normal file
284
options.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package goproxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"time"
|
||||
|
||||
"github.com/darkit/goproxy/config"
|
||||
"github.com/darkit/goproxy/pkg/auth"
|
||||
"github.com/darkit/goproxy/pkg/cache"
|
||||
"github.com/darkit/goproxy/pkg/healthcheck"
|
||||
"github.com/darkit/goproxy/pkg/loadbalance"
|
||||
"github.com/darkit/goproxy/pkg/metrics"
|
||||
)
|
||||
|
||||
// Option 用于配置代理选项的函数类型
|
||||
type Option func(*Options)
|
||||
|
||||
// WithConfig 设置代理配置
|
||||
func WithConfig(cfg *config.Config) Option {
|
||||
return func(opt *Options) {
|
||||
opt.Config = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// WithDisableKeepAlive 设置连接是否重用
|
||||
func WithDisableKeepAlive(disableKeepAlive bool) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
// 在transport中设置DisableKeepAlives
|
||||
}
|
||||
}
|
||||
|
||||
// WithClientTrace 设置HTTP客户端跟踪
|
||||
func WithClientTrace(t *httptrace.ClientTrace) Option {
|
||||
return func(opt *Options) {
|
||||
opt.ClientTrace = t
|
||||
}
|
||||
}
|
||||
|
||||
// WithDelegate 设置委托类
|
||||
func WithDelegate(delegate Delegate) Option {
|
||||
return func(opt *Options) {
|
||||
opt.Delegate = delegate
|
||||
}
|
||||
}
|
||||
|
||||
// WithTransport 使用自定义HTTP传输
|
||||
func WithTransport(t *http.Transport) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
// 在New方法中处理transport
|
||||
}
|
||||
}
|
||||
|
||||
// WithDecryptHTTPS 启用中间人代理解密HTTPS
|
||||
func WithDecryptHTTPS(c CertificateCache) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.DecryptHTTPS = true
|
||||
opt.CertCache = c
|
||||
}
|
||||
}
|
||||
|
||||
// WithEnableWebsocketIntercept 启用WebSocket拦截
|
||||
func WithEnableWebsocketIntercept() Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
// WebSocket拦截在代理处理逻辑中实现
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPCache 设置HTTP缓存
|
||||
func WithHTTPCache(c cache.Cache) Option {
|
||||
return func(opt *Options) {
|
||||
opt.HTTPCache = c
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.EnableCache = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithLoadBalancer 设置负载均衡器
|
||||
func WithLoadBalancer(lb loadbalance.LoadBalancer) Option {
|
||||
return func(opt *Options) {
|
||||
opt.LoadBalancer = lb
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.EnableLoadBalancing = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithHealthChecker 设置健康检查器
|
||||
func WithHealthChecker(hc *healthcheck.HealthChecker) Option {
|
||||
return func(opt *Options) {
|
||||
opt.HealthChecker = hc
|
||||
}
|
||||
}
|
||||
|
||||
// WithMetrics 设置监控指标
|
||||
func WithMetrics(m metrics.MetricsCollector) Option {
|
||||
return func(opt *Options) {
|
||||
opt.Metrics = m
|
||||
}
|
||||
}
|
||||
|
||||
// WithTLSCertAndKey 设置TLS证书和密钥
|
||||
func WithTLSCertAndKey(certPath, keyPath string) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.TLSCert = certPath
|
||||
opt.Config.TLSKey = keyPath
|
||||
}
|
||||
}
|
||||
|
||||
// WithCACertAndKey 设置CA证书和密钥(用于生成动态证书)
|
||||
func WithCACertAndKey(caCertPath, caKeyPath string) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.CACert = caCertPath
|
||||
opt.Config.CAKey = caKeyPath
|
||||
}
|
||||
}
|
||||
|
||||
// WithConnectionPoolSize 设置连接池大小
|
||||
func WithConnectionPoolSize(size int) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.ConnectionPoolSize = size
|
||||
}
|
||||
}
|
||||
|
||||
// WithIdleTimeout 设置空闲超时时间
|
||||
func WithIdleTimeout(timeout time.Duration) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.IdleTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// WithRequestTimeout 设置请求超时时间
|
||||
func WithRequestTimeout(timeout time.Duration) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.RequestTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// WithReverseProxy 启用反向代理模式
|
||||
func WithReverseProxy(enable bool) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.ReverseProxy = enable
|
||||
}
|
||||
}
|
||||
|
||||
// WithEnableRetry 启用请求重试
|
||||
func WithEnableRetry(maxRetries int, baseBackoff, maxBackoff time.Duration) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.EnableRetry = true
|
||||
opt.Config.MaxRetries = maxRetries
|
||||
opt.Config.BaseBackoff = baseBackoff
|
||||
opt.Config.MaxBackoff = maxBackoff
|
||||
}
|
||||
}
|
||||
|
||||
// WithRateLimit 设置请求限流
|
||||
func WithRateLimit(rps float64) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.EnableRateLimit = true
|
||||
opt.Config.RateLimit = rps
|
||||
}
|
||||
}
|
||||
|
||||
// WithDNSCacheTTL 设置DNS缓存TTL
|
||||
func WithDNSCacheTTL(ttl time.Duration) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.DNSCacheTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
// WithEnableCORS 启用CORS支持
|
||||
func WithEnableCORS(enable bool) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.EnableCORS = enable
|
||||
}
|
||||
}
|
||||
|
||||
// WithCertManager 设置证书管理器
|
||||
// 这是一个内部函数,主要用于在New方法中设置CertManager
|
||||
func WithCertManager(certManager *CertManager) Option {
|
||||
return func(opt *Options) {
|
||||
opt.CertManager = certManager
|
||||
}
|
||||
}
|
||||
|
||||
// WithEnableECDSA 启用ECDSA证书生成(默认使用RSA)
|
||||
func WithEnableECDSA(enable bool) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.UseECDSA = enable
|
||||
}
|
||||
}
|
||||
|
||||
// WithAuth 设置认证系统
|
||||
func WithAuth(auth *auth.Auth) Option {
|
||||
return func(o *Options) {
|
||||
o.Auth = auth
|
||||
}
|
||||
}
|
||||
|
||||
// Options 代理选项
|
||||
type Options struct {
|
||||
// 配置
|
||||
Config *config.Config
|
||||
// 委托
|
||||
Delegate Delegate
|
||||
// 证书缓存
|
||||
CertCache CertificateCache
|
||||
// HTTP缓存
|
||||
HTTPCache cache.Cache
|
||||
// 负载均衡器
|
||||
LoadBalancer loadbalance.LoadBalancer
|
||||
// 健康检查器
|
||||
HealthChecker *healthcheck.HealthChecker
|
||||
// 监控指标
|
||||
Metrics metrics.MetricsCollector
|
||||
// 客户端跟踪
|
||||
ClientTrace *httptrace.ClientTrace
|
||||
// 证书管理器
|
||||
CertManager *CertManager
|
||||
// 认证系统
|
||||
Auth *auth.Auth
|
||||
}
|
||||
|
||||
// NewWithOptions 使用选项函数创建代理
|
||||
func NewWithOptions(options ...Option) *Proxy {
|
||||
opts := &Options{
|
||||
Config: config.DefaultConfig(),
|
||||
}
|
||||
|
||||
// 应用所有选项
|
||||
for _, option := range options {
|
||||
option(opts)
|
||||
}
|
||||
|
||||
return New(opts)
|
||||
}
|
231
pkg/auth/auth.go
Normal file
231
pkg/auth/auth.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
// Auth 认证授权系统
|
||||
type Auth struct {
|
||||
// JWT密钥
|
||||
secretKey []byte
|
||||
// 用户存储
|
||||
users map[string]*User
|
||||
// 角色权限映射
|
||||
rolePermissions map[string][]string
|
||||
// 锁
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// User 用户信息
|
||||
type User struct {
|
||||
// 用户名
|
||||
Username string
|
||||
// 密码
|
||||
Password string
|
||||
// 角色列表
|
||||
Roles []string
|
||||
// 创建时间
|
||||
CreatedAt time.Time
|
||||
// 最后登录时间
|
||||
LastLoginAt time.Time
|
||||
}
|
||||
|
||||
// NewAuth 创建认证授权系统
|
||||
func NewAuth(secretKey string) *Auth {
|
||||
return &Auth{
|
||||
secretKey: []byte(secretKey),
|
||||
users: make(map[string]*User),
|
||||
rolePermissions: make(map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
// AddUser 添加用户
|
||||
func (a *Auth) AddUser(username, password string, roles []string) error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if _, exists := a.users[username]; exists {
|
||||
return fmt.Errorf("用户已存在")
|
||||
}
|
||||
|
||||
// 密码加密
|
||||
hashedPassword := hashPassword(password)
|
||||
|
||||
// 创建用户
|
||||
a.users[username] = &User{
|
||||
Username: username,
|
||||
Password: hashedPassword,
|
||||
Roles: roles,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Authenticate 认证用户
|
||||
func (a *Auth) Authenticate(username, password string) (string, error) {
|
||||
a.mu.RLock()
|
||||
user, exists := a.users[username]
|
||||
a.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return "", fmt.Errorf("用户不存在")
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if !a.ValidateUser(username, password) {
|
||||
return "", fmt.Errorf("密码错误")
|
||||
}
|
||||
|
||||
// 更新最后登录时间
|
||||
a.mu.Lock()
|
||||
user.LastLoginAt = time.Now()
|
||||
a.mu.Unlock()
|
||||
|
||||
// 生成JWT令牌
|
||||
token, err := a.GenerateToken(username)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Authorize 授权检查
|
||||
func (a *Auth) Authorize(token, permission string) error {
|
||||
// 验证JWT令牌
|
||||
claims, err := a.ValidateToken(token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查用户权限
|
||||
a.mu.RLock()
|
||||
username, ok := (*claims)["username"].(string)
|
||||
if !ok {
|
||||
a.mu.RUnlock()
|
||||
return fmt.Errorf("无效的用户名")
|
||||
}
|
||||
user, exists := a.users[username]
|
||||
a.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("用户不存在")
|
||||
}
|
||||
|
||||
// 检查用户角色权限
|
||||
for _, role := range user.Roles {
|
||||
if a.hasPermission(role, permission) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("权限不足")
|
||||
}
|
||||
|
||||
// Middleware 认证中间件
|
||||
func (a *Auth) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 获取认证头
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
http.Error(w, "未提供认证信息", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析Bearer令牌
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
http.Error(w, "认证格式错误", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证令牌
|
||||
if err := a.Authorize(parts[1], r.URL.Path); err != nil {
|
||||
http.Error(w, "认证失败", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateUser 验证用户
|
||||
func (a *Auth) ValidateUser(username, password string) bool {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
|
||||
user, exists := a.users[username]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
return user.Password == hashPassword(password)
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT令牌
|
||||
func (a *Auth) GenerateToken(username string) (string, error) {
|
||||
a.mu.RLock()
|
||||
user, exists := a.users[username]
|
||||
a.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return "", errors.New("用户不存在")
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"username": username,
|
||||
"roles": user.Roles,
|
||||
"exp": time.Now().Add(24 * time.Hour).Unix(),
|
||||
})
|
||||
|
||||
return token.SignedString(a.secretKey)
|
||||
}
|
||||
|
||||
// ValidateToken 验证JWT令牌
|
||||
func (a *Auth) ValidateToken(tokenString string) (*jwt.MapClaims, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
return a.secretKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("无效的令牌")
|
||||
}
|
||||
|
||||
// hashPassword 密码加密
|
||||
func hashPassword(password string) string {
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte(password))
|
||||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
// hasPermission 检查角色是否有权限
|
||||
func (a *Auth) hasPermission(role, permission string) bool {
|
||||
permissions, exists := a.rolePermissions[role]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, p := range permissions {
|
||||
if p == permission {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
220
pkg/cache/cache.go
vendored
Normal file
220
pkg/cache/cache.go
vendored
Normal file
@@ -0,0 +1,220 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Cache 缓存接口
|
||||
type Cache interface {
|
||||
// Get 获取缓存
|
||||
Get(key string) (*http.Response, bool)
|
||||
// Set 设置缓存
|
||||
Set(key string, resp *http.Response)
|
||||
// Delete 删除缓存
|
||||
Delete(key string)
|
||||
// Clear 清空缓存
|
||||
Clear()
|
||||
}
|
||||
|
||||
// MemoryCache 内存缓存实现
|
||||
type MemoryCache struct {
|
||||
// 缓存内容
|
||||
items sync.Map
|
||||
// 过期时间
|
||||
ttl time.Duration
|
||||
// 清理间隔
|
||||
cleanupInterval time.Duration
|
||||
// 最大条目数
|
||||
maxEntries int
|
||||
// 当前条目数
|
||||
size int32
|
||||
// 互斥锁
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// CacheItem 缓存项
|
||||
type CacheItem struct {
|
||||
response *http.Response
|
||||
responseBody []byte
|
||||
expiry time.Time
|
||||
}
|
||||
|
||||
// NewMemoryCache 创建内存缓存
|
||||
func NewMemoryCache(ttl, cleanupInterval time.Duration, maxEntries int) *MemoryCache {
|
||||
cache := &MemoryCache{
|
||||
ttl: ttl,
|
||||
cleanupInterval: cleanupInterval,
|
||||
maxEntries: maxEntries,
|
||||
}
|
||||
|
||||
// 启动过期清理
|
||||
if cleanupInterval > 0 {
|
||||
go cache.startCleanup()
|
||||
}
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
// Get 获取缓存
|
||||
func (c *MemoryCache) Get(key string) (*http.Response, bool) {
|
||||
value, ok := c.items.Load(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
item := value.(*CacheItem)
|
||||
if time.Now().After(item.expiry) {
|
||||
c.Delete(key)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 克隆响应,避免修改原始数据
|
||||
resp := cloneResponse(item.response, item.responseBody)
|
||||
return resp, true
|
||||
}
|
||||
|
||||
// Set 设置缓存
|
||||
func (c *MemoryCache) Set(key string, resp *http.Response) {
|
||||
// 检查缓存是否已满
|
||||
c.mu.Lock()
|
||||
if c.maxEntries > 0 && c.size >= int32(c.maxEntries) {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
c.size++
|
||||
c.mu.Unlock()
|
||||
|
||||
// 读取并保存响应体
|
||||
var bodyBytes []byte
|
||||
if resp.Body != nil {
|
||||
bodyBytes, _ = io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
}
|
||||
|
||||
item := &CacheItem{
|
||||
response: resp,
|
||||
responseBody: bodyBytes,
|
||||
expiry: time.Now().Add(c.ttl),
|
||||
}
|
||||
|
||||
c.items.Store(key, item)
|
||||
}
|
||||
|
||||
// Delete 删除缓存
|
||||
func (c *MemoryCache) Delete(key string) {
|
||||
c.items.Delete(key)
|
||||
c.mu.Lock()
|
||||
c.size--
|
||||
if c.size < 0 {
|
||||
c.size = 0
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Clear 清空缓存
|
||||
func (c *MemoryCache) Clear() {
|
||||
c.items = sync.Map{}
|
||||
c.mu.Lock()
|
||||
c.size = 0
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// startCleanup 启动过期清理
|
||||
func (c *MemoryCache) startCleanup() {
|
||||
ticker := time.NewTicker(c.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
c.items.Range(func(key, value interface{}) bool {
|
||||
item := value.(*CacheItem)
|
||||
if now.After(item.expiry) {
|
||||
c.Delete(key.(string))
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateCacheKey 生成缓存键
|
||||
func GenerateCacheKey(req *http.Request) string {
|
||||
// 忽略一些可变的头部
|
||||
ignoredHeaders := map[string]bool{
|
||||
"Connection": true,
|
||||
"Keep-Alive": true,
|
||||
"Proxy-Authenticate": true,
|
||||
"Proxy-Authorization": true,
|
||||
"TE": true,
|
||||
"Trailers": true,
|
||||
"Transfer-Encoding": true,
|
||||
"Upgrade": true,
|
||||
}
|
||||
|
||||
// 提取缓存键组件
|
||||
components := []string{
|
||||
req.Method,
|
||||
req.URL.String(),
|
||||
}
|
||||
|
||||
// 添加选择性头部
|
||||
for key, values := range req.Header {
|
||||
if !ignoredHeaders[key] {
|
||||
for _, value := range values {
|
||||
components = append(components, key+":"+value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 连接并计算哈希
|
||||
data := strings.Join(components, "|")
|
||||
hash := md5.New()
|
||||
hash.Write([]byte(data))
|
||||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
// cloneResponse 克隆HTTP响应
|
||||
func cloneResponse(resp *http.Response, body []byte) *http.Response {
|
||||
clone := *resp
|
||||
clone.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||
clone.Header = make(http.Header)
|
||||
|
||||
for k, v := range resp.Header {
|
||||
clone.Header[k] = v
|
||||
}
|
||||
|
||||
return &clone
|
||||
}
|
||||
|
||||
// ShouldCache 判断请求是否应该缓存
|
||||
func ShouldCache(req *http.Request, resp *http.Response) bool {
|
||||
// 只缓存GET请求
|
||||
if req.Method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查响应状态码
|
||||
if resp.StatusCode != http.StatusOK &&
|
||||
resp.StatusCode != http.StatusNotModified &&
|
||||
resp.StatusCode != http.StatusMovedPermanently &&
|
||||
resp.StatusCode != http.StatusPermanentRedirect {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查Cache-Control头
|
||||
cacheControl := resp.Header.Get("Cache-Control")
|
||||
if strings.Contains(cacheControl, "no-store") ||
|
||||
strings.Contains(cacheControl, "no-cache") ||
|
||||
strings.Contains(cacheControl, "private") {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
67
pkg/dns/README.md
Normal file
67
pkg/dns/README.md
Normal file
@@ -0,0 +1,67 @@
|
||||
# DNS包单元测试
|
||||
|
||||
本目录包含goproxy/pkg/dns包的单元测试。这些测试覆盖了DNS包的所有主要组件和功能。
|
||||
|
||||
## 测试内容
|
||||
|
||||
1. **Endpoint测试** (`endpoint_test.go`)
|
||||
- 测试网络端点的创建、解析和字符串表示
|
||||
|
||||
2. **Resolver测试** (`resolver_test.go`)
|
||||
- 测试DNS解析器的基本功能
|
||||
- 测试通配符解析
|
||||
- 测试记录的添加、删除和清除
|
||||
- 测试负载均衡策略
|
||||
|
||||
3. **Dialer测试** (`dialer_test.go`)
|
||||
- 测试自定义DNS拨号器的功能
|
||||
- 使用模拟解析器测试域名解析和拨号
|
||||
|
||||
4. **WildcardResolver测试** (`wildcard_test.go`)
|
||||
- 测试通配符匹配功能
|
||||
- 测试通配符解析器的添加、删除和清除
|
||||
|
||||
5. **Config测试** (`config_test.go`)
|
||||
- 测试配置的保存和加载
|
||||
- 测试从hosts文件格式加载配置
|
||||
- 测试从配置创建解析器
|
||||
|
||||
6. **集成测试** (`integration_test.go`)
|
||||
- 测试各组件协同工作
|
||||
- 测试负载均衡场景
|
||||
|
||||
## 运行测试
|
||||
|
||||
### 运行所有测试
|
||||
|
||||
```bash
|
||||
cd /www/goproxy
|
||||
go test -v ./pkg/dns/tests/...
|
||||
```
|
||||
|
||||
### 运行特定测试组件
|
||||
|
||||
```bash
|
||||
# 只运行Endpoint测试
|
||||
go test -v ./pkg/dns/tests -run "TestEndpoint"
|
||||
|
||||
# 只运行Resolver测试
|
||||
go test -v ./pkg/dns/tests -run "TestResolver"
|
||||
|
||||
# 只运行集成测试
|
||||
go test -v ./pkg/dns/tests -run "TestIntegration"
|
||||
```
|
||||
|
||||
### 测试覆盖率报告
|
||||
|
||||
```bash
|
||||
cd /www/goproxy
|
||||
go test -cover -coverprofile=coverage.out ./pkg/dns/tests/...
|
||||
go tool cover -html=coverage.out -o coverage.html
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 部分测试涉及到网络连接,如果网络不可用,某些测试可能会失败,这是正常的。
|
||||
2. 测试设计考虑了离线环境,大多数测试不需要网络连接。
|
||||
3. 集成测试中的`TestIntegrationWithRealDomains`在网络不可用时会被跳过。
|
151
pkg/dns/config.go
Normal file
151
pkg/dns/config.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 用于解析hosts文件中的IP:端口格式
|
||||
var ipPortRegex = regexp.MustCompile(`^([0-9.]+)(?::(\d+))?$`)
|
||||
|
||||
// DNSConfig DNS配置文件结构
|
||||
type DNSConfig struct {
|
||||
Records map[string]string `json:"records"` // 普通记录和泛解析记录
|
||||
Fallback bool `json:"fallback"` // 是否回退到系统DNS
|
||||
TTL int `json:"ttl"` // 缓存TTL,单位为秒
|
||||
}
|
||||
|
||||
// SaveToJSON 将DNS配置保存为JSON文件
|
||||
func (c *DNSConfig) SaveToJSON(filePath string) error {
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建DNS配置文件失败: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
encoder := json.NewEncoder(file)
|
||||
encoder.SetIndent("", " ")
|
||||
if err := encoder.Encode(c); err != nil {
|
||||
return fmt.Errorf("保存DNS配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromJSON 从JSON文件加载DNS配置
|
||||
func LoadFromJSON(filePath string) (*DNSConfig, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开DNS配置文件失败: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
config := &DNSConfig{
|
||||
Records: make(map[string]string),
|
||||
Fallback: true,
|
||||
TTL: 300, // 默认5分钟
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(file)
|
||||
if err := decoder.Decode(config); err != nil {
|
||||
return nil, fmt.Errorf("解析DNS配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// LoadFromHostsFile 从hosts文件格式加载DNS配置
|
||||
func LoadFromHostsFile(filePath string) (*DNSConfig, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开hosts文件失败: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
config := &DNSConfig{
|
||||
Records: make(map[string]string),
|
||||
Fallback: true,
|
||||
TTL: 300, // 默认5分钟
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
lineNum := 0
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
// 跳过空行和注释
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
continue // 行格式不正确,跳过
|
||||
}
|
||||
|
||||
ipPortStr := fields[0]
|
||||
domains := fields[1:]
|
||||
|
||||
// 解析IP和可能的端口
|
||||
matches := ipPortRegex.FindStringSubmatch(ipPortStr)
|
||||
if matches == nil {
|
||||
continue // IP格式不正确,跳过
|
||||
}
|
||||
|
||||
ip := matches[1]
|
||||
portStr := matches[2]
|
||||
|
||||
// 构造记录值
|
||||
value := ip
|
||||
if portStr != "" {
|
||||
value = ip + ":" + portStr
|
||||
}
|
||||
|
||||
for _, domain := range domains {
|
||||
// 跳过注释
|
||||
if strings.HasPrefix(domain, "#") {
|
||||
break
|
||||
}
|
||||
|
||||
// 支持通配符和普通域名
|
||||
config.Records[domain] = value
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("读取hosts文件失败: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// NewResolverFromConfig 从配置创建解析器
|
||||
func NewResolverFromConfig(config *DNSConfig) *CustomResolver {
|
||||
var ttl time.Duration
|
||||
if config.TTL > 0 {
|
||||
ttl = time.Duration(config.TTL) * time.Second
|
||||
} else {
|
||||
ttl = 5 * time.Minute // 默认5分钟
|
||||
}
|
||||
|
||||
resolver := NewResolver(
|
||||
WithFallback(config.Fallback),
|
||||
WithTTL(ttl),
|
||||
)
|
||||
|
||||
// 加载记录
|
||||
resolver.LoadFromMap(config.Records)
|
||||
|
||||
return resolver
|
||||
}
|
||||
|
||||
// IsWildcardDomain 检查是否为通配符域名
|
||||
func IsWildcardDomain(domain string) bool {
|
||||
return strings.Contains(domain, "*")
|
||||
}
|
239
pkg/dns/config_test.go
Normal file
239
pkg/dns/config_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestDNSConfigSaveLoad 测试DNS配置保存和加载
|
||||
func TestDNSConfigSaveLoad(t *testing.T) {
|
||||
// 创建临时文件
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "dns_config.json")
|
||||
|
||||
// 创建配置
|
||||
config := &DNSConfig{
|
||||
Records: map[string]string{
|
||||
"example.com": "192.168.1.100",
|
||||
"api.example.com": "192.168.1.101:8443",
|
||||
"*.example.org": "192.168.2.100",
|
||||
"*.api.example.org": "192.168.2.101:8443",
|
||||
},
|
||||
Fallback: true,
|
||||
TTL: 300,
|
||||
}
|
||||
|
||||
// 保存到文件
|
||||
err := config.SaveToJSON(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("保存配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 确认文件存在
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
t.Fatalf("配置文件未创建: %s", configPath)
|
||||
}
|
||||
|
||||
// 加载配置
|
||||
loadedConfig, err := LoadFromJSON(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("加载配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证加载的配置
|
||||
if loadedConfig.Fallback != config.Fallback {
|
||||
t.Errorf("Fallback不匹配: 期望 %v, 得到 %v", config.Fallback, loadedConfig.Fallback)
|
||||
}
|
||||
|
||||
if loadedConfig.TTL != config.TTL {
|
||||
t.Errorf("TTL不匹配: 期望 %v, 得到 %v", config.TTL, loadedConfig.TTL)
|
||||
}
|
||||
|
||||
// 检查记录是否完全一致
|
||||
if len(loadedConfig.Records) != len(config.Records) {
|
||||
t.Errorf("记录数量不匹配: 期望 %d, 得到 %d", len(config.Records), len(loadedConfig.Records))
|
||||
}
|
||||
|
||||
for host, ip := range config.Records {
|
||||
if loadedIP, ok := loadedConfig.Records[host]; !ok || loadedIP != ip {
|
||||
t.Errorf("记录不匹配 %s: 期望 %s, 得到 %s", host, ip, loadedIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadFromHostsFile 测试从hosts文件格式加载配置
|
||||
func TestLoadFromHostsFile(t *testing.T) {
|
||||
// 创建临时hosts文件
|
||||
tempDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tempDir, "hosts")
|
||||
|
||||
// 写入测试数据
|
||||
hostsContent := `# 这是一个测试的hosts文件
|
||||
192.168.1.100 example.com www.example.com
|
||||
192.168.1.101:8443 api.example.com
|
||||
# 下面是通配符域名
|
||||
192.168.2.100 *.example.org
|
||||
192.168.2.101:8443 *.api.example.org
|
||||
|
||||
# 注释和空行
|
||||
|
||||
127.0.0.1 localhost # 本地回环
|
||||
`
|
||||
|
||||
err := os.WriteFile(hostsPath, []byte(hostsContent), 0o644)
|
||||
if err != nil {
|
||||
t.Fatalf("创建hosts文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 从hosts文件加载配置
|
||||
config, err := LoadFromHostsFile(hostsPath)
|
||||
if err != nil {
|
||||
t.Fatalf("从hosts文件加载配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证配置
|
||||
expectedRecords := map[string]string{
|
||||
"example.com": "192.168.1.100",
|
||||
"www.example.com": "192.168.1.100",
|
||||
"api.example.com": "192.168.1.101:8443",
|
||||
"*.example.org": "192.168.2.100",
|
||||
"*.api.example.org": "192.168.2.101:8443",
|
||||
"localhost": "127.0.0.1",
|
||||
}
|
||||
|
||||
// 检查记录是否与预期一致
|
||||
if len(config.Records) != len(expectedRecords) {
|
||||
t.Errorf("记录数量不匹配: 期望 %d, 得到 %d", len(expectedRecords), len(config.Records))
|
||||
for k, v := range config.Records {
|
||||
t.Logf("加载的记录: %s -> %s", k, v)
|
||||
}
|
||||
}
|
||||
|
||||
for host, expectedIP := range expectedRecords {
|
||||
if loadedIP, ok := config.Records[host]; !ok || loadedIP != expectedIP {
|
||||
t.Errorf("记录不匹配 %s: 期望 %s, 得到 %s", host, expectedIP, loadedIP)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证默认值
|
||||
if !config.Fallback {
|
||||
t.Error("默认Fallback应为true")
|
||||
}
|
||||
|
||||
if config.TTL != 300 {
|
||||
t.Errorf("默认TTL应为300,得到 %d", config.TTL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewResolverFromConfig 测试从配置创建解析器
|
||||
func TestNewResolverFromConfig(t *testing.T) {
|
||||
// 创建配置
|
||||
config := &DNSConfig{
|
||||
Records: map[string]string{
|
||||
"example.com": "192.168.1.100",
|
||||
"api.example.com": "192.168.1.101:8443",
|
||||
"*.example.org": "192.168.2.100",
|
||||
"*.api.example.org": "192.168.2.101:8443",
|
||||
},
|
||||
Fallback: false,
|
||||
TTL: 60,
|
||||
}
|
||||
|
||||
// 从配置创建解析器
|
||||
resolver := NewResolverFromConfig(config)
|
||||
|
||||
// 测试解析器是否正确加载了配置中的记录
|
||||
|
||||
// 测试普通记录
|
||||
ip, err := resolver.Resolve("example.com")
|
||||
if err != nil {
|
||||
t.Errorf("解析example.com失败: %v", err)
|
||||
}
|
||||
if ip != "192.168.1.100" {
|
||||
t.Errorf("解析结果错误,期望 192.168.1.100,得到 %s", ip)
|
||||
}
|
||||
|
||||
// 测试带端口的记录
|
||||
endpoint, err := resolver.ResolveWithPort("api.example.com", 443)
|
||||
if err != nil {
|
||||
t.Errorf("解析api.example.com失败: %v", err)
|
||||
}
|
||||
// 只检查IP是否正确
|
||||
if endpoint.IP != "192.168.1.101" {
|
||||
t.Errorf("解析结果错误,期望 IP=192.168.1.101,得到 IP=%s", endpoint.IP)
|
||||
}
|
||||
// 测试Port值仅用于日志记录
|
||||
t.Logf("api.example.com解析端口为: %d", endpoint.Port)
|
||||
|
||||
// 测试通配符记录
|
||||
ip, err = resolver.Resolve("test.example.org")
|
||||
if err != nil {
|
||||
t.Errorf("解析test.example.org失败: %v", err)
|
||||
}
|
||||
if ip != "192.168.2.100" {
|
||||
t.Errorf("解析结果错误,期望 192.168.2.100,得到 %s", ip)
|
||||
}
|
||||
|
||||
// 测试带端口的通配符记录
|
||||
endpoint, err = resolver.ResolveWithPort("v1.api.example.org", 443)
|
||||
if err != nil {
|
||||
t.Errorf("解析v1.api.example.org失败: %v", err)
|
||||
}
|
||||
// 只检查IP是否正确
|
||||
if endpoint.IP != "192.168.2.101" {
|
||||
t.Errorf("解析结果错误,期望 IP=192.168.2.101,得到 IP=%s", endpoint.IP)
|
||||
}
|
||||
// 测试Port值仅用于日志记录
|
||||
t.Logf("v1.api.example.org解析端口为: %d", endpoint.Port)
|
||||
|
||||
// 测试回退是否已被禁用
|
||||
_, err = resolver.Resolve("unknown.example.com")
|
||||
if err == nil {
|
||||
t.Error("由于回退已禁用,应该返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsWildcardDomain 测试通配符域名检测
|
||||
func TestIsWildcardDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "通配符域名前缀",
|
||||
domain: "*.example.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "通配符域名中间",
|
||||
domain: "api.*.com",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "通配符域名后缀",
|
||||
domain: "example.*",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "非通配符域名",
|
||||
domain: "example.com",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "IP地址",
|
||||
domain: "192.168.1.1",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsWildcardDomain(tt.domain)
|
||||
if got != tt.expected {
|
||||
t.Errorf("IsWildcardDomain(%s) = %v, 期望 %v", tt.domain, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
119
pkg/dns/dialer.go
Normal file
119
pkg/dns/dialer.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Dialer 自定义DNS拨号器
|
||||
type Dialer struct {
|
||||
resolver Resolver
|
||||
timeout time.Duration
|
||||
keepAlive time.Duration
|
||||
defaultPort string
|
||||
}
|
||||
|
||||
// NewDialer 创建新的自定义DNS拨号器
|
||||
func NewDialer(resolver Resolver) *Dialer {
|
||||
return &Dialer{
|
||||
resolver: resolver,
|
||||
timeout: 30 * time.Second,
|
||||
keepAlive: 30 * time.Second,
|
||||
defaultPort: "80",
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout 设置拨号超时
|
||||
func (d *Dialer) WithTimeout(timeout time.Duration) *Dialer {
|
||||
d.timeout = timeout
|
||||
return d
|
||||
}
|
||||
|
||||
// WithKeepAlive 设置保持连接时间
|
||||
func (d *Dialer) WithKeepAlive(keepAlive time.Duration) *Dialer {
|
||||
d.keepAlive = keepAlive
|
||||
return d
|
||||
}
|
||||
|
||||
// WithDefaultPort 设置默认端口
|
||||
func (d *Dialer) WithDefaultPort(port string) *Dialer {
|
||||
d.defaultPort = port
|
||||
return d
|
||||
}
|
||||
|
||||
// DialContext 使用自定义DNS解析拨号
|
||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// 解析主机和端口
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
// 地址没有端口,使用默认端口
|
||||
host = address
|
||||
port = d.defaultPort
|
||||
}
|
||||
|
||||
// 将端口字符串转换为整数
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 解析域名为端点(包含IP和可能的自定义端口)
|
||||
if !isIP(host) {
|
||||
endpoint, err := d.resolver.ResolveWithPort(host, portInt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 确保正确使用端口
|
||||
// 如果端点有自定义端口,使用它;否则使用原始端口
|
||||
usePort := port
|
||||
if endpoint.Port > 0 {
|
||||
usePort = strconv.Itoa(endpoint.Port)
|
||||
}
|
||||
|
||||
// 创建标准拨号器
|
||||
dialer := &net.Dialer{
|
||||
Timeout: d.timeout,
|
||||
KeepAlive: d.keepAlive,
|
||||
}
|
||||
|
||||
// 使用IP地址和正确的端口拨号
|
||||
return dialer.DialContext(ctx, network, net.JoinHostPort(endpoint.IP, usePort))
|
||||
}
|
||||
|
||||
// 创建标准拨号器
|
||||
dialer := &net.Dialer{
|
||||
Timeout: d.timeout,
|
||||
KeepAlive: d.keepAlive,
|
||||
}
|
||||
|
||||
// 使用IP地址拨号
|
||||
return dialer.DialContext(ctx, network, net.JoinHostPort(host, port))
|
||||
}
|
||||
|
||||
// Dial 使用自定义DNS解析拨号(不带上下文)
|
||||
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, address)
|
||||
}
|
||||
|
||||
// 检查字符串是否为IP地址
|
||||
func isIP(s string) bool {
|
||||
return net.ParseIP(s) != nil
|
||||
}
|
||||
|
||||
// DialFunc 返回一个net.DialContext函数
|
||||
func (d *Dialer) DialFunc() func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return d.DialContext
|
||||
}
|
||||
|
||||
// UpdateResolver 更新解析器
|
||||
func (d *Dialer) UpdateResolver(resolver Resolver) {
|
||||
d.resolver = resolver
|
||||
}
|
||||
|
||||
// GetProxyDialContext 获取代理DialContext函数
|
||||
func GetProxyDialContext(resolver Resolver) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return NewDialer(resolver).DialContext
|
||||
}
|
171
pkg/dns/dialer_test.go
Normal file
171
pkg/dns/dialer_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockResolver 模拟解析器实现
|
||||
type mockResolver struct {
|
||||
records map[string]*Endpoint
|
||||
}
|
||||
|
||||
// Resolve 实现Resolver接口的Resolve方法
|
||||
func (m *mockResolver) Resolve(host string) (string, error) {
|
||||
if endpoint, ok := m.records[host]; ok {
|
||||
return endpoint.IP, nil
|
||||
}
|
||||
return "", net.ErrClosed
|
||||
}
|
||||
|
||||
// ResolveWithPort 实现Resolver接口的ResolveWithPort方法
|
||||
func (m *mockResolver) ResolveWithPort(host string, defaultPort int) (*Endpoint, error) {
|
||||
if endpoint, ok := m.records[host]; ok {
|
||||
return endpoint, nil
|
||||
}
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
|
||||
// Add 实现Resolver接口的Add方法
|
||||
func (m *mockResolver) Add(host, ip string) error {
|
||||
m.records[host] = NewEndpoint(ip)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddWithPort 实现Resolver接口的AddWithPort方法
|
||||
func (m *mockResolver) AddWithPort(host, ip string, port int) error {
|
||||
m.records[host] = NewEndpointWithPort(ip, port)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddWildcard 实现Resolver接口的AddWildcard方法
|
||||
func (m *mockResolver) AddWildcard(wildcardDomain, ip string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddWildcardWithPort 实现Resolver接口的AddWildcardWithPort方法
|
||||
func (m *mockResolver) AddWildcardWithPort(wildcardDomain, ip string, port int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove 实现Resolver接口的Remove方法
|
||||
func (m *mockResolver) Remove(host string) error {
|
||||
delete(m.records, host)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear 实现Resolver接口的Clear方法
|
||||
func (m *mockResolver) Clear() {
|
||||
m.records = make(map[string]*Endpoint)
|
||||
}
|
||||
|
||||
// newMockResolver 创建新的模拟解析器
|
||||
func newMockResolver() *mockResolver {
|
||||
return &mockResolver{
|
||||
records: make(map[string]*Endpoint),
|
||||
}
|
||||
}
|
||||
|
||||
// TestDialerBasic 测试Dialer的基本功能
|
||||
func TestDialerBasic(t *testing.T) {
|
||||
// 创建模拟解析器
|
||||
resolver := newMockResolver()
|
||||
resolver.Add("example.com", "192.168.1.100")
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.101", 8443)
|
||||
|
||||
// 创建拨号器
|
||||
dialer := NewDialer(resolver)
|
||||
|
||||
// 测试自定义选项
|
||||
dialer.WithTimeout(10 * time.Second)
|
||||
dialer.WithKeepAlive(20 * time.Second)
|
||||
dialer.WithDefaultPort("443")
|
||||
|
||||
// 无法直接测试网络连接,但可以测试代理DialContext功能
|
||||
proxyDialFunc := GetProxyDialContext(resolver)
|
||||
if proxyDialFunc == nil {
|
||||
t.Fatal("GetProxyDialContext应该返回一个非nil的函数")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDialerResolve 测试Dialer的解析功能
|
||||
func TestDialerResolve(t *testing.T) {
|
||||
// 创建模拟解析器
|
||||
resolver := newMockResolver()
|
||||
resolver.Add("example.com", "192.168.1.100")
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.101", 8443)
|
||||
|
||||
// 创建拨号器
|
||||
dialer := NewDialer(resolver)
|
||||
|
||||
// 由于无法真正建立连接,我们创建一个自定义网络接口来测试DialContext的行为
|
||||
// 这个测试检查拨号器是否正确解析域名
|
||||
|
||||
// 测试普通域名解析,预期使用默认端口
|
||||
ctx := context.Background()
|
||||
_, err := dialer.DialContext(ctx, "tcp", "example.com")
|
||||
if err == nil {
|
||||
// 在实际环境中,这会尝试连接192.168.1.100:80,但在测试环境中应该会失败
|
||||
// 我们这里只是确保流程能够正确运行到尝试连接的步骤
|
||||
t.Error("应该由于无法建立连接而失败")
|
||||
}
|
||||
|
||||
// 测试带端口的域名解析
|
||||
_, err = dialer.DialContext(ctx, "tcp", "api.example.com:443")
|
||||
if err == nil {
|
||||
// 在实际环境中,这会尝试连接192.168.1.101:8443,但在测试环境中应该会失败
|
||||
t.Error("应该由于无法建立连接而失败")
|
||||
}
|
||||
|
||||
// 测试直接使用IP的情况
|
||||
_, err = dialer.DialContext(ctx, "tcp", "127.0.0.1:80")
|
||||
if err == nil {
|
||||
// 这应该会尝试直接连接127.0.0.1:80
|
||||
t.Error("应该由于无法建立连接而失败")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDialerUpdateResolver 测试更新解析器
|
||||
func TestDialerUpdateResolver(t *testing.T) {
|
||||
// 创建两个模拟解析器
|
||||
resolver1 := newMockResolver()
|
||||
resolver1.Add("example.com", "192.168.1.100")
|
||||
|
||||
resolver2 := newMockResolver()
|
||||
resolver2.Add("example.com", "192.168.1.200")
|
||||
|
||||
// 创建拨号器并使用resolver1
|
||||
dialer := NewDialer(resolver1)
|
||||
|
||||
// 更新为resolver2
|
||||
dialer.UpdateResolver(resolver2)
|
||||
|
||||
// 无法直接测试网络连接,但可以验证dialer使用了更新后的resolver
|
||||
// 在实际应用中,这将影响DialContext的行为
|
||||
}
|
||||
|
||||
// TestDialerDialFunc 测试获取DialFunc函数
|
||||
func TestDialerDialFunc(t *testing.T) {
|
||||
// 创建模拟解析器
|
||||
resolver := newMockResolver()
|
||||
resolver.Add("example.com", "192.168.1.100")
|
||||
|
||||
// 创建拨号器
|
||||
dialer := NewDialer(resolver)
|
||||
|
||||
// 获取DialFunc
|
||||
dialFunc := dialer.DialFunc()
|
||||
if dialFunc == nil {
|
||||
t.Fatal("DialFunc应该返回一个非nil的函数")
|
||||
}
|
||||
|
||||
// 尝试使用DialFunc
|
||||
ctx := context.Background()
|
||||
_, err := dialFunc(ctx, "tcp", "example.com")
|
||||
if err == nil {
|
||||
// 在实际环境中,这会尝试连接,但在测试环境中应该会失败
|
||||
t.Error("应该由于无法建立连接而失败")
|
||||
}
|
||||
}
|
70
pkg/dns/endpoint.go
Normal file
70
pkg/dns/endpoint.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Endpoint 表示一个网络端点(IP和端口)
|
||||
type Endpoint struct {
|
||||
IP string // IP地址
|
||||
Port int // 端口号,0表示使用默认端口
|
||||
}
|
||||
|
||||
// NewEndpoint 从IP地址创建端点,使用默认端口0
|
||||
func NewEndpoint(ip string) *Endpoint {
|
||||
return &Endpoint{
|
||||
IP: ip,
|
||||
Port: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// NewEndpointWithPort 从IP地址和端口创建端点
|
||||
func NewEndpointWithPort(ip string, port int) *Endpoint {
|
||||
return &Endpoint{
|
||||
IP: ip,
|
||||
Port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// ParseEndpoint 从字符串解析端点,格式为"IP:端口"或仅"IP"
|
||||
func ParseEndpoint(s string) (*Endpoint, error) {
|
||||
// 检查是否包含端口
|
||||
if strings.Contains(s, ":") {
|
||||
parts := strings.Split(s, ":")
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("无效的端点格式: %s", s)
|
||||
}
|
||||
|
||||
// 解析IP
|
||||
ip := parts[0]
|
||||
|
||||
// 解析端口
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无效的端口: %s", parts[1])
|
||||
}
|
||||
|
||||
return NewEndpointWithPort(ip, port), nil
|
||||
}
|
||||
|
||||
// 只有IP,没有端口
|
||||
return NewEndpoint(s), nil
|
||||
}
|
||||
|
||||
// String 返回端点的字符串表示
|
||||
func (e *Endpoint) String() string {
|
||||
if e.Port > 0 {
|
||||
return fmt.Sprintf("%s:%d", e.IP, e.Port)
|
||||
}
|
||||
return e.IP
|
||||
}
|
||||
|
||||
// GetAddressWithDefaultPort 获取带有默认端口的地址
|
||||
func (e *Endpoint) GetAddressWithDefaultPort(defaultPort int) string {
|
||||
if e.Port > 0 {
|
||||
return fmt.Sprintf("%s:%d", e.IP, e.Port)
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", e.IP, defaultPort)
|
||||
}
|
163
pkg/dns/endpoint_test.go
Normal file
163
pkg/dns/endpoint_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNewEndpoint 测试从IP地址创建端点
|
||||
func TestNewEndpoint(t *testing.T) {
|
||||
ip := "192.168.1.1"
|
||||
endpoint := NewEndpoint(ip)
|
||||
|
||||
if endpoint.IP != ip {
|
||||
t.Errorf("预期IP为 %s,实际得到 %s", ip, endpoint.IP)
|
||||
}
|
||||
|
||||
if endpoint.Port != 0 {
|
||||
t.Errorf("预期端口为 0,实际得到 %d", endpoint.Port)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewEndpointWithPort 测试从IP地址和端口创建端点
|
||||
func TestNewEndpointWithPort(t *testing.T) {
|
||||
ip := "192.168.1.1"
|
||||
port := 8080
|
||||
endpoint := NewEndpointWithPort(ip, port)
|
||||
|
||||
if endpoint.IP != ip {
|
||||
t.Errorf("预期IP为 %s,实际得到 %s", ip, endpoint.IP)
|
||||
}
|
||||
|
||||
if endpoint.Port != port {
|
||||
t.Errorf("预期端口为 %d,实际得到 %d", port, endpoint.Port)
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseEndpoint 测试从字符串解析端点
|
||||
func TestParseEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantIP string
|
||||
wantPort int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "仅IP地址",
|
||||
input: "192.168.1.1",
|
||||
wantIP: "192.168.1.1",
|
||||
wantPort: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "IP地址和端口",
|
||||
input: "192.168.1.1:8080",
|
||||
wantIP: "192.168.1.1",
|
||||
wantPort: 8080,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "无效的端口格式",
|
||||
input: "192.168.1.1:abc",
|
||||
wantIP: "",
|
||||
wantPort: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "无效的端点格式",
|
||||
input: "192.168.1.1:8080:extra",
|
||||
wantIP: "",
|
||||
wantPort: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
endpoint, err := ParseEndpoint(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseEndpoint() 错误 = %v, 期望错误 = %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
if endpoint.IP != tt.wantIP {
|
||||
t.Errorf("IP = %v, 期望 = %v", endpoint.IP, tt.wantIP)
|
||||
}
|
||||
|
||||
if endpoint.Port != tt.wantPort {
|
||||
t.Errorf("Port = %v, 期望 = %v", endpoint.Port, tt.wantPort)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEndpointString 测试端点的字符串表示
|
||||
func TestEndpointString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
port int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "仅IP地址",
|
||||
ip: "192.168.1.1",
|
||||
port: 0,
|
||||
expected: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "IP地址和端口",
|
||||
ip: "192.168.1.1",
|
||||
port: 8080,
|
||||
expected: "192.168.1.1:8080",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
endpoint := NewEndpointWithPort(tt.ip, tt.port)
|
||||
if got := endpoint.String(); got != tt.expected {
|
||||
t.Errorf("Endpoint.String() = %v, 期望 = %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetAddressWithDefaultPort 测试获取带有默认端口的地址
|
||||
func TestGetAddressWithDefaultPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
port int
|
||||
defaultPort int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "使用默认端口",
|
||||
ip: "192.168.1.1",
|
||||
port: 0,
|
||||
defaultPort: 80,
|
||||
expected: "192.168.1.1:80",
|
||||
},
|
||||
{
|
||||
name: "使用自定义端口",
|
||||
ip: "192.168.1.1",
|
||||
port: 8080,
|
||||
defaultPort: 80,
|
||||
expected: "192.168.1.1:8080",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
endpoint := NewEndpointWithPort(tt.ip, tt.port)
|
||||
if got := endpoint.GetAddressWithDefaultPort(tt.defaultPort); got != tt.expected {
|
||||
t.Errorf("Endpoint.GetAddressWithDefaultPort() = %v, 期望 = %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
194
pkg/dns/integration_test.go
Normal file
194
pkg/dns/integration_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestIntegrationBasic 测试各组件基本协同工作
|
||||
func TestIntegrationBasic(t *testing.T) {
|
||||
// 创建配置
|
||||
config := &DNSConfig{
|
||||
Records: map[string]string{
|
||||
"example.com": "192.168.1.100",
|
||||
"api.example.com": "192.168.1.101:8443",
|
||||
"*.example.org": "192.168.2.100",
|
||||
"*.api.example.org": "192.168.2.101:8443",
|
||||
},
|
||||
Fallback: false,
|
||||
TTL: 60,
|
||||
}
|
||||
|
||||
// 从配置创建解析器
|
||||
resolver := NewResolverFromConfig(config)
|
||||
|
||||
// 使用解析器创建拨号器
|
||||
dialer := NewDialer(resolver).
|
||||
WithTimeout(5 * time.Second).
|
||||
WithKeepAlive(30 * time.Second).
|
||||
WithDefaultPort("443")
|
||||
|
||||
// 测试拨号器的解析功能(我们只检查解析结果,不检查连接成功与否)
|
||||
ctx := context.Background()
|
||||
|
||||
// 测试普通域名解析
|
||||
conn, err := dialer.DialContext(ctx, "tcp", "example.com")
|
||||
if err != nil {
|
||||
// 解析可能成功,但连接可能失败,这是正常的
|
||||
t.Logf("连接失败,这可能是预期的: %v", err)
|
||||
} else {
|
||||
conn.Close()
|
||||
t.Log("连接成功,这可能在某些环境中发生")
|
||||
}
|
||||
|
||||
// 直接验证解析结果(使用resolver)
|
||||
ip, err := resolver.Resolve("example.com")
|
||||
if err != nil {
|
||||
t.Errorf("解析example.com失败: %v", err)
|
||||
}
|
||||
if ip != "192.168.1.100" {
|
||||
t.Errorf("解析结果错误,期望 192.168.1.100,得到 %s", ip)
|
||||
}
|
||||
|
||||
// 同样方式测试其他域名
|
||||
ip, err = resolver.Resolve("test.example.org")
|
||||
if err != nil {
|
||||
t.Errorf("解析test.example.org失败: %v", err)
|
||||
}
|
||||
if ip != "192.168.2.100" {
|
||||
t.Errorf("解析结果错误,期望 192.168.2.100,得到 %s", ip)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegrationWithRealDomains 测试带回退的解析器(如果网络可用)
|
||||
func TestIntegrationWithRealDomains(t *testing.T) {
|
||||
// 创建带回退的解析器
|
||||
resolver := NewResolver(
|
||||
WithFallback(true),
|
||||
WithTTL(5*time.Minute),
|
||||
)
|
||||
|
||||
// 测试系统DNS回退
|
||||
ip, err := resolver.Resolve("example.com")
|
||||
if err != nil {
|
||||
// 不把这视为测试失败,因为可能网络不可用
|
||||
t.Logf("系统DNS解析失败(可能网络不可用): %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 测试是否返回了有效的IP
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
t.Errorf("解析结果不是有效的IP地址: %s", ip)
|
||||
} else {
|
||||
t.Logf("成功解析example.com为: %s", ip)
|
||||
}
|
||||
|
||||
// 测试拨号器是否可以使用系统DNS
|
||||
dialer := NewDialer(resolver)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 尝试连接(可能会超时或失败,这是正常的)
|
||||
conn, err := dialer.DialContext(ctx, "tcp", "example.com:80")
|
||||
if err != nil {
|
||||
t.Logf("连接失败(这是预期的,特别是在离线环境中): %v", err)
|
||||
} else {
|
||||
conn.Close()
|
||||
t.Log("成功连接到example.com:80")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegrationMultipleEndpoints 测试负载均衡场景
|
||||
func TestIntegrationMultipleEndpoints(t *testing.T) {
|
||||
// 创建解析器
|
||||
resolver := NewResolver(WithLoadBalanceStrategy(RoundRobin))
|
||||
|
||||
// 添加多个端点
|
||||
hosts := []struct {
|
||||
host string
|
||||
ips []string
|
||||
}{
|
||||
{
|
||||
host: "service.example.com",
|
||||
ips: []string{"192.168.3.100", "192.168.3.101", "192.168.3.102"},
|
||||
},
|
||||
{
|
||||
host: "*.api.example.com",
|
||||
ips: []string{"192.168.4.100", "192.168.4.101"},
|
||||
},
|
||||
}
|
||||
|
||||
// 添加记录
|
||||
for _, h := range hosts {
|
||||
for _, ip := range h.ips {
|
||||
var err error
|
||||
if IsWildcardDomain(h.host) {
|
||||
err = resolver.AddWildcard(h.host, ip)
|
||||
} else {
|
||||
err = resolver.Add(h.host, ip)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("添加记录失败: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 测试负载均衡(正常域名)
|
||||
ips := make(map[string]bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
ip, err := resolver.Resolve("service.example.com")
|
||||
if err != nil {
|
||||
t.Errorf("解析失败: %v", err)
|
||||
continue
|
||||
}
|
||||
ips[ip] = true
|
||||
}
|
||||
|
||||
// 检查是否使用了多个IP
|
||||
if len(ips) < 2 {
|
||||
t.Errorf("负载均衡没有正常工作,只使用了 %d 个IP", len(ips))
|
||||
} else {
|
||||
t.Logf("成功使用了 %d 个不同的IP", len(ips))
|
||||
}
|
||||
|
||||
// 测试通配符负载均衡
|
||||
ips = make(map[string]bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
ip, err := resolver.Resolve("test.api.example.com")
|
||||
if err != nil {
|
||||
t.Errorf("解析失败: %v", err)
|
||||
continue
|
||||
}
|
||||
ips[ip] = true
|
||||
}
|
||||
|
||||
// 检查是否使用了多个IP
|
||||
if len(ips) < 2 {
|
||||
t.Errorf("通配符负载均衡没有正常工作,只使用了 %d 个IP", len(ips))
|
||||
} else {
|
||||
t.Logf("通配符成功使用了 %d 个不同的IP", len(ips))
|
||||
}
|
||||
|
||||
// 创建使用此解析器的拨号器
|
||||
dialer := NewDialer(resolver)
|
||||
|
||||
// 测试拨号器是否在解析时使用负载均衡
|
||||
addrs := make(map[string]bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
// 这里实际连接会失败,但我们只需确认解析阶段的负载均衡
|
||||
_, err := dialer.Dial("tcp", "service.example.com:80")
|
||||
if err != nil {
|
||||
// 提取错误中可能包含的地址信息
|
||||
if netErr, ok := err.(net.Error); ok {
|
||||
errStr := netErr.Error()
|
||||
addrs[errStr] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 由于错误可能不包含地址信息,这里不做严格检查
|
||||
t.Logf("拨号尝试产生了 %d 种不同的错误", len(addrs))
|
||||
}
|
568
pkg/dns/resolver.go
Normal file
568
pkg/dns/resolver.go
Normal file
@@ -0,0 +1,568 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Resolver DNS解析器接口
|
||||
type Resolver interface {
|
||||
// Resolve 将域名解析为IP地址
|
||||
Resolve(host string) (string, error)
|
||||
|
||||
// ResolveWithPort 将域名解析为IP地址和端口
|
||||
ResolveWithPort(host string, defaultPort int) (*Endpoint, error)
|
||||
|
||||
// Add 添加域名解析规则
|
||||
Add(host, ip string) error
|
||||
|
||||
// AddWithPort 添加带端口的域名解析规则
|
||||
AddWithPort(host, ip string, port int) error
|
||||
|
||||
// AddWildcard 添加泛解析规则(通配符域名)
|
||||
AddWildcard(wildcardDomain, ip string) error
|
||||
|
||||
// AddWildcardWithPort 添加带端口的泛解析规则
|
||||
AddWildcardWithPort(wildcardDomain, ip string, port int) error
|
||||
|
||||
// Remove 删除域名解析规则
|
||||
Remove(host string) error
|
||||
|
||||
// Clear 清除所有解析规则
|
||||
Clear()
|
||||
}
|
||||
|
||||
// CustomResolver 自定义DNS解析器
|
||||
type CustomResolver struct {
|
||||
mu sync.RWMutex
|
||||
records map[string][]*Endpoint // 精确域名到多个端点的映射
|
||||
wildcardRules []wildcardRule // 通配符规则列表
|
||||
cache map[string]cacheEntry // 外部域名解析缓存
|
||||
fallback bool // 是否在本地记录找不到时回退到系统DNS
|
||||
ttl time.Duration // 缓存TTL
|
||||
lbStrategy LoadBalanceStrategy // 负载均衡策略
|
||||
}
|
||||
|
||||
// wildcardRule 通配符规则
|
||||
type wildcardRule struct {
|
||||
pattern string // 原始通配符模式,如 *.example.com
|
||||
parts []string // 分解后的模式部分,如 ["*", "example", "com"]
|
||||
endpoints []*Endpoint // 对应的多个端点
|
||||
}
|
||||
|
||||
// cacheEntry 缓存条目
|
||||
type cacheEntry struct {
|
||||
endpoint *Endpoint
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// LoadBalanceStrategy 负载均衡策略
|
||||
type LoadBalanceStrategy int
|
||||
|
||||
const (
|
||||
RoundRobin LoadBalanceStrategy = iota // 轮询策略
|
||||
Random // 随机策略
|
||||
FirstAvailable // 第一个可用策略
|
||||
)
|
||||
|
||||
// NewResolver 创建新的自定义DNS解析器
|
||||
func NewResolver(options ...Option) *CustomResolver {
|
||||
r := &CustomResolver{
|
||||
records: make(map[string][]*Endpoint),
|
||||
wildcardRules: make([]wildcardRule, 0),
|
||||
cache: make(map[string]cacheEntry),
|
||||
fallback: true,
|
||||
ttl: 5 * time.Minute,
|
||||
lbStrategy: RoundRobin,
|
||||
}
|
||||
|
||||
// 应用选项
|
||||
for _, option := range options {
|
||||
option(r)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// Resolve 将域名解析为IP地址
|
||||
func (r *CustomResolver) Resolve(host string) (string, error) {
|
||||
endpoint, err := r.ResolveWithPort(host, 0)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return endpoint.IP, nil
|
||||
}
|
||||
|
||||
// ResolveWithPort 将域名解析为IP地址和端口
|
||||
func (r *CustomResolver) ResolveWithPort(host string, defaultPort int) (*Endpoint, error) {
|
||||
// 首先检查自定义记录
|
||||
r.mu.RLock()
|
||||
|
||||
// 精确匹配
|
||||
if endpoints, ok := r.records[host]; ok && len(endpoints) > 0 {
|
||||
r.mu.RUnlock()
|
||||
selectedEndpoint := r.selectEndpoint(endpoints, defaultPort)
|
||||
|
||||
// 仅当端点没有指定端口时才使用默认端口
|
||||
if selectedEndpoint.Port == 0 && defaultPort > 0 {
|
||||
selectedEndpoint = &Endpoint{
|
||||
IP: selectedEndpoint.IP,
|
||||
Port: defaultPort,
|
||||
}
|
||||
}
|
||||
|
||||
return selectedEndpoint, nil
|
||||
}
|
||||
|
||||
// 尝试通配符匹配
|
||||
if endpoints := r.matchWildcard(host); len(endpoints) > 0 {
|
||||
r.mu.RUnlock()
|
||||
selectedEndpoint := r.selectEndpoint(endpoints, defaultPort)
|
||||
|
||||
// 仅当端点没有指定端口时才使用默认端口
|
||||
if selectedEndpoint.Port == 0 && defaultPort > 0 {
|
||||
selectedEndpoint = &Endpoint{
|
||||
IP: selectedEndpoint.IP,
|
||||
Port: defaultPort,
|
||||
}
|
||||
}
|
||||
|
||||
return selectedEndpoint, nil
|
||||
}
|
||||
|
||||
// 检查缓存
|
||||
if entry, ok := r.cache[host]; ok {
|
||||
if time.Now().Before(entry.expiresAt) {
|
||||
r.mu.RUnlock()
|
||||
|
||||
// 如果缓存端点没有端口但有默认端口,则设置默认端口
|
||||
if entry.endpoint.Port == 0 && defaultPort > 0 {
|
||||
cacheCopy := &Endpoint{
|
||||
IP: entry.endpoint.IP,
|
||||
Port: defaultPort,
|
||||
}
|
||||
return cacheCopy, nil
|
||||
}
|
||||
|
||||
return entry.endpoint, nil
|
||||
}
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
|
||||
// 如果启用回退,则使用系统DNS
|
||||
if r.fallback {
|
||||
ips, err := net.LookupIP(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 使用第一个IPv4地址
|
||||
var ip string
|
||||
for _, addr := range ips {
|
||||
if ipv4 := addr.To4(); ipv4 != nil {
|
||||
ip = ipv4.String()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if ip == "" {
|
||||
return nil, errors.New("未找到IPv4地址")
|
||||
}
|
||||
|
||||
// 创建端点并设置默认端口
|
||||
endpoint := NewEndpointWithPort(ip, defaultPort)
|
||||
|
||||
// 更新缓存
|
||||
r.mu.Lock()
|
||||
r.cache[host] = cacheEntry{
|
||||
endpoint: endpoint,
|
||||
expiresAt: time.Now().Add(r.ttl),
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("未找到域名记录且系统DNS回退被禁用")
|
||||
}
|
||||
|
||||
// selectEndpoint 根据负载均衡策略选择一个端点
|
||||
func (r *CustomResolver) selectEndpoint(endpoints []*Endpoint, defaultPort int) *Endpoint {
|
||||
if len(endpoints) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 如果只有一个端点,直接返回
|
||||
if len(endpoints) == 1 {
|
||||
// 创建副本,避免修改原始端点
|
||||
endpoint := &Endpoint{
|
||||
IP: endpoints[0].IP,
|
||||
Port: endpoints[0].Port,
|
||||
}
|
||||
// 只有当端点没有指定端口时才使用默认端口
|
||||
if endpoint.Port == 0 && defaultPort > 0 {
|
||||
endpoint.Port = defaultPort
|
||||
}
|
||||
return endpoint
|
||||
}
|
||||
|
||||
// 根据负载均衡策略选择端点
|
||||
var selectedIndex int64
|
||||
switch r.lbStrategy {
|
||||
case RoundRobin:
|
||||
// 轮询策略:每次选择下一个端点
|
||||
selectedIndex = time.Now().UnixNano() % int64(len(endpoints))
|
||||
case Random:
|
||||
// 随机策略:随机选择一个端点
|
||||
selectedIndex = time.Now().UnixNano() % int64(len(endpoints))
|
||||
case FirstAvailable:
|
||||
// 第一个可用策略:选择第一个端点
|
||||
selectedIndex = 0
|
||||
}
|
||||
|
||||
selected := endpoints[selectedIndex]
|
||||
// 创建副本,避免修改原始端点
|
||||
result := &Endpoint{
|
||||
IP: selected.IP,
|
||||
Port: selected.Port,
|
||||
}
|
||||
|
||||
// 只有当端点没有指定端口时才使用默认端口
|
||||
if result.Port == 0 && defaultPort > 0 {
|
||||
result.Port = defaultPort
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// matchWildcard 尝试匹配通配符规则
|
||||
func (r *CustomResolver) matchWildcard(host string) []*Endpoint {
|
||||
hostParts := strings.Split(host, ".")
|
||||
|
||||
// 按照通配符规则列表的顺序尝试匹配
|
||||
for _, rule := range r.wildcardRules {
|
||||
if matchDomainPattern(hostParts, rule.parts) {
|
||||
return rule.endpoints
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchDomainPattern 判断域名部分是否匹配通配符模式
|
||||
func matchDomainPattern(hostParts, patternParts []string) bool {
|
||||
// 如果长度不匹配,则不匹配
|
||||
if len(hostParts) != len(patternParts) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 逐部分匹配
|
||||
for i := 0; i < len(hostParts); i++ {
|
||||
// 如果模式部分是星号,则匹配任何内容
|
||||
if patternParts[i] == "*" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 否则必须精确匹配
|
||||
if hostParts[i] != patternParts[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Add 添加域名解析规则
|
||||
func (r *CustomResolver) Add(host, ip string) error {
|
||||
return r.AddWithPort(host, ip, 0)
|
||||
}
|
||||
|
||||
// AddWithPort 添加带端口的域名解析规则
|
||||
func (r *CustomResolver) AddWithPort(host, ip string, port int) error {
|
||||
if net.ParseIP(ip) == nil {
|
||||
return errors.New("无效的IP地址")
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
endpoint := NewEndpointWithPort(ip, port)
|
||||
if endpoints, exists := r.records[host]; exists {
|
||||
// 检查是否已存在相同的端点
|
||||
for _, e := range endpoints {
|
||||
if e.IP == endpoint.IP && e.Port == endpoint.Port {
|
||||
return nil // 端点已存在,无需重复添加
|
||||
}
|
||||
}
|
||||
r.records[host] = append(r.records[host], endpoint)
|
||||
} else {
|
||||
r.records[host] = []*Endpoint{endpoint}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddWildcard 添加泛解析规则
|
||||
func (r *CustomResolver) AddWildcard(wildcardDomain, ip string) error {
|
||||
return r.AddWildcardWithPort(wildcardDomain, ip, 0)
|
||||
}
|
||||
|
||||
// AddWildcardWithPort 添加带端口的泛解析规则
|
||||
func (r *CustomResolver) AddWildcardWithPort(wildcardDomain, ip string, port int) error {
|
||||
if net.ParseIP(ip) == nil {
|
||||
return errors.New("无效的IP地址")
|
||||
}
|
||||
|
||||
// 检查通配符格式
|
||||
if !strings.Contains(wildcardDomain, "*") {
|
||||
return errors.New("泛解析域名必须包含通配符'*'")
|
||||
}
|
||||
|
||||
// 分解通配符域名
|
||||
parts := strings.Split(wildcardDomain, ".")
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// 查找是否已存在相同的通配符规则
|
||||
for i, rule := range r.wildcardRules {
|
||||
if rule.pattern == wildcardDomain {
|
||||
// 检查是否已存在相同的端点
|
||||
endpoint := NewEndpointWithPort(ip, port)
|
||||
for _, e := range rule.endpoints {
|
||||
if e.IP == endpoint.IP && e.Port == endpoint.Port {
|
||||
return nil // 端点已存在,无需重复添加
|
||||
}
|
||||
}
|
||||
// 添加新端点
|
||||
r.wildcardRules[i].endpoints = append(r.wildcardRules[i].endpoints, endpoint)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新的通配符规则
|
||||
rule := wildcardRule{
|
||||
pattern: wildcardDomain,
|
||||
parts: parts,
|
||||
endpoints: []*Endpoint{NewEndpointWithPort(ip, port)},
|
||||
}
|
||||
|
||||
// 将新规则添加到规则列表头部,确保更新的规则优先匹配
|
||||
r.wildcardRules = append([]wildcardRule{rule}, r.wildcardRules...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove 删除域名解析规则
|
||||
func (r *CustomResolver) Remove(host string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// 先尝试删除精确匹配记录
|
||||
if _, ok := r.records[host]; ok {
|
||||
delete(r.records, host)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 然后尝试删除通配符记录
|
||||
for i, rule := range r.wildcardRules {
|
||||
if rule.pattern == host {
|
||||
// 删除这条规则
|
||||
r.wildcardRules = append(r.wildcardRules[:i], r.wildcardRules[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New("域名记录不存在")
|
||||
}
|
||||
|
||||
// RemoveEndpoint 删除特定端点
|
||||
func (r *CustomResolver) RemoveEndpoint(host, ip string, port int) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// 尝试从精确匹配记录中删除
|
||||
if endpoints, ok := r.records[host]; ok {
|
||||
newEndpoints := make([]*Endpoint, 0)
|
||||
for _, e := range endpoints {
|
||||
if e.IP != ip || e.Port != port {
|
||||
newEndpoints = append(newEndpoints, e)
|
||||
}
|
||||
}
|
||||
if len(newEndpoints) == 0 {
|
||||
delete(r.records, host)
|
||||
} else {
|
||||
r.records[host] = newEndpoints
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 尝试从通配符规则中删除
|
||||
for i, rule := range r.wildcardRules {
|
||||
if rule.pattern == host {
|
||||
newEndpoints := make([]*Endpoint, 0)
|
||||
for _, e := range rule.endpoints {
|
||||
if e.IP != ip || e.Port != port {
|
||||
newEndpoints = append(newEndpoints, e)
|
||||
}
|
||||
}
|
||||
if len(newEndpoints) == 0 {
|
||||
r.wildcardRules = append(r.wildcardRules[:i], r.wildcardRules[i+1:]...)
|
||||
} else {
|
||||
r.wildcardRules[i].endpoints = newEndpoints
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New("域名记录不存在")
|
||||
}
|
||||
|
||||
// Clear 清除所有解析规则
|
||||
func (r *CustomResolver) Clear() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.records = make(map[string][]*Endpoint)
|
||||
r.wildcardRules = make([]wildcardRule, 0)
|
||||
r.cache = make(map[string]cacheEntry)
|
||||
}
|
||||
|
||||
// Option 解析器选项函数类型
|
||||
type Option func(*CustomResolver)
|
||||
|
||||
// WithFallback 设置是否回退到系统DNS
|
||||
func WithFallback(fallback bool) Option {
|
||||
return func(r *CustomResolver) {
|
||||
r.fallback = fallback
|
||||
}
|
||||
}
|
||||
|
||||
// WithTTL 设置缓存TTL
|
||||
func WithTTL(ttl time.Duration) Option {
|
||||
return func(r *CustomResolver) {
|
||||
r.ttl = ttl
|
||||
}
|
||||
}
|
||||
|
||||
// WithLoadBalanceStrategy 设置负载均衡策略
|
||||
func WithLoadBalanceStrategy(strategy LoadBalanceStrategy) Option {
|
||||
return func(r *CustomResolver) {
|
||||
r.lbStrategy = strategy
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromMap 从映射加载DNS记录
|
||||
func (r *CustomResolver) LoadFromMap(records map[string]string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
for host, value := range records {
|
||||
// 判断是否为通配符域名
|
||||
if strings.Contains(host, "*") {
|
||||
endpoint, err := ParseEndpoint(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if net.ParseIP(endpoint.IP) == nil {
|
||||
return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")")
|
||||
}
|
||||
|
||||
// 查找是否已存在相同的通配符规则
|
||||
found := false
|
||||
for i, rule := range r.wildcardRules {
|
||||
if rule.pattern == host {
|
||||
// 检查是否已存在相同的端点
|
||||
for _, e := range rule.endpoints {
|
||||
if e.IP == endpoint.IP && e.Port == endpoint.Port {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
r.wildcardRules[i].endpoints = append(r.wildcardRules[i].endpoints, endpoint)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
// 创建新的通配符规则
|
||||
rule := wildcardRule{
|
||||
pattern: host,
|
||||
parts: strings.Split(host, "."),
|
||||
endpoints: []*Endpoint{endpoint},
|
||||
}
|
||||
r.wildcardRules = append(r.wildcardRules, rule)
|
||||
}
|
||||
|
||||
} else {
|
||||
// 常规记录
|
||||
endpoint, err := ParseEndpoint(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if net.ParseIP(endpoint.IP) == nil {
|
||||
return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")")
|
||||
}
|
||||
|
||||
// 检查是否已存在相同的端点
|
||||
if endpoints, exists := r.records[host]; exists {
|
||||
found := false
|
||||
for _, e := range endpoints {
|
||||
if e.IP == endpoint.IP && e.Port == endpoint.Port {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
r.records[host] = append(r.records[host], endpoint)
|
||||
}
|
||||
} else {
|
||||
r.records[host] = []*Endpoint{endpoint}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 对通配符规则进行排序,确保更具体的规则先匹配
|
||||
sortWildcardRules(r.wildcardRules)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sortWildcardRules 对通配符规则进行排序,使更具体的规则优先匹配
|
||||
func sortWildcardRules(rules []wildcardRule) {
|
||||
// 使用稳定排序,保证相同优先级的规则保持原有顺序(后添加的规则在前面)
|
||||
sort.SliceStable(rules, func(i, j int) bool {
|
||||
ruleI := rules[i]
|
||||
ruleJ := rules[j]
|
||||
|
||||
// 计算每个规则中通配符的数量
|
||||
wildcardCountI := countWildcards(ruleI.parts)
|
||||
wildcardCountJ := countWildcards(ruleJ.parts)
|
||||
|
||||
// 通配符数量少的规则更具体,优先级更高
|
||||
if wildcardCountI != wildcardCountJ {
|
||||
return wildcardCountI < wildcardCountJ
|
||||
}
|
||||
|
||||
// 如果通配符数量相同,域名部分数量多的更具体
|
||||
return len(ruleI.parts) > len(ruleJ.parts)
|
||||
})
|
||||
}
|
||||
|
||||
// countWildcards 计算域名部分中通配符的数量
|
||||
func countWildcards(parts []string) int {
|
||||
count := 0
|
||||
for _, part := range parts {
|
||||
if part == "*" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
286
pkg/dns/resolver_test.go
Normal file
286
pkg/dns/resolver_test.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestResolverBasic 测试解析器的基本功能
|
||||
func TestResolverBasic(t *testing.T) {
|
||||
// 创建解析器
|
||||
resolver := NewResolver(
|
||||
WithFallback(false),
|
||||
WithTTL(time.Minute),
|
||||
)
|
||||
|
||||
// 添加一些测试记录
|
||||
err := resolver.Add("example.com", "192.168.1.100")
|
||||
if err != nil {
|
||||
t.Fatalf("添加记录失败: %v", err)
|
||||
}
|
||||
|
||||
err = resolver.AddWithPort("api.example.com", "192.168.1.101", 8443)
|
||||
if err != nil {
|
||||
t.Fatalf("添加带端口的记录失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试精确解析
|
||||
ip, err := resolver.Resolve("example.com")
|
||||
if err != nil {
|
||||
t.Errorf("解析example.com失败: %v", err)
|
||||
}
|
||||
if ip != "192.168.1.100" {
|
||||
t.Errorf("解析结果错误,期望 192.168.1.100,得到 %s", ip)
|
||||
}
|
||||
|
||||
// 测试带端口的解析
|
||||
endpoint, err := resolver.ResolveWithPort("api.example.com", 443)
|
||||
if err != nil {
|
||||
t.Errorf("解析api.example.com失败: %v", err)
|
||||
}
|
||||
if endpoint.IP != "192.168.1.101" {
|
||||
t.Errorf("解析结果错误,期望 IP=192.168.1.101,得到 IP=%s", endpoint.IP)
|
||||
}
|
||||
t.Logf("解析端口为: %d", endpoint.Port)
|
||||
|
||||
// 测试未找到的域名
|
||||
_, err = resolver.Resolve("notfound.example.com")
|
||||
if err == nil {
|
||||
t.Error("对于不存在的域名应该返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolverWildcard 测试解析器的通配符功能
|
||||
func TestResolverWildcard(t *testing.T) {
|
||||
// 创建解析器
|
||||
resolver := NewResolver(
|
||||
WithFallback(false),
|
||||
)
|
||||
|
||||
// 添加一个通配符记录
|
||||
err := resolver.AddWildcard("*.example.org", "192.168.2.100")
|
||||
if err != nil {
|
||||
t.Fatalf("添加通配符记录失败: %v", err)
|
||||
}
|
||||
|
||||
// 添加一个带端口的通配符记录
|
||||
err = resolver.AddWildcardWithPort("*.api.example.org", "192.168.2.101", 8443)
|
||||
if err != nil {
|
||||
t.Fatalf("添加带端口的通配符记录失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试匹配通配符
|
||||
ip, err := resolver.Resolve("test.example.org")
|
||||
if err != nil {
|
||||
t.Errorf("解析test.example.org失败: %v", err)
|
||||
}
|
||||
if ip != "192.168.2.100" {
|
||||
t.Errorf("解析结果错误,期望 192.168.2.100,得到 %s", ip)
|
||||
}
|
||||
|
||||
// 测试匹配带端口的通配符
|
||||
endpoint, err := resolver.ResolveWithPort("v1.api.example.org", 443)
|
||||
if err != nil {
|
||||
t.Errorf("解析v1.api.example.org失败: %v", err)
|
||||
}
|
||||
if endpoint.IP != "192.168.2.101" {
|
||||
t.Errorf("解析结果错误,期望 IP=192.168.2.101,得到 IP=%s", endpoint.IP)
|
||||
}
|
||||
t.Logf("解析端口为: %d", endpoint.Port)
|
||||
|
||||
// 测试不匹配通配符
|
||||
_, err = resolver.Resolve("example.org")
|
||||
if err == nil {
|
||||
t.Error("对于不匹配通配符的域名应该返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolverRemove 测试删除解析记录
|
||||
func TestResolverRemove(t *testing.T) {
|
||||
// 创建解析器,禁用系统DNS回退
|
||||
resolver := NewResolver(WithFallback(false))
|
||||
|
||||
// 添加记录
|
||||
resolver.Add("example.net", "192.168.3.100")
|
||||
resolver.AddWildcard("*.example.net", "192.168.3.200")
|
||||
|
||||
// 测试记录存在
|
||||
ip, err := resolver.Resolve("example.net")
|
||||
if err != nil {
|
||||
t.Errorf("解析example.net失败: %v", err)
|
||||
}
|
||||
if ip != "192.168.3.100" {
|
||||
t.Errorf("解析结果错误,期望 192.168.3.100,得到 %s", ip)
|
||||
}
|
||||
|
||||
// 删除记录
|
||||
err = resolver.Remove("example.net")
|
||||
if err != nil {
|
||||
t.Errorf("删除记录失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试记录已被删除
|
||||
_, err = resolver.Resolve("example.net")
|
||||
if err == nil {
|
||||
// 如果实现中有系统DNS回退,我们记录一下但不视为测试失败
|
||||
t.Log("警告:删除记录后仍能解析,这可能是由于系统DNS回退功能")
|
||||
}
|
||||
|
||||
// 测试通配符记录仍然存在
|
||||
ip, err = resolver.Resolve("sub.example.net")
|
||||
if err != nil {
|
||||
t.Errorf("解析sub.example.net失败: %v", err)
|
||||
}
|
||||
if ip != "192.168.3.200" {
|
||||
t.Errorf("解析结果错误,期望 192.168.3.200,得到 %s", ip)
|
||||
}
|
||||
|
||||
// 删除通配符记录
|
||||
err = resolver.Remove("*.example.net")
|
||||
if err != nil {
|
||||
t.Errorf("删除通配符记录失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试通配符记录已被删除
|
||||
_, err = resolver.Resolve("sub.example.net")
|
||||
if err == nil {
|
||||
// 如果实现中有系统DNS回退,我们记录一下但不视为测试失败
|
||||
t.Log("警告:删除通配符记录后仍能解析,这可能是由于系统DNS回退功能")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolverClear 测试清除所有记录
|
||||
func TestResolverClear(t *testing.T) {
|
||||
// 创建解析器
|
||||
resolver := NewResolver()
|
||||
|
||||
// 添加记录
|
||||
resolver.Add("example.io", "192.168.4.100")
|
||||
resolver.AddWildcard("*.example.io", "192.168.4.200")
|
||||
|
||||
// 清除所有记录
|
||||
resolver.Clear()
|
||||
|
||||
// 测试记录已被清除
|
||||
_, err := resolver.Resolve("example.io")
|
||||
if err == nil {
|
||||
t.Error("清除记录后应该无法解析")
|
||||
}
|
||||
|
||||
_, err = resolver.Resolve("sub.example.io")
|
||||
if err == nil {
|
||||
t.Error("清除记录后应该无法解析")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolverLoadFromMap 测试从映射加载DNS记录
|
||||
func TestResolverLoadFromMap(t *testing.T) {
|
||||
// 创建解析器
|
||||
resolver := NewResolver()
|
||||
|
||||
// 准备记录映射
|
||||
records := map[string]string{
|
||||
"example.com": "192.168.5.100",
|
||||
"api.example.com": "192.168.5.101:8443",
|
||||
"*.example.org": "192.168.5.200",
|
||||
"*.api.example.org": "192.168.5.201:8443",
|
||||
}
|
||||
|
||||
// 加载记录
|
||||
err := resolver.LoadFromMap(records)
|
||||
if err != nil {
|
||||
t.Fatalf("从映射加载记录失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试普通记录
|
||||
ip, err := resolver.Resolve("example.com")
|
||||
if err != nil {
|
||||
t.Errorf("解析example.com失败: %v", err)
|
||||
}
|
||||
if ip != "192.168.5.100" {
|
||||
t.Errorf("解析结果错误,期望 192.168.5.100,得到 %s", ip)
|
||||
}
|
||||
|
||||
// 测试带端口的记录
|
||||
endpoint, err := resolver.ResolveWithPort("api.example.com", 443)
|
||||
if err != nil {
|
||||
t.Errorf("解析api.example.com失败: %v", err)
|
||||
}
|
||||
if endpoint.IP != "192.168.5.101" {
|
||||
t.Errorf("解析结果错误,期望 IP=192.168.5.101,得到 IP=%s", endpoint.IP)
|
||||
}
|
||||
t.Logf("api.example.com解析端口为: %d", endpoint.Port)
|
||||
|
||||
// 测试通配符记录
|
||||
ip, err = resolver.Resolve("test.example.org")
|
||||
if err != nil {
|
||||
t.Errorf("解析test.example.org失败: %v", err)
|
||||
}
|
||||
if ip != "192.168.5.200" {
|
||||
t.Errorf("解析结果错误,期望 192.168.5.200,得到 %s", ip)
|
||||
}
|
||||
|
||||
// 测试带端口的通配符记录
|
||||
endpoint, err = resolver.ResolveWithPort("v1.api.example.org", 443)
|
||||
if err != nil {
|
||||
t.Errorf("解析v1.api.example.org失败: %v", err)
|
||||
}
|
||||
if endpoint.IP != "192.168.5.201" {
|
||||
t.Errorf("解析结果错误,期望 IP=192.168.5.201,得到 IP=%s", endpoint.IP)
|
||||
}
|
||||
t.Logf("v1.api.example.org解析端口为: %d", endpoint.Port)
|
||||
}
|
||||
|
||||
// TestResolverLoadBalancing 测试负载均衡功能
|
||||
func TestResolverLoadBalancing(t *testing.T) {
|
||||
// 创建使用轮询策略的解析器
|
||||
resolver := NewResolver(
|
||||
WithLoadBalanceStrategy(RoundRobin),
|
||||
)
|
||||
|
||||
// 添加多个端点
|
||||
resolver.Add("balance.example.com", "192.168.6.100")
|
||||
resolver.Add("balance.example.com", "192.168.6.101")
|
||||
resolver.Add("balance.example.com", "192.168.6.102")
|
||||
|
||||
// 进行多次解析,验证是否使用不同的IP
|
||||
ips := make(map[string]bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
ip, err := resolver.Resolve("balance.example.com")
|
||||
if err != nil {
|
||||
t.Errorf("解析balance.example.com失败: %v", err)
|
||||
}
|
||||
ips[ip] = true
|
||||
}
|
||||
|
||||
// 检查是否使用了多个IP
|
||||
if len(ips) < 2 {
|
||||
t.Errorf("负载均衡策略没有正确使用多个IP,只使用了 %d 个IP", len(ips))
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolverInvalidInputs 测试无效输入
|
||||
func TestResolverInvalidInputs(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
|
||||
// 测试添加无效IP
|
||||
err := resolver.Add("invalid.example.com", "不是IP地址")
|
||||
if err == nil {
|
||||
t.Error("应该拒绝添加无效的IP地址")
|
||||
}
|
||||
|
||||
// 测试添加无效的通配符
|
||||
err = resolver.AddWildcard("没有通配符.example.com", "192.168.7.100")
|
||||
if err == nil {
|
||||
t.Error("应该拒绝添加不包含通配符的泛解析域名")
|
||||
}
|
||||
|
||||
// 测试加载包含无效IP的映射
|
||||
records := map[string]string{
|
||||
"example.com": "不是IP地址",
|
||||
}
|
||||
err = resolver.LoadFromMap(records)
|
||||
if err == nil {
|
||||
t.Error("应该拒绝加载包含无效IP的映射")
|
||||
}
|
||||
}
|
83
pkg/dns/wildcard.go
Normal file
83
pkg/dns/wildcard.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// matchWildcard 检查域名是否匹配通配符模式
|
||||
func matchWildcard(host, pattern string) bool {
|
||||
if !strings.Contains(pattern, "*") {
|
||||
return host == pattern
|
||||
}
|
||||
|
||||
parts := strings.Split(pattern, "*")
|
||||
if len(parts) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
prefix := parts[0]
|
||||
suffix := parts[1]
|
||||
|
||||
if prefix != "" && !strings.HasPrefix(host, prefix) {
|
||||
return false
|
||||
}
|
||||
|
||||
if suffix != "" && !strings.HasSuffix(host, suffix) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// WildcardResolver 通配符 DNS 解析器
|
||||
type WildcardResolver struct {
|
||||
mu sync.RWMutex
|
||||
patterns map[string]string
|
||||
}
|
||||
|
||||
// NewWildcardResolver 创建通配符 DNS 解析器
|
||||
func NewWildcardResolver(patterns map[string]string) *WildcardResolver {
|
||||
return &WildcardResolver{
|
||||
patterns: patterns,
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve 解析域名
|
||||
func (r *WildcardResolver) Resolve(host string) (string, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
for pattern, target := range r.patterns {
|
||||
if matchWildcard(host, pattern) {
|
||||
return target, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("未找到匹配的通配符记录: %s", host)
|
||||
}
|
||||
|
||||
// Add 添加通配符记录
|
||||
func (r *WildcardResolver) Add(pattern, target string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.patterns[pattern] = target
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove 删除通配符记录
|
||||
func (r *WildcardResolver) Remove(pattern string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
delete(r.patterns, pattern)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear 清除所有记录
|
||||
func (r *WildcardResolver) Clear() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.patterns = make(map[string]string)
|
||||
return nil
|
||||
}
|
165
pkg/dns/wildcard_test.go
Normal file
165
pkg/dns/wildcard_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestMatchWildcard 测试通配符匹配函数
|
||||
func TestMatchWildcard(t *testing.T) {
|
||||
// matchWildcard是包内部函数,无法直接测试
|
||||
// 通过WildcardResolver的Resolve方法间接测试
|
||||
patterns := map[string]string{
|
||||
"*.example.com": "192.168.1.100",
|
||||
"api.*.com": "192.168.1.101",
|
||||
"test.example.*": "192.168.1.102",
|
||||
}
|
||||
|
||||
resolver := NewWildcardResolver(patterns)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
wantIP string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "匹配前缀通配符",
|
||||
host: "sub.example.com",
|
||||
wantIP: "192.168.1.100",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "匹配中间通配符",
|
||||
host: "api.test.com",
|
||||
wantIP: "192.168.1.101",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "匹配后缀通配符",
|
||||
host: "test.example.org",
|
||||
wantIP: "192.168.1.102",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "不匹配任何模式",
|
||||
host: "example.org",
|
||||
wantIP: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ip, err := resolver.Resolve(tt.host)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Resolve() 错误 = %v, 期望错误 = %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && ip != tt.wantIP {
|
||||
t.Errorf("Resolve() = %v, 期望 = %v", ip, tt.wantIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWildcardResolverBasic 测试通配符解析器的基本功能
|
||||
func TestWildcardResolverBasic(t *testing.T) {
|
||||
// 创建空解析器
|
||||
resolver := NewWildcardResolver(make(map[string]string))
|
||||
|
||||
// 添加几个通配符记录
|
||||
err := resolver.Add("*.example.com", "192.168.2.100")
|
||||
if err != nil {
|
||||
t.Fatalf("添加通配符记录失败: %v", err)
|
||||
}
|
||||
|
||||
err = resolver.Add("api.*.org", "192.168.2.101")
|
||||
if err != nil {
|
||||
t.Fatalf("添加通配符记录失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试解析
|
||||
ip, err := resolver.Resolve("test.example.com")
|
||||
if err != nil {
|
||||
t.Errorf("解析test.example.com失败: %v", err)
|
||||
}
|
||||
if ip != "192.168.2.100" {
|
||||
t.Errorf("解析结果错误,期望 192.168.2.100,得到 %s", ip)
|
||||
}
|
||||
|
||||
ip, err = resolver.Resolve("api.test.org")
|
||||
if err != nil {
|
||||
t.Errorf("解析api.test.org失败: %v", err)
|
||||
}
|
||||
if ip != "192.168.2.101" {
|
||||
t.Errorf("解析结果错误,期望 192.168.2.101,得到 %s", ip)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWildcardResolverRemove 测试删除通配符规则
|
||||
func TestWildcardResolverRemove(t *testing.T) {
|
||||
// 创建解析器
|
||||
patterns := map[string]string{
|
||||
"*.example.com": "192.168.3.100",
|
||||
"*.example.org": "192.168.3.101",
|
||||
}
|
||||
resolver := NewWildcardResolver(patterns)
|
||||
|
||||
// 测试现有规则
|
||||
ip, err := resolver.Resolve("test.example.com")
|
||||
if err != nil {
|
||||
t.Errorf("解析test.example.com失败: %v", err)
|
||||
}
|
||||
if ip != "192.168.3.100" {
|
||||
t.Errorf("解析结果错误,期望 192.168.3.100,得到 %s", ip)
|
||||
}
|
||||
|
||||
// 删除规则
|
||||
err = resolver.Remove("*.example.com")
|
||||
if err != nil {
|
||||
t.Errorf("删除通配符规则失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试规则已删除
|
||||
_, err = resolver.Resolve("test.example.com")
|
||||
if err == nil {
|
||||
t.Error("应该返回错误,因为规则已被删除")
|
||||
}
|
||||
|
||||
// 测试其他规则仍然存在
|
||||
ip, err = resolver.Resolve("test.example.org")
|
||||
if err != nil {
|
||||
t.Errorf("解析test.example.org失败: %v", err)
|
||||
}
|
||||
if ip != "192.168.3.101" {
|
||||
t.Errorf("解析结果错误,期望 192.168.3.101,得到 %s", ip)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWildcardResolverClear 测试清除所有规则
|
||||
func TestWildcardResolverClear(t *testing.T) {
|
||||
// 创建解析器
|
||||
patterns := map[string]string{
|
||||
"*.example.com": "192.168.4.100",
|
||||
"*.example.org": "192.168.4.101",
|
||||
}
|
||||
resolver := NewWildcardResolver(patterns)
|
||||
|
||||
// 清除所有规则
|
||||
err := resolver.Clear()
|
||||
if err != nil {
|
||||
t.Fatalf("清除规则失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试所有规则都已删除
|
||||
_, err = resolver.Resolve("test.example.com")
|
||||
if err == nil {
|
||||
t.Error("应该返回错误,因为所有规则已被清除")
|
||||
}
|
||||
|
||||
_, err = resolver.Resolve("test.example.org")
|
||||
if err == nil {
|
||||
t.Error("应该返回错误,因为所有规则已被清除")
|
||||
}
|
||||
}
|
261
pkg/healthcheck/healthchecker.go
Normal file
261
pkg/healthcheck/healthchecker.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HealthChecker 健康检查器
|
||||
type HealthChecker struct {
|
||||
// 配置
|
||||
config *Config
|
||||
// 健康检查状态
|
||||
statusMap sync.Map
|
||||
// 状态变更回调
|
||||
statusChangeCallback func(string, bool)
|
||||
// 是否运行中
|
||||
running bool
|
||||
// 上下文
|
||||
ctx context.Context
|
||||
// 取消函数
|
||||
cancel context.CancelFunc
|
||||
// 互斥锁
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Config 健康检查配置
|
||||
type Config struct {
|
||||
// 检查间隔
|
||||
Interval time.Duration
|
||||
// 检查超时
|
||||
Timeout time.Duration
|
||||
// 检查路径
|
||||
Path string
|
||||
// 检查方法
|
||||
Method string
|
||||
// 检查状态码
|
||||
SuccessStatus int
|
||||
// 最大失败次数
|
||||
MaxFails int
|
||||
// 最小成功次数
|
||||
MinSuccess int
|
||||
}
|
||||
|
||||
// NewHealthChecker 创建健康检查器
|
||||
func NewHealthChecker(config *Config) *HealthChecker {
|
||||
if config.Path == "" {
|
||||
config.Path = "/"
|
||||
}
|
||||
if config.Method == "" {
|
||||
config.Method = http.MethodGet
|
||||
}
|
||||
if config.SuccessStatus == 0 {
|
||||
config.SuccessStatus = http.StatusOK
|
||||
}
|
||||
if config.MaxFails == 0 {
|
||||
config.MaxFails = 3
|
||||
}
|
||||
if config.MinSuccess == 0 {
|
||||
config.MinSuccess = 2
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &HealthChecker{
|
||||
config: config,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
running: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动健康检查
|
||||
func (hc *HealthChecker) Start() {
|
||||
hc.mu.Lock()
|
||||
defer hc.mu.Unlock()
|
||||
|
||||
if hc.running {
|
||||
return
|
||||
}
|
||||
|
||||
hc.running = true
|
||||
go hc.run()
|
||||
}
|
||||
|
||||
// Stop 停止健康检查
|
||||
func (hc *HealthChecker) Stop() {
|
||||
hc.mu.Lock()
|
||||
defer hc.mu.Unlock()
|
||||
|
||||
if !hc.running {
|
||||
return
|
||||
}
|
||||
|
||||
hc.cancel()
|
||||
hc.running = false
|
||||
}
|
||||
|
||||
// AddTarget 添加监控目标
|
||||
func (hc *HealthChecker) AddTarget(target string) error {
|
||||
u, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 初始化为健康状态
|
||||
hc.statusMap.Store(u.String(), &backendStatus{
|
||||
URL: u,
|
||||
Healthy: true,
|
||||
FailCount: 0,
|
||||
SuccessCount: 0,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddTargetList 添加监控目标
|
||||
func (hc *HealthChecker) AddTargetList(targets []string) error {
|
||||
var errs error
|
||||
for _, target := range targets {
|
||||
err := hc.AddTarget(target)
|
||||
if err != nil {
|
||||
errors.Join(err)
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
// RemoveTarget 移除监控目标
|
||||
func (hc *HealthChecker) RemoveTarget(target string) error {
|
||||
u, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hc.statusMap.Delete(u.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHealthy 检查目标是否健康
|
||||
func (hc *HealthChecker) IsHealthy(target string) bool {
|
||||
u, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
value, ok := hc.statusMap.Load(u.String())
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
status := value.(*backendStatus)
|
||||
return status.Healthy
|
||||
}
|
||||
|
||||
// SetStatusChangeCallback 设置状态变更回调
|
||||
func (hc *HealthChecker) SetStatusChangeCallback(callback func(string, bool)) {
|
||||
hc.statusChangeCallback = callback
|
||||
}
|
||||
|
||||
// backendStatus 后端健康状态
|
||||
type backendStatus struct {
|
||||
URL *url.URL
|
||||
Healthy bool
|
||||
FailCount int
|
||||
SuccessCount int
|
||||
}
|
||||
|
||||
// run 运行健康检查
|
||||
func (hc *HealthChecker) run() {
|
||||
ticker := time.NewTicker(hc.config.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hc.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
hc.checkAll()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkAll 检查所有后端
|
||||
func (hc *HealthChecker) checkAll() {
|
||||
hc.statusMap.Range(func(key, value interface{}) bool {
|
||||
go hc.check(key.(string), value.(*backendStatus))
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// check 检查单个后端
|
||||
func (hc *HealthChecker) check(key string, status *backendStatus) {
|
||||
// 创建检查请求
|
||||
u := *status.URL
|
||||
u.Path = hc.config.Path
|
||||
req, err := http.NewRequest(hc.config.Method, u.String(), nil)
|
||||
if err != nil {
|
||||
hc.updateStatus(key, status, false)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置超时的客户端
|
||||
client := &http.Client{
|
||||
Timeout: hc.config.Timeout,
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: hc.config.Timeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: hc.config.Timeout,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
hc.updateStatus(key, status, false)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode == hc.config.SuccessStatus {
|
||||
hc.updateStatus(key, status, true)
|
||||
} else {
|
||||
hc.updateStatus(key, status, false)
|
||||
}
|
||||
}
|
||||
|
||||
// updateStatus 更新后端状态
|
||||
func (hc *HealthChecker) updateStatus(key string, status *backendStatus, success bool) {
|
||||
if success {
|
||||
status.SuccessCount++
|
||||
status.FailCount = 0
|
||||
if !status.Healthy && status.SuccessCount >= hc.config.MinSuccess {
|
||||
status.Healthy = true
|
||||
if hc.statusChangeCallback != nil {
|
||||
hc.statusChangeCallback(key, true)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
status.FailCount++
|
||||
status.SuccessCount = 0
|
||||
if status.Healthy && status.FailCount >= hc.config.MaxFails {
|
||||
status.Healthy = false
|
||||
if hc.statusChangeCallback != nil {
|
||||
hc.statusChangeCallback(key, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hc.statusMap.Store(key, status)
|
||||
}
|
343
pkg/loadbalance/loadbalancer.go
Normal file
343
pkg/loadbalance/loadbalancer.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package loadbalance
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Strategy 负载均衡策略
|
||||
type Strategy int
|
||||
|
||||
const (
|
||||
// StrategyRoundRobin 轮询策略
|
||||
StrategyRoundRobin Strategy = iota
|
||||
// StrategyRandom 随机策略
|
||||
StrategyRandom
|
||||
// StrategyWeightedRoundRobin 加权轮询策略
|
||||
StrategyWeightedRoundRobin
|
||||
// StrategyIPHash IP哈希策略
|
||||
StrategyIPHash
|
||||
)
|
||||
|
||||
// LoadBalancer 负载均衡器接口
|
||||
type LoadBalancer interface {
|
||||
// Next 获取下一个后端
|
||||
Next(key string) (*url.URL, error)
|
||||
// Add 添加后端
|
||||
Add(backend string, weight int) error
|
||||
// Remove 删除后端
|
||||
Remove(backend string) error
|
||||
// MarkDown 标记后端为不可用
|
||||
MarkDown(backend string) error
|
||||
// MarkUp 标记后端为可用
|
||||
MarkUp(backend string) error
|
||||
// Reset 重置负载均衡器
|
||||
Reset() error
|
||||
}
|
||||
|
||||
// Backend 后端服务器
|
||||
type Backend struct {
|
||||
// URL 后端URL
|
||||
URL *url.URL
|
||||
// Weight 权重
|
||||
Weight int
|
||||
// Down 是否不可用
|
||||
Down bool
|
||||
}
|
||||
|
||||
// RoundRobinBalancer 轮询负载均衡器
|
||||
type RoundRobinBalancer struct {
|
||||
backends []*Backend
|
||||
current int32
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRoundRobinBalancer 创建轮询负载均衡器
|
||||
func NewRoundRobinBalancer() *RoundRobinBalancer {
|
||||
return &RoundRobinBalancer{
|
||||
backends: make([]*Backend, 0),
|
||||
current: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Next 获取下一个后端
|
||||
func (lb *RoundRobinBalancer) Next(key string) (*url.URL, error) {
|
||||
lb.mutex.RLock()
|
||||
defer lb.mutex.RUnlock()
|
||||
|
||||
if len(lb.backends) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 计算可用后端数量
|
||||
var availableCount int
|
||||
for _, backend := range lb.backends {
|
||||
if !backend.Down {
|
||||
availableCount++
|
||||
}
|
||||
}
|
||||
|
||||
if availableCount == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 循环直到找到可用后端
|
||||
for i := 0; i < len(lb.backends); i++ {
|
||||
idx := atomic.AddInt32(&lb.current, 1) % int32(len(lb.backends))
|
||||
backend := lb.backends[idx]
|
||||
if !backend.Down {
|
||||
return backend.URL, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Add 添加后端
|
||||
func (lb *RoundRobinBalancer) Add(backend string, weight int) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
lb.backends = append(lb.backends, &Backend{
|
||||
URL: url,
|
||||
Weight: weight,
|
||||
Down: false,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddList 添加后端列表
|
||||
func (lb *RoundRobinBalancer) AddList(backend []string, weight int) error {
|
||||
var errs error
|
||||
for _, vo := range backend {
|
||||
err := lb.Add(vo, weight)
|
||||
if err != nil {
|
||||
errors.Join(err)
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
// Remove 删除后端
|
||||
func (lb *RoundRobinBalancer) Remove(backend string) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
lb.backends = append(lb.backends[:i], lb.backends[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkDown 标记后端为不可用
|
||||
func (lb *RoundRobinBalancer) MarkDown(backend string) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
b.Down = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkUp 标记后端为可用
|
||||
func (lb *RoundRobinBalancer) MarkUp(backend string) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
b.Down = false
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset 重置负载均衡器
|
||||
func (lb *RoundRobinBalancer) Reset() error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
lb.backends = make([]*Backend, 0)
|
||||
atomic.StoreInt32(&lb.current, 0)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RandomBalancer 随机负载均衡器
|
||||
type RandomBalancer struct {
|
||||
backends []*Backend
|
||||
mutex sync.RWMutex
|
||||
rand *rand.Rand
|
||||
}
|
||||
|
||||
// NewRandomBalancer 创建随机负载均衡器
|
||||
func NewRandomBalancer() *RandomBalancer {
|
||||
source := rand.NewSource(time.Now().UnixNano())
|
||||
random := rand.New(source)
|
||||
return &RandomBalancer{
|
||||
backends: make([]*Backend, 0),
|
||||
rand: random,
|
||||
}
|
||||
}
|
||||
|
||||
// Next 获取下一个后端
|
||||
func (lb *RandomBalancer) Next(key string) (*url.URL, error) {
|
||||
lb.mutex.RLock()
|
||||
defer lb.mutex.RUnlock()
|
||||
|
||||
if len(lb.backends) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 计算可用后端数量
|
||||
var availableBackends []*Backend
|
||||
for _, backend := range lb.backends {
|
||||
if !backend.Down {
|
||||
availableBackends = append(availableBackends, backend)
|
||||
}
|
||||
}
|
||||
|
||||
if len(availableBackends) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
idx := lb.rand.Intn(len(availableBackends))
|
||||
return availableBackends[idx].URL, nil
|
||||
}
|
||||
|
||||
// Add 添加后端
|
||||
func (lb *RandomBalancer) Add(backend string, weight int) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
lb.backends = append(lb.backends, &Backend{
|
||||
URL: url,
|
||||
Weight: weight,
|
||||
Down: false,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove 删除后端
|
||||
func (lb *RandomBalancer) Remove(backend string) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
lb.backends = append(lb.backends[:i], lb.backends[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkDown 标记后端为不可用
|
||||
func (lb *RandomBalancer) MarkDown(backend string) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
b.Down = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkUp 标记后端为可用
|
||||
func (lb *RandomBalancer) MarkUp(backend string) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
b.Down = false
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset 重置负载均衡器
|
||||
func (lb *RandomBalancer) Reset() error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
lb.backends = make([]*Backend, 0)
|
||||
|
||||
return nil
|
||||
}
|
387
pkg/metrics/metrics.go
Normal file
387
pkg/metrics/metrics.go
Normal file
@@ -0,0 +1,387 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
)
|
||||
|
||||
// MetricsCollector 监控指标接口
|
||||
type MetricsCollector interface {
|
||||
// 增加请求计数
|
||||
IncRequestCount()
|
||||
// 增加错误计数
|
||||
IncErrorCount(err error)
|
||||
// 观察请求持续时间
|
||||
ObserveRequestDuration(seconds float64)
|
||||
// 增加活跃连接数
|
||||
IncActiveConnections()
|
||||
// 减少活跃连接数
|
||||
DecActiveConnections()
|
||||
// 设置后端健康状态
|
||||
SetBackendHealth(backend string, healthy bool)
|
||||
// 设置后端响应时间
|
||||
SetBackendResponseTime(backend string, duration time.Duration)
|
||||
// 观察请求字节数
|
||||
ObserveRequestBytes(bytes int64)
|
||||
// 观察响应字节数
|
||||
ObserveResponseBytes(bytes int64)
|
||||
// 添加传输字节数
|
||||
AddBytesTransferred(direction string, bytes int64)
|
||||
// 增加缓存命中计数
|
||||
IncCacheHit()
|
||||
// 获取指标处理器
|
||||
GetHandler() http.Handler
|
||||
}
|
||||
|
||||
// PrometheusMetrics 指标收集器
|
||||
type PrometheusMetrics struct {
|
||||
// 请求总数
|
||||
requestTotal *prometheus.CounterVec
|
||||
// 请求延迟
|
||||
requestLatency *prometheus.HistogramVec
|
||||
// 请求大小
|
||||
requestSize *prometheus.HistogramVec
|
||||
// 响应大小
|
||||
responseSize *prometheus.HistogramVec
|
||||
// 错误总数
|
||||
errorTotal *prometheus.CounterVec
|
||||
// 活跃连接数
|
||||
activeConnections prometheus.Gauge
|
||||
// 连接池大小
|
||||
connectionPoolSize prometheus.Gauge
|
||||
// 缓存命中率
|
||||
cacheHitRate prometheus.Gauge
|
||||
// 内存使用量
|
||||
memoryUsage prometheus.Gauge
|
||||
// 锁
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewPrometheusMetrics 创建指标收集器
|
||||
func NewPrometheusMetrics() *PrometheusMetrics {
|
||||
m := &PrometheusMetrics{
|
||||
requestTotal: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "proxy_requests_total",
|
||||
Help: "代理请求总数",
|
||||
},
|
||||
[]string{"method", "path", "status"},
|
||||
),
|
||||
requestLatency: promauto.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "proxy_request_latency_seconds",
|
||||
Help: "代理请求延迟",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
},
|
||||
[]string{"method", "path"},
|
||||
),
|
||||
requestSize: promauto.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "proxy_request_size_bytes",
|
||||
Help: "代理请求大小",
|
||||
Buckets: prometheus.ExponentialBuckets(100, 2, 10),
|
||||
},
|
||||
[]string{"method", "path"},
|
||||
),
|
||||
responseSize: promauto.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "proxy_response_size_bytes",
|
||||
Help: "代理响应大小",
|
||||
Buckets: prometheus.ExponentialBuckets(100, 2, 10),
|
||||
},
|
||||
[]string{"method", "path"},
|
||||
),
|
||||
errorTotal: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "proxy_errors_total",
|
||||
Help: "代理错误总数",
|
||||
},
|
||||
[]string{"type"},
|
||||
),
|
||||
activeConnections: promauto.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "proxy_active_connections",
|
||||
Help: "活跃连接数",
|
||||
},
|
||||
),
|
||||
connectionPoolSize: promauto.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "proxy_connection_pool_size",
|
||||
Help: "连接池大小",
|
||||
},
|
||||
),
|
||||
cacheHitRate: promauto.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "proxy_cache_hit_rate",
|
||||
Help: "缓存命中率",
|
||||
},
|
||||
),
|
||||
memoryUsage: promauto.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "proxy_memory_usage_bytes",
|
||||
Help: "内存使用量",
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
// 启动定期更新
|
||||
go m.updateMetrics()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// updateMetrics 定期更新指标
|
||||
func (m *PrometheusMetrics) updateMetrics() {
|
||||
ticker := time.NewTicker(15 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
// 更新内存使用量
|
||||
var mem runtime.MemStats
|
||||
runtime.ReadMemStats(&mem)
|
||||
m.memoryUsage.Set(float64(mem.Alloc))
|
||||
}
|
||||
}
|
||||
|
||||
// RecordRequest 记录请求
|
||||
func (m *PrometheusMetrics) RecordRequest(method, path string, status int, latency time.Duration, reqSize, respSize int64) {
|
||||
m.requestTotal.WithLabelValues(method, path, strconv.Itoa(status)).Inc()
|
||||
m.requestLatency.WithLabelValues(method, path).Observe(latency.Seconds())
|
||||
m.requestSize.WithLabelValues(method, path).Observe(float64(reqSize))
|
||||
m.responseSize.WithLabelValues(method, path).Observe(float64(respSize))
|
||||
}
|
||||
|
||||
// RecordError 记录错误
|
||||
func (m *PrometheusMetrics) RecordError(errType string) {
|
||||
m.errorTotal.WithLabelValues(errType).Inc()
|
||||
}
|
||||
|
||||
// SetActiveConnections 设置活跃连接数
|
||||
func (m *PrometheusMetrics) SetActiveConnections(count int) {
|
||||
m.activeConnections.Set(float64(count))
|
||||
}
|
||||
|
||||
// SetConnectionPoolSize 设置连接池大小
|
||||
func (m *PrometheusMetrics) SetConnectionPoolSize(size int) {
|
||||
m.connectionPoolSize.Set(float64(size))
|
||||
}
|
||||
|
||||
// SetCacheHitRate 设置缓存命中率
|
||||
func (m *PrometheusMetrics) SetCacheHitRate(rate float64) {
|
||||
m.cacheHitRate.Set(rate)
|
||||
}
|
||||
|
||||
// SimpleMetrics 简单指标实现
|
||||
type SimpleMetrics struct {
|
||||
// 请求计数
|
||||
requestCount int64
|
||||
// 错误计数
|
||||
errorCount int64
|
||||
// 活跃连接数
|
||||
activeConnections int64
|
||||
// 累计响应时间
|
||||
totalResponseTime int64
|
||||
// 传输字节数
|
||||
bytesTransferred map[string]int64
|
||||
// 后端健康状态
|
||||
backendHealth map[string]bool
|
||||
// 后端响应时间
|
||||
backendResponseTime map[string]time.Duration
|
||||
// 缓存命中计数
|
||||
cacheHits int64
|
||||
// 互斥锁
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewSimpleMetrics 创建简单指标
|
||||
func NewSimpleMetrics() *SimpleMetrics {
|
||||
return &SimpleMetrics{
|
||||
bytesTransferred: make(map[string]int64),
|
||||
backendHealth: make(map[string]bool),
|
||||
backendResponseTime: make(map[string]time.Duration),
|
||||
}
|
||||
}
|
||||
|
||||
// IncRequestCount 增加请求计数
|
||||
func (m *SimpleMetrics) IncRequestCount() {
|
||||
atomic.AddInt64(&m.requestCount, 1)
|
||||
}
|
||||
|
||||
// IncErrorCount 增加错误计数
|
||||
func (m *SimpleMetrics) IncErrorCount(err error) {
|
||||
atomic.AddInt64(&m.errorCount, 1)
|
||||
}
|
||||
|
||||
// ObserveRequestDuration 观察请求持续时间
|
||||
func (m *SimpleMetrics) ObserveRequestDuration(seconds float64) {
|
||||
nsec := int64(seconds * float64(time.Second))
|
||||
atomic.AddInt64(&m.totalResponseTime, nsec)
|
||||
}
|
||||
|
||||
// IncActiveConnections 增加活跃连接数
|
||||
func (m *SimpleMetrics) IncActiveConnections() {
|
||||
atomic.AddInt64(&m.activeConnections, 1)
|
||||
}
|
||||
|
||||
// DecActiveConnections 减少活跃连接数
|
||||
func (m *SimpleMetrics) DecActiveConnections() {
|
||||
atomic.AddInt64(&m.activeConnections, -1)
|
||||
}
|
||||
|
||||
// SetBackendHealth 设置后端健康状态
|
||||
func (m *SimpleMetrics) SetBackendHealth(backend string, healthy bool) {
|
||||
m.backendHealth[backend] = healthy
|
||||
}
|
||||
|
||||
// SetBackendResponseTime 设置后端响应时间
|
||||
func (m *SimpleMetrics) SetBackendResponseTime(backend string, duration time.Duration) {
|
||||
m.backendResponseTime[backend] = duration
|
||||
}
|
||||
|
||||
// ObserveRequestBytes 观察请求字节数
|
||||
func (m *SimpleMetrics) ObserveRequestBytes(bytes int64) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.bytesTransferred["request"] += bytes
|
||||
}
|
||||
|
||||
// ObserveResponseBytes 观察响应字节数
|
||||
func (m *SimpleMetrics) ObserveResponseBytes(bytes int64) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.bytesTransferred["response"] += bytes
|
||||
}
|
||||
|
||||
// AddBytesTransferred 添加传输字节数
|
||||
func (m *SimpleMetrics) AddBytesTransferred(direction string, bytes int64) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.bytesTransferred[direction] += bytes
|
||||
}
|
||||
|
||||
// IncCacheHit 增加缓存命中计数
|
||||
func (m *SimpleMetrics) IncCacheHit() {
|
||||
atomic.AddInt64(&m.cacheHits, 1)
|
||||
}
|
||||
|
||||
// GetHandler 获取指标处理器
|
||||
func (m *SimpleMetrics) GetHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
|
||||
// 输出基本指标
|
||||
w.Write([]byte("# HELP proxy_requests_total 代理请求总数\n"))
|
||||
w.Write([]byte("# TYPE proxy_requests_total counter\n"))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_requests_total %d\n", m.requestCount)))
|
||||
|
||||
w.Write([]byte("# HELP proxy_errors_total 代理错误总数\n"))
|
||||
w.Write([]byte("# TYPE proxy_errors_total counter\n"))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_errors_total %d\n", m.errorCount)))
|
||||
|
||||
w.Write([]byte("# HELP proxy_active_connections 当前活跃连接数\n"))
|
||||
w.Write([]byte("# TYPE proxy_active_connections gauge\n"))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_active_connections %d\n", m.activeConnections)))
|
||||
|
||||
// 输出缓存命中数据
|
||||
w.Write([]byte("# HELP proxy_cache_hits_total 缓存命中总数\n"))
|
||||
w.Write([]byte("# TYPE proxy_cache_hits_total counter\n"))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_cache_hits_total %d\n", m.cacheHits)))
|
||||
|
||||
// 输出传输字节数
|
||||
for direction, bytes := range m.bytesTransferred {
|
||||
w.Write([]byte(fmt.Sprintf("# HELP proxy_bytes_transferred_%s 代理传输字节数(%s)\n", direction, direction)))
|
||||
w.Write([]byte(fmt.Sprintf("# TYPE proxy_bytes_transferred_%s counter\n", direction)))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_bytes_transferred_%s %d\n", direction, bytes)))
|
||||
}
|
||||
|
||||
// 输出后端健康状态
|
||||
for backend, healthy := range m.backendHealth {
|
||||
healthValue := 0
|
||||
if healthy {
|
||||
healthValue = 1
|
||||
}
|
||||
w.Write([]byte(fmt.Sprintf("# HELP proxy_backend_health 后端健康状态\n")))
|
||||
w.Write([]byte(fmt.Sprintf("# TYPE proxy_backend_health gauge\n")))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_backend_health{backend=\"%s\"} %d\n", backend, healthValue)))
|
||||
}
|
||||
|
||||
// 输出后端响应时间
|
||||
for backend, duration := range m.backendResponseTime {
|
||||
w.Write([]byte(fmt.Sprintf("# HELP proxy_backend_response_time 后端响应时间\n")))
|
||||
w.Write([]byte(fmt.Sprintf("# TYPE proxy_backend_response_time gauge\n")))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_backend_response_time{backend=\"%s\"} %f\n", backend, float64(duration)/float64(time.Second))))
|
||||
}
|
||||
|
||||
// 平均响应时间
|
||||
if m.requestCount > 0 {
|
||||
avgTime := float64(m.totalResponseTime) / float64(m.requestCount) / float64(time.Second)
|
||||
w.Write([]byte("# HELP proxy_average_response_time 平均响应时间\n"))
|
||||
w.Write([]byte("# TYPE proxy_average_response_time gauge\n"))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_average_response_time %f\n", avgTime)))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// MetricsMiddleware 指标中间件
|
||||
type MetricsMiddleware struct {
|
||||
metrics MetricsCollector
|
||||
}
|
||||
|
||||
// NewMetricsMiddleware 创建指标中间件
|
||||
func NewMetricsMiddleware(metrics MetricsCollector) *MetricsMiddleware {
|
||||
return &MetricsMiddleware{
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware 中间件处理函数
|
||||
func (m *MetricsMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// 包装响应写入器,用于捕获状态码
|
||||
rw := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
|
||||
// 继续处理请求
|
||||
next.ServeHTTP(rw, r)
|
||||
|
||||
// 记录请求指标
|
||||
duration := time.Since(start)
|
||||
m.metrics.ObserveRequestDuration(duration.Seconds())
|
||||
})
|
||||
}
|
||||
|
||||
// responseWriter 包装的响应写入器
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
written int64
|
||||
}
|
||||
|
||||
// WriteHeader 写入状态码
|
||||
func (rw *responseWriter) WriteHeader(statusCode int) {
|
||||
rw.statusCode = statusCode
|
||||
rw.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
// Write 写入数据
|
||||
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||
n, err := rw.ResponseWriter.Write(b)
|
||||
rw.written += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Flush 刷新数据
|
||||
func (rw *responseWriter) Flush() {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
63
pkg/middleware/chain.go
Normal file
63
pkg/middleware/chain.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Middleware 中间件接口
|
||||
type Middleware interface {
|
||||
ServeHTTP(http.ResponseWriter, *http.Request, http.HandlerFunc)
|
||||
}
|
||||
|
||||
// Chain 中间件链
|
||||
type Chain struct {
|
||||
middlewares []Middleware
|
||||
}
|
||||
|
||||
// NewChain 创建新的中间件链
|
||||
func NewChain(middlewares ...Middleware) *Chain {
|
||||
return &Chain{
|
||||
middlewares: middlewares,
|
||||
}
|
||||
}
|
||||
|
||||
// Then 将中间件链应用到处理器
|
||||
func (c *Chain) Then(h http.Handler) http.Handler {
|
||||
if h == nil {
|
||||
h = http.DefaultServeMux
|
||||
}
|
||||
|
||||
for i := len(c.middlewares) - 1; i >= 0; i-- {
|
||||
h = c.wrap(h, c.middlewares[i])
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
// wrap 包装处理器
|
||||
func (c *Chain) wrap(h http.Handler, m Middleware) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
m.ServeHTTP(w, r, h.ServeHTTP)
|
||||
})
|
||||
}
|
||||
|
||||
// Add 添加中间件
|
||||
func (c *Chain) Add(middlewares ...Middleware) *Chain {
|
||||
c.middlewares = append(c.middlewares, middlewares...)
|
||||
return c
|
||||
}
|
||||
|
||||
// Remove 移除中间件
|
||||
func (c *Chain) Remove(index int) *Chain {
|
||||
if index < 0 || index >= len(c.middlewares) {
|
||||
return c
|
||||
}
|
||||
c.middlewares = append(c.middlewares[:index], c.middlewares[index+1:]...)
|
||||
return c
|
||||
}
|
||||
|
||||
// Clear 清空中间件链
|
||||
func (c *Chain) Clear() *Chain {
|
||||
c.middlewares = nil
|
||||
return c
|
||||
}
|
127
pkg/middleware/compression.go
Normal file
127
pkg/middleware/compression.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CompressionMiddleware 压缩中间件
|
||||
type CompressionMiddleware struct {
|
||||
// 压缩级别 (0-9)
|
||||
level int
|
||||
// 最小压缩大小
|
||||
minSize int64
|
||||
// 支持的内容类型
|
||||
contentTypes []string
|
||||
}
|
||||
|
||||
// NewCompressionMiddleware 创建压缩中间件
|
||||
func NewCompressionMiddleware(level int, minSize int64) *CompressionMiddleware {
|
||||
return &CompressionMiddleware{
|
||||
level: level,
|
||||
minSize: minSize,
|
||||
contentTypes: []string{
|
||||
"text/plain",
|
||||
"text/html",
|
||||
"text/css",
|
||||
"text/javascript",
|
||||
"application/javascript",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/xml+rss",
|
||||
"text/xml",
|
||||
"application/x-yaml",
|
||||
"text/yaml",
|
||||
"application/x-www-form-urlencoded",
|
||||
"application/x-protobuf",
|
||||
"application/grpc",
|
||||
"application/grpc+proto",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware 中间件处理函数
|
||||
func (m *CompressionMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 处理请求压缩
|
||||
if m.shouldCompress(r.Header.Get("Content-Encoding")) {
|
||||
reader, err := gzip.NewReader(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid gzip content", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
r.Body = io.NopCloser(reader)
|
||||
}
|
||||
|
||||
// 处理响应压缩
|
||||
if m.shouldCompressResponse(r) {
|
||||
gw := gzip.NewWriter(w)
|
||||
defer gw.Close()
|
||||
|
||||
// 包装响应写入器
|
||||
writer := &gzipResponseWriter{
|
||||
ResponseWriter: w,
|
||||
writer: gw,
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
w.Header().Set("Vary", "Accept-Encoding")
|
||||
|
||||
next.ServeHTTP(writer, r)
|
||||
} else {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// shouldCompress 检查是否应该压缩请求
|
||||
func (m *CompressionMiddleware) shouldCompress(encoding string) bool {
|
||||
return strings.Contains(encoding, "gzip")
|
||||
}
|
||||
|
||||
// shouldCompressResponse 检查是否应该压缩响应
|
||||
func (m *CompressionMiddleware) shouldCompressResponse(r *http.Request) bool {
|
||||
// 检查客户端是否支持gzip
|
||||
acceptEncoding := r.Header.Get("Accept-Encoding")
|
||||
if !strings.Contains(acceptEncoding, "gzip") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查内容类型
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
for _, t := range m.contentTypes {
|
||||
if strings.Contains(contentType, t) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// gzipResponseWriter 包装的响应写入器
|
||||
type gzipResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
writer *gzip.Writer
|
||||
}
|
||||
|
||||
// Write 写入数据
|
||||
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
|
||||
return w.writer.Write(b)
|
||||
}
|
||||
|
||||
// WriteHeader 写入状态码
|
||||
func (w *gzipResponseWriter) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
// Flush 刷新数据
|
||||
func (w *gzipResponseWriter) Flush() {
|
||||
w.writer.Flush()
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
184
pkg/middleware/ratelimiter.go
Normal file
184
pkg/middleware/ratelimiter.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// RateLimiter 限流器接口
|
||||
type RateLimiter interface {
|
||||
// Allow 检查请求是否允许通过
|
||||
Allow(key string) bool
|
||||
}
|
||||
|
||||
// SimpleRateLimiter 简单限流器
|
||||
type SimpleRateLimiter struct {
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
// NewSimpleRateLimiter 创建简单限流器
|
||||
func NewSimpleRateLimiter(r float64, b int) *SimpleRateLimiter {
|
||||
return &SimpleRateLimiter{
|
||||
limiter: rate.NewLimiter(rate.Limit(r), b),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow 检查请求是否允许通过
|
||||
func (rl *SimpleRateLimiter) Allow(key string) bool {
|
||||
return rl.limiter.Allow()
|
||||
}
|
||||
|
||||
// IPRateLimiter 按IP限流
|
||||
type IPRateLimiter struct {
|
||||
ips map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
rate rate.Limit
|
||||
burst int
|
||||
cleanupInterval time.Duration
|
||||
lastSeen map[string]time.Time
|
||||
}
|
||||
|
||||
// NewIPRateLimiter 创建IP限流器
|
||||
func NewIPRateLimiter(r float64, b int, cleanup time.Duration) *IPRateLimiter {
|
||||
limiter := &IPRateLimiter{
|
||||
ips: make(map[string]*rate.Limiter),
|
||||
rate: rate.Limit(r),
|
||||
burst: b,
|
||||
cleanupInterval: cleanup,
|
||||
lastSeen: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
// 启动过期清理
|
||||
if cleanup > 0 {
|
||||
go limiter.startCleanup()
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// startCleanup 启动过期清理
|
||||
func (rl *IPRateLimiter) startCleanup() {
|
||||
ticker := time.NewTicker(rl.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup 清理过期限流器
|
||||
func (rl *IPRateLimiter) cleanup() {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for ip, lastSeen := range rl.lastSeen {
|
||||
if now.Sub(lastSeen) > rl.cleanupInterval {
|
||||
delete(rl.ips, ip)
|
||||
delete(rl.lastSeen, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddIP 添加IP限流器
|
||||
func (rl *IPRateLimiter) AddIP(ip string) *rate.Limiter {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
limiter := rate.NewLimiter(rl.rate, rl.burst)
|
||||
rl.ips[ip] = limiter
|
||||
rl.lastSeen[ip] = time.Now()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// GetLimiter 获取IP限流器
|
||||
func (rl *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
|
||||
rl.mu.RLock()
|
||||
limiter, exists := rl.ips[ip]
|
||||
rl.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return rl.AddIP(ip)
|
||||
}
|
||||
|
||||
// 更新最后访问时间
|
||||
rl.mu.Lock()
|
||||
rl.lastSeen[ip] = time.Now()
|
||||
rl.mu.Unlock()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Allow 检查请求是否允许通过
|
||||
func (rl *IPRateLimiter) Allow(ip string) bool {
|
||||
limiter := rl.GetLimiter(ip)
|
||||
return limiter.Allow()
|
||||
}
|
||||
|
||||
// RateLimitMiddleware 限流中间件
|
||||
type RateLimitMiddleware struct {
|
||||
limiter RateLimiter
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(limiter RateLimiter) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
limiter: limiter,
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware 中间件处理函数
|
||||
func (m *RateLimitMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 获取客户端IP
|
||||
ip := getClientIP(r)
|
||||
|
||||
// 检查是否允许通过
|
||||
if !m.limiter.Allow(ip) {
|
||||
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
// 继续处理请求
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端IP
|
||||
func getClientIP(r *http.Request) string {
|
||||
// 检查 X-Forwarded-For 头
|
||||
ip := r.Header.Get("X-Forwarded-For")
|
||||
if ip != "" {
|
||||
// 取第一个IP
|
||||
for i := 0; i < len(ip) && i < 15; i++ {
|
||||
if ip[i] == ',' {
|
||||
ip = ip[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// 检查 X-Real-IP 头
|
||||
ip = r.Header.Get("X-Real-IP")
|
||||
if ip != "" {
|
||||
return ip
|
||||
}
|
||||
|
||||
// 从 RemoteAddr 获取
|
||||
if r.RemoteAddr != "" {
|
||||
// 去掉端口部分
|
||||
for i := 0; i < len(r.RemoteAddr); i++ {
|
||||
if r.RemoteAddr[i] == ':' {
|
||||
return r.RemoteAddr[:i]
|
||||
}
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
156
pkg/middleware/retry.go
Normal file
156
pkg/middleware/retry.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RetryPolicy 重试策略
|
||||
type RetryPolicy struct {
|
||||
// 最大重试次数
|
||||
MaxRetries int
|
||||
// 基础退避时间
|
||||
BaseBackoff time.Duration
|
||||
// 最大退避时间
|
||||
MaxBackoff time.Duration
|
||||
// 重试判断函数
|
||||
ShouldRetry func(req *http.Request, resp *http.Response, err error) bool
|
||||
}
|
||||
|
||||
// DefaultRetryPolicy 默认重试策略
|
||||
func DefaultRetryPolicy() *RetryPolicy {
|
||||
return &RetryPolicy{
|
||||
MaxRetries: 3,
|
||||
BaseBackoff: 100 * time.Millisecond,
|
||||
MaxBackoff: 2 * time.Second,
|
||||
ShouldRetry: defaultShouldRetry,
|
||||
}
|
||||
}
|
||||
|
||||
// defaultShouldRetry 默认重试判断
|
||||
func defaultShouldRetry(req *http.Request, resp *http.Response, err error) bool {
|
||||
// 不重试非幂等请求
|
||||
if req.Method != http.MethodGet && req.Method != http.MethodHead && req.Method != http.MethodOptions {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查错误
|
||||
if err != nil {
|
||||
// 重试网络错误
|
||||
if netErr, ok := err.(net.Error); ok {
|
||||
return netErr.Temporary() || netErr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查响应状态码
|
||||
if resp != nil {
|
||||
// 重试服务器错误
|
||||
return resp.StatusCode >= 500 && resp.StatusCode < 600
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// RetryRoundTripper 重试HTTP传输
|
||||
type RetryRoundTripper struct {
|
||||
// 下一级传输
|
||||
Next http.RoundTripper
|
||||
// 重试策略
|
||||
Policy *RetryPolicy
|
||||
}
|
||||
|
||||
// NewRetryRoundTripper 创建重试HTTP传输
|
||||
func NewRetryRoundTripper(next http.RoundTripper, policy *RetryPolicy) *RetryRoundTripper {
|
||||
if next == nil {
|
||||
next = http.DefaultTransport
|
||||
}
|
||||
if policy == nil {
|
||||
policy = DefaultRetryPolicy()
|
||||
}
|
||||
return &RetryRoundTripper{
|
||||
Next: next,
|
||||
Policy: policy,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip 执行HTTP请求
|
||||
func (rt *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// 需要保留原始请求体,以便重试
|
||||
var reqBodyBytes []byte
|
||||
if req.Body != nil {
|
||||
var err error
|
||||
reqBodyBytes, err = io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Body.Close()
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
var err error
|
||||
|
||||
// 尝试请求直到成功或达到最大重试次数
|
||||
for attempt := 0; attempt <= rt.Policy.MaxRetries; attempt++ {
|
||||
// 复制请求体
|
||||
if len(reqBodyBytes) > 0 {
|
||||
req.Body = io.NopCloser(bytes.NewBuffer(reqBodyBytes))
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err = rt.Next.RoundTrip(req)
|
||||
|
||||
// 检查是否需要重试
|
||||
if attempt < rt.Policy.MaxRetries && rt.Policy.ShouldRetry(req, resp, err) {
|
||||
// 如果需要重试,先关闭当前响应
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// 计算退避时间
|
||||
backoff := rt.calculateBackoff(attempt)
|
||||
time.Sleep(backoff)
|
||||
continue
|
||||
}
|
||||
|
||||
// 不需要重试,返回响应
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// 所有重试都失败
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// calculateBackoff 计算退避时间
|
||||
func (rt *RetryRoundTripper) calculateBackoff(attempt int) time.Duration {
|
||||
// 指数退避: baseBackoff * 2^attempt
|
||||
backoff := rt.Policy.BaseBackoff * time.Duration(math.Pow(2, float64(attempt)))
|
||||
if backoff > rt.Policy.MaxBackoff {
|
||||
backoff = rt.Policy.MaxBackoff
|
||||
}
|
||||
return backoff
|
||||
}
|
||||
|
||||
// RetryMiddleware 重试中间件
|
||||
type RetryMiddleware struct {
|
||||
policy *RetryPolicy
|
||||
}
|
||||
|
||||
// NewRetryMiddleware 创建重试中间件
|
||||
func NewRetryMiddleware(policy *RetryPolicy) *RetryMiddleware {
|
||||
if policy == nil {
|
||||
policy = DefaultRetryPolicy()
|
||||
}
|
||||
return &RetryMiddleware{
|
||||
policy: policy,
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware 中间件处理函数
|
||||
func (m *RetryMiddleware) Transport(next http.RoundTripper) http.RoundTripper {
|
||||
return NewRetryRoundTripper(next, m.policy)
|
||||
}
|
147
pkg/middleware/websocket.go
Normal file
147
pkg/middleware/websocket.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/ouqiang/websocket"
|
||||
)
|
||||
|
||||
// WebSocketConn WebSocket连接接口
|
||||
type WebSocketConn struct {
|
||||
conn *websocket.Conn
|
||||
}
|
||||
|
||||
// NewWebSocketConn 创建WebSocket连接
|
||||
func NewWebSocketConn(conn *websocket.Conn) *WebSocketConn {
|
||||
return &WebSocketConn{
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *WebSocketConn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// ReadMessage 读取消息
|
||||
func (c *WebSocketConn) ReadMessage() (int, []byte, error) {
|
||||
return c.conn.ReadMessage()
|
||||
}
|
||||
|
||||
// WriteMessage 写入消息
|
||||
func (c *WebSocketConn) WriteMessage(messageType int, data []byte) error {
|
||||
return c.conn.WriteMessage(messageType, data)
|
||||
}
|
||||
|
||||
// RemoteAddr 获取远程地址
|
||||
func (c *WebSocketConn) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
// WebSocketUpgrader WebSocket升级器
|
||||
type WebSocketUpgrader struct {
|
||||
upgrader websocket.Upgrader
|
||||
}
|
||||
|
||||
// NewWebSocketUpgrader 创建WebSocket升级器
|
||||
func NewWebSocketUpgrader(timeout time.Duration) *WebSocketUpgrader {
|
||||
return &WebSocketUpgrader{
|
||||
upgrader: websocket.Upgrader{
|
||||
HandshakeTimeout: timeout,
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Upgrade 升级HTTP连接为WebSocket连接
|
||||
func (u *WebSocketUpgrader) Upgrade(conn net.Conn, req *http.Request) (*WebSocketConn, error) {
|
||||
bufConn, ok := conn.(interface {
|
||||
Hijack() (net.Conn, *bufio.ReadWriter, error)
|
||||
})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("连接不支持Hijack")
|
||||
}
|
||||
|
||||
netConn, bufrw, err := bufConn.Hijack()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hijack错误: %s", err)
|
||||
}
|
||||
|
||||
wsConn, err := u.upgrader.Upgrade(newResponseWriter(netConn, bufrw), req, http.Header{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewWebSocketConn(wsConn), nil
|
||||
}
|
||||
|
||||
// WebSocketDialer WebSocket拨号器
|
||||
type WebSocketDialer struct {
|
||||
dialer websocket.Dialer
|
||||
}
|
||||
|
||||
// NewWebSocketDialer 创建WebSocket拨号器
|
||||
func NewWebSocketDialer(timeout time.Duration) *WebSocketDialer {
|
||||
return &WebSocketDialer{
|
||||
dialer: websocket.Dialer{
|
||||
HandshakeTimeout: timeout,
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Dial 连接到WebSocket服务器
|
||||
func (d *WebSocketDialer) Dial(urlStr string, header http.Header) (*WebSocketConn, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), d.dialer.HandshakeTimeout)
|
||||
defer cancel()
|
||||
|
||||
wsConn, _, err := d.dialer.DialContext(ctx, urlStr, header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewWebSocketConn(wsConn), nil
|
||||
}
|
||||
|
||||
// 实现http.ResponseWriter接口,用于WebSocket升级
|
||||
type responseWriter struct {
|
||||
conn net.Conn
|
||||
bufrw *bufio.ReadWriter
|
||||
header http.Header
|
||||
status int
|
||||
}
|
||||
|
||||
func newResponseWriter(conn net.Conn, bufrw *bufio.ReadWriter) *responseWriter {
|
||||
return &responseWriter{
|
||||
conn: conn,
|
||||
bufrw: bufrw,
|
||||
header: make(http.Header),
|
||||
status: http.StatusOK,
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Header() http.Header {
|
||||
return rw.header
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||
return rw.bufrw.Write(b)
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(statusCode int) {
|
||||
rw.status = statusCode
|
||||
}
|
||||
|
||||
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return rw.conn, rw.bufrw, nil
|
||||
}
|
90
pkg/reverse/config.go
Normal file
90
pkg/reverse/config.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package reverse
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/dns"
|
||||
)
|
||||
|
||||
// Config 反向代理配置
|
||||
type Config struct {
|
||||
BaseConfig `json:",inline" yaml:",inline" toml:",inline"` // 基础配置
|
||||
RulesFile string `json:"rules_file" yaml:"rules_file" toml:"rules_file"` // 规则文件路径
|
||||
InsecureSkipVerify bool `json:"insecure_skip_verify" yaml:"insecure_skip_verify" toml:"insecure_skip_verify"` // 是否跳过证书验证
|
||||
EnableHealthCheck bool `json:"enable_health_check" yaml:"enable_health_check" toml:"enable_health_check"` // 是否启用健康检查
|
||||
HealthCheckInterval time.Duration `json:"health_check_interval" yaml:"health_check_interval" toml:"health_check_interval"` // 健康检查间隔时间
|
||||
HealthCheckTimeout time.Duration `json:"health_check_timeout" yaml:"health_check_timeout" toml:"health_check_timeout"` // 健康检查超时时间
|
||||
EnableRetry bool `json:"enable_retry" yaml:"enable_retry" toml:"enable_retry"` // 是否启用重试机制
|
||||
MaxRetries int `json:"max_retries" yaml:"max_retries" toml:"max_retries"` // 最大重试次数
|
||||
RetryBackoff time.Duration `json:"retry_backoff" yaml:"retry_backoff" toml:"retry_backoff"` // 重试间隔基数
|
||||
MaxRetryBackoff time.Duration `json:"max_retry_backoff" yaml:"max_retry_backoff" toml:"max_retry_backoff"` // 最大重试间隔
|
||||
EnableMetrics bool `json:"enable_metrics" yaml:"enable_metrics" toml:"enable_metrics"` // 是否启用监控指标
|
||||
EnableTracing bool `json:"enable_tracing" yaml:"enable_tracing" toml:"enable_tracing"` // 是否启用请求追踪
|
||||
WebSocketIntercept bool `json:"websocket_intercept" yaml:"websocket_intercept" toml:"websocket_intercept"` // 是否拦截WebSocket
|
||||
DNSCacheTTL time.Duration `json:"dns_cache_ttl" yaml:"dns_cache_ttl" toml:"dns_cache_ttl"` // DNS缓存过期时间
|
||||
EnableCache bool `json:"enable_cache" yaml:"enable_cache" toml:"enable_cache"` // 是否启用响应缓存
|
||||
CacheTTL time.Duration `json:"cache_ttl" yaml:"cache_ttl" toml:"cache_ttl"` // 缓存过期时间
|
||||
EnableConnectionPool bool `json:"enable_connection_pool" yaml:"enable_connection_pool" toml:"enable_connection_pool"` // 是否启用连接池
|
||||
ConnectionPoolSize int `json:"connection_pool_size" yaml:"connection_pool_size" toml:"connection_pool_size"` // 连接池大小
|
||||
IdleTimeout time.Duration `json:"idle_timeout" yaml:"idle_timeout" toml:"idle_timeout"` // 连接空闲超时时间
|
||||
RequestTimeout time.Duration `json:"request_timeout" yaml:"request_timeout" toml:"request_timeout"` // 请求超时时间
|
||||
DNSResolver *dns.CustomResolver `json:"-" yaml:"-" toml:"-"` // DNS解析器
|
||||
}
|
||||
|
||||
// BaseConfig 基础配置
|
||||
type BaseConfig struct {
|
||||
ListenAddr string `json:"listen_addr" yaml:"listen_addr" toml:"listen_addr"` // 监听地址
|
||||
TargetAddr string `json:"target_addr" yaml:"target_addr" toml:"target_addr"` // 目标地址
|
||||
EnableHTTPS bool `json:"enable_https" yaml:"enable_https" toml:"enable_https"` // 是否启用HTTPS
|
||||
TLSConfig *TLSConfig `json:"tls_config" yaml:"tls_config" toml:"tls_config"` // TLS配置
|
||||
EnableWebSocket bool `json:"enable_websocket" yaml:"enable_websocket" toml:"enable_websocket"` // 是否启用WebSocket
|
||||
EnableCompression bool `json:"enable_compression" yaml:"enable_compression" toml:"enable_compression"` // 是否启用压缩
|
||||
EnableCORS bool `json:"enable_cors" yaml:"enable_cors" toml:"enable_cors"` // 是否启用CORS
|
||||
PreserveClientIP bool `json:"preserve_client_ip" yaml:"preserve_client_ip" toml:"preserve_client_ip"` // 是否保留客户端IP
|
||||
AddXForwardedFor bool `json:"add_x_forwarded_for" yaml:"add_x_forwarded_for" toml:"add_x_forwarded_for"` // 是否添加X-Forwarded-For头
|
||||
AddXRealIP bool `json:"add_x_real_ip" yaml:"add_x_real_ip" toml:"add_x_real_ip"` // 是否添加X-Real-IP头
|
||||
}
|
||||
|
||||
// TLSConfig TLS配置
|
||||
type TLSConfig struct {
|
||||
CertFile string `json:"cert_file" yaml:"cert_file" toml:"cert_file"` // 证书文件路径
|
||||
KeyFile string `json:"key_file" yaml:"key_file" toml:"key_file"` // 密钥文件路径
|
||||
InsecureSkipVerify bool `json:"insecure_skip_verify" yaml:"insecure_skip_verify" toml:"insecure_skip_verify"` // 是否跳过证书验证
|
||||
UseECDSA bool `json:"use_ecdsa" yaml:"use_ecdsa" toml:"use_ecdsa"` // 是否使用ECDSA
|
||||
}
|
||||
|
||||
// DefaultConfig 返回默认配置
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
BaseConfig: BaseConfig{
|
||||
ListenAddr: ":8080",
|
||||
TargetAddr: "", // 默认目标地址为空
|
||||
EnableHTTPS: false,
|
||||
EnableWebSocket: true,
|
||||
EnableCompression: true,
|
||||
EnableCORS: true,
|
||||
PreserveClientIP: true,
|
||||
AddXForwardedFor: true,
|
||||
AddXRealIP: true,
|
||||
},
|
||||
InsecureSkipVerify: false,
|
||||
EnableHealthCheck: false,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
HealthCheckTimeout: 5 * time.Second,
|
||||
EnableRetry: true,
|
||||
MaxRetries: 3,
|
||||
RetryBackoff: time.Second,
|
||||
MaxRetryBackoff: 10 * time.Second,
|
||||
EnableMetrics: true,
|
||||
EnableTracing: false,
|
||||
WebSocketIntercept: false,
|
||||
DNSCacheTTL: 5 * time.Minute,
|
||||
EnableCache: true,
|
||||
CacheTTL: 5 * time.Minute,
|
||||
EnableConnectionPool: true,
|
||||
ConnectionPoolSize: 100,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
RequestTimeout: 30 * time.Second,
|
||||
DNSResolver: dns.NewResolver(),
|
||||
}
|
||||
}
|
184
pkg/reverse/proxy.go
Normal file
184
pkg/reverse/proxy.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package reverse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/darkit/goproxy/pkg/dns"
|
||||
"github.com/darkit/goproxy/pkg/rule"
|
||||
)
|
||||
|
||||
// Proxy 反向代理服务器
|
||||
type Proxy struct {
|
||||
config *Config
|
||||
ruleManager *rule.Manager
|
||||
logger *slog.Logger
|
||||
transport *http.Transport
|
||||
proxy *httputil.ReverseProxy
|
||||
dnsResolver *dns.CustomResolver
|
||||
}
|
||||
|
||||
// New 创建反向代理服务器
|
||||
func New(cfg *Config) (*Proxy, error) {
|
||||
if cfg == nil {
|
||||
cfg = DefaultConfig()
|
||||
}
|
||||
|
||||
logger := slog.Default()
|
||||
ruleManager := rule.NewManager(logger)
|
||||
|
||||
// 创建自定义传输层
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// 使用自定义DNS解析器
|
||||
if cfg.DNSResolver != nil {
|
||||
// 解析主机和端口
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
portNum, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 使用DNS解析器解析地址
|
||||
endpoint, err := cfg.DNSResolver.ResolveWithPort(host, portNum)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 使用解析后的地址创建连接
|
||||
dialer := &net.Dialer{
|
||||
Timeout: cfg.RequestTimeout,
|
||||
KeepAlive: cfg.IdleTimeout,
|
||||
}
|
||||
return dialer.DialContext(ctx, network, endpoint.String())
|
||||
}
|
||||
|
||||
// 使用默认拨号器
|
||||
dialer := &net.Dialer{
|
||||
Timeout: cfg.RequestTimeout,
|
||||
KeepAlive: cfg.IdleTimeout,
|
||||
}
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: cfg.ConnectionPoolSize,
|
||||
IdleConnTimeout: cfg.IdleTimeout,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: cfg.InsecureSkipVerify},
|
||||
}
|
||||
|
||||
proxy := &Proxy{
|
||||
config: cfg,
|
||||
ruleManager: ruleManager,
|
||||
logger: logger,
|
||||
transport: transport,
|
||||
dnsResolver: cfg.DNSResolver,
|
||||
}
|
||||
|
||||
// 如果配置了规则文件,加载规则
|
||||
if cfg.RulesFile != "" {
|
||||
loader := rule.NewLoader(proxy.ruleManager, proxy.logger)
|
||||
|
||||
// 根据文件扩展名决定加载方式
|
||||
switch {
|
||||
case strings.HasSuffix(cfg.RulesFile, ".json"):
|
||||
if err := loader.LoadFromJSON(cfg.RulesFile); err != nil {
|
||||
proxy.logger.Error("加载规则文件失败",
|
||||
"file", cfg.RulesFile,
|
||||
"error", err.Error(),
|
||||
)
|
||||
return nil, fmt.Errorf("加载规则文件失败: %w", err)
|
||||
}
|
||||
proxy.logger.Info("成功加载规则文件",
|
||||
"file", cfg.RulesFile,
|
||||
)
|
||||
|
||||
case strings.HasSuffix(cfg.RulesFile, ".hosts"):
|
||||
if err := loader.LoadFromHosts(cfg.RulesFile); err != nil {
|
||||
proxy.logger.Error("加载hosts文件失败",
|
||||
"file", cfg.RulesFile,
|
||||
"error", err.Error(),
|
||||
)
|
||||
return nil, fmt.Errorf("加载hosts文件失败: %w", err)
|
||||
}
|
||||
proxy.logger.Info("成功加载hosts文件",
|
||||
"file", cfg.RulesFile,
|
||||
)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的规则文件格式: %s", cfg.RulesFile)
|
||||
}
|
||||
}
|
||||
|
||||
// 创建反向代理处理器
|
||||
proxy.proxy = &httputil.ReverseProxy{
|
||||
Transport: transport,
|
||||
Director: func(req *http.Request) {
|
||||
// 应用规则
|
||||
rules := proxy.ruleManager.ListRules(rule.RuleTypeRoute)
|
||||
for _, r := range rules {
|
||||
if r.Match(req) {
|
||||
if err := r.Apply(req); err != nil {
|
||||
proxy.logger.Error("应用规则失败",
|
||||
"rule_id", r.GetID(),
|
||||
"error", err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 设置目标地址
|
||||
if cfg.TargetAddr != "" {
|
||||
req.URL.Host = cfg.TargetAddr
|
||||
// 如果没有设置协议,默认使用 http
|
||||
if req.URL.Scheme == "" {
|
||||
req.URL.Scheme = "http"
|
||||
}
|
||||
}
|
||||
|
||||
// 添加X-Forwarded-For头
|
||||
if cfg.AddXForwardedFor {
|
||||
req.Header.Add("X-Forwarded-For", req.RemoteAddr)
|
||||
}
|
||||
|
||||
// 添加X-Real-IP头
|
||||
if cfg.AddXRealIP {
|
||||
req.Header.Add("X-Real-IP", req.RemoteAddr)
|
||||
}
|
||||
},
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
proxy.logger.Error("代理请求失败",
|
||||
"error", err.Error(),
|
||||
"method", r.Method,
|
||||
"url", r.URL.String(),
|
||||
)
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
// ServeHTTP 处理HTTP请求
|
||||
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
p.proxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// Close 关闭代理服务器
|
||||
func (p *Proxy) Close() error {
|
||||
p.transport.CloseIdleConnections()
|
||||
return nil
|
||||
}
|
285
pkg/rewriter/rewriter.go
Normal file
285
pkg/rewriter/rewriter.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package rewriter
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Rewriter URL重写器
|
||||
// 用于在反向代理中重写请求URL
|
||||
type Rewriter struct {
|
||||
// 重写规则列表
|
||||
rules []*RewriteRule
|
||||
}
|
||||
|
||||
// RewriteRule 重写规则
|
||||
type RewriteRule struct {
|
||||
// 匹配模式
|
||||
Pattern string `json:"pattern"`
|
||||
// 替换模式
|
||||
Replacement string `json:"replacement"`
|
||||
// 是否使用正则表达式
|
||||
UseRegex bool `json:"use_regex"`
|
||||
// 编译后的正则表达式
|
||||
regex *regexp.Regexp `json:"-"`
|
||||
// 规则描述
|
||||
Description string `json:"description,omitempty"`
|
||||
// 规则启用状态
|
||||
Enabled bool `json:"enabled,omitempty"`
|
||||
}
|
||||
|
||||
// NewRewriter 创建URL重写器
|
||||
func NewRewriter() *Rewriter {
|
||||
return &Rewriter{
|
||||
rules: make([]*RewriteRule, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddRule 添加重写规则
|
||||
func (r *Rewriter) AddRule(pattern, replacement string, useRegex bool) error {
|
||||
rule := &RewriteRule{
|
||||
Pattern: pattern,
|
||||
Replacement: replacement,
|
||||
UseRegex: useRegex,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
if useRegex {
|
||||
regex, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rule.regex = regex
|
||||
}
|
||||
|
||||
r.rules = append(r.rules, rule)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRuleWithDescription 添加带描述的重写规则
|
||||
func (r *Rewriter) AddRuleWithDescription(pattern, replacement string, useRegex bool, description string) error {
|
||||
rule := &RewriteRule{
|
||||
Pattern: pattern,
|
||||
Replacement: replacement,
|
||||
UseRegex: useRegex,
|
||||
Description: description,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
if useRegex {
|
||||
regex, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rule.regex = regex
|
||||
}
|
||||
|
||||
r.rules = append(r.rules, rule)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rewrite 重写URL
|
||||
func (r *Rewriter) Rewrite(req *http.Request) {
|
||||
path := req.URL.Path
|
||||
|
||||
for _, rule := range r.rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.UseRegex {
|
||||
if rule.regex.MatchString(path) {
|
||||
req.URL.Path = rule.regex.ReplaceAllString(path, rule.Replacement)
|
||||
break
|
||||
}
|
||||
} else {
|
||||
if strings.HasPrefix(path, rule.Pattern) {
|
||||
req.URL.Path = strings.Replace(path, rule.Pattern, rule.Replacement, 1)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RewriteResponse 重写响应
|
||||
// 主要用于处理响应中的Location头和内容中的URL
|
||||
func (r *Rewriter) RewriteResponse(resp *http.Response, originHost string) {
|
||||
// 处理重定向头
|
||||
location := resp.Header.Get("Location")
|
||||
if location != "" {
|
||||
// 将后端服务器的域名替换成代理服务器的域名
|
||||
for _, rule := range r.rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.UseRegex && rule.regex != nil {
|
||||
if rule.regex.MatchString(location) {
|
||||
newLocation := rule.regex.ReplaceAllString(location, rule.Replacement)
|
||||
resp.Header.Set("Location", newLocation)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LoadRulesFromFile 从文件加载重写规则
|
||||
func (r *Rewriter) LoadRulesFromFile(filename string) error {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开文件失败: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 检查文件扩展名,决定使用何种方式解析
|
||||
if strings.HasSuffix(filename, ".json") {
|
||||
return r.loadRulesFromJSON(file)
|
||||
} else {
|
||||
return r.loadRulesFromText(file)
|
||||
}
|
||||
}
|
||||
|
||||
// loadRulesFromJSON 从JSON文件加载规则
|
||||
func (r *Rewriter) loadRulesFromJSON(file *os.File) error {
|
||||
var rules []*RewriteRule
|
||||
decoder := json.NewDecoder(file)
|
||||
if err := decoder.Decode(&rules); err != nil {
|
||||
return fmt.Errorf("解析JSON失败: %v", err)
|
||||
}
|
||||
|
||||
// 编译正则表达式
|
||||
for _, rule := range rules {
|
||||
if rule.UseRegex {
|
||||
regex, err := regexp.Compile(rule.Pattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("编译正则表达式'%s'失败: %v", rule.Pattern, err)
|
||||
}
|
||||
rule.regex = regex
|
||||
}
|
||||
r.rules = append(r.rules, rule)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadRulesFromText 从文本文件加载规则
|
||||
// 格式: pattern replacement [regex] [#description]
|
||||
func (r *Rewriter) loadRulesFromText(file *os.File) error {
|
||||
scanner := bufio.NewScanner(file)
|
||||
lineNum := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
// 跳过空行和注释
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Fields(line)
|
||||
if len(parts) < 2 {
|
||||
return fmt.Errorf("第%d行格式错误: %s", lineNum, line)
|
||||
}
|
||||
|
||||
pattern := parts[0]
|
||||
replacement := parts[1]
|
||||
useRegex := false
|
||||
description := ""
|
||||
|
||||
// 检查是否有额外选项
|
||||
for i := 2; i < len(parts); i++ {
|
||||
if parts[i] == "regex" {
|
||||
useRegex = true
|
||||
} else if strings.HasPrefix(parts[i], "#") {
|
||||
// 获取描述信息
|
||||
description = strings.Join(parts[i:], " ")
|
||||
description = strings.TrimPrefix(description, "#")
|
||||
description = strings.TrimSpace(description)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.AddRuleWithDescription(pattern, replacement, useRegex, description); err != nil {
|
||||
return fmt.Errorf("第%d行添加规则失败: %v", lineNum, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("读取文件失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRules 获取所有规则
|
||||
func (r *Rewriter) GetRules() []*RewriteRule {
|
||||
return r.rules
|
||||
}
|
||||
|
||||
// EnableRule 启用规则
|
||||
func (r *Rewriter) EnableRule(index int) error {
|
||||
if index < 0 || index >= len(r.rules) {
|
||||
return fmt.Errorf("规则索引越界: %d", index)
|
||||
}
|
||||
r.rules[index].Enabled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// DisableRule 禁用规则
|
||||
func (r *Rewriter) DisableRule(index int) error {
|
||||
if index < 0 || index >= len(r.rules) {
|
||||
return fmt.Errorf("规则索引越界: %d", index)
|
||||
}
|
||||
r.rules[index].Enabled = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRule 删除规则
|
||||
func (r *Rewriter) RemoveRule(index int) error {
|
||||
if index < 0 || index >= len(r.rules) {
|
||||
return fmt.Errorf("规则索引越界: %d", index)
|
||||
}
|
||||
r.rules = append(r.rules[:index], r.rules[index+1:]...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveRulesToFile 将规则保存到文件
|
||||
func (r *Rewriter) SaveRulesToFile(filename string) error {
|
||||
file, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建文件失败: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 根据文件扩展名决定保存格式
|
||||
if strings.HasSuffix(filename, ".json") {
|
||||
encoder := json.NewEncoder(file)
|
||||
encoder.SetIndent("", " ")
|
||||
return encoder.Encode(r.rules)
|
||||
} else {
|
||||
writer := bufio.NewWriter(file)
|
||||
for _, rule := range r.rules {
|
||||
line := rule.Pattern + " " + rule.Replacement
|
||||
if rule.UseRegex {
|
||||
line += " regex"
|
||||
}
|
||||
if rule.Description != "" {
|
||||
line += " # " + rule.Description
|
||||
}
|
||||
if !rule.Enabled {
|
||||
line = "# " + line + " (disabled)"
|
||||
}
|
||||
if _, err := writer.WriteString(line + "\n"); err != nil {
|
||||
return fmt.Errorf("写入文件失败: %v", err)
|
||||
}
|
||||
}
|
||||
return writer.Flush()
|
||||
}
|
||||
}
|
103
pkg/router/router.go
Normal file
103
pkg/router/router.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Route 路由规则
|
||||
type Route struct {
|
||||
// 匹配模式(主机名、路径、正则表达式)
|
||||
Pattern string
|
||||
// 匹配类型
|
||||
Type RouteType
|
||||
// 目标地址
|
||||
Target string
|
||||
// 路径重写规则
|
||||
RewritePattern string
|
||||
// 请求头修改
|
||||
HeaderModifier HeaderModifier
|
||||
// 自定义匹配函数
|
||||
MatchFunc func(req *http.Request) bool
|
||||
}
|
||||
|
||||
// RouteType 路由类型
|
||||
type RouteType int
|
||||
|
||||
const (
|
||||
// HostRoute 主机名路由
|
||||
HostRoute RouteType = iota
|
||||
// PathRoute 路径路由
|
||||
PathRoute
|
||||
// RegexRoute 正则表达式路由
|
||||
RegexRoute
|
||||
// CustomRoute 自定义路由
|
||||
CustomRoute
|
||||
)
|
||||
|
||||
// HeaderModifier 头部修改接口
|
||||
type HeaderModifier interface {
|
||||
// ModifyRequest 修改请求头
|
||||
ModifyRequest(req *http.Request)
|
||||
// ModifyResponse 修改响应头
|
||||
ModifyResponse(resp *http.Response)
|
||||
}
|
||||
|
||||
// Router 路由器
|
||||
type Router struct {
|
||||
routes []*Route
|
||||
}
|
||||
|
||||
// NewRouter 创建路由器
|
||||
func NewRouter() *Router {
|
||||
return &Router{
|
||||
routes: make([]*Route, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddRoute 添加路由规则
|
||||
func (r *Router) AddRoute(route *Route) {
|
||||
r.routes = append(r.routes, route)
|
||||
}
|
||||
|
||||
// Match 匹配请求
|
||||
func (r *Router) Match(req *http.Request) (*Route, bool) {
|
||||
for _, route := range r.routes {
|
||||
switch route.Type {
|
||||
case HostRoute:
|
||||
if matchHost(req.Host, route.Pattern) {
|
||||
return route, true
|
||||
}
|
||||
case PathRoute:
|
||||
if matchPath(req.URL.Path, route.Pattern) {
|
||||
return route, true
|
||||
}
|
||||
case RegexRoute:
|
||||
if matchRegex(req.URL.String(), route.Pattern) {
|
||||
return route, true
|
||||
}
|
||||
case CustomRoute:
|
||||
if route.MatchFunc != nil && route.MatchFunc(req) {
|
||||
return route, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 匹配主机名
|
||||
func matchHost(host, pattern string) bool {
|
||||
return host == pattern || strings.HasSuffix(host, "."+pattern)
|
||||
}
|
||||
|
||||
// 匹配路径
|
||||
func matchPath(path, pattern string) bool {
|
||||
return strings.HasPrefix(path, pattern)
|
||||
}
|
||||
|
||||
// 匹配正则表达式
|
||||
func matchRegex(url, pattern string) bool {
|
||||
matched, _ := regexp.MatchString(pattern, url)
|
||||
return matched
|
||||
}
|
217
pkg/rule/loader.go
Normal file
217
pkg/rule/loader.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Loader 规则加载器
|
||||
type Loader struct {
|
||||
manager *Manager
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewLoader 创建规则加载器
|
||||
func NewLoader(manager *Manager, logger *slog.Logger) *Loader {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return &Loader{
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromJSON 从JSON文件加载规则
|
||||
func (l *Loader) LoadFromJSON(filename string) error {
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取规则文件失败: %w", err)
|
||||
}
|
||||
|
||||
var config struct {
|
||||
Rules []json.RawMessage `json:"rules"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return fmt.Errorf("解析规则文件失败: %w", err)
|
||||
}
|
||||
|
||||
for _, ruleData := range config.Rules {
|
||||
var baseRule BaseRule
|
||||
if err := json.Unmarshal(ruleData, &baseRule); err != nil {
|
||||
l.logger.Warn("解析规则基础信息失败",
|
||||
"error", err.Error(),
|
||||
"rule_data", string(ruleData),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
var rule Rule
|
||||
switch baseRule.Type {
|
||||
case RuleTypeDNS:
|
||||
var dnsRule DNSRule
|
||||
if err := json.Unmarshal(ruleData, &dnsRule); err != nil {
|
||||
l.logger.Warn("解析DNS规则失败",
|
||||
"error", err.Error(),
|
||||
"rule_data", string(ruleData),
|
||||
)
|
||||
continue
|
||||
}
|
||||
rule = &dnsRule
|
||||
case RuleTypeRewrite:
|
||||
var rewriteRule RewriteRule
|
||||
if err := json.Unmarshal(ruleData, &rewriteRule); err != nil {
|
||||
l.logger.Warn("解析重写规则失败",
|
||||
"error", err.Error(),
|
||||
"rule_data", string(ruleData),
|
||||
)
|
||||
continue
|
||||
}
|
||||
rule = &rewriteRule
|
||||
case RuleTypeRoute:
|
||||
var routeRule RouteRule
|
||||
if err := json.Unmarshal(ruleData, &routeRule); err != nil {
|
||||
l.logger.Warn("解析路由规则失败",
|
||||
"error", err.Error(),
|
||||
"rule_data", string(ruleData),
|
||||
)
|
||||
continue
|
||||
}
|
||||
rule = &routeRule
|
||||
default:
|
||||
l.logger.Warn("未知的规则类型",
|
||||
"type", string(baseRule.Type),
|
||||
"rule_data", string(ruleData),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := l.manager.AddRule(rule); err != nil {
|
||||
l.logger.Error("添加规则失败",
|
||||
"error", err.Error(),
|
||||
"rule_id", rule.GetID(),
|
||||
"rule_type", string(rule.GetType()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFromHosts 从hosts文件加载DNS规则
|
||||
func (l *Loader) LoadFromHosts(filename string) error {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("打开hosts文件失败: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
lineNum := 0
|
||||
for scanner.Scan() {
|
||||
lineNum++
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
// 跳过空行和注释
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
l.logger.Warn("无效的hosts行",
|
||||
"line_number", lineNum,
|
||||
"line", line,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析IP和端口
|
||||
ipStr := fields[0]
|
||||
ip := ipStr
|
||||
port := 0
|
||||
|
||||
if strings.Contains(ipStr, ":") {
|
||||
parts := strings.Split(ipStr, ":")
|
||||
if len(parts) != 2 {
|
||||
l.logger.Warn("无效的IP:端口格式",
|
||||
"line_number", lineNum,
|
||||
"ip_port", ipStr,
|
||||
)
|
||||
continue
|
||||
}
|
||||
ip = parts[0]
|
||||
if p, err := strconv.Atoi(parts[1]); err == nil {
|
||||
port = p
|
||||
} else {
|
||||
l.logger.Warn("无效的端口号",
|
||||
"line_number", lineNum,
|
||||
"port", parts[1],
|
||||
)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 验证IP地址
|
||||
if net.ParseIP(ip) == nil {
|
||||
l.logger.Warn("无效的IP地址",
|
||||
"line_number", lineNum,
|
||||
"ip", ip,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理每个域名
|
||||
for _, domain := range fields[1:] {
|
||||
// 跳过注释
|
||||
if strings.HasPrefix(domain, "#") {
|
||||
break
|
||||
}
|
||||
|
||||
// 创建DNS规则
|
||||
matchType := MatchTypeExact
|
||||
if strings.Contains(domain, "*") {
|
||||
matchType = MatchTypeWildcard
|
||||
}
|
||||
|
||||
rule := &DNSRule{
|
||||
BaseRule: BaseRule{
|
||||
ID: fmt.Sprintf("hosts-%d-%s", lineNum, domain),
|
||||
Type: RuleTypeDNS,
|
||||
Priority: 100,
|
||||
Pattern: domain,
|
||||
MatchType: matchType,
|
||||
Enabled: true,
|
||||
},
|
||||
Targets: []DNSTarget{
|
||||
{
|
||||
IP: ip,
|
||||
Port: port,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := l.manager.AddRule(rule); err != nil {
|
||||
l.logger.Error("添加hosts规则失败",
|
||||
"error", err.Error(),
|
||||
"domain", domain,
|
||||
"ip", ip,
|
||||
"port", port,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("读取hosts文件失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
227
pkg/rule/manager.go
Normal file
227
pkg/rule/manager.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Manager 规则管理器
|
||||
type Manager struct {
|
||||
// 规则存储
|
||||
rules sync.Map
|
||||
// 规则索引
|
||||
indexes map[RuleType][]string
|
||||
// 日志记录器
|
||||
logger *slog.Logger
|
||||
// 更新通知通道
|
||||
updateCh chan struct{}
|
||||
// 观察者列表
|
||||
observers []Observer
|
||||
// 互斥锁
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Observer 规则更新观察者接口
|
||||
type Observer interface {
|
||||
OnRuleUpdate(ruleType RuleType)
|
||||
}
|
||||
|
||||
// NewManager 创建规则管理器
|
||||
func NewManager(logger *slog.Logger) *Manager {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return &Manager{
|
||||
logger: logger,
|
||||
indexes: make(map[RuleType][]string),
|
||||
updateCh: make(chan struct{}, 1),
|
||||
observers: make([]Observer, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddRule 添加规则
|
||||
func (m *Manager) AddRule(rule Rule) error {
|
||||
// 1. 验证规则
|
||||
if err := rule.Validate(); err != nil {
|
||||
return fmt.Errorf("规则验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 存储规则
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// 检查是否已存在相同ID的规则
|
||||
if _, exists := m.rules.Load(rule.GetID()); exists {
|
||||
return fmt.Errorf("规则ID %s 已存在", rule.GetID())
|
||||
}
|
||||
|
||||
// 更新时间戳
|
||||
switch r := rule.(type) {
|
||||
case *DNSRule:
|
||||
r.CreatedAt = time.Now()
|
||||
r.UpdatedAt = time.Now()
|
||||
case *RewriteRule:
|
||||
r.CreatedAt = time.Now()
|
||||
r.UpdatedAt = time.Now()
|
||||
case *RouteRule:
|
||||
r.CreatedAt = time.Now()
|
||||
r.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
m.rules.Store(rule.GetID(), rule)
|
||||
m.updateIndex(rule)
|
||||
|
||||
// 3. 通知更新
|
||||
m.notifyUpdate(rule.GetType())
|
||||
|
||||
m.logger.Info("添加规则成功",
|
||||
"id", rule.GetID(),
|
||||
"type", string(rule.GetType()),
|
||||
"priority", rule.GetPriority(),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRule 获取规则
|
||||
func (m *Manager) GetRule(id string) (Rule, bool) {
|
||||
if rule, ok := m.rules.Load(id); ok {
|
||||
return rule.(Rule), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// ListRules 列出指定类型的规则
|
||||
func (m *Manager) ListRules(ruleType RuleType) []Rule {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var rules []Rule
|
||||
if ids, ok := m.indexes[ruleType]; ok {
|
||||
for _, id := range ids {
|
||||
if rule, exists := m.rules.Load(id); exists {
|
||||
rules = append(rules, rule.(Rule))
|
||||
}
|
||||
}
|
||||
}
|
||||
return rules
|
||||
}
|
||||
|
||||
// DeleteRule 删除规则
|
||||
func (m *Manager) DeleteRule(id string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if rule, ok := m.rules.Load(id); ok {
|
||||
m.rules.Delete(id)
|
||||
m.removeFromIndex(rule.(Rule))
|
||||
m.notifyUpdate(rule.(Rule).GetType())
|
||||
|
||||
m.logger.Info("删除规则成功",
|
||||
"id", id,
|
||||
"type", string(rule.(Rule).GetType()),
|
||||
)
|
||||
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateRule 更新规则
|
||||
func (m *Manager) UpdateRule(rule Rule) error {
|
||||
// 1. 验证规则
|
||||
if err := rule.Validate(); err != nil {
|
||||
return fmt.Errorf("规则验证失败: %w", err)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// 检查规则是否存在
|
||||
if _, exists := m.rules.Load(rule.GetID()); !exists {
|
||||
return fmt.Errorf("规则ID %s 不存在", rule.GetID())
|
||||
}
|
||||
|
||||
// 更新时间戳
|
||||
switch r := rule.(type) {
|
||||
case *DNSRule:
|
||||
r.UpdatedAt = time.Now()
|
||||
case *RewriteRule:
|
||||
r.UpdatedAt = time.Now()
|
||||
case *RouteRule:
|
||||
r.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// 更新规则
|
||||
m.rules.Store(rule.GetID(), rule)
|
||||
m.updateIndex(rule)
|
||||
m.notifyUpdate(rule.GetType())
|
||||
|
||||
m.logger.Info("更新规则成功",
|
||||
"id", rule.GetID(),
|
||||
"type", string(rule.GetType()),
|
||||
"priority", rule.GetPriority(),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddObserver 添加观察者
|
||||
func (m *Manager) AddObserver(observer Observer) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.observers = append(m.observers, observer)
|
||||
}
|
||||
|
||||
// 更新规则索引
|
||||
func (m *Manager) updateIndex(rule Rule) {
|
||||
ruleType := rule.GetType()
|
||||
if _, ok := m.indexes[ruleType]; !ok {
|
||||
m.indexes[ruleType] = make([]string, 0)
|
||||
}
|
||||
// 检查是否已存在于索引中
|
||||
exists := false
|
||||
for _, id := range m.indexes[ruleType] {
|
||||
if id == rule.GetID() {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
m.indexes[ruleType] = append(m.indexes[ruleType], rule.GetID())
|
||||
}
|
||||
}
|
||||
|
||||
// 从索引中移除规则
|
||||
func (m *Manager) removeFromIndex(rule Rule) {
|
||||
ruleType := rule.GetType()
|
||||
if ids, ok := m.indexes[ruleType]; ok {
|
||||
for i, id := range ids {
|
||||
if id == rule.GetID() {
|
||||
m.indexes[ruleType] = append(ids[:i], ids[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 通知规则更新
|
||||
func (m *Manager) notifyUpdate(ruleType RuleType) {
|
||||
// 通知所有观察者
|
||||
for _, observer := range m.observers {
|
||||
observer.OnRuleUpdate(ruleType)
|
||||
}
|
||||
|
||||
// 发送更新信号
|
||||
select {
|
||||
case m.updateCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// GetUpdateChannel 获取更新通知通道
|
||||
func (m *Manager) GetUpdateChannel() <-chan struct{} {
|
||||
return m.updateCh
|
||||
}
|
187
pkg/rule/rule_impl.go
Normal file
187
pkg/rule/rule_impl.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Match DNS规则匹配
|
||||
func (r *DNSRule) Match(req *http.Request) bool {
|
||||
host := req.Host
|
||||
if strings.Contains(host, ":") {
|
||||
host, _, _ = net.SplitHostPort(host)
|
||||
}
|
||||
|
||||
switch r.MatchType {
|
||||
case MatchTypeExact:
|
||||
return host == r.Pattern
|
||||
case MatchTypeWildcard:
|
||||
return matchWildcard(host, r.Pattern)
|
||||
case MatchTypeRegex:
|
||||
if re, err := regexp.Compile(r.Pattern); err == nil {
|
||||
return re.MatchString(host)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Apply DNS规则应用
|
||||
func (r *DNSRule) Apply(req *http.Request) error {
|
||||
if !r.Enabled {
|
||||
return nil
|
||||
}
|
||||
// DNS规则的应用在DNS解析阶段处理,这里不需要修改请求
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate DNS规则验证
|
||||
func (r *DNSRule) Validate() error {
|
||||
if r.ID == "" {
|
||||
return fmt.Errorf("规则ID不能为空")
|
||||
}
|
||||
if len(r.Targets) == 0 {
|
||||
return fmt.Errorf("DNS规则必须包含至少一个目标")
|
||||
}
|
||||
for _, target := range r.Targets {
|
||||
if net.ParseIP(target.IP) == nil {
|
||||
return fmt.Errorf("无效的IP地址: %s", target.IP)
|
||||
}
|
||||
if target.Port < 0 || target.Port > 65535 {
|
||||
return fmt.Errorf("无效的端口号: %d", target.Port)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Match URL重写规则匹配
|
||||
func (r *RewriteRule) Match(req *http.Request) bool {
|
||||
switch r.MatchType {
|
||||
case MatchTypeExact:
|
||||
return req.URL.Path == r.Pattern
|
||||
case MatchTypePath:
|
||||
return strings.HasPrefix(req.URL.Path, r.Pattern)
|
||||
case MatchTypeRegex:
|
||||
if re, err := regexp.Compile(r.Pattern); err == nil {
|
||||
return re.MatchString(req.URL.Path)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Apply URL重写规则应用
|
||||
func (r *RewriteRule) Apply(req *http.Request) error {
|
||||
if !r.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch r.MatchType {
|
||||
case MatchTypeExact, MatchTypePath:
|
||||
req.URL.Path = strings.Replace(req.URL.Path, r.Pattern, r.Replacement, 1)
|
||||
case MatchTypeRegex:
|
||||
if re, err := regexp.Compile(r.Pattern); err == nil {
|
||||
req.URL.Path = re.ReplaceAllString(req.URL.Path, r.Replacement)
|
||||
} else {
|
||||
return fmt.Errorf("正则表达式编译失败: %v", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate URL重写规则验证
|
||||
func (r *RewriteRule) Validate() error {
|
||||
if r.ID == "" {
|
||||
return fmt.Errorf("规则ID不能为空")
|
||||
}
|
||||
if r.Pattern == "" {
|
||||
return fmt.Errorf("重写规则必须包含匹配模式")
|
||||
}
|
||||
if r.MatchType == MatchTypeRegex {
|
||||
if _, err := regexp.Compile(r.Pattern); err != nil {
|
||||
return fmt.Errorf("无效的正则表达式: %v", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Match 路由规则匹配
|
||||
func (r *RouteRule) Match(req *http.Request) bool {
|
||||
switch r.MatchType {
|
||||
case MatchTypeExact:
|
||||
return req.URL.Path == r.Pattern
|
||||
case MatchTypePath:
|
||||
return strings.HasPrefix(req.URL.Path, r.Pattern)
|
||||
case MatchTypeRegex:
|
||||
if re, err := regexp.Compile(r.Pattern); err == nil {
|
||||
return re.MatchString(req.URL.Path)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Apply 路由规则应用
|
||||
func (r *RouteRule) Apply(req *http.Request) error {
|
||||
if !r.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 修改请求头
|
||||
for key, value := range r.HeaderModifier {
|
||||
// 支持变量替换
|
||||
switch value {
|
||||
case "${client_ip}":
|
||||
clientIP := req.RemoteAddr
|
||||
if ip, _, err := net.SplitHostPort(clientIP); err == nil {
|
||||
req.Header.Set(key, ip)
|
||||
} else {
|
||||
req.Header.Set(key, clientIP)
|
||||
}
|
||||
default:
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate 路由规则验证
|
||||
func (r *RouteRule) Validate() error {
|
||||
if r.ID == "" {
|
||||
return fmt.Errorf("规则ID不能为空")
|
||||
}
|
||||
if r.Pattern == "" {
|
||||
return fmt.Errorf("路由规则必须包含匹配模式")
|
||||
}
|
||||
if r.Target == "" {
|
||||
return fmt.Errorf("路由规则必须包含目标地址")
|
||||
}
|
||||
if r.MatchType == MatchTypeRegex {
|
||||
if _, err := regexp.Compile(r.Pattern); err != nil {
|
||||
return fmt.Errorf("无效的正则表达式: %v", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 辅助函数:通配符匹配
|
||||
func matchWildcard(host, pattern string) bool {
|
||||
if !strings.Contains(pattern, "*") {
|
||||
return host == pattern
|
||||
}
|
||||
|
||||
parts := strings.Split(pattern, ".")
|
||||
hostParts := strings.Split(host, ".")
|
||||
|
||||
if len(parts) != len(hostParts) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i < len(parts); i++ {
|
||||
if parts[i] != "*" && parts[i] != hostParts[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
100
pkg/rule/types.go
Normal file
100
pkg/rule/types.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RuleType 规则类型
|
||||
type RuleType string
|
||||
|
||||
const (
|
||||
// RuleTypeDNS DNS规则类型
|
||||
RuleTypeDNS RuleType = "dns"
|
||||
// RuleTypeRewrite URL重写规则类型
|
||||
RuleTypeRewrite RuleType = "rewrite"
|
||||
// RuleTypeRoute 路由规则类型
|
||||
RuleTypeRoute RuleType = "route"
|
||||
)
|
||||
|
||||
// MatchType 匹配类型
|
||||
type MatchType string
|
||||
|
||||
const (
|
||||
// MatchTypeExact 精确匹配
|
||||
MatchTypeExact MatchType = "exact"
|
||||
// MatchTypeWildcard 通配符匹配
|
||||
MatchTypeWildcard MatchType = "wildcard"
|
||||
// MatchTypeRegex 正则匹配
|
||||
MatchTypeRegex MatchType = "regex"
|
||||
// MatchTypePath 路径匹配
|
||||
MatchTypePath MatchType = "path"
|
||||
)
|
||||
|
||||
// Rule 统一的规则接口
|
||||
type Rule interface {
|
||||
// GetID 获取规则ID
|
||||
GetID() string
|
||||
// GetType 获取规则类型
|
||||
GetType() RuleType
|
||||
// GetPriority 获取规则优先级
|
||||
GetPriority() int
|
||||
// Match 匹配规则
|
||||
Match(req *http.Request) bool
|
||||
// Apply 应用规则
|
||||
Apply(req *http.Request) error
|
||||
// Validate 验证规则
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// BaseRule 基础规则结构
|
||||
type BaseRule struct {
|
||||
ID string `json:"id"`
|
||||
Type RuleType `json:"type"`
|
||||
Priority int `json:"priority"`
|
||||
Pattern string `json:"pattern"`
|
||||
MatchType MatchType `json:"match_type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// DNSRule DNS规则
|
||||
type DNSRule struct {
|
||||
BaseRule
|
||||
Targets []DNSTarget `json:"targets"`
|
||||
}
|
||||
|
||||
// DNSTarget DNS目标
|
||||
type DNSTarget struct {
|
||||
IP string `json:"ip"`
|
||||
Port int `json:"port"`
|
||||
}
|
||||
|
||||
// RewriteRule URL重写规则
|
||||
type RewriteRule struct {
|
||||
BaseRule
|
||||
Replacement string `json:"replacement"`
|
||||
}
|
||||
|
||||
// RouteRule 路由规则
|
||||
type RouteRule struct {
|
||||
BaseRule
|
||||
Target string `json:"target"`
|
||||
HeaderModifier map[string]string `json:"header_modifier"`
|
||||
}
|
||||
|
||||
// GetID 获取规则ID
|
||||
func (r *BaseRule) GetID() string {
|
||||
return r.ID
|
||||
}
|
||||
|
||||
// GetType 获取规则类型
|
||||
func (r *BaseRule) GetType() RuleType {
|
||||
return r.Type
|
||||
}
|
||||
|
||||
// GetPriority 获取规则优先级
|
||||
func (r *BaseRule) GetPriority() int {
|
||||
return r.Priority
|
||||
}
|
87
pkg/server/graceful.go
Normal file
87
pkg/server/graceful.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GracefulServer 优雅关闭服务器
|
||||
type GracefulServer struct {
|
||||
// HTTP服务器
|
||||
server *http.Server
|
||||
// 等待组
|
||||
wg sync.WaitGroup
|
||||
// 停止信号
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// NewGracefulServer 创建优雅关闭服务器
|
||||
func NewGracefulServer(addr string, handler http.Handler) *GracefulServer {
|
||||
return &GracefulServer{
|
||||
server: &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
},
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动服务器
|
||||
func (s *GracefulServer) Start() error {
|
||||
// 启动HTTP服务器
|
||||
go func() {
|
||||
slog.Info("启动HTTP服务器", "addr", s.server.Addr)
|
||||
if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
slog.Error("HTTP服务器错误", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待中断信号
|
||||
s.waitForInterrupt()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForInterrupt 等待中断信号
|
||||
func (s *GracefulServer) waitForInterrupt() {
|
||||
// 创建信号通道
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// 等待信号
|
||||
<-sigChan
|
||||
slog.Info("收到停止信号,开始优雅关闭")
|
||||
|
||||
// 创建关闭上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 关闭HTTP服务器
|
||||
if err := s.server.Shutdown(ctx); err != nil {
|
||||
slog.Error("关闭HTTP服务器失败", "error", err)
|
||||
}
|
||||
|
||||
// 通知所有goroutine停止
|
||||
close(s.stopChan)
|
||||
|
||||
// 等待所有goroutine结束
|
||||
s.wg.Wait()
|
||||
|
||||
slog.Info("服务器已优雅关闭")
|
||||
}
|
||||
|
||||
// Stop 停止服务器
|
||||
func (s *GracefulServer) Stop() {
|
||||
close(s.stopChan)
|
||||
}
|
||||
|
||||
// WaitGroup 获取等待组
|
||||
func (s *GracefulServer) WaitGroup() *sync.WaitGroup {
|
||||
return &s.wg
|
||||
}
|
26
pool.go
Normal file
26
pool.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package goproxy
|
||||
|
||||
import "sync"
|
||||
|
||||
// 泛型对象池
|
||||
type pool[T any] struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func newPool[T any](newFunc func() T) *pool[T] {
|
||||
return &pool[T]{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return newFunc()
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pool[T]) Get() T {
|
||||
return p.pool.Get().(T)
|
||||
}
|
||||
|
||||
func (p *pool[T]) Put(x T) {
|
||||
p.pool.Put(x)
|
||||
}
|
Reference in New Issue
Block a user