commit 1a53a9a8f30f401fc5e2cf0287290c1689059893 Author: DarkiT Date: Fri Mar 14 18:50:49 2025 +0000 init diff --git a/README.md b/README.md new file mode 100644 index 0000000..c638a14 --- /dev/null +++ b/README.md @@ -0,0 +1,895 @@ +# GoProxy + +GoProxy是一个功能强大的Go语言HTTP代理库,支持HTTP、HTTPS和WebSocket代理,并提供了丰富的功能和扩展点。 + +## 功能特性 + +- 支持HTTP、HTTPS和WebSocket代理 +- 支持正向代理和反向代理 +- 支持HTTPS解密(中间人模式) + - 自定义CA证书和私钥 + - 动态证书生成与缓存 + - 通配符域名证书支持 + - 支持RSA和ECDSA证书算法选择 +- 支持上游代理链 +- 支持多后端DNS解析和负载均衡 + - 支持一个域名对应多个后端服务器 + - 支持多种负载均衡策略(轮询、随机、第一个可用) + - 支持通配符域名解析 + - 支持自定义DNS解析规则 + - 支持动态添加/删除后端服务器 +- 支持健康检查 +- 支持请求重试 +- 支持HTTP缓存 +- 支持请求限流 +- 支持请求/响应压缩 + - 支持gzip压缩 + - 智能压缩决策 + - 可配置压缩级别 + - 支持最小压缩大小 + - 支持多种内容类型 +- 支持监控指标收集(Prometheus格式) + - 请求总数和延迟统计 + - 请求和响应大小统计 + - 错误计数 + - 活跃连接数 + - 连接池大小 + - 缓存命中率 + - 内存使用量 + - 后端健康状态 + - 后端响应时间 +- 支持自定义处理逻辑(委托模式) +- 支持DNS缓存 +- 支持认证授权 + - JWT认证 + - 基于角色的访问控制 + - 用户管理 + - 权限管理 + +## 安装 + +```bash +go get github.com/darkit/goproxy +``` + +## 快速开始 + +### 正向代理 + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy" +) + +func main() { + // 创建代理 + p := goproxy.New(nil) + + // 启动HTTP服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +### 启用HTTPS解密 + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy/pkg/config" + "github.com/darkit/goproxy" +) + +func main() { + // 创建配置 + cfg := config.DefaultConfig() + cfg.DecryptHTTPS = true + cfg.CACert = "ca.crt" // CA证书路径 + cfg.CAKey = "ca.key" // CA私钥路径 + cfg.UseECDSA = true // 使用ECDSA生成证书(默认为false,使用RSA) + + // 可选:使用自定义TLS证书 + // cfg.TLSCert = "server.crt" + // cfg.TLSKey = "server.key" + + // 创建证书缓存 + certCache := &goproxy.MemCertCache{} + + // 创建代理 + p := goproxy.New(&goproxy.Options{ + Config: cfg, + CertCache: certCache, + }) + + // 启动HTTP服务器 + log.Println("HTTPS解密代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +> **注意**: 使用HTTPS解密功能时,需要在客户端安装CA证书,否则会出现证书警告。 + +### 反向代理 + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy/pkg/config" + "github.com/darkit/goproxy" +) + +func main() { + // 创建配置 + cfg := config.DefaultConfig() + cfg.ReverseProxy = true + cfg.EnableURLRewrite = true + cfg.AddXForwardedFor = true + cfg.AddXRealIP = true + + // 创建自定义委托 + delegate := &ReverseProxyDelegate{ + backend: "localhost:8081", + } + + // 创建代理 + p := goproxy.New(&goproxy.Options{ + Config: cfg, + Delegate: delegate, + }) + + // 启动HTTP服务器 + log.Println("反向代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} + +// ReverseProxyDelegate 反向代理委托 +type ReverseProxyDelegate struct { + goproxy.DefaultDelegate + backend string +} + +// ResolveBackend 解析后端服务器 +func (d *ReverseProxyDelegate) ResolveBackend(req *http.Request) (string, error) { + return d.backend, nil +} +``` + +### 自定义委托 + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy" +) + +func main() { + // 创建自定义委托 + delegate := &CustomDelegate{} + + // 创建代理 + p := goproxy.New(&goproxy.Options{ + Delegate: delegate, + }) + + // 启动HTTP服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} + +// CustomDelegate 自定义委托 +type CustomDelegate struct { + goproxy.DefaultDelegate +} + +// BeforeRequest 请求前事件 +func (d *CustomDelegate) BeforeRequest(ctx *goproxy.Context) { + log.Printf("请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.String()) +} + +// BeforeResponse 响应前事件 +func (d *CustomDelegate) BeforeResponse(ctx *goproxy.Context, resp *http.Response, err error) { + if err != nil { + log.Printf("响应错误: %v\n", err) + return + } + log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status) +} +``` + +### 完整示例 + +- 正向代理示例: [cmd/example/main.go](cmd/example/main.go) +- 反向代理示例: [cmd/reverse_proxy_example/main.go](cmd/reverse_proxy_example/main.go) + +### 使用函数式选项模式 + +```go +package main + +import ( + "log" + "net/http" + "time" + + "github.com/darkit/goproxy/pkg/metrics" + "github.com/darkit/goproxy" +) + +func main() { + // 创建监控指标 + metricsCollector := metrics.NewSimpleMetrics() + + // 创建证书缓存 + certCache := &goproxy.MemCertCache{} + + // 使用函数式选项模式创建代理 + p := goproxy.NewProxy( + // 启用HTTPS解密 + goproxy.WithDecryptHTTPS(certCache), + goproxy.WithCACertAndKey("ca.crt", "ca.key"), + + // 设置监控指标 + goproxy.WithMetrics(metricsCollector), + + // 设置请求超时和连接池 + goproxy.WithRequestTimeout(30 * time.Second), + goproxy.WithConnectionPoolSize(100), + goproxy.WithIdleTimeout(90 * time.Second), + + // 启用DNS缓存 + goproxy.WithDNSCacheTTL(10 * time.Minute), + + // 启用请求重试 + goproxy.WithEnableRetry(3, 1*time.Second, 10*time.Second), + + // 启用CORS支持 + goproxy.WithEnableCORS(true), + ) + + // 启动HTTP服务器和监控服务器 + go func() { + log.Println("监控服务器启动在 :8081") + http.Handle("/metrics", metricsCollector.GetHandler()) + if err := http.ListenAndServe(":8081", nil); err != nil { + log.Fatalf("监控服务器启动失败: %v", err) + } + }() + + // 启动代理服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +### 多后端DNS解析和负载均衡 + +```go +package main + +import ( + "log" + "net/http" + "time" + + "github.com/darkit/goproxy/pkg/dns" + "github.com/darkit/goproxy" +) + +func main() { + // 创建DNS解析器 + resolver := dns.NewResolver( + dns.WithLoadBalanceStrategy(dns.RoundRobin), // 设置负载均衡策略 + dns.WithTTL(5*time.Minute), // 设置DNS缓存TTL + ) + + // 添加多个后端服务器 + resolver.AddWithPort("api.example.com", "192.168.1.1", 8080) + resolver.AddWithPort("api.example.com", "192.168.1.2", 8080) + resolver.AddWithPort("api.example.com", "192.168.1.3", 8080) + + // 添加通配符域名解析 + resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080) + resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080) + + // 创建自定义委托 + delegate := &CustomDelegate{ + resolver: resolver, + } + + // 创建代理 + p := goproxy.New(&goproxy.Options{ + Delegate: delegate, + }) + + // 启动HTTP服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} + +// CustomDelegate 自定义委托 +type CustomDelegate struct { + goproxy.DefaultDelegate + resolver *dns.CustomResolver +} + +// ResolveBackend 解析后端服务器 +func (d *CustomDelegate) ResolveBackend(req *http.Request) (string, error) { + // 从请求中获取目标主机 + host := req.Host + if host == "" { + host = req.URL.Host + } + + // 解析域名获取后端服务器 + endpoint, err := d.resolver.ResolveWithPort(host, 80) + if err != nil { + return "", err + } + + return endpoint.String(), nil +} +``` + +### 从配置文件加载DNS规则 + +```go +package main + +import ( + "encoding/json" + "log" + "net/http" + "os" + + "github.com/darkit/goproxy/pkg/dns" + "github.com/darkit/goproxy" +) + +func main() { + // 创建DNS解析器 + resolver := dns.NewResolver() + + // 从JSON文件加载DNS规则 + if err := loadDNSConfig(resolver, "dns_config.json"); err != nil { + log.Fatalf("加载DNS配置失败: %v", err) + } + + // 创建自定义委托 + delegate := &CustomDelegate{ + resolver: resolver, + } + + // 创建代理 + p := goproxy.New(&goproxy.Options{ + Delegate: delegate, + }) + + // 启动HTTP服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} + +// DNSConfig DNS配置结构 +type DNSConfig struct { + Records map[string][]string `json:"records"` // 域名到IP地址列表的映射 + Wildcards map[string][]string `json:"wildcards"` // 通配符域名到IP地址列表的映射 +} + +func loadDNSConfig(resolver *dns.CustomResolver, filename string) error { + data, err := os.ReadFile(filename) + if err != nil { + return err + } + + var config DNSConfig + if err := json.Unmarshal(data, &config); err != nil { + return err + } + + // 加载精确匹配记录 + for host, ips := range config.Records { + for _, ip := range ips { + if err := resolver.Add(host, ip); err != nil { + return err + } + } + } + + // 加载通配符记录 + for pattern, ips := range config.Wildcards { + for _, ip := range ips { + if err := resolver.AddWildcard(pattern, ip); err != nil { + return err + } + } + } + + return nil +} +``` + +示例配置文件 `dns_config.json`: +```json +{ + "records": { + "api.example.com": [ + "192.168.1.1:8080", + "192.168.1.2:8080", + "192.168.1.3:8080" + ] + }, + "wildcards": { + "*.example.com": [ + "192.168.1.1:8080", + "192.168.1.2:8080" + ] + } +} +``` + +### 动态管理后端服务器 + +```go +// 添加新的后端服务器 +resolver.AddWithPort("api.example.com", "192.168.1.4", 8080) + +// 删除特定的后端服务器 +resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080) + +// 删除整个域名记录 +resolver.Remove("api.example.com") + +// 清除所有记录 +resolver.Clear() +``` + +## 架构设计 + +GoProxy采用模块化设计,主要包含以下模块: + +- **代理核心(Proxy)**:处理HTTP请求和响应,实现代理功能 +- **反向代理(ReverseProxy)**:处理反向代理请求,支持请求修改 +- **路由(Router)**:基于主机名、路径、正则表达式等规则路由请求到不同的后端 +- **代理上下文(Context)**:保存请求上下文信息,用于在处理过程中传递数据 +- **代理委托(Delegate)**:定义代理处理请求的各个阶段的回调方法,用于自定义处理逻辑 +- **连接缓冲区(ConnBuffer)**:封装网络连接,提供缓冲读写功能 +- **负载均衡(LoadBalancer)**:实现负载均衡算法,支持轮询、随机、权重等 +- **健康检查(HealthChecker)**:检查上游服务器的健康状态,自动剔除不健康的服务器 +- **缓存(Cache)**:实现HTTP缓存,减少重复请求 +- **缓存适配器(CacheAdapter)**:统一不同缓存实现的接口,提高代码可读性和性能 +- **证书生成(CertGenerator)**:动态生成TLS证书,支持HTTPS解密 +- **限流(RateLimit)**:实现请求限流,防止过载 +- **监控(Metrics)**:收集代理运行指标,用于监控和分析 +- **重试(Retry)**:实现请求重试,提高请求成功率 + +## 配置选项 + +GoProxy提供了丰富的配置选项,可以通过`Options`结构体进行配置: + +```go +type Options struct { + // 配置 + Config *config.Config + // 委托 + Delegate Delegate + // 证书缓存 + CertCache CertificateCache + // HTTP缓存 + HTTPCache cache.Cache + // 负载均衡器 + LoadBalancer loadbalance.LoadBalancer + // 健康检查器 + HealthChecker *healthcheck.HealthChecker + // 监控指标 + Metrics metrics.Metrics + // 客户端跟踪 + ClientTrace *httptrace.ClientTrace +} +``` + +### 函数式选项模式 + +GoProxy 现在支持函数式选项模式(Functional Options Pattern),通过一系列的 `With` 方法提供更加灵活和可读性更高的配置方式。此模式的优势在于: + +- 参数配置更加直观和清晰 +- 可以灵活选择需要的配置项,不必记忆参数顺序 +- 代码可读性更高,便于维护 +- 可以逐步添加新的配置选项而不破坏兼容性 + +可以使用 `NewProxy` 函数和函数式选项创建代理: + +```go +// 创建一个简单的代理 +proxy := goproxy.NewProxy() + +// 创建一个功能丰富的代理 +proxy := goproxy.NewProxy( + goproxy.WithConfig(config.DefaultConfig()), + goproxy.WithHTTPCache(myCache), + goproxy.WithDecryptHTTPS(myCertCache), + goproxy.WithCACertAndKey("ca.crt", "ca.key"), + goproxy.WithMetrics(myMetrics), + goproxy.WithLoadBalancer(myLoadBalancer), + goproxy.WithRequestTimeout(10 * time.Second), + goproxy.WithEnableCORS(true) +) +``` + +### 可用的 With 方法 + +GoProxy 提供了以下 With 方法用于配置代理的各个方面: + +#### 基础配置选项 +- `WithConfig(cfg *config.Config)`: 设置代理配置 +- `WithDisableKeepAlive(disableKeepAlive bool)`: 设置连接是否重用 +- `WithTransport(t *http.Transport)`: 使用自定义HTTP传输 +- `WithClientTrace(t *httptrace.ClientTrace)`: 设置HTTP客户端跟踪 + +#### 功能模块选项 +- `WithDelegate(delegate Delegate)`: 设置委托类 +- `WithHTTPCache(c cache.Cache)`: 设置HTTP缓存 +- `WithLoadBalancer(lb loadbalance.LoadBalancer)`: 设置负载均衡器 +- `WithHealthChecker(hc *healthcheck.HealthChecker)`: 设置健康检查器 +- `WithMetrics(m metrics.Metrics)`: 设置监控指标 + +#### 功能开启选项 +- `WithDecryptHTTPS(c CertificateCache)`: 启用中间人代理解密HTTPS +- `WithEnableECDSA(enable bool)`: 启用ECDSA证书生成(默认使用RSA) +- `WithEnableWebsocketIntercept()`: 启用WebSocket拦截 +- `WithReverseProxy(enable bool)`: 启用反向代理模式 +- `WithEnableRetry(maxRetries int, baseBackoff, maxBackoff time.Duration)`: 启用请求重试 +- `WithRateLimit(rps float64)`: 设置请求限流 +- `WithEnableCORS(enable bool)`: 启用CORS支持 + +#### 证书相关选项 +- `WithTLSCertAndKey(certPath, keyPath string)`: 设置TLS证书和密钥 +- `WithCACertAndKey(caCertPath, caKeyPath string)`: 设置CA证书和密钥 + +#### 性能和超时相关选项 +- `WithConnectionPoolSize(size int)`: 设置连接池大小 +- `WithIdleTimeout(timeout time.Duration)`: 设置空闲超时时间 +- `WithRequestTimeout(timeout time.Duration)`: 设置请求超时时间 +- `WithDNSCacheTTL(ttl time.Duration)`: 设置DNS缓存TTL + +### 配置HTTPS解密 + +要启用HTTPS解密功能,需要在配置中设置以下选项: + +```go +config := &config.Config{ + // 启用HTTPS解密 + DecryptHTTPS: true, + + // 方式一:使用CA证书和私钥动态生成证书 + CACert: "path/to/ca.crt", // CA证书路径 + CAKey: "path/to/ca.key", // CA私钥路径 + + // 选择证书生成算法(可选) + UseECDSA: true, // 使用ECDSA生成证书(默认为false,使用RSA) + + // 方式二:使用固定的TLS证书和私钥 + // TLSCert: "path/to/server.crt", + // TLSKey: "path/to/server.key", +} +``` + +或者使用函数式选项模式: + +```go +proxy := goproxy.NewProxy( + goproxy.WithDecryptHTTPS(&goproxy.MemCertCache{}), + goproxy.WithCACertAndKey("path/to/ca.crt", "path/to/ca.key"), + goproxy.WithEnableECDSA(true), // 使用ECDSA生成证书 + // 或者使用静态TLS证书 + // goproxy.WithTLSCertAndKey("path/to/server.crt", "path/to/server.key") +) +``` + +同时,建议配置证书缓存以提高性能: + +```go +certCache := &goproxy.MemCertCache{} +``` + +## 扩展点 + +GoProxy提供了多个扩展点,可以通过实现相应的接口进行扩展: + +- **Delegate**:代理委托接口,用于自定义代理处理逻辑 +- **LoadBalancer**:负载均衡接口,用于实现自定义负载均衡算法 +- **Cache**:缓存接口,用于实现自定义缓存策略 +- **CertificateCache**:证书缓存接口,用于自定义证书存储方式 +- **Metrics**:监控接口,用于实现自定义监控指标收集 + +## 反向代理特性 + +GoProxy的反向代理模式提供以下特性: + +- **路由规则**:支持基于主机名、路径、正则表达式等的路由规则 +- **请求修改**:支持修改发往后端服务器的请求 +- **响应修改**:支持修改来自后端服务器的响应 +- **保留客户端信息**:支持添加X-Forwarded-For和X-Real-IP头 +- **CORS支持**:支持自动添加CORS头 +- **WebSocket支持**:支持WebSocket协议的透明代理 +- **负载均衡**:支持多种负载均衡算法 +- **健康检查**:支持对后端服务器进行健康检查 +- **监控指标**:支持收集反向代理的监控指标 + +## 监控指标 + +GoProxy提供了两种监控指标实现:PrometheusMetrics和SimpleMetrics。 + +### PrometheusMetrics + +PrometheusMetrics是一个完整的Prometheus指标实现,提供以下指标: + +- `proxy_requests_total`: 请求总数(按方法、路径、状态码分类) +- `proxy_request_latency_seconds`: 请求延迟(按方法、路径分类) +- `proxy_request_size_bytes`: 请求大小(按方法、路径分类) +- `proxy_response_size_bytes`: 响应大小(按方法、路径分类) +- `proxy_errors_total`: 错误总数(按类型分类) +- `proxy_active_connections`: 活跃连接数 +- `proxy_connection_pool_size`: 连接池大小 +- `proxy_cache_hit_rate`: 缓存命中率 +- `proxy_memory_usage_bytes`: 内存使用量 + +使用示例: + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy/pkg/metrics" + "github.com/darkit/goproxy" +) + +func main() { + // 创建Prometheus指标收集器 + metricsCollector := metrics.NewPrometheusMetrics() + + // 创建代理 + p := goproxy.NewProxy( + goproxy.WithMetrics(metricsCollector), + ) + + // 启动监控服务器 + go func() { + log.Println("监控服务器启动在 :8081") + http.Handle("/metrics", metricsCollector.GetHandler()) + if err := http.ListenAndServe(":8081", nil); err != nil { + log.Fatalf("监控服务器启动失败: %v", err) + } + }() + + // 启动代理服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +### SimpleMetrics + +SimpleMetrics是一个简单的指标实现,提供基本的指标收集功能: + +- 请求计数 +- 错误计数 +- 活跃连接数 +- 累计响应时间 +- 传输字节数 +- 后端健康状态 +- 后端响应时间 +- 缓存命中计数 + +使用示例: + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy/pkg/metrics" + "github.com/darkit/goproxy" +) + +func main() { + // 创建简单指标收集器 + metricsCollector := metrics.NewSimpleMetrics() + + // 创建代理 + p := goproxy.NewProxy( + goproxy.WithMetrics(metricsCollector), + ) + + // 启动监控服务器 + go func() { + log.Println("监控服务器启动在 :8081") + http.Handle("/metrics", metricsCollector.GetHandler()) + if err := http.ListenAndServe(":8081", nil); err != nil { + log.Fatalf("监控服务器启动失败: %v", err) + } + }() + + // 启动代理服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +### 指标中间件 + +GoProxy还提供了一个指标中间件,可以用于收集HTTP请求的指标: + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy/pkg/metrics" +) + +func main() { + // 创建指标收集器 + metricsCollector := metrics.NewPrometheusMetrics() + + // 创建指标中间件 + metricsMiddleware := metrics.NewMetricsMiddleware(metricsCollector) + + // 创建HTTP处理器 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 处理请求 + w.Write([]byte("Hello, World!")) + }) + + // 使用中间件包装处理器 + wrappedHandler := metricsMiddleware.Middleware(handler) + + // 启动HTTP服务器 + log.Println("HTTP服务器启动在 :8080") + if err := http.ListenAndServe(":8080", wrappedHandler); err != nil { + log.Fatalf("HTTP服务器启动失败: %v", err) + } +} +``` + +### 使用压缩中间件 + +GoProxy提供了压缩中间件,支持请求和响应的gzip压缩: + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy/pkg/middleware" + "github.com/darkit/goproxy" +) + +func main() { + // 创建压缩中间件 + compressionMiddleware := middleware.NewCompressionMiddleware(6, 1024) // 压缩级别6,最小压缩大小1KB + + // 创建代理 + p := goproxy.NewProxy( + goproxy.WithMiddleware(compressionMiddleware), + ) + + // 启动HTTP服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +压缩中间件提供以下特性: + +- 自动检测客户端是否支持gzip压缩 +- 智能判断内容类型是否适合压缩 +- 可配置压缩级别(0-9) +- 可配置最小压缩大小 +- 支持多种内容类型 +- 自动处理压缩请求体 +- 自动添加压缩相关响应头 + +### 使用认证授权系统 + +GoProxy提供了完整的认证授权系统,支持JWT认证和基于角色的访问控制: + +```go +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy/pkg/auth" + "github.com/darkit/goproxy" +) + +func main() { + // 创建认证系统 + auth := auth.NewAuth("your-secret-key") + + // 添加用户和角色 + auth.AddUser("admin", "password123", []string{"admin"}) + auth.AddUser("user", "password456", []string{"user"}) + + // 创建代理 + p := goproxy.NewProxy( + goproxy.WithAuth(auth), + ) + + // 启动HTTP服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +认证授权系统提供以下特性: + +- JWT令牌认证 +- 基于角色的访问控制 +- 用户管理 +- 密码加密存储 +- 权限管理 +- 认证中间件 + +## 贡献 + +欢迎贡献代码、报告问题或提出建议。请遵循以下步骤: + +1. Fork 项目 +2. 创建特性分支 (`git checkout -b feature/amazing-feature`) +3. 提交更改 (`git commit -m 'Add some amazing feature'`) +4. 推送到分支 (`git push origin feature/amazing-feature`) +5. 创建 Pull Request + +## 许可证 + +本项目采用 MIT 许可证,详情请参阅 [LICENSE](LICENSE) 文件。 \ No newline at end of file diff --git a/README_DNS.md b/README_DNS.md new file mode 100644 index 0000000..8521c4d --- /dev/null +++ b/README_DNS.md @@ -0,0 +1,338 @@ +# GoProxy自定义DNS解析功能 + +该功能允许GoProxy使用自定义DNS解析器,实现以下功能: + +- 自定义域名解析 +- 动态变更后端服务器IP +- 自定义后端服务器端口 +- 泛解析(通配符域名)支持 +- 多后端服务器支持 + - 一个域名对应多个后端服务器 + - 支持多种负载均衡策略(轮询、随机、第一个可用) + - 支持动态添加/删除后端服务器 +- 负载均衡和故障转移 +- 绕过DNS污染 +- 高效的DNS缓存 + +## 特性 + +- **自定义记录**:直接设置域名到IP的映射 +- **自定义端口**:为每个域名指定自定义端口,无需在URL中指定 +- **泛解析**:支持通配符域名(如`*.example.com`)自动匹配多个子域名 +- **多级泛解析**:支持复杂的通配符模式(如`api.*.example.com`) +- **多后端支持**:支持一个域名配置多个后端服务器 +- **负载均衡**:支持多种负载均衡策略 +- **备用解析**:在自定义记录未找到时可选择使用系统DNS +- **DNS缓存**:缓存解析结果以提高性能 +- **自动重试**:解析失败时可配置重试策略 +- **加载配置**:从JSON文件或hosts格式文件加载配置 +- **自定义拨号器**:与net标准库兼容的拨号器 + +## 安装 + +确保已安装Go(建议1.22或更高版本): + +```bash +git clone github.com/darkit/goproxy +cd goproxy +go build ./... +``` + +## 使用方法 + +### 1. HTTP代理(自定义DNS) + +以下命令会启动一个HTTP代理(监听端口8080),它使用自定义DNS解析: + +```bash +go run cmd/custom_dns_proxy/main.go +``` + +### 2. 自定义端口代理 + +支持为不同域名指定不同端口的代理: + +```bash +go run cmd/custom_port_proxy/main.go +``` + +### 3. 泛解析DNS代理 + +支持通配符域名解析的代理: + +```bash +# 使用默认的示例泛解析规则 +go run cmd/wildcard_dns_proxy/main.go + +# 透明代理模式(使用请求中的Host进行匹配) +go run cmd/wildcard_dns_proxy/main.go -target "" + +# 使用自定义配置文件 +go run cmd/wildcard_dns_proxy/main.go -dns examples/wildcard_dns_config.json + +# 使用hosts格式配置文件 +go run cmd/wildcard_dns_proxy/main.go -hosts examples/wildcard_hosts.txt +``` + +#### 参数说明 + +- `-listen`: 监听地址,默认 `:8080` +- `-target`: 目标主机名,空字符串表示使用请求中的Host头 +- `-port`: 默认目标端口,默认 `443` +- `-dns`: DNS配置文件(JSON格式) +- `-hosts`: hosts格式配置文件 + +## DNS配置格式 + +### JSON格式 + +```json +{ + "records": { + "example.com": ["93.184.216.34", "93.184.216.35"], + "api.example.com": ["93.184.216.35:8443", "93.184.216.36:8443"], + + "*.github.com": ["140.82.121.3", "140.82.121.4"], + "github.com": ["140.82.121.4", "140.82.121.5"], + + "*.dev.local": ["127.0.0.1:3000", "127.0.0.1:3001"], + "api.*.dev.local": ["127.0.0.1:3001", "127.0.0.1:3002"] + }, + "use_fallback": true, + "ttl": 300, + "load_balance_strategy": "round_robin" +} +``` + +### Hosts格式 + +``` +# 精确匹配(多后端) +93.184.216.34 example.com +93.184.216.35 example.com +93.184.216.35:8443 api.example.com +93.184.216.36:8443 api.example.com + +# 泛解析(多后端) +140.82.121.3 *.github.com +140.82.121.4 *.github.com +127.0.0.1:3000 *.dev.local +127.0.0.1:3001 *.dev.local +127.0.0.1:3001 api.*.dev.local +127.0.0.1:3002 api.*.dev.local +``` + +## 编程接口 + +### 使用DNS解析器 + +```go +import "github.com/darkit/goproxy/pkg/dns" + +// 创建解析器 +resolver := dns.NewResolver( + dns.WithLoadBalanceStrategy(dns.RoundRobin), // 设置负载均衡策略 + dns.WithTTL(5*time.Minute), // 设置DNS缓存TTL +) + +// 添加多个后端服务器 +resolver.AddWithPort("api.example.com", "192.168.1.1", 8080) +resolver.AddWithPort("api.example.com", "192.168.1.2", 8080) +resolver.AddWithPort("api.example.com", "192.168.1.3", 8080) + +// 添加泛解析记录(多后端) +resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080) +resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080) + +// 解析域名(使用负载均衡策略选择后端) +endpoint, err := resolver.ResolveWithPort("api.example.com", 443) +if err != nil { + log.Fatalf("解析失败: %v", err) +} +fmt.Printf("解析结果: IP=%s, 端口=%d\n", endpoint.IP, endpoint.Port) + +// 动态添加/删除后端服务器 +resolver.AddWithPort("api.example.com", "192.168.1.4", 8080) +resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080) +``` + +### 负载均衡策略 + +GoProxy支持三种负载均衡策略: + +1. **轮询策略(Round Robin)** + ```go + resolver := dns.NewResolver( + dns.WithLoadBalanceStrategy(dns.RoundRobin), + ) + ``` + +2. **随机策略(Random)** + ```go + resolver := dns.NewResolver( + dns.WithLoadBalanceStrategy(dns.Random), + ) + ``` + +3. **第一个可用策略(First Available)** + ```go + resolver := dns.NewResolver( + dns.WithLoadBalanceStrategy(dns.FirstAvailable), + ) + ``` + +### 使用DNS拨号器 + +```go +import "github.com/darkit/goproxy/pkg/dns" + +// 创建解析器 +resolver := dns.NewResolver( + dns.WithLoadBalanceStrategy(dns.RoundRobin), +) +resolver.AddWithPort("example.com", "192.168.1.1", 8080) +resolver.AddWithPort("example.com", "192.168.1.2", 8080) +resolver.AddWildcardWithPort("*.example.com", "192.168.1.3", 8080) +resolver.AddWildcardWithPort("*.example.com", "192.168.1.4", 8080) + +// 创建拨号器 +dialer := dns.NewDialer(resolver) + +// 使用拨号器连接(会自动应用负载均衡) +conn, err := dialer.Dial("tcp", "api.example.com:443") +if err != nil { + log.Fatalf("连接失败: %v", err) +} +defer conn.Close() + +// 或者获取用于http.Transport的拨号上下文函数 +transport := &http.Transport{ + DialContext: dialer.DialContext, +} +client := &http.Client{Transport: transport} +``` + +## 高级用法 + +### 多后端配置模式 + +多后端配置支持以下模式: + +1. **精确匹配多后端**: + ```go + resolver.AddWithPort("api.example.com", "192.168.1.1", 8080) + resolver.AddWithPort("api.example.com", "192.168.1.2", 8080) + resolver.AddWithPort("api.example.com", "192.168.1.3", 8080) + ``` + +2. **泛解析多后端**: + ```go + resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080) + resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080) + ``` + +3. **混合配置**: + ```go + // 精确匹配优先于泛解析 + resolver.AddWithPort("api.example.com", "192.168.1.1", 8080) + resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080) + ``` + +### 动态管理后端服务器 + +```go +// 添加新的后端服务器 +resolver.AddWithPort("api.example.com", "192.168.1.4", 8080) + +// 删除特定的后端服务器 +resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080) + +// 删除整个域名记录 +resolver.Remove("api.example.com") + +// 清除所有记录 +resolver.Clear() +``` + +### 健康检查集成 + +可以结合健康检查功能,自动剔除不健康的后端服务器: + +```go +// 创建健康检查器 +healthChecker := healthcheck.NewChecker( + healthcheck.WithCheckInterval(30*time.Second), + healthcheck.WithTimeout(5*time.Second), +) + +// 添加健康检查 +healthChecker.Add("api.example.com", "192.168.1.1:8080") +healthChecker.Add("api.example.com", "192.168.1.2:8080") + +// 在解析器中使用健康检查结果 +resolver.SetHealthChecker(healthChecker) +``` + +## 应用场景 + +### 高可用部署 + +使用多后端配置实现高可用: + +``` +# 主备模式 +192.168.1.1 api.example.com # 主服务器 +192.168.1.2 api.example.com # 备用服务器 + +# 负载均衡模式 +192.168.1.1 api.example.com # 服务器1 +192.168.1.2 api.example.com # 服务器2 +192.168.1.3 api.example.com # 服务器3 +``` + +### 多环境部署 + +为不同环境配置不同的后端服务器组: + +``` +# 测试环境 +192.168.1.10 *.test.example.com +192.168.1.11 *.test.example.com + +# 预发布环境 +192.168.1.20 *.staging.example.com +192.168.1.21 *.staging.example.com + +# 生产环境 +192.168.1.30 *.production.example.com +192.168.1.31 *.production.example.com +``` + +### 微服务架构 + +为不同类型的微服务配置多个后端: + +``` +# 认证服务 +10.0.0.1:8001 *.auth.internal +10.0.0.2:8001 *.auth.internal + +# 用户服务 +10.0.0.3:8002 *.user.internal +10.0.0.4:8002 *.user.internal + +# 支付服务 +10.0.0.5:8003 *.payment.internal +10.0.0.6:8003 *.payment.internal +``` + +## 注意事项 + +1. 通配符域名只在我们的自定义DNS解析器中有效,不会影响系统DNS +2. 泛解析规则的顺序会影响匹配结果,后添加的规则优先级更高 +3. 过多的泛解析规则可能会影响性能,建议合理组织规则 +4. 当域名同时匹配多个规则时,精确匹配优先于通配符匹配 +5. 自签名证书会导致浏览器警告,仅用于测试目的 +6. 多后端配置时,建议使用健康检查确保后端服务器可用性 +7. 负载均衡策略的选择应根据实际需求进行配置 +8. 动态添加/删除后端服务器时,需要考虑并发安全性 \ No newline at end of file diff --git a/base.go b/base.go new file mode 100644 index 0000000..7065c63 --- /dev/null +++ b/base.go @@ -0,0 +1,47 @@ +package goproxy + +import ( + "net/http" +) + +// BaseProxy 基础代理接口 +type BaseProxy interface { + // ServeHTTP 处理HTTP请求 + ServeHTTP(w http.ResponseWriter, r *http.Request) + // Close 关闭代理 + Close() error +} + +// BaseConfig 基础配置 +type BaseConfig struct { + // 监听地址 + ListenAddr string + // 是否启用HTTPS + EnableHTTPS bool + // TLS配置 + TLSConfig *TLSConfig + // 是否启用WebSocket + EnableWebSocket bool + // 是否启用压缩 + EnableCompression bool + // 是否启用CORS + EnableCORS bool + // 是否保留客户端IP + PreserveClientIP bool + // 是否添加X-Forwarded-For头 + AddXForwardedFor bool + // 是否添加X-Real-IP头 + AddXRealIP bool +} + +// TLSConfig TLS配置 +type TLSConfig struct { + // 证书文件路径 + CertFile string + // 密钥文件路径 + KeyFile string + // 是否跳过证书验证 + InsecureSkipVerify bool + // 是否使用ECDSA + UseECDSA bool +} diff --git a/certificate.go b/certificate.go new file mode 100644 index 0000000..3c66206 --- /dev/null +++ b/certificate.go @@ -0,0 +1,552 @@ +package goproxy + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "hash/fnv" + "math/big" + "net" + "os" + "strings" + "time" +) + +// 内置默认CA证书和私钥 +var ( + // 默认根证书 + defaultRootCAPem = []byte(`-----BEGIN CERTIFICATE----- +MIICJzCCAcygAwIBAgIITWWCIQf8/VIwCgYIKoZIzj0EAwIwUzEOMAwGA1UEBhMF +Q2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0GA1UEBxMGWGlhbWVuMRAwDgYDVQQK +EwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMB4XDTIyMDMyNTA1NDgwMFoXDTQyMDQy +NTA1NDgwMFowUzEOMAwGA1UEBhMFQ2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0G +A1UEBxMGWGlhbWVuMRAwDgYDVQQKEwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMFkw +EwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQX +mRgsFV5KHHmxOrVJBFC/nDetmGowkARShWtBsX1Irm4w6i6Qk2QliKOBiTCBhjAO +BgNVHQ8BAf8EBAMCAQYwHQYDVR0lBBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIG +A1UdEwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFBI5TkWYcvUIWsBAdffs833FnBrI +MCIGA1UdEQQbMBmBF3FpbmdxaWFubHVkYW9AZ21haWwuY29tMAoGCCqGSM49BAMC +A0kAMEYCIQCk1DhW7AmIW/n/QLftQq8BHZKLevWYJ813zdrNr5kXlwIhAIVvqglY +9BkYWg4NEe/mVO4C5Vtu4FnzNU9I+rFpXVSO +-----END CERTIFICATE----- +`) + // 默认根私钥 + defaultRootKeyPem = []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIAXeEHO0FtFqQhTvsn/DT4g3rEos97+1Nibp9RfKOKhroAoGCCqGSM49 +AwEHoUQDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQXmRgsFV5KHHmxOrVJBFC/ +nDetmGowkARShWtBsX1Irm4w6i6Qk2QliA== +-----END EC PRIVATE KEY----- +`) +) + +// 加载和初始化默认根证书 +var ( + defaultRootCA *x509.Certificate + defaultRootKey *ecdsa.PrivateKey +) + +func init() { + var err error + block, _ := pem.Decode(defaultRootCAPem) + if block == nil { + panic("解析默认根证书PEM块失败") + } + defaultRootCA, err = x509.ParseCertificate(block.Bytes) + if err != nil { + panic(fmt.Errorf("加载默认根证书失败: %s", err)) + } + + block, _ = pem.Decode(defaultRootKeyPem) + if block == nil { + panic("解析默认根私钥PEM块失败") + } + defaultRootKey, err = x509.ParseECPrivateKey(block.Bytes) + if err != nil { + panic(fmt.Errorf("加载默认根私钥失败: %s", err)) + } +} + +// CertManager 证书管理器 +type CertManager struct { + // 证书缓存 + cache CertificateCache + // 默认私钥,可用于多个证书共享 + defaultPrivateKey interface{} // 改为interface{}以支持不同类型的私钥 + // 默认使用ECDSA P-256曲线 + curve elliptic.Curve + // 证书有效期(年) + validityYears int + // 是否使用ECDSA(否则使用RSA) + useECDSA bool +} + +// NewCertManager 创建证书管理器 +func NewCertManager(cache CertificateCache, options ...CertManagerOption) *CertManager { + manager := &CertManager{ + cache: cache, + curve: elliptic.P256(), // 默认使用P-256曲线 + validityYears: 1, // 默认证书有效期1年 + useECDSA: true, // 默认使用ECDSA + } + + // 应用选项 + for _, option := range options { + option(manager) + } + + return manager +} + +// CertManagerOption 证书管理器选项 +type CertManagerOption func(*CertManager) + +// WithUseECDSA 设置是否使用ECDSA(否则使用RSA) +func WithUseECDSA(useECDSA bool) CertManagerOption { + return func(m *CertManager) { + m.useECDSA = useECDSA + } +} + +// WithDefaultPrivateKey 设置是否使用默认私钥 +func WithDefaultPrivateKey(enable bool) CertManagerOption { + return func(m *CertManager) { + if enable { + if m.useECDSA { + // 生成ECDSA私钥 + priv, err := ecdsa.GenerateKey(m.curve, rand.Reader) + if err != nil { + panic(fmt.Errorf("生成默认ECDSA私钥失败: %s", err)) + } + m.defaultPrivateKey = priv + } else { + // 生成RSA私钥 + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(fmt.Errorf("生成默认RSA私钥失败: %s", err)) + } + m.defaultPrivateKey = priv + } + } + } +} + +// WithCurve 设置椭圆曲线 +func WithCurve(curve elliptic.Curve) CertManagerOption { + return func(m *CertManager) { + m.curve = curve + } +} + +// WithValidityYears 设置证书有效期(年) +func WithValidityYears(years int) CertManagerOption { + return func(m *CertManager) { + if years > 0 { + m.validityYears = years + } + } +} + +// GenerateTLSConfig 为指定主机生成TLS配置 +func (m *CertManager) GenerateTLSConfig(host string) (*tls.Config, error) { + // 处理可能的端口 + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } + + // 检查证书缓存 + if m.cache != nil { + // 检查主域名和子域名 + fields := strings.Split(host, ".") + domains := []string{host} + + // 添加父域名 + if len(fields) > 2 { + domains = append(domains, strings.Join(fields[1:], ".")) + } + + // 查找缓存 + for _, domain := range domains { + if cert := m.cache.Get(domain); cert != nil { + return &tls.Config{ + Certificates: []tls.Certificate{*cert}, + }, nil + } + } + } + + // 生成新证书 + cert, err := m.GenerateCertificate(host, defaultRootCA, defaultRootKey) + if err != nil { + return nil, err + } + + // 缓存证书 + if m.cache != nil { + // 缓存主机名 + m.cache.Set(host, cert) + + // 如果是IP地址,不进行其他处理 + if net.ParseIP(host) != nil { + return &tls.Config{ + Certificates: []tls.Certificate{*cert}, + }, nil + } + + // 缓存域名和子域名证书 + fields := strings.Split(host, ".") + if len(fields) >= 2 { + // 缓存主域名 + domain := strings.Join(fields[1:], ".") + m.cache.Set(domain, cert) + } + } + + return &tls.Config{ + Certificates: []tls.Certificate{*cert}, + }, nil +} + +// GenerateCertificate 生成证书 +func (m *CertManager) GenerateCertificate(host string, rootCA *x509.Certificate, rootKey *ecdsa.PrivateKey) (*tls.Certificate, error) { + // 准备私钥 + var priv interface{} + var pubKey interface{} + var err error + + // 使用默认私钥或生成新私钥 + if m.defaultPrivateKey != nil { + priv = m.defaultPrivateKey + } else if m.useECDSA { + // 生成ECDSA私钥 + ecdsaKey, err := ecdsa.GenerateKey(m.curve, rand.Reader) + if err != nil { + return nil, fmt.Errorf("生成ECDSA私钥失败: %s", err) + } + priv = ecdsaKey + pubKey = &ecdsaKey.PublicKey + } else { + // 生成RSA私钥 + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("生成RSA私钥失败: %s", err) + } + priv = rsaKey + pubKey = &rsaKey.PublicKey + } + + // 获取公钥 + if pubKey == nil { + switch k := priv.(type) { + case *ecdsa.PrivateKey: + pubKey = &k.PublicKey + case *rsa.PrivateKey: + pubKey = &k.PublicKey + default: + return nil, fmt.Errorf("不支持的私钥类型") + } + } + + // 创建证书模板 + template := m.createCertificateTemplate(host) + + // 签名证书 + derBytes, err := x509.CreateCertificate(rand.Reader, template, rootCA, pubKey, rootKey) + if err != nil { + return nil, fmt.Errorf("创建证书失败: %s", err) + } + + // 编码为PEM格式 + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: derBytes, + }) + + // 编码私钥 + var keyPEM []byte + switch k := priv.(type) { + case *ecdsa.PrivateKey: + privBytes, err := x509.MarshalECPrivateKey(k) + if err != nil { + return nil, fmt.Errorf("序列化ECDSA私钥失败: %s", err) + } + keyPEM = pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: privBytes, + }) + case *rsa.PrivateKey: + privBytes := x509.MarshalPKCS1PrivateKey(k) + keyPEM = pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: privBytes, + }) + default: + return nil, fmt.Errorf("不支持的私钥类型") + } + + // 创建TLS证书 + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return nil, fmt.Errorf("创建TLS证书对失败: %s", err) + } + + return &tlsCert, nil +} + +// createCertificateTemplate 创建证书模板 +func (m *CertManager) createCertificateTemplate(host string) *x509.Certificate { + // 使用基于主机名的哈希值作为序列号 + fv := fnv.New64a() + fv.Write([]byte(host)) + serialNumber := big.NewInt(0).SetUint64(fv.Sum64()) + + // 准备模板 + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: host, + Organization: []string{"GoProxy Dynamic CA"}, + Country: []string{"CN"}, + Province: []string{"GuangDong"}, + Locality: []string{"Guangzhou"}, + }, + NotBefore: time.Now().Add(-10 * time.Minute), // 提前10分钟生效,容忍时间偏差 + NotAfter: time.Now().AddDate(m.validityYears, 0, 0), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + IsCA: false, + } + + // 处理IP地址和域名 + ipAddr := net.ParseIP(host) + if ipAddr != nil { + template.IPAddresses = []net.IP{ipAddr} + } else { + // 移除可能的端口部分 + if strings.Contains(host, ":") { + host = strings.Split(host, ":")[0] + } + + // 将主机名添加到DNS名称列表 + template.DNSNames = []string{host} + + // 添加通配符域名支持 + fields := strings.Split(host, ".") + fieldNum := len(fields) + + // 为每一级子域名添加通配符 + for i := 0; i <= (fieldNum - 2); i++ { + wildcardDomain := "*." + strings.Join(fields[i:], ".") + // 避免重复 + if wildcardDomain != host { + template.DNSNames = append(template.DNSNames, wildcardDomain) + } + } + } + + return template +} + +// LoadCAFromFiles 从文件加载CA证书和私钥 +func LoadCAFromFiles(certFile, keyFile string) (*x509.Certificate, *ecdsa.PrivateKey, error) { + // 读取CA证书 + caCertPEM, err := os.ReadFile(certFile) + if err != nil { + return nil, nil, fmt.Errorf("读取CA证书文件失败: %s", err) + } + + block, _ := pem.Decode(caCertPEM) + if block == nil { + return nil, nil, fmt.Errorf("解析CA证书PEM块失败") + } + + caCert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("解析CA证书失败: %s", err) + } + + // 读取CA私钥 + caKeyPEM, err := os.ReadFile(keyFile) + if err != nil { + return nil, nil, fmt.Errorf("读取CA私钥文件失败: %s", err) + } + + block, _ = pem.Decode(caKeyPEM) + if block == nil { + return nil, nil, fmt.Errorf("解析CA私钥PEM块失败") + } + + var caKey *ecdsa.PrivateKey + + // 尝试不同的私钥格式 + switch block.Type { + case "EC PRIVATE KEY": + caKey, err = x509.ParseECPrivateKey(block.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("解析EC私钥失败: %s", err) + } + case "PRIVATE KEY": + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("解析PKCS8私钥失败: %s", err) + } + var ok bool + caKey, ok = key.(*ecdsa.PrivateKey) + if !ok { + return nil, nil, fmt.Errorf("私钥不是ECDSA类型") + } + default: + return nil, nil, fmt.Errorf("不支持的私钥类型: %s", block.Type) + } + + return caCert, caKey, nil +} + +// GetDefaultRootCA 获取默认根证书和私钥 +func GetDefaultRootCA() (*x509.Certificate, *ecdsa.PrivateKey) { + return defaultRootCA, defaultRootKey +} + +// GenerateRootCA 生成新的根证书和私钥 +func GenerateRootCA(validYears int) (*x509.Certificate, *ecdsa.PrivateKey, error) { + // 生成私钥 + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("生成根证书私钥失败: %s", err) + } + + // 随机生成序列号 + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, nil, fmt.Errorf("生成序列号失败: %s", err) + } + + // 创建根证书模板 + notBefore := time.Now().Add(-10 * time.Minute) + notAfter := notBefore.AddDate(validYears, 0, 0) + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: "GoProxy Root CA", + Organization: []string{"GoProxy"}, + Country: []string{"CN"}, + Province: []string{"GuangDong"}, + Locality: []string{"Guangzhou"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 2, + } + + // 自签名 + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + return nil, nil, fmt.Errorf("创建根证书失败: %s", err) + } + + // 解析生成的证书 + cert, err := x509.ParseCertificate(derBytes) + if err != nil { + return nil, nil, fmt.Errorf("解析生成的根证书失败: %s", err) + } + + return cert, priv, nil +} + +// SaveCertificateToFile 将证书和私钥保存到文件 +func SaveCertificateToFile(cert *x509.Certificate, key *ecdsa.PrivateKey, certFile, keyFile string) error { + // 保存证书 + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + }) + if err := os.WriteFile(certFile, certPEM, 0o644); err != nil { + return fmt.Errorf("保存证书到文件失败: %s", err) + } + + // 保存私钥 + keyBytes, err := x509.MarshalECPrivateKey(key) + if err != nil { + return fmt.Errorf("序列化私钥失败: %s", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: keyBytes, + }) + if err := os.WriteFile(keyFile, keyPEM, 0o600); err != nil { + return fmt.Errorf("保存私钥到文件失败: %s", err) + } + + return nil +} + +// GenerateRootCA 生成根证书 +func (m *CertManager) GenerateRootCA() (*x509.Certificate, interface{}, error) { + // 生成私钥 + var priv interface{} + var pubKey interface{} + var err error + + if m.useECDSA { + // 生成ECDSA私钥 + ecdsaKey, err := ecdsa.GenerateKey(m.curve, rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("生成ECDSA根私钥失败: %s", err) + } + priv = ecdsaKey + pubKey = &ecdsaKey.PublicKey + } else { + // 生成RSA私钥 + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, fmt.Errorf("生成RSA根私钥失败: %s", err) + } + priv = rsaKey + pubKey = &rsaKey.PublicKey + } + + // 创建根证书模板 + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "GoProxy Root CA", + Organization: []string{"GoProxy Authority"}, + Country: []string{"CN"}, + Province: []string{"GuangDong"}, + Locality: []string{"Guangzhou"}, + }, + NotBefore: time.Now().Add(-10 * time.Minute), + NotAfter: time.Now().AddDate(10, 0, 0), // 10年有效期 + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 1, + } + + // 自签名 + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, pubKey, priv) + if err != nil { + return nil, nil, fmt.Errorf("创建根证书失败: %s", err) + } + + // 解析证书 + cert, err := x509.ParseCertificate(derBytes) + if err != nil { + return nil, nil, fmt.Errorf("解析生成的根证书失败: %s", err) + } + + return cert, priv, nil +} diff --git a/cmd/cmd_reverse_proxy.go b/cmd/cmd_reverse_proxy.go new file mode 100644 index 0000000..89f6806 --- /dev/null +++ b/cmd/cmd_reverse_proxy.go @@ -0,0 +1,86 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/config" +) + +func main() { + // 解析命令行参数 + listenAddr := flag.String("listen", ":8080", "监听地址") + targetAddr := flag.String("target", "http://localhost:8080", "目标服务地址") + enableHTTPS := flag.Bool("https", false, "启用HTTPS") + certFile := flag.String("cert", "", "TLS证书文件") + keyFile := flag.String("key", "", "TLS私钥文件") + enableCache := flag.Bool("cache", false, "启用缓存") + enableCompression := flag.Bool("compression", true, "启用压缩") + enableCORS := flag.Bool("cors", false, "启用CORS") + preserveClientIP := flag.Bool("preserve-ip", true, "保留客户端IP") + addXForwardedFor := flag.Bool("x-forwarded-for", true, "添加X-Forwarded-For头") + addXRealIP := flag.Bool("x-real-ip", true, "添加X-Real-IP头") + + flag.Parse() + + // 创建配置 + cfg := config.DefaultConfig() + + // 配置反向代理 + cfg.ReverseProxy = true // 启用反向代理模式 + cfg.ListenAddr = *listenAddr // 监听地址 + cfg.TargetAddr = *targetAddr // 目标地址 + cfg.DecryptHTTPS = *enableHTTPS // 启用HTTPS + cfg.TLSCert = *certFile // 证书文件 + cfg.TLSKey = *keyFile // 私钥文件 + cfg.EnableCache = *enableCache // 启用缓存 + cfg.EnableCompression = *enableCompression // 启用压缩 + cfg.EnableCORS = *enableCORS // 启用CORS + cfg.PreserveClientIP = *preserveClientIP // 保留客户端IP + cfg.AddXForwardedFor = *addXForwardedFor // 添加X-Forwarded-For头 + cfg.AddXRealIP = *addXRealIP // 添加X-Real-IP头 + + // 创建代理实例 + proxy := goproxy.New(&goproxy.Options{ + Config: cfg, + }) + + // 创建HTTP服务器 + server := &http.Server{ + Addr: *listenAddr, + Handler: proxy, + } + + // 优雅退出 + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + + // 启动HTTP服务器 + go func() { + fmt.Printf("反向代理启动在 %s,转发到 %s\n", *listenAddr, *targetAddr) + var err error + if *enableHTTPS && *certFile != "" && *keyFile != "" { + err = server.ListenAndServeTLS(*certFile, *keyFile) + } else { + err = server.ListenAndServe() + } + if err != nil && err != http.ErrServerClosed { + log.Fatalf("服务器启动失败: %v\n", err) + } + }() + + // 等待退出信号 + <-quit + fmt.Println("服务器正在关闭...") + + // 关闭服务器 + server.Shutdown(nil) // 简化版,不使用context + + fmt.Println("服务器已关闭") +} diff --git a/cmd/functional_options_proxy.go b/cmd/functional_options_proxy.go new file mode 100644 index 0000000..543aef0 --- /dev/null +++ b/cmd/functional_options_proxy.go @@ -0,0 +1,100 @@ +package main + +import ( + "fmt" + "log" + "net/http" + "time" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/config" + "github.com/darkit/goproxy/pkg/cache" + "github.com/darkit/goproxy/pkg/healthcheck" + "github.com/darkit/goproxy/pkg/loadbalance" + "github.com/darkit/goproxy/pkg/metrics" +) + +// 这是functional_options_proxy的main函数,改名避免与其他文件冲突 +func main_functional_options() { + // 创建基本配置 + cfg := config.DefaultConfig() + + // 配置反向代理 + cfg.ReverseProxy = true // 启用反向代理模式 + cfg.ListenAddr = ":8080" // 监听地址 + cfg.TargetAddr = "http://example.com" // 目标地址 + cfg.PreserveClientIP = true // 保留客户端IP + cfg.AddXForwardedFor = true // 添加X-Forwarded-For头 + cfg.AddXRealIP = true // 添加X-Real-IP头 + cfg.EnableCompression = true // 启用压缩 + + // 创建负载均衡器 + // 根据linter错误,要检查正确的API + backends := []string{ + "http://server1:8080", + "http://server2:8080", + "http://server3:8080", + } + lb := loadbalance.NewRoundRobinBalancer() + // 设置后端服务器列表 + lb.AddList(backends, 1) + + // 健康检查配置 + // 根据错误信息调整构造函数 + healthCfg := &healthcheck.Config{ + Interval: 10 * time.Second, + Timeout: 3 * time.Second, + MaxFails: 3, + } + // 假设NewHealthChecker只接受Config参数 + healthChecker := healthcheck.NewHealthChecker(healthCfg) + // 然后单独设置后端 + healthChecker.AddTargetList(backends) + + // 创建缓存 + // 根据错误信息修正参数 + memCache := cache.NewMemoryCache(5*time.Minute, 30*time.Second, 1000) + + // 使用功能选项模式创建反向代理 + proxy := goproxy.NewWithOptions( + // 将配置对象传入 + goproxy.WithConfig(cfg), + + // 性能优化 + goproxy.WithConnectionPoolSize(100), + goproxy.WithIdleTimeout(30*time.Second), + goproxy.WithRequestTimeout(10*time.Second), + + // 负载均衡 + goproxy.WithLoadBalancer(lb), + + // 健康检查 + goproxy.WithHealthChecker(healthChecker), + + // 缓存 + goproxy.WithHTTPCache(memCache), + + // 指标收集 + goproxy.WithMetrics(metrics.NewSimpleMetrics()), + + // 重试策略 + goproxy.WithEnableRetry(3, 1*time.Second, 10*time.Second), + + // DNS缓存 + goproxy.WithDNSCacheTTL(5*time.Minute), + + // 启用CORS + goproxy.WithEnableCORS(true), + ) + + // 启动服务器 + fmt.Println("反向代理启动在 :8080,转发到 http://example.com 并启用负载均衡") + if err := http.ListenAndServe(":8080", proxy); err != nil { + log.Fatalf("服务器启动失败: %v", err) + } +} + +// 添加实际的main函数,调用上面的示例函数 +func main() { + main_functional_options() +} diff --git a/cmd/https_reverse_proxy.go b/cmd/https_reverse_proxy.go new file mode 100644 index 0000000..14e5db1 --- /dev/null +++ b/cmd/https_reverse_proxy.go @@ -0,0 +1,152 @@ +package main + +import ( + "crypto/tls" + "flag" + "fmt" + "log" + "net" + "net/http" + "os" + "os/signal" + "syscall" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/config" +) + +func main() { + // 解析命令行参数 + var ( + listenAddr = flag.String("listen", ":443", "HTTPS监听地址") + httpAddr = flag.String("http", ":80", "HTTP监听地址,用于重定向到HTTPS") + targetAddr = flag.String("target", "http://localhost:8080", "目标服务地址") + certFile = flag.String("cert", "server.crt", "TLS证书文件") + keyFile = flag.String("key", "server.key", "TLS私钥文件") + autoRedirect = flag.Bool("redirect", true, "自动将HTTP请求重定向到HTTPS") + insecureSkipVerify = flag.Bool("insecure", false, "跳过目标HTTPS验证") + ) + flag.Parse() + + // 检查证书和私钥文件 + if !fileExists(*certFile) || !fileExists(*keyFile) { + log.Printf("证书或私钥文件不存在: %s, %s", *certFile, *keyFile) + if *autoRedirect { + log.Printf("将只启动HTTP服务") + } else { + log.Fatalf("无法启动HTTPS服务,请提供有效的证书和私钥文件") + } + } + + // 创建配置 + cfg := config.DefaultConfig() + + // 配置反向代理 + cfg.ReverseProxy = true // 启用反向代理模式 + cfg.ListenAddr = *listenAddr // HTTPS监听地址 + cfg.TargetAddr = *targetAddr // 目标地址 + cfg.DecryptHTTPS = true // 启用HTTPS + cfg.TLSCert = *certFile // 证书文件 + cfg.TLSKey = *keyFile // 私钥文件 + cfg.PreserveClientIP = true // 保留客户端IP + cfg.AddXForwardedFor = true // 添加X-Forwarded-For头 + cfg.InsecureSkipVerify = *insecureSkipVerify // 是否跳过目标HTTPS验证 + + // 创建代理实例 + proxy := goproxy.New(&goproxy.Options{ + Config: cfg, + }) + + // 创建HTTPS服务器 + httpsServer := &http.Server{ + Addr: *listenAddr, + Handler: proxy, + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + } + + // 创建HTTP重定向服务器 + var httpServer *http.Server + if *autoRedirect { + httpServer = &http.Server{ + Addr: *httpAddr, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 构建重定向URL + host := r.Host + // 如果Host包含端口,去除端口 + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } + // 构建HTTPS URL + target := "https://" + host + // 如果HTTPS端口不是443,添加端口 + if *listenAddr != ":443" { + _, port, _ := net.SplitHostPort(*listenAddr) + if port != "" && port != "443" { + target = "https://" + host + ":" + port + } + } + // 添加路径和查询参数 + target += r.URL.Path + if r.URL.RawQuery != "" { + target += "?" + r.URL.RawQuery + } + // 重定向 + http.Redirect(w, r, target, http.StatusMovedPermanently) + }), + } + } + + // 优雅退出 + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + + // 启动HTTPS服务器 + if fileExists(*certFile) && fileExists(*keyFile) { + go func() { + fmt.Printf("HTTPS反向代理启动在 %s,转发到 %s\n", *listenAddr, *targetAddr) + if err := httpsServer.ListenAndServeTLS(*certFile, *keyFile); err != nil && err != http.ErrServerClosed { + log.Printf("HTTPS服务器启动失败: %v\n", err) + } + }() + } + + // 启动HTTP重定向服务器 + if *autoRedirect && httpServer != nil { + go func() { + fmt.Printf("HTTP重定向服务启动在 %s\n", *httpAddr) + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Printf("HTTP服务器启动失败: %v\n", err) + } + }() + } + + // 等待退出信号 + <-quit + fmt.Println("服务器正在关闭...") + + // 关闭服务器 + if httpServer != nil { + if err := httpServer.Close(); err != nil { + log.Printf("HTTP服务器关闭失败: %v\n", err) + } + } + + if fileExists(*certFile) && fileExists(*keyFile) { + if err := httpsServer.Close(); err != nil { + log.Printf("HTTPS服务器关闭失败: %v\n", err) + } + } + + fmt.Println("服务器已关闭") +} + +// 检查文件是否存在 +func fileExists(filename string) bool { + info, err := os.Stat(filename) + if os.IsNotExist(err) { + return false + } + return !info.IsDir() +} diff --git a/cmd/middleware_reverse_proxy.go b/cmd/middleware_reverse_proxy.go new file mode 100644 index 0000000..97fd070 --- /dev/null +++ b/cmd/middleware_reverse_proxy.go @@ -0,0 +1,93 @@ +package main + +import ( + "fmt" + "log" + "net/http" + "time" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/config" +) + +// 自定义代理委托实现 +type CustomDelegate struct { + goproxy.DefaultDelegate + logger *log.Logger +} + +// 创建自定义代理委托 +func NewCustomDelegate() *CustomDelegate { + return &CustomDelegate{ + logger: log.New(log.Writer(), "[反向代理] ", log.LstdFlags), + } +} + +// 连接处理 +func (d *CustomDelegate) Connect(ctx *goproxy.Context, rw http.ResponseWriter) { + d.logger.Printf("新连接: %s -> %s", ctx.Req.RemoteAddr, ctx.Req.Host) +} + +// 请求前处理 +func (d *CustomDelegate) BeforeRequest(ctx *goproxy.Context) { + // 添加自定义请求头 + ctx.Req.Header.Set("X-Proxy-Time", time.Now().Format(time.RFC3339)) + ctx.Req.Header.Set("X-Proxy-ID", "custom-proxy-1") + + d.logger.Printf("处理请求: %s %s", ctx.Req.Method, ctx.Req.URL.String()) +} + +// 响应前处理 +func (d *CustomDelegate) BeforeResponse(ctx *goproxy.Context, resp *http.Response, err error) { + if err != nil { + d.logger.Printf("请求错误: %s", err.Error()) + return + } + + if resp != nil { + // 添加自定义响应头 + resp.Header.Set("X-Proxy-Server", "GoProxy") + + // 记录响应状态 + d.logger.Printf("响应状态: %d %s", resp.StatusCode, http.StatusText(resp.StatusCode)) + } +} + +// 请求完成处理 +func (d *CustomDelegate) Finish(ctx *goproxy.Context) { + d.logger.Printf("请求完成: %s %s", ctx.Req.Method, ctx.Req.URL.String()) +} + +// 错误日志 +func (d *CustomDelegate) ErrorLog(err error) { + d.logger.Printf("错误: %s", err.Error()) +} + +func main() { + // 创建基本配置 + cfg := config.DefaultConfig() + + // 配置反向代理 + cfg.ReverseProxy = true // 启用反向代理模式 + cfg.ListenAddr = ":8080" // 监听本地8080端口 + cfg.TargetAddr = "http://example.com" // 转发到example.com + + // 启用基本功能 + cfg.PreserveClientIP = true // 保留客户端IP + cfg.AddXForwardedFor = true // 添加X-Forwarded-For头 + + // 创建自定义代理委托 + delegate := NewCustomDelegate() + + // 创建代理实例 + proxy := goproxy.New(&goproxy.Options{ + Config: cfg, + Delegate: delegate, // 使用自定义委托 + }) + + // 启动服务器 + fmt.Println("反向代理启动在 :8080,转发到 http://example.com") + if err := http.ListenAndServe(":8080", proxy); err != nil { + log.Fatalf("服务器启动失败: %v", err) + } +} diff --git a/cmd/proxy_rules.json b/cmd/proxy_rules.json new file mode 100644 index 0000000..a8d5840 --- /dev/null +++ b/cmd/proxy_rules.json @@ -0,0 +1,33 @@ +{ + "rules": [ + { + "path": "/api/", + "target": "http://api-server:8000", + "strip_prefix": true + }, + { + "path": "/admin/", + "target": "http://admin-server:9000", + "strip_prefix": false + }, + { + "path": "/static/", + "target": "http://static-server:8080", + "strip_prefix": true, + "cache": true, + "cache_ttl": 3600 + }, + { + "host": "api.example.com", + "target": "http://api-server:8000" + }, + { + "host": "admin.example.com", + "target": "http://admin-server:9000" + }, + { + "default": true, + "target": "http://default-server:8080" + } + ] +} \ No newline at end of file diff --git a/cmd/reverse_proxy_example.go b/cmd/reverse_proxy_example.go new file mode 100644 index 0000000..30d5aa4 --- /dev/null +++ b/cmd/reverse_proxy_example.go @@ -0,0 +1,175 @@ +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/config" + "github.com/darkit/goproxy/pkg/cache" + "github.com/darkit/goproxy/pkg/healthcheck" + "github.com/darkit/goproxy/pkg/loadbalance" + "github.com/darkit/goproxy/pkg/metrics" +) + +// 命令行参数 +var ( + // 监听地址 + listenAddr string + // 目标地址(后端服务) + targetAddr string + // 启用HTTPS + enableHTTPS bool + // TLS证书文件 + certFile string + // TLS私钥文件 + keyFile string + // 启用负载均衡 + enableLoadBalancing bool + // 负载均衡目标地址列表 + targets string + // 启用健康检查 + enableHealthCheck bool + // 健康检查间隔 + healthCheckInterval time.Duration + // 启用缓存 + enableCache bool + // 启用压缩 + enableCompression bool + // 启用CORS + enableCORS bool +) + +func init() { + // 解析命令行参数 + flag.StringVar(&listenAddr, "listen", ":8080", "监听地址") + flag.StringVar(&targetAddr, "target", "http://localhost:9090", "目标地址") + flag.BoolVar(&enableHTTPS, "https", false, "启用HTTPS") + flag.StringVar(&certFile, "cert", "", "TLS证书文件") + flag.StringVar(&keyFile, "key", "", "TLS私钥文件") + flag.BoolVar(&enableLoadBalancing, "lb", false, "启用负载均衡") + flag.StringVar(&targets, "targets", "", "负载均衡目标地址列表,用逗号分隔") + flag.BoolVar(&enableHealthCheck, "health", false, "启用健康检查") + flag.DurationVar(&healthCheckInterval, "health-interval", 10*time.Second, "健康检查间隔") + flag.BoolVar(&enableCache, "cache", false, "启用缓存") + flag.BoolVar(&enableCompression, "compression", true, "启用压缩") + flag.BoolVar(&enableCORS, "cors", false, "启用CORS") +} + +func main() { + // 解析命令行参数 + flag.Parse() + + // 创建配置 + cfg := config.DefaultConfig() + + // 设置基本配置 + cfg.ReverseProxy = true // 启用反向代理模式 + cfg.ListenAddr = listenAddr + cfg.TargetAddr = targetAddr + cfg.DecryptHTTPS = enableHTTPS + cfg.TLSCert = certFile + cfg.TLSKey = keyFile + cfg.EnableCache = enableCache + cfg.EnableCompression = enableCompression + cfg.EnableCORS = enableCORS + cfg.PreserveClientIP = true // 保留客户端IP + cfg.AddXForwardedFor = true // 添加X-Forwarded-For头 + cfg.AddXRealIP = true // 添加X-Real-IP头 + + // 健康检查配置 + cfg.EnableHealthCheck = enableHealthCheck + cfg.HealthCheckInterval = healthCheckInterval + cfg.HealthCheckTimeout = 3 * time.Second + + // 负载均衡配置 + cfg.EnableLoadBalancing = enableLoadBalancing + + // 创建代理选项 + opts := &goproxy.Options{ + Config: cfg, + } + + // 如果启用负载均衡,创建负载均衡器 + if enableLoadBalancing && targets != "" { + // 创建轮询负载均衡器 + lb := loadbalance.NewRoundRobinBalancer() + lb.Add(targets, 1) + opts.LoadBalancer = lb + + // 如果启用健康检查,创建健康检查器 + if enableHealthCheck { + // 健康检查配置 + healthCfg := &healthcheck.Config{ + Interval: healthCheckInterval, + Timeout: 3 * time.Second, + MinSuccess: 1, + MaxFails: 3, + } + healthChecker := healthcheck.NewHealthChecker(healthCfg) + healthChecker.AddTarget(targets) + opts.HealthChecker = healthChecker + + // 启动健康检查 + healthChecker.Start() + defer healthChecker.Stop() + } + } + + // 如果启用缓存,创建缓存 + if enableCache { + // 创建一个内存缓存,TTL为5分钟 + memCache := cache.NewMemoryCache(5*time.Minute, time.Second, 10000) + opts.HTTPCache = memCache + } + + // 创建指标收集器 + metricsCollector := metrics.NewSimpleMetrics() + opts.Metrics = metricsCollector + + // 创建代理 + proxy := goproxy.New(opts) + + // 创建HTTP服务器 + server := &http.Server{ + Addr: listenAddr, + Handler: proxy, + } + + // 启动HTTP服务器 + go func() { + fmt.Printf("反向代理启动在 %s,目标地址为 %s\n", listenAddr, targetAddr) + var err error + if enableHTTPS && certFile != "" && keyFile != "" { + err = server.ListenAndServeTLS(certFile, keyFile) + } else { + err = server.ListenAndServe() + } + if err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("服务器启动失败: %v\n", err) + } + }() + + // 优雅退出 + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + + // 关闭服务器 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := server.Shutdown(ctx); err != nil { + log.Fatalf("服务器关闭失败: %v\n", err) + } + + fmt.Println("服务器已关闭") +} diff --git a/cmd/rules_reverse_proxy.go b/cmd/rules_reverse_proxy.go new file mode 100644 index 0000000..cb4c8ea --- /dev/null +++ b/cmd/rules_reverse_proxy.go @@ -0,0 +1,121 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/config" +) + +func main() { + // 解析命令行参数 + var ( + listenAddr = flag.String("listen", ":8080", "监听地址") + rulesFile = flag.String("rules", "proxy_rules.json", "代理规则文件") + enableCache = flag.Bool("cache", false, "启用缓存") + enableCORS = flag.Bool("cors", true, "启用CORS") + ) + flag.Parse() + + // 检查规则文件是否存在 + if _, err := os.Stat(*rulesFile); os.IsNotExist(err) { + // 创建示例规则文件 + createExampleRulesFile(*rulesFile) + fmt.Printf("已创建示例规则文件: %s\n", *rulesFile) + } + + // 创建配置 + cfg := config.DefaultConfig() + + // 配置反向代理 + cfg.ReverseProxy = true // 启用反向代理模式 + cfg.ListenAddr = *listenAddr // 监听地址 + cfg.ReverseProxyRulesFile = *rulesFile // 规则文件 + cfg.EnableCache = *enableCache // 是否启用缓存 + cfg.EnableCORS = *enableCORS // 是否启用CORS + cfg.PreserveClientIP = true // 保留客户端IP + cfg.AddXForwardedFor = true // 添加X-Forwarded-For头 + cfg.AddXRealIP = true // 添加X-Real-IP头 + cfg.TargetAddr = "www.chipeak.com" + + // 创建代理实例 + proxy := goproxy.New(&goproxy.Options{ + Config: cfg, + }) + + // 创建HTTP服务器 + server := &http.Server{ + Addr: *listenAddr, + Handler: proxy, + } + + // 优雅退出 + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + + // 启动服务 + go func() { + fmt.Printf("反向代理启动在 %s,使用规则文件 %s\n", *listenAddr, *rulesFile) + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("服务器启动失败: %v\n", err) + } + }() + + // 等待退出信号 + <-quit + fmt.Println("服务器正在关闭...") + + // 关闭服务器 + if err := server.Close(); err != nil { + log.Fatalf("服务器关闭失败: %v\n", err) + } + fmt.Println("服务器已关闭") +} + +// 创建示例规则文件 +func createExampleRulesFile(filename string) { + content := `{ + "rules": [ + { + "path": "/api/", + "target": "http://api-server:8000", + "strip_prefix": true + }, + { + "path": "/admin/", + "target": "http://admin-server:9000", + "strip_prefix": false + }, + { + "path": "/static/", + "target": "http://static-server:8080", + "strip_prefix": true, + "cache": true, + "cache_ttl": 3600 + }, + { + "host": "api.example.com", + "target": "http://api-server:8000" + }, + { + "host": "admin.example.com", + "target": "http://admin-server:9000" + }, + { + "default": true, + "target": "http://default-server:8080" + } + ] +}` + + err := os.WriteFile(filename, []byte(content), 0644) + if err != nil { + log.Fatalf("创建规则文件失败: %v", err) + } +} diff --git a/cmd/simple_reverse_proxy.go b/cmd/simple_reverse_proxy.go new file mode 100644 index 0000000..e52605f --- /dev/null +++ b/cmd/simple_reverse_proxy.go @@ -0,0 +1,35 @@ +package main + +import ( + "fmt" + "log" + "net/http" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/config" +) + +func main() { + // 创建基本配置 + cfg := config.DefaultConfig() + + // 配置反向代理 + cfg.ReverseProxy = true // 启用反向代理模式 + cfg.ListenAddr = ":8080" // 监听本地8080端口 + cfg.TargetAddr = "http://example.com" // 转发到example.com + + // 启用一些基本功能 + cfg.PreserveClientIP = true // 保留客户端IP + cfg.AddXForwardedFor = true // 添加X-Forwarded-For头 + + // 创建代理实例 + proxy := goproxy.New(&goproxy.Options{ + Config: cfg, + }) + + // 启动服务器 + fmt.Println("反向代理启动在 :8080,转发到 http://example.com") + if err := http.ListenAndServe(":8080", proxy); err != nil { + log.Fatalf("服务器启动失败: %v", err) + } +} diff --git a/cmd/simple_reverse_proxy_fixed.go b/cmd/simple_reverse_proxy_fixed.go new file mode 100644 index 0000000..7e5e920 --- /dev/null +++ b/cmd/simple_reverse_proxy_fixed.go @@ -0,0 +1,40 @@ +package main + +import ( + "fmt" + "log" + "net/http" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/config" +) + +func main() { + // 创建基本配置 + cfg := config.DefaultConfig() + + // 配置反向代理 + cfg.ReverseProxy = true // 启用反向代理模式 + cfg.ListenAddr = ":8080" // 监听本地8080端口 + cfg.TargetAddr = "http://example.com" // 转发到example.com + + // 设置HTTP头部选项 + cfg.PreserveClientIP = true // 保留客户端IP + cfg.AddXForwardedFor = true // 添加X-Forwarded-For头 + cfg.AddXRealIP = true // 添加X-Real-IP头 + + // 设置性能选项 + cfg.EnableCompression = true // 启用压缩 + cfg.EnableCache = true // 启用缓存 + + // 创建代理实例 + proxy := goproxy.New(&goproxy.Options{ + Config: cfg, + }) + + // 启动服务器 + fmt.Println("反向代理启动在 :8080,转发到 http://example.com") + if err := http.ListenAndServe(":8080", proxy); err != nil { + log.Fatalf("服务器启动失败: %v", err) + } +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..0293681 --- /dev/null +++ b/config/config.go @@ -0,0 +1,111 @@ +package config + +import ( + "log/slog" + "time" +) + +// Config 代理配置 +type Config struct { + // 基本配置 + ListenAddr string `json:"listen_addr" yaml:"listen_addr" toml:"listen_addr"` // 监听地址 + TargetAddr string `json:"target_addr" yaml:"target_addr" toml:"target_addr"` // 目标地址 + DecryptHTTPS bool `json:"decrypt_https" yaml:"decrypt_https" toml:"decrypt_https"` // 是否启用HTTPS解密 + CACert string `json:"ca_cert" yaml:"ca_cert" toml:"ca_cert"` // CA证书文件路径(用于生成动态证书) + CAKey string `json:"ca_key" yaml:"ca_key" toml:"ca_key"` // CA密钥文件路径(用于生成动态证书) + UseECDSA bool `json:"use_ecdsa" yaml:"use_ecdsa" toml:"use_ecdsa"` // 是否使用ECDSA生成证书(默认使用RSA) + TLSCert string `json:"tls_cert" yaml:"tls_cert" toml:"tls_cert"` // TLS证书文件路径 + TLSKey string `json:"tls_key" yaml:"tls_key" toml:"tls_key"` // TLS密钥文件路径 + + // 连接配置 + DisableKeepAlive bool `json:"disable_keep_alive" yaml:"disable_keep_alive" toml:"disable_keep_alive"` // 是否禁用连接复用 + RequestTimeout time.Duration `json:"request_timeout" yaml:"request_timeout" toml:"request_timeout"` // 请求超时时间 + EnableCache bool `json:"enable_cache" yaml:"enable_cache" toml:"enable_cache"` // 是否启用响应缓存 + IdleTimeout time.Duration `json:"idle_timeout" yaml:"idle_timeout" toml:"idle_timeout"` // 连接空闲超时时间 + MaxIdleConns int `json:"max_idle_conns" yaml:"max_idle_conns" toml:"max_idle_conns"` // 最大空闲连接数 + + // 缓存配置 + DNSCacheTTL time.Duration `json:"dns_cache_ttl" yaml:"dns_cache_ttl" toml:"dns_cache_ttl"` // DNS缓存过期时间 + CacheTTL time.Duration `json:"cache_ttl" yaml:"cache_ttl" toml:"cache_ttl"` // 缓存过期时间 + + // 重试配置 + EnableRetry bool `json:"enable_retry" yaml:"enable_retry" toml:"enable_retry"` // 是否启用重试机制 + MaxRetries int `json:"max_retries" yaml:"max_retries" toml:"max_retries"` // 最大重试次数 + BaseBackoff time.Duration `json:"base_backoff" yaml:"base_backoff" toml:"base_backoff"` // 重试间隔基数 + MaxBackoff time.Duration `json:"max_backoff" yaml:"max_backoff" toml:"max_backoff"` // 最大重试间隔 + + // 限流配置 + RateLimit float64 `json:"rate_limit" yaml:"rate_limit" toml:"rate_limit"` // 每秒请求速率限制 + + // 其他配置 + EnableCORS bool `json:"enable_cors" yaml:"enable_cors" toml:"enable_cors"` // 是否自动添加CORS头 + + // 负载均衡配置 + EnableLoadBalancing bool `json:"enable_load_balancing" yaml:"enable_load_balancing" toml:"enable_load_balancing"` // 是否启用负载均衡 + Backends []string `json:"backends" yaml:"backends" toml:"backends"` // 负载均衡后端列表 + EnableRateLimit bool `json:"enable_rate_limit" yaml:"enable_rate_limit" toml:"enable_rate_limit"` // 是否启用限流 + MaxBurst int `json:"max_burst" yaml:"max_burst" toml:"max_burst"` // 并发请求峰值限制 + MaxConnections int `json:"max_connections" yaml:"max_connections" toml:"max_connections"` // 最大连接数 + EnableConnectionPool bool `json:"enable_connection_pool" yaml:"enable_connection_pool" toml:"enable_connection_pool"` // 是否启用连接池 + ConnectionPoolSize int `json:"connection_pool_size" yaml:"connection_pool_size" toml:"connection_pool_size"` // 连接池大小 + EnableHealthCheck bool `json:"enable_health_check" yaml:"enable_health_check" toml:"enable_health_check"` // 是否启用健康检查 + HealthCheckInterval time.Duration `json:"health_check_interval" yaml:"health_check_interval" toml:"health_check_interval"` // 健康检查间隔时间 + HealthCheckTimeout time.Duration `json:"health_check_timeout" yaml:"health_check_timeout" toml:"health_check_timeout"` // 健康检查超时时间 + EnableMetrics bool `json:"enable_metrics" yaml:"enable_metrics" toml:"enable_metrics"` // 是否启用监控指标 + EnableTracing bool `json:"enable_tracing" yaml:"enable_tracing" toml:"enable_tracing"` // 是否启用请求追踪 + WebSocketIntercept bool `json:"websocket_intercept" yaml:"websocket_intercept" toml:"websocket_intercept"` // 是否拦截WebSocket + ReverseProxy bool `json:"reverse_proxy" yaml:"reverse_proxy" toml:"reverse_proxy"` // 是否作为反向代理 + ReverseProxyRulesFile string `json:"reverse_proxy_rules_file" yaml:"reverse_proxy_rules_file" toml:"reverse_proxy_rules_file"` // 反向代理规则文件路径 + PreserveClientIP bool `json:"preserve_client_ip" yaml:"preserve_client_ip" toml:"preserve_client_ip"` // 是否保留客户端IP + EnableCompression bool `json:"enable_compression" yaml:"enable_compression" toml:"enable_compression"` // 是否启用压缩 + RewriteHostHeader bool `json:"rewrite_host_header" yaml:"rewrite_host_header" toml:"rewrite_host_header"` // 重写Host头 + AddXForwardedFor bool `json:"add_x_forwarded_for" yaml:"add_x_forwarded_for" toml:"add_x_forwarded_for"` // 是否添加X-Forwarded-For头 + AddXRealIP bool `json:"add_x_real_ip" yaml:"add_x_real_ip" toml:"add_x_real_ip"` // 是否添加X-Real-IP头 + SupportWebSocketUpgrade bool `json:"support_websocket_upgrade" yaml:"support_websocket_upgrade" toml:"support_websocket_upgrade"` // 是否支持Websocket升级 + InsecureSkipVerify bool `json:"insecure_skip_verify" yaml:"insecure_skip_verify" toml:"insecure_skip_verify"` // 是否跳过TLS证书验证 + Logger *slog.Logger `json:"-" yaml:"-" toml:"-"` // 日志记录器 +} + +// DefaultConfig 返回默认配置 +func DefaultConfig() *Config { + return &Config{ + ListenAddr: ":8080", + DecryptHTTPS: false, + UseECDSA: false, + RequestTimeout: 30 * time.Second, + IdleTimeout: 90 * time.Second, + EnableCache: false, + MaxIdleConns: 100, + DNSCacheTTL: 5 * time.Minute, + CacheTTL: 5 * time.Minute, + EnableRetry: true, + MaxRetries: 3, + BaseBackoff: time.Second, + MaxBackoff: 10 * time.Second, + RateLimit: 0, // 0 表示不限流 + EnableCORS: true, + EnableLoadBalancing: false, + Backends: []string{}, + EnableRateLimit: false, + MaxBurst: 50, + MaxConnections: 1000, + EnableConnectionPool: true, + ConnectionPoolSize: 100, + EnableHealthCheck: false, + HealthCheckInterval: 10 * time.Second, + HealthCheckTimeout: 5 * time.Second, + EnableMetrics: false, + EnableTracing: false, + WebSocketIntercept: false, + ReverseProxy: false, + ReverseProxyRulesFile: "", + PreserveClientIP: true, + EnableCompression: false, + RewriteHostHeader: false, + AddXForwardedFor: true, + AddXRealIP: true, + SupportWebSocketUpgrade: true, + InsecureSkipVerify: false, + Logger: slog.Default(), + } +} diff --git a/config/hot_reload.go b/config/hot_reload.go new file mode 100644 index 0000000..c9aefc9 --- /dev/null +++ b/config/hot_reload.go @@ -0,0 +1,117 @@ +package config + +import ( + "encoding/json" + "log/slog" + "os" + "path/filepath" + "sync" + "time" + + "github.com/fsnotify/fsnotify" +) + +// HotReloadConfig 热更新配置 +type HotReloadConfig struct { + // 配置文件路径 + ConfigPath string + // 配置更新回调函数 + OnUpdate func(*Config) + // 配置锁 + mu sync.RWMutex + // 当前配置 + current *Config + // 文件监视器 + watcher *fsnotify.Watcher + // 停止信号 + stopChan chan struct{} +} + +// NewHotReloadConfig 创建热更新配置 +func NewHotReloadConfig(configPath string, onUpdate func(*Config)) (*HotReloadConfig, error) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + hrc := &HotReloadConfig{ + ConfigPath: configPath, + OnUpdate: onUpdate, + watcher: watcher, + stopChan: make(chan struct{}), + } + + // 加载初始配置 + if err := hrc.loadConfig(); err != nil { + watcher.Close() + return nil, err + } + + // 启动文件监视 + go hrc.watch() + + return hrc, nil +} + +// loadConfig 加载配置 +func (hrc *HotReloadConfig) loadConfig() error { + data, err := os.ReadFile(hrc.ConfigPath) + if err != nil { + return err + } + + var config Config + if err := json.Unmarshal(data, &config); err != nil { + return err + } + + hrc.mu.Lock() + hrc.current = &config + hrc.mu.Unlock() + + if hrc.OnUpdate != nil { + hrc.OnUpdate(&config) + } + + return nil +} + +// watch 监视配置文件变化 +func (hrc *HotReloadConfig) watch() { + // 添加配置文件目录到监视 + configDir := filepath.Dir(hrc.ConfigPath) + if err := hrc.watcher.Add(configDir); err != nil { + slog.Error("添加配置目录到监视失败", "error", err) + return + } + + for { + select { + case event := <-hrc.watcher.Events: + if event.Name == hrc.ConfigPath { + // 等待文件写入完成 + time.Sleep(100 * time.Millisecond) + if err := hrc.loadConfig(); err != nil { + slog.Error("重新加载配置失败", "error", err) + } + } + case err := <-hrc.watcher.Errors: + slog.Error("配置文件监视错误", "error", err) + case <-hrc.stopChan: + hrc.watcher.Close() + return + } + } +} + +// Get 获取当前配置 +func (hrc *HotReloadConfig) Get() *Config { + hrc.mu.RLock() + defer hrc.mu.RUnlock() + return hrc.current +} + +// Stop 停止热更新 +func (hrc *HotReloadConfig) Stop() { + close(hrc.stopChan) +} diff --git a/conn_buffer.go b/conn_buffer.go new file mode 100644 index 0000000..aa983ba --- /dev/null +++ b/conn_buffer.go @@ -0,0 +1,134 @@ +package goproxy + +import ( + "bufio" + "io" + "net" + "time" +) + +// ConnBuffer 连接缓冲区 +// 封装了底层网络连接和缓冲读取器,提供了更方便的读写接口 +type ConnBuffer struct { + // 底层连接 + conn net.Conn + // 缓冲读取器 + reader *bufio.Reader +} + +// NewConnBuffer 创建连接缓冲区 +func NewConnBuffer(conn net.Conn, reader *bufio.Reader) *ConnBuffer { + if reader == nil { + reader = bufio.NewReader(conn) + } + return &ConnBuffer{ + conn: conn, + reader: reader, + } +} + +// Read 从连接读取数据 +func (c *ConnBuffer) Read(b []byte) (int, error) { + return c.reader.Read(b) +} + +// Write 向连接写入数据 +func (c *ConnBuffer) Write(b []byte) (int, error) { + return c.conn.Write(b) +} + +// Close 关闭连接 +func (c *ConnBuffer) Close() error { + return c.conn.Close() +} + +// LocalAddr 获取本地地址 +func (c *ConnBuffer) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr 获取远程地址 +func (c *ConnBuffer) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline 设置读写超时 +func (c *ConnBuffer) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +// SetReadDeadline 设置读取超时 +func (c *ConnBuffer) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetWriteDeadline 设置写入超时 +func (c *ConnBuffer) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +// BufferReader 获取缓冲读取器 +func (c *ConnBuffer) BufferReader() *bufio.Reader { + return c.reader +} + +// Peek 查看缓冲区中的数据,但不消费 +func (c *ConnBuffer) Peek(n int) ([]byte, error) { + return c.reader.Peek(n) +} + +// ReadByte 读取一个字节 +func (c *ConnBuffer) ReadByte() (byte, error) { + return c.reader.ReadByte() +} + +// UnreadByte 将最后读取的字节放回缓冲区 +func (c *ConnBuffer) UnreadByte() error { + return c.reader.UnreadByte() +} + +// ReadLine 读取一行数据 +func (c *ConnBuffer) ReadLine() (string, error) { + line, isPrefix, err := c.reader.ReadLine() + if err != nil { + return "", err + } + + // 如果一行数据没有读取完整,继续读取 + if isPrefix { + var buf []byte + buf = append(buf, line...) + for isPrefix && err == nil { + line, isPrefix, err = c.reader.ReadLine() + if err != nil { + return "", err + } + buf = append(buf, line...) + } + return string(buf), nil + } + + return string(line), nil +} + +// ReadN 读取指定字节数的数据 +func (c *ConnBuffer) ReadN(n int) ([]byte, error) { + buf := make([]byte, n) + _, err := io.ReadFull(c.reader, buf) + if err != nil { + return nil, err + } + return buf, nil +} + +// Flush 刷新缓冲区 +func (c *ConnBuffer) Flush() error { + // 由于我们只有读取缓冲区,没有写入缓冲区,所以这里不需要实际操作 + return nil +} + +// Reset 重置连接缓冲区 +func (c *ConnBuffer) Reset(conn net.Conn) { + c.conn = conn + c.reader = bufio.NewReader(conn) +} diff --git a/context.go b/context.go new file mode 100644 index 0000000..dc84995 --- /dev/null +++ b/context.go @@ -0,0 +1,132 @@ +package goproxy + +import ( + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// Context 代理上下文 +// 包含了代理请求的上下文信息,用于在代理处理过程中传递数据 +type Context struct { + // 原始请求 + Req *http.Request + // 请求开始时间 + StartTime time.Time + // 上下文数据,用于在各个处理阶段传递数据 + Data map[interface{}]interface{} + // 是否是隧道代理 + TunnelProxy bool + // 请求ID + RequestID string + // 目标地址 + TargetAddr string + // 上级代理地址 + ParentProxyURL *url.URL + // 是否中断执行 + abort bool + // 请求标签,用于标记请求类型 + Tags []string + // 是否已中止 + aborted bool + // 互斥锁 + mu sync.RWMutex +} + +// IsHTTPS 是否是HTTPS请求 +func (c *Context) IsHTTPS() bool { + return c.Req.URL.Scheme == "https" || c.Req.Method == http.MethodConnect +} + +// defaultPorts 默认端口映射 +var defaultPorts = map[string]string{ + "https": "443", + "http": "80", + "": "80", +} + +// WebSocketURL 获取WebSocket URL +func (c *Context) WebSocketURL() *url.URL { + u := *c.Req.URL + if c.IsHTTPS() { + u.Scheme = "wss" + } else { + u.Scheme = "ws" + } + return &u +} + +// Addr 获取请求地址 +func (c *Context) Addr() string { + addr := c.Req.Host + + if !strings.Contains(c.Req.URL.Host, ":") { + addr += ":" + defaultPorts[c.Req.URL.Scheme] + } + + return addr +} + +// AddTag 添加请求标签 +func (c *Context) AddTag(tag string) { + c.Tags = append(c.Tags, tag) +} + +// HasTag 检查是否包含指定标签 +func (c *Context) HasTag(tag string) bool { + for _, t := range c.Tags { + if t == tag { + return true + } + } + return false +} + +// Abort 中断执行 +func (c *Context) Abort() { + c.aborted = true +} + +// IsAborted 是否已中断执行 +func (c *Context) IsAborted() bool { + return c.aborted +} + +// Reset 重置上下文 +func (c *Context) Reset(req *http.Request) { + c.Req = req + c.StartTime = time.Now() + c.Data = make(map[interface{}]interface{}) + c.abort = false + c.TunnelProxy = false + c.Tags = make([]string, 0) + c.RequestID = "" + c.TargetAddr = "" + c.ParentProxyURL = nil + c.aborted = false +} + +// Set 设置数据 +func (c *Context) Set(key, value interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.Data == nil { + c.Data = make(map[interface{}]interface{}) + } + c.Data[key] = value +} + +// Get 获取数据 +func (c *Context) Get(key interface{}) (interface{}, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.Data == nil { + return nil, false + } + val, ok := c.Data[key] + return val, ok +} diff --git a/coverage.out b/coverage.out new file mode 100644 index 0000000..53f3923 --- /dev/null +++ b/coverage.out @@ -0,0 +1,225 @@ +mode: set +github.com/darkit/goproxy/pkg/dns/config.go:24.55,26.16 2 1 +github.com/darkit/goproxy/pkg/dns/config.go:26.16,28.3 1 0 +github.com/darkit/goproxy/pkg/dns/config.go:29.2,33.42 4 1 +github.com/darkit/goproxy/pkg/dns/config.go:33.42,35.3 1 0 +github.com/darkit/goproxy/pkg/dns/config.go:37.2,37.12 1 1 +github.com/darkit/goproxy/pkg/dns/config.go:41.56,43.16 2 1 +github.com/darkit/goproxy/pkg/dns/config.go:43.16,45.3 1 0 +github.com/darkit/goproxy/pkg/dns/config.go:46.2,55.47 4 1 +github.com/darkit/goproxy/pkg/dns/config.go:55.47,57.3 1 0 +github.com/darkit/goproxy/pkg/dns/config.go:59.2,59.20 1 1 +github.com/darkit/goproxy/pkg/dns/config.go:63.61,65.16 2 1 +github.com/darkit/goproxy/pkg/dns/config.go:65.16,67.3 1 0 +github.com/darkit/goproxy/pkg/dns/config.go:68.2,78.21 5 1 +github.com/darkit/goproxy/pkg/dns/config.go:78.21,83.49 3 1 +github.com/darkit/goproxy/pkg/dns/config.go:83.49,84.12 1 1 +github.com/darkit/goproxy/pkg/dns/config.go:87.3,88.22 2 1 +github.com/darkit/goproxy/pkg/dns/config.go:88.22,89.12 1 0 +github.com/darkit/goproxy/pkg/dns/config.go:92.3,97.21 4 1 +github.com/darkit/goproxy/pkg/dns/config.go:97.21,98.12 1 0 +github.com/darkit/goproxy/pkg/dns/config.go:101.3,106.20 4 1 +github.com/darkit/goproxy/pkg/dns/config.go:106.20,108.4 1 1 +github.com/darkit/goproxy/pkg/dns/config.go:110.3,110.34 1 1 +github.com/darkit/goproxy/pkg/dns/config.go:110.34,112.38 1 1 +github.com/darkit/goproxy/pkg/dns/config.go:112.38,113.10 1 1 +github.com/darkit/goproxy/pkg/dns/config.go:117.4,117.34 1 1 +github.com/darkit/goproxy/pkg/dns/config.go:121.2,121.38 1 1 +github.com/darkit/goproxy/pkg/dns/config.go:121.38,123.3 1 0 +github.com/darkit/goproxy/pkg/dns/config.go:125.2,125.20 1 1 +github.com/darkit/goproxy/pkg/dns/config.go:129.63,131.20 2 1 +github.com/darkit/goproxy/pkg/dns/config.go:131.20,133.3 1 1 +github.com/darkit/goproxy/pkg/dns/config.go:133.8,135.3 1 0 +github.com/darkit/goproxy/pkg/dns/config.go:137.2,145.17 3 1 +github.com/darkit/goproxy/pkg/dns/config.go:149.43,151.2 1 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:19.43,26.2 1 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:29.61,32.2 2 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:35.65,38.2 2 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:41.55,44.2 2 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:47.94,50.16 2 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:50.16,54.3 2 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:57.2,58.16 2 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:58.16,60.3 1 0 +github.com/darkit/goproxy/pkg/dns/dialer.go:63.2,63.17 1 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:63.17,65.17 2 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:65.17,67.4 1 0 +github.com/darkit/goproxy/pkg/dns/dialer.go:70.3,70.24 1 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:70.24,72.4 1 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:74.3,74.21 1 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:78.2,84.71 2 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:88.66,90.2 1 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:93.26,95.2 1 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:98.98,100.2 1 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:103.52,105.2 1 1 +github.com/darkit/goproxy/pkg/dns/dialer.go:108.111,110.2 1 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:16.39,21.2 1 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:24.57,29.2 1 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:32.49,34.30 1 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:34.30,36.22 2 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:36.22,38.4 1 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:41.3,45.17 3 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:45.17,47.4 1 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:49.3,49.44 1 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:53.2,53.28 1 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:57.36,58.16 1 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:58.16,60.3 1 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:61.2,61.13 1 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:65.70,66.16 1 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:66.16,68.3 1 1 +github.com/darkit/goproxy/pkg/dns/endpoint.go:69.2,69.48 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:73.53,84.33 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:84.33,86.3 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:88.2,88.10 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:92.63,94.16 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:94.16,96.3 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:97.2,97.25 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:101.91,106.64 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:106.64,109.3 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:112.2,112.60 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:112.60,115.3 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:118.2,118.36 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:118.36,119.41 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:119.41,122.4 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:124.2,127.16 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:127.16,129.17 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:129.17,131.4 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:134.3,135.28 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:135.28,136.39 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:136.39,138.10 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:142.3,142.15 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:142.15,144.4 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:147.3,148.22 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:148.22,150.4 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:153.3,160.23 4 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:163.2,163.76 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:167.91,168.25 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:168.25,170.3 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:173.2,173.25 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:173.25,175.22 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:175.22,177.4 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:178.3,178.18 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:182.2,183.22 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:184.18,186.68 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:187.14,189.68 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:190.22,192.26 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:195.2,195.40 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:195.40,197.3 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:198.2,198.17 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:202.65,206.39 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:206.39,207.48 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:207.48,209.4 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:212.2,212.12 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:216.64,218.41 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:218.41,220.3 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:223.2,223.38 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:223.38,225.29 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:225.29,226.12 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:230.3,230.38 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:230.38,232.4 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:235.2,235.13 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:239.53,241.2 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:244.71,245.28 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:245.28,247.3 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:249.2,253.50 4 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:253.50,255.31 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:255.31,256.54 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:256.54,258.5 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:260.3,260.54 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:261.8,263.3 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:264.2,264.12 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:268.71,270.2 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:273.89,274.28 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:274.28,276.3 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:279.2,279.44 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:279.44,281.3 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:284.2,290.39 4 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:290.39,291.37 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:291.37,294.37 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:294.37,295.55 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:295.55,297.6 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:300.4,301.14 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:306.2,315.12 3 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:319.52,324.34 3 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:324.34,327.3 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:330.2,330.39 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:330.39,331.27 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:331.27,335.4 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:338.2,338.44 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:342.74,347.42 3 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:347.42,349.31 2 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:349.31,350.36 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:350.36,352.5 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:354.3,354.29 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:354.29,356.4 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:356.9,358.4 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:359.3,359.13 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:363.2,363.39 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:363.39,364.27 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:364.27,366.37 2 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:366.37,367.37 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:367.37,369.6 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:371.4,371.30 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:371.30,373.5 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:373.10,375.5 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:376.4,376.14 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:380.2,380.44 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:384.34,391.2 5 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:397.41,398.33 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:398.33,400.3 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:404.40,405.33 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:405.33,407.3 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:411.67,412.33 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:412.33,414.3 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:418.71,422.35 3 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:422.35,424.34 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:424.34,426.18 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:426.18,428.5 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:430.4,430.39 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:430.39,432.5 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:435.4,436.41 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:436.41,437.29 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:437.29,439.39 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:439.39,440.57 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:440.57,442.13 2 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:445.6,445.16 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:445.16,447.7 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:448.6,448.11 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:452.4,452.14 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:452.14,460.5 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:462.9,465.18 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:465.18,467.5 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:469.4,469.39 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:469.39,471.5 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:474.4,474.52 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:474.52,476.33 2 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:476.33,477.56 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:477.56,479.12 2 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:482.5,482.15 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:482.15,484.6 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:485.10,487.5 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:492.2,494.12 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:498.46,500.46 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:500.46,509.39 5 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:509.39,511.4 1 0 +github.com/darkit/goproxy/pkg/dns/resolver.go:514.3,514.45 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:519.41,521.29 2 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:521.29,522.18 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:522.18,524.4 1 1 +github.com/darkit/goproxy/pkg/dns/resolver.go:526.2,526.14 1 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:10.47,11.37 1 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:11.37,13.3 1 0 +github.com/darkit/goproxy/pkg/dns/wildcard.go:15.2,16.21 2 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:16.21,18.3 1 0 +github.com/darkit/goproxy/pkg/dns/wildcard.go:20.2,23.54 3 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:23.54,25.3 1 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:27.2,27.54 1 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:27.54,29.3 1 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:31.2,31.13 1 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:41.72,45.2 1 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:48.65,52.42 3 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:52.42,53.35 1 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:53.35,55.4 1 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:58.2,58.70 1 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:62.62,67.2 4 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:70.57,75.2 4 1 +github.com/darkit/goproxy/pkg/dns/wildcard.go:78.42,83.2 4 1 diff --git a/delegate.go b/delegate.go new file mode 100644 index 0000000..092f008 --- /dev/null +++ b/delegate.go @@ -0,0 +1,120 @@ +package goproxy + +import ( + "net/http" + "net/url" +) + +// Delegate 代理委托接口 +// 定义了代理处理请求的各个阶段的回调方法 +type Delegate interface { + // Connect 连接事件 + Connect(ctx *Context, rw http.ResponseWriter) + + // Auth 认证事件 + Auth(ctx *Context, rw http.ResponseWriter) + + // BeforeRequest 请求前事件 + BeforeRequest(ctx *Context) + + // BeforeResponse 响应前事件 + BeforeResponse(ctx *Context, resp *http.Response, err error) + + // WebSocketSendMessage websocket发送消息拦截 + WebSocketSendMessage(ctx *Context, messageType *int, p *[]byte) + + // WebSocketReceiveMessage websocket接收消息拦截 + WebSocketReceiveMessage(ctx *Context, messageType *int, p *[]byte) + + // ParentProxy 获取上级代理 + ParentProxy(req *http.Request) (*url.URL, error) + + // ErrorLog 错误日志 + ErrorLog(err error) + + // Finish 完成事件 + Finish(ctx *Context) + + // 以下是反向代理相关的方法 + + // ResolveBackend 解析后端服务器 + // 在反向代理模式下,根据请求确定应该转发到哪个后端服务器 + ResolveBackend(req *http.Request) (string, error) + + // ModifyRequest 修改请求 + // 在反向代理模式下,可以修改发往后端服务器的请求 + ModifyRequest(req *http.Request) + + // ModifyResponse 修改响应 + // 在反向代理模式下,可以修改来自后端服务器的响应 + ModifyResponse(resp *http.Response) error + + // HandleError 处理错误 + // 在反向代理模式下,可以自定义错误处理逻辑 + HandleError(rw http.ResponseWriter, req *http.Request, err error) +} + +// DefaultDelegate 默认代理委托 +type DefaultDelegate struct{} + +// Connect 连接事件 +func (d *DefaultDelegate) Connect(ctx *Context, rw http.ResponseWriter) {} + +// Auth 认证事件 +func (d *DefaultDelegate) Auth(ctx *Context, rw http.ResponseWriter) {} + +// BeforeRequest 请求前事件 +func (d *DefaultDelegate) BeforeRequest(ctx *Context) {} + +// BeforeResponse 响应前事件 +func (d *DefaultDelegate) BeforeResponse(ctx *Context, resp *http.Response, err error) { + // 默认实现不做任何处理 +} + +// ParentProxy 获取上级代理 +func (d *DefaultDelegate) ParentProxy(req *http.Request) (*url.URL, error) { + // 默认实现不使用上级代理 + return nil, nil +} + +// WebSocketSendMessage websocket发送消息拦截 +func (h *DefaultDelegate) WebSocketSendMessage(ctx *Context, messageType *int, payload *[]byte) { + // 默认实现不做任何处理 +} + +// WebSocketReceiveMessage websocket接收消息拦截 +func (h *DefaultDelegate) WebSocketReceiveMessage(ctx *Context, messageType *int, payload *[]byte) { + // 默认实现不做任何处理 +} + +// ErrorLog 错误日志 +func (d *DefaultDelegate) ErrorLog(err error) { + // 默认实现不做任何处理 +} + +// Finish 完成事件 +func (d *DefaultDelegate) Finish(ctx *Context) { + // 默认实现不做任何处理 +} + +// ResolveBackend 解析后端服务器 +func (d *DefaultDelegate) ResolveBackend(req *http.Request) (string, error) { + // 默认实现返回请求中的主机 + return req.Host, nil +} + +// ModifyRequest 修改请求 +func (d *DefaultDelegate) ModifyRequest(req *http.Request) { + // 默认实现不做任何修改 +} + +// ModifyResponse 修改响应 +func (d *DefaultDelegate) ModifyResponse(resp *http.Response) error { + // 默认实现不做任何修改 + return nil +} + +// HandleError 处理错误 +func (d *DefaultDelegate) HandleError(rw http.ResponseWriter, req *http.Request, err error) { + http.Error(rw, err.Error(), http.StatusBadGateway) +} diff --git a/examples/auth_proxy/main.go b/examples/auth_proxy/main.go new file mode 100644 index 0000000..8c55222 --- /dev/null +++ b/examples/auth_proxy/main.go @@ -0,0 +1,29 @@ +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/pkg/auth" +) + +func main() { + // 创建认证系统 + auths := auth.NewAuth("1234") + auths.Authenticate("admin", "password") + + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithAuth(auths), + ) + + // 启动代理服务器 + log.Println("认证代理服务器启动在 :8080") + log.Println("认证配置:") + log.Printf("- 用户名: admin\n") + log.Printf("- 密码: password\n") + if err := http.ListenAndServe(":8080", proxy); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} diff --git a/examples/custom_dns_https_proxy/main.go b/examples/custom_dns_https_proxy/main.go new file mode 100644 index 0000000..32fde5f --- /dev/null +++ b/examples/custom_dns_https_proxy/main.go @@ -0,0 +1,59 @@ +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/pkg/dns" +) + +// CustomDNSHTTPSDelegate 自定义 DNS HTTPS 代理委托 +type CustomDNSHTTPSDelegate struct { + goproxy.DefaultDelegate + dnsResolver *dns.CustomResolver +} + +// ResolveBackend 解析后端服务器 +func (d *CustomDNSHTTPSDelegate) ResolveBackend(req *http.Request) (string, error) { + return d.dnsResolver.Resolve(req.URL.Host) +} + +func main() { + // 创建证书缓存 + certCache := &goproxy.MemCertCache{} + + // 创建自定义 DNS 解析器 + resolver := dns.NewResolver(dns.WithFallback(true)) + + // 添加 DNS 记录 + resolver.LoadFromMap(map[string]string{ + "example.com": "http://backend1.example.com", + "test.com": "http://backend2.test.com", + }) + + // 创建自定义 DNS HTTPS 代理委托 + delegate := &CustomDNSHTTPSDelegate{ + dnsResolver: resolver, + } + + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithDelegate(delegate), + goproxy.WithDecryptHTTPS(certCache), + goproxy.WithCACertAndKey("ca.crt", "ca.key"), + goproxy.WithEnableECDSA(true), + ) + + // 启动代理服务器 + log.Println("自定义 DNS HTTPS 代理服务器启动在 :8443") + log.Println("配置说明:") + log.Printf("- 支持 HTTPS 解密\n") + log.Printf("- 使用 ECDSA 证书\n") + log.Println("DNS 配置:") + log.Printf("- example.com -> backend1.example.com\n") + log.Printf("- test.com -> backend2.test.com\n") + if err := http.ListenAndServeTLS(":8443", "server.crt", "server.key", proxy); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} diff --git a/examples/custom_dns_proxy/main.go b/examples/custom_dns_proxy/main.go new file mode 100644 index 0000000..e3be68d --- /dev/null +++ b/examples/custom_dns_proxy/main.go @@ -0,0 +1,49 @@ +package custom_dns_proxy + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/pkg/dns" +) + +// CustomDNSDelegate 自定义 DNS 代理委托 +type CustomDNSDelegate struct { + goproxy.DefaultDelegate + dnsResolver dns.Resolver +} + +// ResolveBackend 解析后端服务器 +func (d *CustomDNSDelegate) ResolveBackend(req *http.Request) (string, error) { + return d.dnsResolver.Resolve(req.URL.Host) +} + +// RunCustomDNSProxy 运行自定义 DNS 代理服务器 +func RunCustomDNSProxy() error { + // 创建自定义 DNS 解析器 + resolver := dns.NewResolver(dns.WithFallback(true)) + + // 添加 DNS 记录 + resolver.LoadFromMap(map[string]string{ + "example.com": "http://backend1.example.com", + "test.com": "http://backend2.test.com", + }) + + // 创建自定义 DNS 代理委托 + delegate := &CustomDNSDelegate{ + dnsResolver: resolver, + } + + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithDelegate(delegate), + ) + + // 启动代理服务器 + log.Println("自定义 DNS 代理服务器启动在 :8080") + log.Println("DNS 配置:") + log.Printf("- example.com -> backend1.example.com\n") + log.Printf("- test.com -> backend2.test.com\n") + return http.ListenAndServe(":8080", proxy) +} diff --git a/examples/custom_dns_reverse_proxy/README.md b/examples/custom_dns_reverse_proxy/README.md new file mode 100644 index 0000000..5e163a5 --- /dev/null +++ b/examples/custom_dns_reverse_proxy/README.md @@ -0,0 +1,124 @@ +# 自定义DNS反向代理示例 + +这个示例展示了如何使用 goproxy 创建一个支持自定义 DNS 解析和负载均衡的反向代理服务器。 + +## 功能特点 + +- 支持自定义 DNS 解析规则 +- 支持多后端服务器负载均衡 +- 支持通配符域名解析 +- 支持 DNS 缓存 +- 支持优雅关闭 +- 支持配置文件 + +## 使用方法 + +1. 编译示例: + +```bash +go build -o custom_dns_proxy main.go +``` + +2. 运行示例: + +```bash +# 使用默认配置 +./custom_dns_proxy + +# 指定配置文件 +./custom_dns_proxy -config config.json +``` + +## 配置说明 + +### 命令行参数 + +- `-addr`: 监听地址,默认为 ":8080" +- `-config`: 配置文件路径,默认为 "config.json" + +### 配置文件 + +配置文件支持 JSON 格式,包含以下主要配置项: + +```json +{ + "dns": { + "records": { + "api.example.com": [ + "192.168.1.1:8080", + "192.168.1.2:8080", + "192.168.1.3:8080" + ], + "*.example.com": [ + "192.168.1.1:8080", + "192.168.1.2:8080" + ] + }, + "fallback": true, + "ttl": 300, + "strategy": "round-robin" + }, + "proxy": { + "listen_addr": ":8080", + "enable_https": false, + "enable_websocket": true, + "enable_compression": true, + "enable_cors": true, + "preserve_client_ip": true, + "add_x_forwarded_for": true, + "add_x_real_ip": true, + "insecure_skip_verify": false, + "enable_health_check": false, + "health_check_interval": 30, + "health_check_timeout": 5, + "enable_retry": true, + "max_retries": 3, + "retry_backoff": 1, + "max_retry_backoff": 10, + "enable_metrics": true, + "enable_tracing": false, + "websocket_intercept": false, + "dns_cache_ttl": 300, + "enable_cache": true, + "cache_ttl": 300, + "enable_connection_pool": true, + "connection_pool_size": 100, + "idle_timeout": 60, + "request_timeout": 30 + } +} +``` + +## 负载均衡策略 + +当前示例使用轮询(Round Robin)负载均衡策略,支持以下特性: + +1. 多后端服务器轮询 +2. 自动故障转移 +3. 健康检查 +4. 连接池管理 + +## 示例用法 + +1. 启动代理服务器: + +```bash +./custom_dns_proxy -addr :8080 +``` + +2. 使用 curl 测试: + +```bash +# 测试 API 访问 +curl http://api.example.com:8080/api/v1/users + +# 测试通配符域名 +curl http://test.example.com:8080/api/v1/users +``` + +## 注意事项 + +1. 确保配置文件中的 DNS 记录正确配置 +2. 确保后端服务器正常运行 +3. 建议在生产环境中启用 HTTPS +4. 根据实际需求调整连接池大小和超时设置 \ No newline at end of file diff --git a/examples/custom_dns_reverse_proxy/config.json b/examples/custom_dns_reverse_proxy/config.json new file mode 100644 index 0000000..30ce8c2 --- /dev/null +++ b/examples/custom_dns_reverse_proxy/config.json @@ -0,0 +1,23 @@ +{ + "dns": { + "records": { + "*.shabi.in": [ + "192.168.1.1:80", + "192.168.1.6:80", + "192.168.1.252:80" + ], + "api.example.com": [ + "192.168.1.1:8080", + "192.168.1.2:8080", + "192.168.1.3:8080" + ], + "www.*.shabi.in": [ + "192.168.1.1:80", + "192.168.1.252:80" + ] + }, + "fallback": true, + "ttl": 300, + "strategy": "round-robin" + } +} \ No newline at end of file diff --git a/examples/custom_dns_reverse_proxy/main.go b/examples/custom_dns_reverse_proxy/main.go new file mode 100644 index 0000000..1dafe75 --- /dev/null +++ b/examples/custom_dns_reverse_proxy/main.go @@ -0,0 +1,188 @@ +package main + +import ( + "encoding/json" + "errors" + "flag" + "log/slog" + "net" + "net/http" + "os" + "os/signal" + "strconv" + "syscall" + + "github.com/darkit/goproxy/pkg/dns" + "github.com/darkit/goproxy/pkg/reverse" +) + +// 命令行参数 +var configFile = flag.String("config", "config.json", "配置文件路径") + +// Config 配置文件结构 +type Config struct { + DNS struct { + Records map[string][]string `json:"records"` // 普通记录和泛解析记录 + Fallback bool `json:"fallback"` // 是否回退到系统DNS + TTL int `json:"ttl"` // 缓存TTL,单位为秒 + Strategy string `json:"strategy"` // 负载均衡策略:round-robin, random, first-available + } `json:"dns" yaml:"dns" toml:"dns"` +} + +func main() { + flag.Parse() + + // 创建日志记录器 + logger := slog.Default() + + // 加载配置文件 + config := &Config{} + if *configFile != "" { + data, err := os.ReadFile(*configFile) + if err != nil { + logger.Error("读取配置文件失败", "file", *configFile, "error", err) + os.Exit(1) + } + + if err := json.Unmarshal(data, config); err != nil { + logger.Error("解析配置文件失败", "error", err) + os.Exit(1) + } + + logger.Info("成功加载配置文件", "file", *configFile) + } + + // 选择负载均衡策略 + //var strategy dns.LoadBalanceStrategy + //switch config.DNS.Strategy { + //case "random": + // strategy = dns.Random + //case "first-available": + // strategy = dns.FirstAvailable + //default: + // strategy = dns.RoundRobin + //} + + // 创建反向代理配置 + cfg := reverse.DefaultConfig() + cfg.TargetAddr = "www.shabi.in" + //cfg.DNSResolver = dns.NewResolver( + // dns.WithFallback(config.DNS.Fallback), // 是否回退到系统DNS + // dns.WithLoadBalanceStrategy(strategy), // 使用配置的负载均衡策略 + // dns.WithTTL(time.Duration(config.DNS.TTL)*time.Second), // 设置DNS缓存TTL + //) // 创建DNS解析器 + + // 添加DNS记录,支持多个地址 + for domain, addrs := range config.DNS.Records { + for _, addr := range addrs { + host, port, err := net.SplitHostPort(addr) + if err != nil { + // 地址没有端口,将整个地址作为IP使用 + if !dns.IsWildcardDomain(domain) { + cfg.DNSResolver.Add(domain, addr) + } else { + cfg.DNSResolver.AddWildcard(domain, addr) + } + } else { + tPort, err := strconv.Atoi(port) + if err != nil { + logger.Error("无效的端口", "domain", domain, "addr", addr, "port", port, "error", err) + continue + } + if !dns.IsWildcardDomain(domain) { + cfg.DNSResolver.AddWithPort(domain, host, tPort) + } else { + cfg.DNSResolver.AddWildcardWithPort(domain, host, tPort) + } + } + } + } + + // 创建反向代理服务器 + proxy, err := reverse.New(cfg) + if err != nil { + logger.Error("创建反向代理服务器失败", "error", err) + os.Exit(1) + } + + // 创建HTTP服务器 + server := &http.Server{ + Addr: cfg.ListenAddr, + Handler: proxy, + } + + // 启动服务器 + go func() { + logger.Info("启动反向代理服务器", "addr", cfg.ListenAddr) + if err = server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + logger.Error("服务器运行错误", "error", err) + } + }() + + // 等待中断信号 + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + + // 优雅关闭 + logger.Info("正在关闭服务器...") + if err := server.Close(); err != nil { + logger.Error("关闭服务器失败", "error", err) + } +} + +// 示例配置文件 config.json: +/* +{ + "dns": { + "records": { + "api.example.com": [ + "192.168.1.1:8080", + "192.168.1.2:8080", + "192.168.1.3:8080" + ], + "*.example.com": [ + "192.168.1.1:8080", + "192.168.1.2:8080" + ] + }, + "fallback": true, + "ttl": 300, + "strategy": "round-robin" + } +} +*/ + +// 示例配置文件 config.yaml: +/* +dns: + records: + api.example.com: + - 192.168.1.1:8080 + - 192.168.1.2:8080 + - 192.168.1.3:8080 + "*.example.com": + - 192.168.1.1:8080 + - 192.168.1.2:8080 + fallback: true + ttl: 300 + strategy: round-robin +*/ + +// 示例配置文件 config.toml: +/* +[dns] + [dns.records] + api.example.com = [ + "192.168.1.1:8080", + "192.168.1.2:8080", + "192.168.1.3:8080" + ] + "*.example.com" = [ + "192.168.1.1:8080", + "192.168.1.2:8080" + ] + fallback = true + ttl = 300 + strategy = round-robin +*/ diff --git a/examples/custom_port_proxy/main.go b/examples/custom_port_proxy/main.go new file mode 100644 index 0000000..340f170 --- /dev/null +++ b/examples/custom_port_proxy/main.go @@ -0,0 +1,27 @@ +package main + +import ( + "flag" + "log" + "net/http" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/config" +) + +var port = flag.String("port", "8080", "代理服务器端口") + +func main() { + flag.Parse() + + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithConfig(config.DefaultConfig()), + ) + + // 启动代理服务器 + log.Printf("代理服务器启动在 :%s\n", *port) + if err := http.ListenAndServe(":"+*port, proxy); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} diff --git a/examples/goproxy/main.go b/examples/goproxy/main.go new file mode 100644 index 0000000..7c640d9 --- /dev/null +++ b/examples/goproxy/main.go @@ -0,0 +1,19 @@ +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy" +) + +func main() { + // 创建代理 + p := goproxy.New(nil) + + // 启动HTTP服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} diff --git a/examples/http_to_https_proxy/main.go b/examples/http_to_https_proxy/main.go new file mode 100644 index 0000000..37eb1b8 --- /dev/null +++ b/examples/http_to_https_proxy/main.go @@ -0,0 +1,41 @@ +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy" +) + +// HTTPToHTTPSDelegate HTTP 到 HTTPS 代理委托 +type HTTPToHTTPSDelegate struct { + goproxy.DefaultDelegate +} + +// BeforeRequest 请求前事件 +func (d *HTTPToHTTPSDelegate) BeforeRequest(ctx *goproxy.Context) { + // 将 HTTP 请求转换为 HTTPS + if ctx.Req.URL.Scheme == "http" { + ctx.Req.URL.Scheme = "https" + } +} + +func main() { + // 创建 HTTP 到 HTTPS 代理委托 + delegate := &HTTPToHTTPSDelegate{} + + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithDelegate(delegate), + goproxy.WithDecryptHTTPS(&goproxy.MemCertCache{}), + ) + + // 启动代理服务器 + log.Println("HTTP 到 HTTPS 代理服务器启动在 :8080") + log.Println("配置说明:") + log.Printf("- 自动将 HTTP 请求转换为 HTTPS\n") + log.Printf("- 支持 HTTPS 解密\n") + if err := http.ListenAndServe(":8080", proxy); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} diff --git a/examples/https_to_https_proxy/main.go b/examples/https_to_https_proxy/main.go new file mode 100644 index 0000000..caa512f --- /dev/null +++ b/examples/https_to_https_proxy/main.go @@ -0,0 +1,31 @@ +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy" +) + +func main() { + // 创建证书缓存 + certCache := &goproxy.MemCertCache{} + + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithDecryptHTTPS(certCache), + goproxy.WithCACertAndKey("ca.crt", "ca.key"), + goproxy.WithEnableECDSA(true), + ) + + // 启动代理服务器 + log.Println("HTTPS 到 HTTPS 代理服务器启动在 :8443") + log.Println("配置说明:") + log.Printf("- 支持 HTTPS 解密\n") + log.Printf("- 使用 ECDSA 证书\n") + log.Printf("- 最低 TLS 版本: 1.2\n") + log.Printf("- 最高 TLS 版本: 1.3\n") + if err := http.ListenAndServeTLS(":8443", "server.crt", "server.key", proxy); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} diff --git a/examples/other/README.md b/examples/other/README.md new file mode 100644 index 0000000..6b1557f --- /dev/null +++ b/examples/other/README.md @@ -0,0 +1,89 @@ +# GoProxy 示例 + +本目录包含了 GoProxy 库的各种使用示例,展示了不同的代理功能和配置选项。 + +## 示例列表 + +1. [forward_proxy.go](forward_proxy.go) - 基本正向代理 + - 展示最基本的正向代理功能 + - 适用于简单的 HTTP 代理需求 + +2. [https_proxy.go](https_proxy.go) - HTTPS 解密代理 + - 支持 HTTPS 解密(中间人模式) + - 需要配置 CA 证书 + - 支持 ECDSA 证书算法 + +3. [reverse_proxy.go](reverse_proxy.go) - 反向代理 + - 支持反向代理功能 + - 支持 URL 重写 + - 支持 X-Forwarded-For 和 X-Real-IP 头 + +4. [custom_delegate.go](custom_delegate.go) - 自定义委托代理 + - 展示如何自定义代理行为 + - 支持请求和响应的自定义处理 + - 包含详细的日志记录 + +5. [load_balance.go](load_balance.go) - 负载均衡代理 + - 支持多后端服务器 + - 使用轮询算法进行负载均衡 + - 包含健康检查功能 + +6. [metrics_proxy.go](metrics_proxy.go) - 监控指标代理 + - 支持 Prometheus 格式的监控指标 + - 提供详细的性能统计 + - 包含独立的指标服务器 + +7. [cache_proxy.go](cache_proxy.go) - 缓存代理 + - 支持 HTTP 响应缓存 + - 使用内存缓存存储 + - 可配置缓存策略 + +8. [auth_proxy.go](auth_proxy.go) - 认证代理 + - 支持基本认证 + - 可配置用户名和密码 + - 保护代理访问 + +9. [websocket_proxy.go](websocket_proxy.go) - WebSocket 代理 + - 支持 WebSocket 协议 + - 支持 WebSocket 拦截 + - 适用于实时通信场景 + +10. [rate_limit_proxy.go](rate_limit_proxy.go) - 速率限制代理 + - 支持请求速率限制 + - 可配置最大请求速率 + - 防止服务器过载 + +## 使用方法 + +1. 编译示例: +```bash +go build -o forward_proxy examples/forward_proxy.go +``` + +2. 运行示例: +```bash +./forward_proxy +``` + +3. 配置代理: +- 在浏览器中设置代理服务器为 `localhost:8080` +- 或使用环境变量: +```bash +export http_proxy=http://localhost:8080 +export https_proxy=http://localhost:8080 +``` + +## 注意事项 + +1. 使用 HTTPS 解密功能时,需要安装 CA 证书 +2. 某些功能可能需要额外的配置(如证书、密钥等) +3. 建议在生产环境中使用更安全的配置 +4. 监控指标默认在 9090 端口提供 + +## 开发建议 + +1. 根据实际需求选择合适的示例作为起点 +2. 可以组合多个功能来满足复杂需求 +3. 注意处理错误和异常情况 +4. 在生产环境中添加适当的日志记录 +5. 考虑添加监控和告警机制 \ No newline at end of file diff --git a/examples/other/cache/cache_proxy.go b/examples/other/cache/cache_proxy.go new file mode 100644 index 0000000..351e004 --- /dev/null +++ b/examples/other/cache/cache_proxy.go @@ -0,0 +1,32 @@ +package main + +import ( + "log" + "net/http" + "time" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/pkg/cache" +) + +func main() { + // 创建内存缓存 + memCache := cache.NewMemoryCache(time.Minute*5, time.Minute*5, 1000) + + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithHTTPCache(memCache), + goproxy.WithRequestTimeout(30*time.Second), + goproxy.WithIdleTimeout(60*time.Second), + ) + + // 启动代理服务器 + log.Println("缓存代理服务器启动在 :8080") + log.Println("缓存配置:") + log.Printf("- 缓存类型: 内存缓存\n") + log.Printf("- 请求超时: 30s\n") + log.Printf("- 空闲超时: 60s\n") + if err := http.ListenAndServe(":8080", proxy); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} diff --git a/examples/other/custom_delegate/custom_delegate.go b/examples/other/custom_delegate/custom_delegate.go new file mode 100644 index 0000000..8dc0473 --- /dev/null +++ b/examples/other/custom_delegate/custom_delegate.go @@ -0,0 +1,56 @@ +package main + +import ( + "log" + "net/http" + "time" + + "github.com/darkit/goproxy" +) + +// CustomDelegate 自定义委托 +type CustomDelegate struct { + goproxy.DefaultDelegate +} + +// BeforeRequest 请求前事件 +func (d *CustomDelegate) BeforeRequest(ctx *goproxy.Context) { + log.Printf("收到请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.String()) + log.Printf("请求头: %v\n", ctx.Req.Header) +} + +// BeforeResponse 响应前事件 +func (d *CustomDelegate) BeforeResponse(ctx *goproxy.Context, resp *http.Response, err error) { + if err != nil { + log.Printf("响应错误: %v\n", err) + return + } + log.Printf("收到响应: %d %s\n", resp.StatusCode, resp.Status) + log.Printf("响应头: %v\n", resp.Header) +} + +// AfterResponse 响应后事件 +func (d *CustomDelegate) AfterResponse(ctx *goproxy.Context, resp *http.Response) { + log.Printf("请求完成: %s %s, 耗时: %v\n", + ctx.Req.Method, + ctx.Req.URL.String(), + time.Since(ctx.StartTime)) +} + +func main() { + // 创建自定义委托 + delegate := &CustomDelegate{} + + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithDelegate(delegate), + goproxy.WithRequestTimeout(30*time.Second), + goproxy.WithIdleTimeout(60*time.Second), + ) + + // 启动代理服务器 + log.Println("自定义委托代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", proxy); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} diff --git a/examples/other/forward/forward_proxy.go b/examples/other/forward/forward_proxy.go new file mode 100644 index 0000000..74180c4 --- /dev/null +++ b/examples/other/forward/forward_proxy.go @@ -0,0 +1,22 @@ +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/config" +) + +func main() { + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithConfig(config.DefaultConfig()), + ) + + // 启动代理服务器 + log.Println("正向代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", proxy); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} diff --git a/examples/other/https/https_proxy.go b/examples/other/https/https_proxy.go new file mode 100644 index 0000000..1b31b0a --- /dev/null +++ b/examples/other/https/https_proxy.go @@ -0,0 +1,27 @@ +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy" +) + +func main() { + // 创建证书缓存 + certCache := &goproxy.MemCertCache{} + + // 创建代理实例,启用 HTTPS 解密 + proxy := goproxy.NewProxy( + goproxy.WithDecryptHTTPS(certCache), + goproxy.WithCACertAndKey("ca.crt", "ca.key"), + goproxy.WithEnableECDSA(true), + ) + + // 启动代理服务器 + log.Println("HTTPS 解密代理服务器启动在 :8080") + log.Println("请确保已安装 CA 证书 (ca.crt)") + if err := http.ListenAndServe(":8080", proxy); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} diff --git a/examples/other/load_balance/load_balance.go b/examples/other/load_balance/load_balance.go new file mode 100644 index 0000000..a0c7e2c --- /dev/null +++ b/examples/other/load_balance/load_balance.go @@ -0,0 +1,67 @@ +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/pkg/loadbalance" +) + +// LoadBalanceDelegate 负载均衡委托 +type LoadBalanceDelegate struct { + goproxy.DefaultDelegate + lb loadbalance.LoadBalancer +} + +// ResolveBackend 解析后端服务器 +func (d *LoadBalanceDelegate) ResolveBackend(req *http.Request) (string, error) { + backend, err := d.lb.Next(req.URL.Host) + if err != nil { + return "", err + } + if backend == nil { + return "", nil + } + return backend.String(), nil +} + +// 运行负载均衡代理服务器 +func main() { + // 创建后端服务器列表 + backends := []string{ + "http://backend1:8081", + "http://backend2:8082", + "http://backend3:8083", + } + + // 创建负载均衡器 + lb := loadbalance.NewRoundRobinBalancer() + + // 添加后端服务器 + for _, backend := range backends { + if err := lb.Add(backend, 1); err != nil { + log.Fatalf("Error: %v", err) + return + } + } + + // 创建负载均衡委托 + delegate := &LoadBalanceDelegate{ + lb: lb, + } + + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithDelegate(delegate), + goproxy.WithLoadBalancer(lb), + ) + + // 启动代理服务器 + log.Println("负载均衡代理服务器启动在 :8080") + log.Println("后端服务器列表:") + for _, backend := range backends { + log.Printf("- %s\n", backend) + } + http.ListenAndServe(":8080", proxy) +} diff --git a/examples/other/metrics/metrics_proxy.go b/examples/other/metrics/metrics_proxy.go new file mode 100644 index 0000000..853bb9c --- /dev/null +++ b/examples/other/metrics/metrics_proxy.go @@ -0,0 +1,67 @@ +package main + +import ( + "log" + "net/http" + "time" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/config" + "github.com/darkit/goproxy/pkg/metrics" +) + +// CustomDelegate 自定义代理委托 +type CustomDelegate struct { + goproxy.DefaultDelegate +} + +// Connect 处理连接事件 +func (d *CustomDelegate) Connect(ctx *goproxy.Context, rw http.ResponseWriter) { + log.Printf("新的连接: %s", ctx.Req.RemoteAddr) +} + +// 运行监控代理服务器 +func main() { + // 创建监控指标收集器 + metricsCollector := metrics.NewSimpleMetrics() + + // 创建基础配置 + cfg := config.DefaultConfig() + cfg.EnableRetry = true + cfg.MaxRetries = 3 + cfg.BaseBackoff = time.Second + cfg.MaxBackoff = 10 * time.Second + cfg.EnableCache = true + cfg.CacheTTL = 5 * time.Minute + cfg.ConnectionPoolSize = 100 + cfg.RateLimit = 1000 + cfg.EnableMetrics = true + cfg.DecryptHTTPS = false // 禁用HTTPS解密 + cfg.InsecureSkipVerify = true // 跳过证书验证 + cfg.RequestTimeout = 30 * time.Second + cfg.IdleTimeout = 60 * time.Second + + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithMetrics(metricsCollector), + goproxy.WithRequestTimeout(30*time.Second), + goproxy.WithIdleTimeout(60*time.Second), + goproxy.WithConfig(cfg), + goproxy.WithDelegate(&CustomDelegate{}), + ) + + // 启动监控指标服务器 + go func() { + http.Handle("/metrics", metricsCollector.GetHandler()) + log.Println("监控指标服务器启动在 :9090") + if err := http.ListenAndServe(":9090", nil); err != nil { + log.Printf("监控指标服务器启动失败: %v", err) + } + }() + + // 启动代理服务器 + log.Println("监控代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", proxy); err != nil { + log.Printf("代理服务器启动失败: %v", err) + } +} diff --git a/examples/other/rate_limit/rate_limit_proxy.go b/examples/other/rate_limit/rate_limit_proxy.go new file mode 100644 index 0000000..9e0225e --- /dev/null +++ b/examples/other/rate_limit/rate_limit_proxy.go @@ -0,0 +1,23 @@ +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy" +) + +func main() { + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithRateLimit(100), // 每秒最多处理 100 个请求 + ) + + // 启动代理服务器 + log.Println("速率限制代理服务器启动在 :8080") + log.Println("速率限制配置:") + log.Printf("- 最大请求速率: 100 请求/秒\n") + if err := http.ListenAndServe(":8080", proxy); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} diff --git a/examples/other/reverse_proxy/reverse_proxy.go b/examples/other/reverse_proxy/reverse_proxy.go new file mode 100644 index 0000000..c9b65ee --- /dev/null +++ b/examples/other/reverse_proxy/reverse_proxy.go @@ -0,0 +1,45 @@ +package main + +import ( + "log" + "net/http" + "net/url" + + "github.com/darkit/goproxy" +) + +// ReverseProxyDelegate 反向代理委托 +type ReverseProxyDelegate struct { + goproxy.DefaultDelegate + backendURL *url.URL +} + +// ResolveBackend 解析后端服务器 +func (d *ReverseProxyDelegate) ResolveBackend(req *http.Request) (string, error) { + return d.backendURL.String(), nil +} + +// RunReverseProxy 运行反向代理服务器 +func main() { + // 解析后端服务器地址 + backendURL, err := url.Parse("http://localhost:8081") + if err != nil { + return + } + + // 创建反向代理委托 + delegate := &ReverseProxyDelegate{ + backendURL: backendURL, + } + + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithDelegate(delegate), + goproxy.WithReverseProxy(true), + ) + + // 启动代理服务器 + log.Println("反向代理服务器启动在 :8080") + log.Printf("请求将被转发到: %s\n", backendURL.String()) + http.ListenAndServe(":8080", proxy) +} diff --git a/examples/other/websocket/websocket_proxy.go b/examples/other/websocket/websocket_proxy.go new file mode 100644 index 0000000..56488a0 --- /dev/null +++ b/examples/other/websocket/websocket_proxy.go @@ -0,0 +1,23 @@ +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy" +) + +func main() { + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithEnableWebsocketIntercept(), + ) + + // 启动代理服务器 + log.Println("WebSocket 代理服务器启动在 :8080") + log.Println("WebSocket 配置:") + log.Printf("- WebSocket 拦截: 启用\n") + if err := http.ListenAndServe(":8080", proxy); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} diff --git a/examples/rewriter/http_server.go b/examples/rewriter/http_server.go new file mode 100644 index 0000000..459dcb4 --- /dev/null +++ b/examples/rewriter/http_server.go @@ -0,0 +1,121 @@ +package main + +import ( + "log" + "net/http" + "net/http/httputil" + "net/url" + + "github.com/darkit/goproxy/pkg/rewriter" +) + +// RewriteReverseProxy 重写反向代理 +// 在请求发送到后端服务器前重写URL +type RewriteReverseProxy struct { + // 后端服务器地址 + Target *url.URL + // URL重写器 + Rewriter *rewriter.Rewriter + // 反向代理 + Proxy *httputil.ReverseProxy +} + +// NewRewriteReverseProxy 创建重写反向代理 +func NewRewriteReverseProxy(target string, rewriteRulesFile string) (*RewriteReverseProxy, error) { + // 解析目标URL + targetURL, err := url.Parse(target) + if err != nil { + return nil, err + } + + // 创建重写器 + rw := rewriter.NewRewriter() + + // 加载重写规则 + if rewriteRulesFile != "" { + if err := rw.LoadRulesFromFile(rewriteRulesFile); err != nil { + return nil, err + } + } + + // 创建反向代理 + proxy := httputil.NewSingleHostReverseProxy(targetURL) + + // 修改默认的Director函数,添加URL重写逻辑 + defaultDirector := proxy.Director + proxy.Director = func(req *http.Request) { + // 先执行默认的Director函数 + defaultDirector(req) + + // 然后执行URL重写 + rw.Rewrite(req) + + // 记录重写后的URL + log.Printf("请求重写: %s -> %s", req.URL.Path, req.URL.String()) + } + + // 修改响应处理器,重写响应头 + proxy.ModifyResponse = func(resp *http.Response) error { + // 重写响应 + rw.RewriteResponse(resp, targetURL.Host) + return nil + } + + return &RewriteReverseProxy{ + Target: targetURL, + Rewriter: rw, + Proxy: proxy, + }, nil +} + +// ServeHTTP 实现http.Handler接口 +func (rrp *RewriteReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + log.Printf("收到请求: %s %s", r.Method, r.URL.Path) + rrp.Proxy.ServeHTTP(w, r) +} + +// RewriteMiddleware 中间件:将重写器应用到处理链中 +func RewriteMiddleware(rw *rewriter.Rewriter, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 重写URL + originalPath := r.URL.Path + rw.Rewrite(r) + + if originalPath != r.URL.Path { + log.Printf("URL重写: %s -> %s", originalPath, r.URL.Path) + } + + // 继续处理链 + next.ServeHTTP(w, r) + }) +} + +func main() { + // 创建重写器 + rw := rewriter.NewRewriter() + + // 加载重写规则 + if err := rw.LoadRulesFromFile("rules.json"); err != nil { + log.Fatalf("加载重写规则失败: %v", err) + } + + // 创建反向代理 + proxy, err := NewRewriteReverseProxy("http://192.168.1.212:80", "rules.json") + if err != nil { + log.Fatalf("创建反向代理失败: %v", err) + } + + // 创建静态文件服务器(模拟普通Web服务器) + fileServer := http.FileServer(http.Dir("./static")) + + // 使用中间件包装文件服务器 + rewrittenFileServer := RewriteMiddleware(rw, fileServer) + + // 设置处理函数 + http.Handle("/api/", rewrittenFileServer) + http.Handle("/", proxy) + + // 启动服务器 + log.Println("服务器启动在 :8080...") + log.Fatal(http.ListenAndServe(":8080", nil)) +} diff --git a/examples/rewriter/main.go b/examples/rewriter/main.go new file mode 100644 index 0000000..e601861 --- /dev/null +++ b/examples/rewriter/main.go @@ -0,0 +1,88 @@ +package main + +import ( + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + + "github.com/darkit/goproxy/pkg/rewriter" +) + +func main() { + // 创建URL重写器 + rw := rewriter.NewRewriter() + + // 添加一些规则 + rw.AddRule("/api/v1/", "/api/v2/", false) + rw.AddRuleWithDescription("/old/(.*)/page", "/new/$1/page", true, "旧页面重定向") + + // 从文件加载规则 + exampleDir, _ := os.Getwd() + jsonRulesPath := filepath.Join(exampleDir, "rules.json") + textRulesPath := filepath.Join(exampleDir, "rules.txt") + + // 先加载JSON格式规则 + fmt.Println("从JSON文件加载规则...") + if err := rw.LoadRulesFromFile(jsonRulesPath); err != nil { + fmt.Printf("从JSON文件加载规则失败: %v\n", err) + } + + // 然后加载文本格式规则 + fmt.Println("从文本文件加载规则...") + if err := rw.LoadRulesFromFile(textRulesPath); err != nil { + fmt.Printf("从文本文件加载规则失败: %v\n", err) + } + + // 打印所有规则 + fmt.Println("\n当前规则列表:") + for i, rule := range rw.GetRules() { + status := "启用" + if !rule.Enabled { + status = "禁用" + } + fmt.Printf("[%d] %s -> %s [%s] (%s)\n", + i, rule.Pattern, rule.Replacement, + map[bool]string{true: "正则", false: "前缀"}[rule.UseRegex], + status) + } + + // 测试重写 + fmt.Println("\n测试URL重写:") + testURLs := []string{ + "/api/v1/users", + "/old/profile/page", + "/legacy-files/document.pdf", + "/en/about/company", + } + + for _, url := range testURLs { + // 创建测试请求 + req := httptest.NewRequest("GET", "http://example.com"+url, nil) + + // 重写URL + fmt.Printf("原始URL: %s\n", req.URL.Path) + rw.Rewrite(req) + fmt.Printf("重写后: %s\n\n", req.URL.Path) + } + + // 测试响应重写 + fmt.Println("测试响应重写:") + resp := &http.Response{ + Header: http.Header{}, + } + resp.Header.Set("Location", "http://backend.example.com/old/profile/page") + fmt.Printf("原始Location: %s\n", resp.Header.Get("Location")) + rw.RewriteResponse(resp, "backend.example.com") + fmt.Printf("重写后Location: %s\n", resp.Header.Get("Location")) + + // 保存规则到新文件 + newRulesPath := filepath.Join(exampleDir, "new_rules.json") + fmt.Printf("\n保存规则到文件: %s\n", newRulesPath) + if err := rw.SaveRulesToFile(newRulesPath); err != nil { + fmt.Printf("保存规则失败: %v\n", err) + } else { + fmt.Println("规则已保存") + } +} diff --git a/examples/rewriter/new_rules.json b/examples/rewriter/new_rules.json new file mode 100644 index 0000000..37b93b5 --- /dev/null +++ b/examples/rewriter/new_rules.json @@ -0,0 +1,63 @@ +[ + { + "pattern": "/api/v1/", + "replacement": "/api/v2/", + "use_regex": false, + "enabled": true + }, + { + "pattern": "/old/(.*)/page", + "replacement": "/new/$1/page", + "use_regex": true, + "description": "旧页面重定向", + "enabled": true + }, + { + "pattern": "/api/v1/", + "replacement": "/api/v2/", + "use_regex": false, + "description": "将API v1请求重定向到v2", + "enabled": true + }, + { + "pattern": "/old/(.*)/page", + "replacement": "/new/$1/page", + "use_regex": true, + "description": "旧页面格式重定向到新格式", + "enabled": true + }, + { + "pattern": "/legacy-files/", + "replacement": "/files/", + "use_regex": false, + "description": "旧文件路径重定向" + }, + { + "pattern": "^/(en|zh|ja)/(.*)", + "replacement": "/$2?lang=$1", + "use_regex": true, + "description": "将语言路径转换为查询参数", + "enabled": true + }, + { + "pattern": "/api/v1/", + "replacement": "/api/v2/", + "use_regex": false, + "description": "将API v1请求重定向到v2", + "enabled": true + }, + { + "pattern": "/old/(.*)/page", + "replacement": "/new/$1/page", + "use_regex": true, + "description": "旧页面格式重定向到新格式", + "enabled": true + }, + { + "pattern": "^/(en|zh|ja)/(.*)", + "replacement": "/$2?lang=$1", + "use_regex": true, + "description": "将语言路径转换为查询参数", + "enabled": true + } +] diff --git a/examples/rewriter/rules.json b/examples/rewriter/rules.json new file mode 100644 index 0000000..db88ab9 --- /dev/null +++ b/examples/rewriter/rules.json @@ -0,0 +1,30 @@ +[ + { + "pattern": "/api/v1/", + "replacement": "/api/v2/", + "use_regex": false, + "description": "将API v1请求重定向到v2", + "enabled": true + }, + { + "pattern": "/old/(.*)/page", + "replacement": "/new/$1/page", + "use_regex": true, + "description": "旧页面格式重定向到新格式", + "enabled": true + }, + { + "pattern": "/legacy-files/", + "replacement": "/files/", + "use_regex": false, + "description": "旧文件路径重定向", + "enabled": false + }, + { + "pattern": "^/(en|zh|ja)/(.*)", + "replacement": "/$2?lang=$1", + "use_regex": true, + "description": "将语言路径转换为查询参数", + "enabled": true + } +] \ No newline at end of file diff --git a/examples/rewriter/rules.txt b/examples/rewriter/rules.txt new file mode 100644 index 0000000..e61c9f1 --- /dev/null +++ b/examples/rewriter/rules.txt @@ -0,0 +1,7 @@ +# URL重写规则示例 +# 格式: pattern replacement [regex] [#description] + +/api/v1/ /api/v2/ # 将API v1请求重定向到v2 +/old/(.*)/page /new/$1/page regex # 旧页面格式重定向到新格式 +# /legacy-files/ /files/ # 旧文件路径重定向 (已禁用) +^/(en|zh|ja)/(.*) /$2?lang=$1 regex # 将语言路径转换为查询参数 \ No newline at end of file diff --git a/examples/rewriter/web_admin.go b/examples/rewriter/web_admin.go new file mode 100644 index 0000000..8f7a532 --- /dev/null +++ b/examples/rewriter/web_admin.go @@ -0,0 +1,259 @@ +package main + +import ( + "encoding/json" + "fmt" + "html/template" + "log" + "net/http" + "strconv" + + "github.com/darkit/goproxy/pkg/rewriter" +) + +// 全局重写器实例 +var rw *rewriter.Rewriter + +// 规则配置文件路径 +var rulesFile = "rules.json" + +// AdminHandler Web管理页面处理器 +func AdminHandler(w http.ResponseWriter, r *http.Request) { + // 定义模板 + tmpl := ` + + + + + URL重写规则管理 + + + +

URL重写规则管理

+ +

当前规则

+ + + + + + + + + + + {{range $i, $rule := .Rules}} + + + + + + + + + + {{end}} +
索引匹配模式替换模式类型描述状态操作
{{$i}}{{$rule.Pattern}}{{$rule.Replacement}}{{if $rule.UseRegex}}正则表达式{{else}}前缀匹配{{end}}{{$rule.Description}} + {{if $rule.Enabled}}启用{{else}}禁用{{end}} + + {{if $rule.Enabled}} + + {{else}} + + {{end}} + +
+ +

添加新规则

+
+
+ + +
+
+ + +
+
+ +
+
+ + +
+ +
+ +

保存配置

+
+ +
+ + +` + + // 解析模板 + t, err := template.New("admin").Parse(tmpl) + if err != nil { + http.Error(w, fmt.Sprintf("模板解析错误: %v", err), http.StatusInternalServerError) + return + } + + // 渲染模板 + data := struct { + Rules []*rewriter.RewriteRule + }{ + Rules: rw.GetRules(), + } + + if err := t.Execute(w, data); err != nil { + http.Error(w, fmt.Sprintf("模板渲染错误: %v", err), http.StatusInternalServerError) + } +} + +// AddRuleHandler 添加规则处理器 +func AddRuleHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Redirect(w, r, "/", http.StatusSeeOther) + return + } + + r.ParseForm() + + pattern := r.Form.Get("pattern") + replacement := r.Form.Get("replacement") + useRegex := r.Form.Get("regex") == "on" + description := r.Form.Get("description") + + if pattern == "" || replacement == "" { + http.Error(w, "匹配模式和替换模式不能为空", http.StatusBadRequest) + return + } + + if err := rw.AddRuleWithDescription(pattern, replacement, useRegex, description); err != nil { + http.Error(w, fmt.Sprintf("添加规则失败: %v", err), http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/", http.StatusSeeOther) +} + +// EnableRuleHandler 启用规则处理器 +func EnableRuleHandler(w http.ResponseWriter, r *http.Request) { + // 解析URL中的规则索引 + indexStr := r.URL.Path[len("/enable/"):] + index, err := strconv.Atoi(indexStr) + if err != nil { + http.Error(w, "无效的规则索引", http.StatusBadRequest) + return + } + + if err := rw.EnableRule(index); err != nil { + http.Error(w, fmt.Sprintf("启用规则失败: %v", err), http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/", http.StatusSeeOther) +} + +// DisableRuleHandler 禁用规则处理器 +func DisableRuleHandler(w http.ResponseWriter, r *http.Request) { + // 解析URL中的规则索引 + indexStr := r.URL.Path[len("/disable/"):] + index, err := strconv.Atoi(indexStr) + if err != nil { + http.Error(w, "无效的规则索引", http.StatusBadRequest) + return + } + + if err := rw.DisableRule(index); err != nil { + http.Error(w, fmt.Sprintf("禁用规则失败: %v", err), http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/", http.StatusSeeOther) +} + +// RemoveRuleHandler 删除规则处理器 +func RemoveRuleHandler(w http.ResponseWriter, r *http.Request) { + // 解析URL中的规则索引 + indexStr := r.URL.Path[len("/remove/"):] + index, err := strconv.Atoi(indexStr) + if err != nil { + http.Error(w, "无效的规则索引", http.StatusBadRequest) + return + } + + if err := rw.RemoveRule(index); err != nil { + http.Error(w, fmt.Sprintf("删除规则失败: %v", err), http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/", http.StatusSeeOther) +} + +// SaveRulesHandler 保存规则处理器 +func SaveRulesHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Redirect(w, r, "/", http.StatusSeeOther) + return + } + + if err := rw.SaveRulesToFile(rulesFile); err != nil { + http.Error(w, fmt.Sprintf("保存规则失败: %v", err), http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/", http.StatusSeeOther) +} + +// APIHandler API处理器,返回JSON格式的规则列表 +func APIHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + rules := rw.GetRules() + json.NewEncoder(w).Encode(rules) +} + +func main() { + // 创建重写器 + rw = rewriter.NewRewriter() + + // 尝试从文件加载规则 + if err := rw.LoadRulesFromFile(rulesFile); err != nil { + log.Printf("加载规则失败: %v, 将使用空规则集", err) + } + + // 注册处理器 + http.HandleFunc("/", AdminHandler) + http.HandleFunc("/add", AddRuleHandler) + http.HandleFunc("/enable/", EnableRuleHandler) + http.HandleFunc("/disable/", DisableRuleHandler) + http.HandleFunc("/remove/", RemoveRuleHandler) + http.HandleFunc("/save", SaveRulesHandler) + http.HandleFunc("/api/rules", APIHandler) + + // 启动服务器 + log.Println("管理界面启动在 :8080...") + log.Fatal(http.ListenAndServe(":8080", nil)) +} diff --git a/examples/rule/dns_config.json b/examples/rule/dns_config.json new file mode 100644 index 0000000..3e86c97 --- /dev/null +++ b/examples/rule/dns_config.json @@ -0,0 +1,83 @@ +{ + "records": { + "www.github.com": "140.82.121.3", + "github.com": "140.82.121.4", + "api.github.com": "140.82.121.5", + "assets-cdn.github.com": "185.199.108.153", + "collector.github.com": "140.82.121.6", + "codeload.github.com": "140.82.121.9", + "gist.github.com": "140.82.121.4", + "raw.githubusercontent.com": "185.199.108.133", + "s3.amazonaws.com": "52.216.162.69", + "www.google.com": "142.250.4.147", + "google.com": "142.250.4.139", + "ajax.googleapis.com": "142.250.4.95", + "fonts.googleapis.com": "142.250.4.95", + "www.baidu.com": "110.242.68.66", + "baidu.com": "39.156.66.10", + + "custom-port-example.com": "192.168.1.10:8080", + "api.custom-port-example.com": "192.168.1.11:8443", + "dev.example.com": "127.0.0.1:3000", + "api.dev.example.com": "127.0.0.1:3001", + "db.dev.example.com": "127.0.0.1:5432" + }, + "use_fallback": true, + "ttl": 300, + "rules": [ + { + "id": "dns-rule-1", + "type": "dns", + "priority": 100, + "pattern": "example.com", + "match_type": "exact", + "enabled": true, + "targets": [ + { + "ip": "192.168.1.100", + "port": 80 + }, + { + "ip": "192.168.1.101", + "port": 80 + } + ] + }, + { + "id": "dns-rule-2", + "type": "dns", + "priority": 90, + "pattern": "*.example.com", + "match_type": "wildcard", + "enabled": true, + "targets": [ + { + "ip": "192.168.1.102", + "port": 80 + } + ] + }, + { + "id": "route-rule-1", + "type": "route", + "priority": 100, + "pattern": "/api/v1", + "match_type": "path", + "enabled": true, + "target": "http://api.example.com", + "header_modifier": { + "X-Forwarded-Host": "api.example.com", + "X-Real-IP": "${client_ip}" + } + }, + { + "id": "rewrite-rule-1", + "type": "rewrite", + "priority": 90, + "pattern": "/old/(.*)", + "match_type": "regex", + "enabled": true, + "replacement": "/new/$1" + } + ] +} \ No newline at end of file diff --git a/examples/rule/hosts.txt b/examples/rule/hosts.txt new file mode 100644 index 0000000..d53e8ed --- /dev/null +++ b/examples/rule/hosts.txt @@ -0,0 +1,60 @@ +# GitHub相关域名解析 +140.82.121.3 www.github.com +140.82.121.4 github.com +140.82.121.5 api.github.com +185.199.108.153 assets-cdn.github.com +140.82.121.6 collector.github.com +140.82.121.9 codeload.github.com +140.82.121.4 gist.github.com +185.199.108.133 raw.githubusercontent.com + +# AWS服务 +52.216.162.69 s3.amazonaws.com + +# Google相关域名解析 +142.250.4.147 www.google.com +142.250.4.139 google.com +142.250.4.95 ajax.googleapis.com +142.250.4.95 fonts.googleapis.com + +# 百度相关域名解析 +110.242.68.66 www.baidu.com +39.156.66.10 baidu.com + +# 自定义端口示例 +192.168.1.10:8080 custom-port-example.com +192.168.1.11:8443 api.custom-port-example.com + +# 本地开发环境 +127.0.0.1 localhost +127.0.0.1:3000 dev.example.com +127.0.0.1:3001 api.dev.example.com +127.0.0.1:5432 db.dev.example.com + +# 基本格式:IP 域名 [端口] +192.168.1.100 example.com +192.168.1.101 api.example.com +192.168.1.102 test.example.com:8080 + +# 支持通配符域名 +192.168.1.103 *.example.com +192.168.1.104 *.api.example.com + +# 支持多个域名 +192.168.1.105 app1.example.com app2.example.com app3.example.com + +# 支持注释 +# 开发环境 +192.168.1.106 dev.example.com +# 测试环境 +192.168.1.107 test.example.com +# 生产环境 +192.168.1.108 prod.example.com + +# 支持IPv6 +2001:db8::1 ipv6.example.com +2001:db8::2 ipv6-api.example.com:8080 + +# 支持混合IPv4和IPv6 +192.168.1.109 mixed.example.com +2001:db8::3 mixed.example.com \ No newline at end of file diff --git a/examples/rule/wildcard_dns_config.json b/examples/rule/wildcard_dns_config.json new file mode 100644 index 0000000..4c65389 --- /dev/null +++ b/examples/rule/wildcard_dns_config.json @@ -0,0 +1,23 @@ +{ + "records": { + "example.com": "93.184.216.34", + "api.example.com": "93.184.216.35:8443", + + "*.github.com": "140.82.121.3", + "github.com": "140.82.121.4", + "api.github.com": "140.82.121.5", + + "*.s3.amazonaws.com": "52.216.162.69", + "cdn-*.example.org": "203.0.113.10", + + "*.dev.local": "127.0.0.1:3000", + "api.*.dev.local": "127.0.0.1:3001", + "db.*.dev.local": "127.0.0.1:5432", + + "*.test.example.com": "192.168.1.100", + "*.staging.example.com": "192.168.1.101", + "*.production.example.com": "192.168.1.102" + }, + "use_fallback": true, + "ttl": 300 +} \ No newline at end of file diff --git a/examples/rule/wildcard_hosts.txt b/examples/rule/wildcard_hosts.txt new file mode 100644 index 0000000..e434178 --- /dev/null +++ b/examples/rule/wildcard_hosts.txt @@ -0,0 +1,27 @@ +# 精确匹配记录 +93.184.216.34 example.com +93.184.216.35:8443 api.example.com + +# GitHub相关泛解析 +140.82.121.3 *.github.com +140.82.121.4 github.com +140.82.121.5 api.github.com + +# AWS服务泛解析 +52.216.162.69 *.s3.amazonaws.com +203.0.113.10 cdn-*.example.org + +# 本地开发环境泛解析 +127.0.0.1:3000 *.dev.local +127.0.0.1:3001 api.*.dev.local +127.0.0.1:5432 db.*.dev.local + +# 多环境测试泛解析 +192.168.1.100 *.test.example.com +192.168.1.101 *.staging.example.com +192.168.1.102 *.production.example.com + +# 特定子域名泛解析示例 +10.0.0.1 *.api.service.com +10.0.0.2 *.auth.service.com +10.0.0.3 *.cdn.service.com \ No newline at end of file diff --git a/examples/scripts/generate_cert.ps1 b/examples/scripts/generate_cert.ps1 new file mode 100644 index 0000000..b43c8af --- /dev/null +++ b/examples/scripts/generate_cert.ps1 @@ -0,0 +1,84 @@ +# PowerShell 脚本:生成自签名证书 + +# 处理命令行参数 +param( + [Parameter(HelpMessage="证书有效期(天数)")] + [int]$days = 365, + + [Parameter(HelpMessage="证书主题")] + [string]$subject = "CN=localhost,OU=Test,O=GoProxy,L=Shanghai,S=Shanghai,C=CN", + + [Parameter(HelpMessage="公用名(CN)")] + [string]$cn = "", + + [Parameter(HelpMessage="显示帮助信息")] + [switch]$help +) + +# 帮助信息 +function Show-Help { + Write-Host "生成自签名证书" + Write-Host + Write-Host "用法: .\generate_cert.ps1 [选项]" + Write-Host + Write-Host "选项:" + Write-Host " -help 显示此帮助信息" + Write-Host " -days DAYS 证书有效期(天数),默认: 365" + Write-Host " -subject SUB 证书主题,默认: $subject" + Write-Host " -cn CN 公用名(CN),将替换主题中的CN,默认: localhost" + Write-Host + Write-Host "示例:" + Write-Host " .\generate_cert.ps1 -days 730 -cn example.com" + Write-Host +} + +# 如果请求帮助,显示帮助信息并退出 +if ($help) { + Show-Help + exit 0 +} + +# 如果指定了CN,替换主题中的CN部分 +if ($cn -ne "") { + $subject = $subject -replace "CN=[^,]*", "CN=$cn" +} + +Write-Host "生成自签名证书..." +Write-Host "有效期: $days 天" +Write-Host "主题: $subject" + +# 检查OpenSSL是否可用 +$openssl = Get-Command "openssl" -ErrorAction SilentlyContinue +if (-not $openssl) { + Write-Host "错误: 未找到OpenSSL命令。请安装OpenSSL并确保它在PATH环境变量中。" -ForegroundColor Red + Write-Host "您可以从以下地址下载OpenSSL for Windows: https://slproweb.com/products/Win32OpenSSL.html" -ForegroundColor Yellow + exit 1 +} + +try { + # 生成私钥 + Write-Host "正在生成私钥..." -ForegroundColor Cyan + & openssl genrsa -out server.key 2048 + + # 生成证书请求 + Write-Host "正在生成证书请求..." -ForegroundColor Cyan + & openssl req -new -key server.key -out server.csr -subj $subject.Replace(",", "/") + + # 生成自签名证书 + Write-Host "正在生成自签名证书..." -ForegroundColor Cyan + & openssl x509 -req -days $days -in server.csr -signkey server.key -out server.crt + + # 删除证书请求文件 + Remove-Item server.csr -Force + + Write-Host "完成!已生成以下文件:" -ForegroundColor Green + Write-Host " - server.key: 私钥" -ForegroundColor Green + Write-Host " - server.crt: 证书" -ForegroundColor Green + Write-Host + Write-Host "您可以使用这些文件启动HTTPS代理:" -ForegroundColor Cyan + Write-Host "go run cmd/custom_dns_https_proxy/main.go -cert server.crt -key server.key" -ForegroundColor Cyan +} +catch { + Write-Host "错误: 生成证书时发生错误: $_" -ForegroundColor Red + exit 1 +} diff --git a/examples/scripts/generate_cert.sh b/examples/scripts/generate_cert.sh new file mode 100644 index 0000000..836f096 --- /dev/null +++ b/examples/scripts/generate_cert.sh @@ -0,0 +1,81 @@ +#!/bin/bash + +# 退出时如果有任何命令失败 +set -e + +# 默认值 +DAYS=365 +SUBJECT="/C=CN/ST=Shanghai/L=Shanghai/O=GoProxy/OU=Test/CN=localhost" + +# 帮助信息 +function show_help { + echo "生成自签名证书" + echo + echo "用法: $0 [选项]" + echo + echo "选项:" + echo " -h, --help 显示此帮助信息" + echo " -d, --days DAYS 证书有效期(天数),默认: 365" + echo " -s, --subject SUB 证书主题,默认: $SUBJECT" + echo " -c, --cn CN 公用名(CN),将替换主题中的CN,默认: localhost" + echo + echo "示例:" + echo " $0 --days 730 --cn example.com" + echo +} + +# 处理命令行参数 +while [[ $# -gt 0 ]]; do + key="$1" + case $key in + -h|--help) + show_help + exit 0 + ;; + -d|--days) + DAYS="$2" + shift + shift + ;; + -s|--subject) + SUBJECT="$2" + shift + shift + ;; + -c|--cn) + # 替换主题中的CN部分 + SUBJECT=$(echo $SUBJECT | sed "s/CN=[^\/]*/CN=$2/") + shift + shift + ;; + *) + echo "未知选项: $1" + show_help + exit 1 + ;; + esac +done + +echo "生成自签名证书..." +echo "有效期: $DAYS 天" +echo "主题: $SUBJECT" + +# 生成私钥 +openssl genrsa -out server.key 2048 + +# 生成证书请求 +openssl req -new -key server.key -out server.csr -subj "$SUBJECT" + +# 生成自签名证书 +openssl x509 -req -days $DAYS -in server.csr -signkey server.key -out server.crt + +# 删除证书请求文件 +rm server.csr + +echo "完成!已生成以下文件:" +echo " - server.key: 私钥" +echo " - server.crt: 证书" +echo +echo "启动HTTPS代理" + +go run ../custom_dns_https_proxy/main.go -cert server.crt -key server.key \ No newline at end of file diff --git a/examples/scripts/server.crt b/examples/scripts/server.crt new file mode 100644 index 0000000..4c4af9b --- /dev/null +++ b/examples/scripts/server.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDVzCCAj8CFBKBjcPxJ7o8UnWKrYMpI6XWa9crMA0GCSqGSIb3DQEBCwUAMGgx +CzAJBgNVBAYTAkNOMREwDwYDVQQIDAhTaGFuZ2hhaTERMA8GA1UEBwwIU2hhbmdo +YWkxEDAOBgNVBAoMB0dvUHJveHkxDTALBgNVBAsMBFRlc3QxEjAQBgNVBAMMCWxv +Y2FsaG9zdDAeFw0yNTAzMTQxMzA0NDlaFw0yNjAzMTQxMzA0NDlaMGgxCzAJBgNV +BAYTAkNOMREwDwYDVQQIDAhTaGFuZ2hhaTERMA8GA1UEBwwIU2hhbmdoYWkxEDAO +BgNVBAoMB0dvUHJveHkxDTALBgNVBAsMBFRlc3QxEjAQBgNVBAMMCWxvY2FsaG9z +dDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMZWDfiB54iz+hUpUXfC +V2OH674a6EEJTqQ6xZ3b9aKC+IUoerzyj8o/cxNvb6AekxiMlxMbAVK+CARqvqzE +/6w+SZPB8TZGwLM1yPyDaz1+D05n/Am3slccDby/pkPG/igt1q/RVkizw35Mn9ct +gz5niufM78gRQTMr1/8CfgNyfiDa5mZ02fIahUZLBjCotF2jtfN0hX1gagD06wlc +9RL36Ms2hxK+A1J6VUMhXdH4u0PdksiRwtQMFW8A4M3fCTJp1a1H1Oj1gub9AXcL +UOwFZMrZ6LJFNBDQNRU/e104Fqq0XpfnXEq4SC/AW/wkLuCcRtoOGIg7XQ0Np3dv +sj8CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAqCp6RcBW6Q2PlbUeqOl4X9KQffRO +N+ATvcve0hF3+Jr5tPYDvLwtHtyU1yYKyrM8RmqMcOHxXEmuxhlYaR0P4yUwTPOr +l9ZwskIkoTd0nVVlS9nGQMtEc0n+AmWGICE9gqOF66Gup0OPY3OYGdvlqE8NBH43 +95jSB0grAMudd2TW71Ef+PvieOY7ksJwGiP9tusJqq51Bh9gyhUAk9xxeKbeP0Zp +9dy8/kbTW9B5hyLYNYOKhBztKu665cQ1cNL4AA0Y5svBRWlymTqB2HKMIKTGJzDO +6Jh6wXWr7Fx/aDROHqY2vOQ25i0FGJCu8/7TUcveCllyjHtJHhljkXht4g== +-----END CERTIFICATE----- diff --git a/examples/scripts/server.key b/examples/scripts/server.key new file mode 100644 index 0000000..a1e392a --- /dev/null +++ b/examples/scripts/server.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAxlYN+IHniLP6FSlRd8JXY4frvhroQQlOpDrFndv1ooL4hSh6 +vPKPyj9zE29voB6TGIyXExsBUr4IBGq+rMT/rD5Jk8HxNkbAszXI/INrPX4PTmf8 +CbeyVxwNvL+mQ8b+KC3Wr9FWSLPDfkyf1y2DPmeK58zvyBFBMyvX/wJ+A3J+INrm +ZnTZ8hqFRksGMKi0XaO183SFfWBqAPTrCVz1EvfoyzaHEr4DUnpVQyFd0fi7Q92S +yJHC1AwVbwDgzd8JMmnVrUfU6PWC5v0BdwtQ7AVkytnoskU0ENA1FT97XTgWqrRe +l+dcSrhIL8Bb/CQu4JxG2g4YiDtdDQ2nd2+yPwIDAQABAoIBAFLFGvN4kv2TzmwC +YENQUVPyJ0mgxQhPMAiNlmb4opv9eGVprT8pIyTOMeIMgVMbL1vxYCLTBExZjdL6 +ETTcya5CGEaXi2iRQl4HtibbWWfCMfUQpDgR91UvGfSJLoPeibaO2qdo/0875fvR +UmtkTP9ACtINzot51/HY/D0p9xjMdLTa2bzj3fzPSLGROy0OGbam0yWwkSVyL6kF +xfyVWIG8Vw40L7XTxS8vnM1mSU17qqztrVsbB29eLbby3Vv7vHZ6xWWbv/jW0Hm4 +9Zl6GU4G4jB6/zafylqnMvUlLzPHRR3k5fxzgloyTmfJegj73ksTsAaM0qD6S2jQ +QFil1WECgYEA75qX674QEp5J/C24jCef7vG/x2EvX9II9CXUnERUXmvFx2UFFzcM +QAhewLwTywe8WG5BNyde/xv/IRfgOuMtDXB6aQJlibin0CRPlQ/IzEdLt9R5FBSB +t3i2ffdJhYObusiTFfzoSV8jyVNyLSvsZXfKa/ckfb7txCe7AaChaxkCgYEA0+iH +UkVsVVoXu8OLmXH8q7FWbVciqBMiGOEyZpjluJoCQVmCAogTYTR1qY9HnlLnVprA +OS1Q6vjAZaTK6AKcCgBuRjswHi35jpZ77u+Vdd85dVTnlQIdHDyUxHYdAfP7tjKl +VnIifDS2zeDH2QORjqzP4QsKd2gnpciguCOdCxcCgYBVvLDeF22y69c3mLiv1kIB +g5oHYzxLgmHX022n2T+DZfcoqXpP20/T3erh9qryfLslvZYygTEaAk+h7OQ8zivB +4ly7FLN2u4+5CDU99p74kg6DIlGNIOVl3JkYrBMv5m8kQD95n70S/CtXEDgL9+qo +SFwzlAUHxflYtorRQ0RfiQKBgEcja7JJzgmFOix1g/raUlmNKheAxgioi6zQhOv+ +bjgfs5weoU+aQO9D/jATApb6++COCPPo655GLcixntBud9W/uUVof0nSY1Hj4O0g +jwtICfECtM/IKt+c0tB1Wl2ae6j5rZmsrTkHNUs+J7kJwqakCxFgdH4LgCveg13t +zr23AoGBALA9C/3NIC4TKD2jm0NFBttODyf8gbAT1/7akdiH5BttSbrPg8l4kYwL +kvs1TxU0O6cpOHcqG8CEYnTAi4LEocRCQtsD3xMqANqjNLMNvdsV/cQvT6qqoUre +IyGgq9+W0NSeeHwtA+qASUaAorIpnplNOkkhyu78pvEtxUFPCZ1c +-----END RSA PRIVATE KEY----- diff --git a/examples/wildcard_dns_proxy/main.go b/examples/wildcard_dns_proxy/main.go new file mode 100644 index 0000000..b6712b3 --- /dev/null +++ b/examples/wildcard_dns_proxy/main.go @@ -0,0 +1,47 @@ +package main + +import ( + "log" + "net/http" + + "github.com/darkit/goproxy" + "github.com/darkit/goproxy/pkg/dns" +) + +// WildcardDNSDelegate 通配符 DNS 代理委托 +type WildcardDNSDelegate struct { + goproxy.DefaultDelegate + dnsResolver *dns.WildcardResolver +} + +// ResolveBackend 解析后端服务器 +func (d *WildcardDNSDelegate) ResolveBackend(req *http.Request) (string, error) { + return d.dnsResolver.Resolve(req.URL.Host) +} + +func main() { + // 创建 DNS 解析器 + resolver := dns.NewWildcardResolver(map[string]string{ + "*.example.com": "http://backend.example.com", + "*.test.com": "http://backend.test.com", + }) + + // 创建通配符 DNS 代理委托 + delegate := &WildcardDNSDelegate{ + dnsResolver: resolver, + } + + // 创建代理实例 + proxy := goproxy.NewProxy( + goproxy.WithDelegate(delegate), + ) + + // 启动代理服务器 + log.Println("通配符 DNS 代理服务器启动在 :8080") + log.Println("DNS 配置:") + log.Printf("- *.example.com -> backend.example.com\n") + log.Printf("- *.test.com -> backend.test.com\n") + if err := http.ListenAndServe(":8080", proxy); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..190e1d3 --- /dev/null +++ b/go.mod @@ -0,0 +1,23 @@ +module github.com/darkit/goproxy + +go 1.24.0 + +require ( + github.com/fsnotify/fsnotify v1.7.0 + github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/ouqiang/websocket v1.6.2 + github.com/prometheus/client_golang v1.20.4 + github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 + golang.org/x/time v0.11.0 +) + +require ( + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.62.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect + golang.org/x/sys v0.28.0 // indirect + google.golang.org/protobuf v1.36.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..93ab40d --- /dev/null +++ b/go.sum @@ -0,0 +1,40 @@ +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/ouqiang/websocket v1.6.2 h1:LGQIySbQO3ahZCl34v9xBVb0yncDk8yIcuEIbWBab/U= +github.com/ouqiang/websocket v1.6.2/go.mod h1:fIROJIHRlQwgCyUFTMzaaIcs4HIwUj2xlOW43u9Sf+M= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI= +github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= +github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 h1:EVObHAr8DqpoJCVv6KYTle8FEImKhtkfcZetNqxDoJQ= +github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8/go.mod h1:dniwbG03GafCjFohMDmz6Zc6oCuiqgH6tGNyXTkHzXE= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= +golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk= +google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/options.go b/options.go new file mode 100644 index 0000000..048ce5c --- /dev/null +++ b/options.go @@ -0,0 +1,284 @@ +package goproxy + +import ( + "net/http" + "net/http/httptrace" + "time" + + "github.com/darkit/goproxy/config" + "github.com/darkit/goproxy/pkg/auth" + "github.com/darkit/goproxy/pkg/cache" + "github.com/darkit/goproxy/pkg/healthcheck" + "github.com/darkit/goproxy/pkg/loadbalance" + "github.com/darkit/goproxy/pkg/metrics" +) + +// Option 用于配置代理选项的函数类型 +type Option func(*Options) + +// WithConfig 设置代理配置 +func WithConfig(cfg *config.Config) Option { + return func(opt *Options) { + opt.Config = cfg + } +} + +// WithDisableKeepAlive 设置连接是否重用 +func WithDisableKeepAlive(disableKeepAlive bool) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + // 在transport中设置DisableKeepAlives + } +} + +// WithClientTrace 设置HTTP客户端跟踪 +func WithClientTrace(t *httptrace.ClientTrace) Option { + return func(opt *Options) { + opt.ClientTrace = t + } +} + +// WithDelegate 设置委托类 +func WithDelegate(delegate Delegate) Option { + return func(opt *Options) { + opt.Delegate = delegate + } +} + +// WithTransport 使用自定义HTTP传输 +func WithTransport(t *http.Transport) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + // 在New方法中处理transport + } +} + +// WithDecryptHTTPS 启用中间人代理解密HTTPS +func WithDecryptHTTPS(c CertificateCache) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.DecryptHTTPS = true + opt.CertCache = c + } +} + +// WithEnableWebsocketIntercept 启用WebSocket拦截 +func WithEnableWebsocketIntercept() Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + // WebSocket拦截在代理处理逻辑中实现 + } +} + +// WithHTTPCache 设置HTTP缓存 +func WithHTTPCache(c cache.Cache) Option { + return func(opt *Options) { + opt.HTTPCache = c + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.EnableCache = true + } +} + +// WithLoadBalancer 设置负载均衡器 +func WithLoadBalancer(lb loadbalance.LoadBalancer) Option { + return func(opt *Options) { + opt.LoadBalancer = lb + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.EnableLoadBalancing = true + } +} + +// WithHealthChecker 设置健康检查器 +func WithHealthChecker(hc *healthcheck.HealthChecker) Option { + return func(opt *Options) { + opt.HealthChecker = hc + } +} + +// WithMetrics 设置监控指标 +func WithMetrics(m metrics.MetricsCollector) Option { + return func(opt *Options) { + opt.Metrics = m + } +} + +// WithTLSCertAndKey 设置TLS证书和密钥 +func WithTLSCertAndKey(certPath, keyPath string) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.TLSCert = certPath + opt.Config.TLSKey = keyPath + } +} + +// WithCACertAndKey 设置CA证书和密钥(用于生成动态证书) +func WithCACertAndKey(caCertPath, caKeyPath string) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.CACert = caCertPath + opt.Config.CAKey = caKeyPath + } +} + +// WithConnectionPoolSize 设置连接池大小 +func WithConnectionPoolSize(size int) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.ConnectionPoolSize = size + } +} + +// WithIdleTimeout 设置空闲超时时间 +func WithIdleTimeout(timeout time.Duration) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.IdleTimeout = timeout + } +} + +// WithRequestTimeout 设置请求超时时间 +func WithRequestTimeout(timeout time.Duration) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.RequestTimeout = timeout + } +} + +// WithReverseProxy 启用反向代理模式 +func WithReverseProxy(enable bool) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.ReverseProxy = enable + } +} + +// WithEnableRetry 启用请求重试 +func WithEnableRetry(maxRetries int, baseBackoff, maxBackoff time.Duration) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.EnableRetry = true + opt.Config.MaxRetries = maxRetries + opt.Config.BaseBackoff = baseBackoff + opt.Config.MaxBackoff = maxBackoff + } +} + +// WithRateLimit 设置请求限流 +func WithRateLimit(rps float64) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.EnableRateLimit = true + opt.Config.RateLimit = rps + } +} + +// WithDNSCacheTTL 设置DNS缓存TTL +func WithDNSCacheTTL(ttl time.Duration) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.DNSCacheTTL = ttl + } +} + +// WithEnableCORS 启用CORS支持 +func WithEnableCORS(enable bool) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.EnableCORS = enable + } +} + +// WithCertManager 设置证书管理器 +// 这是一个内部函数,主要用于在New方法中设置CertManager +func WithCertManager(certManager *CertManager) Option { + return func(opt *Options) { + opt.CertManager = certManager + } +} + +// WithEnableECDSA 启用ECDSA证书生成(默认使用RSA) +func WithEnableECDSA(enable bool) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.UseECDSA = enable + } +} + +// WithAuth 设置认证系统 +func WithAuth(auth *auth.Auth) Option { + return func(o *Options) { + o.Auth = auth + } +} + +// Options 代理选项 +type Options struct { + // 配置 + Config *config.Config + // 委托 + Delegate Delegate + // 证书缓存 + CertCache CertificateCache + // HTTP缓存 + HTTPCache cache.Cache + // 负载均衡器 + LoadBalancer loadbalance.LoadBalancer + // 健康检查器 + HealthChecker *healthcheck.HealthChecker + // 监控指标 + Metrics metrics.MetricsCollector + // 客户端跟踪 + ClientTrace *httptrace.ClientTrace + // 证书管理器 + CertManager *CertManager + // 认证系统 + Auth *auth.Auth +} + +// NewWithOptions 使用选项函数创建代理 +func NewWithOptions(options ...Option) *Proxy { + opts := &Options{ + Config: config.DefaultConfig(), + } + + // 应用所有选项 + for _, option := range options { + option(opts) + } + + return New(opts) +} diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go new file mode 100644 index 0000000..f8efd5e --- /dev/null +++ b/pkg/auth/auth.go @@ -0,0 +1,231 @@ +package auth + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// Auth 认证授权系统 +type Auth struct { + // JWT密钥 + secretKey []byte + // 用户存储 + users map[string]*User + // 角色权限映射 + rolePermissions map[string][]string + // 锁 + mu sync.RWMutex +} + +// User 用户信息 +type User struct { + // 用户名 + Username string + // 密码 + Password string + // 角色列表 + Roles []string + // 创建时间 + CreatedAt time.Time + // 最后登录时间 + LastLoginAt time.Time +} + +// NewAuth 创建认证授权系统 +func NewAuth(secretKey string) *Auth { + return &Auth{ + secretKey: []byte(secretKey), + users: make(map[string]*User), + rolePermissions: make(map[string][]string), + } +} + +// AddUser 添加用户 +func (a *Auth) AddUser(username, password string, roles []string) error { + a.mu.Lock() + defer a.mu.Unlock() + + if _, exists := a.users[username]; exists { + return fmt.Errorf("用户已存在") + } + + // 密码加密 + hashedPassword := hashPassword(password) + + // 创建用户 + a.users[username] = &User{ + Username: username, + Password: hashedPassword, + Roles: roles, + CreatedAt: time.Now(), + } + + return nil +} + +// Authenticate 认证用户 +func (a *Auth) Authenticate(username, password string) (string, error) { + a.mu.RLock() + user, exists := a.users[username] + a.mu.RUnlock() + + if !exists { + return "", fmt.Errorf("用户不存在") + } + + // 验证密码 + if !a.ValidateUser(username, password) { + return "", fmt.Errorf("密码错误") + } + + // 更新最后登录时间 + a.mu.Lock() + user.LastLoginAt = time.Now() + a.mu.Unlock() + + // 生成JWT令牌 + token, err := a.GenerateToken(username) + if err != nil { + return "", err + } + + return token, nil +} + +// Authorize 授权检查 +func (a *Auth) Authorize(token, permission string) error { + // 验证JWT令牌 + claims, err := a.ValidateToken(token) + if err != nil { + return err + } + + // 检查用户权限 + a.mu.RLock() + username, ok := (*claims)["username"].(string) + if !ok { + a.mu.RUnlock() + return fmt.Errorf("无效的用户名") + } + user, exists := a.users[username] + a.mu.RUnlock() + + if !exists { + return fmt.Errorf("用户不存在") + } + + // 检查用户角色权限 + for _, role := range user.Roles { + if a.hasPermission(role, permission) { + return nil + } + } + + return fmt.Errorf("权限不足") +} + +// Middleware 认证中间件 +func (a *Auth) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 获取认证头 + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + http.Error(w, "未提供认证信息", http.StatusUnauthorized) + return + } + + // 解析Bearer令牌 + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + http.Error(w, "认证格式错误", http.StatusUnauthorized) + return + } + + // 验证令牌 + if err := a.Authorize(parts[1], r.URL.Path); err != nil { + http.Error(w, "认证失败", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) +} + +// ValidateUser 验证用户 +func (a *Auth) ValidateUser(username, password string) bool { + a.mu.RLock() + defer a.mu.RUnlock() + + user, exists := a.users[username] + if !exists { + return false + } + + return user.Password == hashPassword(password) +} + +// GenerateToken 生成JWT令牌 +func (a *Auth) GenerateToken(username string) (string, error) { + a.mu.RLock() + user, exists := a.users[username] + a.mu.RUnlock() + + if !exists { + return "", errors.New("用户不存在") + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "username": username, + "roles": user.Roles, + "exp": time.Now().Add(24 * time.Hour).Unix(), + }) + + return token.SignedString(a.secretKey) +} + +// ValidateToken 验证JWT令牌 +func (a *Auth) ValidateToken(tokenString string) (*jwt.MapClaims, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + return a.secretKey, nil + }) + if err != nil { + return nil, err + } + + if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + return &claims, nil + } + + return nil, errors.New("无效的令牌") +} + +// hashPassword 密码加密 +func hashPassword(password string) string { + hash := sha256.New() + hash.Write([]byte(password)) + return hex.EncodeToString(hash.Sum(nil)) +} + +// hasPermission 检查角色是否有权限 +func (a *Auth) hasPermission(role, permission string) bool { + permissions, exists := a.rolePermissions[role] + if !exists { + return false + } + + for _, p := range permissions { + if p == permission { + return true + } + } + + return false +} diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go new file mode 100644 index 0000000..1e06f2d --- /dev/null +++ b/pkg/cache/cache.go @@ -0,0 +1,220 @@ +package cache + +import ( + "bytes" + "crypto/md5" + "encoding/hex" + "io" + "net/http" + "strings" + "sync" + "time" +) + +// Cache 缓存接口 +type Cache interface { + // Get 获取缓存 + Get(key string) (*http.Response, bool) + // Set 设置缓存 + Set(key string, resp *http.Response) + // Delete 删除缓存 + Delete(key string) + // Clear 清空缓存 + Clear() +} + +// MemoryCache 内存缓存实现 +type MemoryCache struct { + // 缓存内容 + items sync.Map + // 过期时间 + ttl time.Duration + // 清理间隔 + cleanupInterval time.Duration + // 最大条目数 + maxEntries int + // 当前条目数 + size int32 + // 互斥锁 + mu sync.Mutex +} + +// CacheItem 缓存项 +type CacheItem struct { + response *http.Response + responseBody []byte + expiry time.Time +} + +// NewMemoryCache 创建内存缓存 +func NewMemoryCache(ttl, cleanupInterval time.Duration, maxEntries int) *MemoryCache { + cache := &MemoryCache{ + ttl: ttl, + cleanupInterval: cleanupInterval, + maxEntries: maxEntries, + } + + // 启动过期清理 + if cleanupInterval > 0 { + go cache.startCleanup() + } + + return cache +} + +// Get 获取缓存 +func (c *MemoryCache) Get(key string) (*http.Response, bool) { + value, ok := c.items.Load(key) + if !ok { + return nil, false + } + + item := value.(*CacheItem) + if time.Now().After(item.expiry) { + c.Delete(key) + return nil, false + } + + // 克隆响应,避免修改原始数据 + resp := cloneResponse(item.response, item.responseBody) + return resp, true +} + +// Set 设置缓存 +func (c *MemoryCache) Set(key string, resp *http.Response) { + // 检查缓存是否已满 + c.mu.Lock() + if c.maxEntries > 0 && c.size >= int32(c.maxEntries) { + c.mu.Unlock() + return + } + c.size++ + c.mu.Unlock() + + // 读取并保存响应体 + var bodyBytes []byte + if resp.Body != nil { + bodyBytes, _ = io.ReadAll(resp.Body) + resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + } + + item := &CacheItem{ + response: resp, + responseBody: bodyBytes, + expiry: time.Now().Add(c.ttl), + } + + c.items.Store(key, item) +} + +// Delete 删除缓存 +func (c *MemoryCache) Delete(key string) { + c.items.Delete(key) + c.mu.Lock() + c.size-- + if c.size < 0 { + c.size = 0 + } + c.mu.Unlock() +} + +// Clear 清空缓存 +func (c *MemoryCache) Clear() { + c.items = sync.Map{} + c.mu.Lock() + c.size = 0 + c.mu.Unlock() +} + +// startCleanup 启动过期清理 +func (c *MemoryCache) startCleanup() { + ticker := time.NewTicker(c.cleanupInterval) + defer ticker.Stop() + + for range ticker.C { + now := time.Now() + c.items.Range(func(key, value interface{}) bool { + item := value.(*CacheItem) + if now.After(item.expiry) { + c.Delete(key.(string)) + } + return true + }) + } +} + +// GenerateCacheKey 生成缓存键 +func GenerateCacheKey(req *http.Request) string { + // 忽略一些可变的头部 + ignoredHeaders := map[string]bool{ + "Connection": true, + "Keep-Alive": true, + "Proxy-Authenticate": true, + "Proxy-Authorization": true, + "TE": true, + "Trailers": true, + "Transfer-Encoding": true, + "Upgrade": true, + } + + // 提取缓存键组件 + components := []string{ + req.Method, + req.URL.String(), + } + + // 添加选择性头部 + for key, values := range req.Header { + if !ignoredHeaders[key] { + for _, value := range values { + components = append(components, key+":"+value) + } + } + } + + // 连接并计算哈希 + data := strings.Join(components, "|") + hash := md5.New() + hash.Write([]byte(data)) + return hex.EncodeToString(hash.Sum(nil)) +} + +// cloneResponse 克隆HTTP响应 +func cloneResponse(resp *http.Response, body []byte) *http.Response { + clone := *resp + clone.Body = io.NopCloser(bytes.NewBuffer(body)) + clone.Header = make(http.Header) + + for k, v := range resp.Header { + clone.Header[k] = v + } + + return &clone +} + +// ShouldCache 判断请求是否应该缓存 +func ShouldCache(req *http.Request, resp *http.Response) bool { + // 只缓存GET请求 + if req.Method != http.MethodGet { + return false + } + + // 检查响应状态码 + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusNotModified && + resp.StatusCode != http.StatusMovedPermanently && + resp.StatusCode != http.StatusPermanentRedirect { + return false + } + + // 检查Cache-Control头 + cacheControl := resp.Header.Get("Cache-Control") + if strings.Contains(cacheControl, "no-store") || + strings.Contains(cacheControl, "no-cache") || + strings.Contains(cacheControl, "private") { + return false + } + + return true +} diff --git a/pkg/dns/README.md b/pkg/dns/README.md new file mode 100644 index 0000000..d25a8ba --- /dev/null +++ b/pkg/dns/README.md @@ -0,0 +1,67 @@ +# DNS包单元测试 + +本目录包含goproxy/pkg/dns包的单元测试。这些测试覆盖了DNS包的所有主要组件和功能。 + +## 测试内容 + +1. **Endpoint测试** (`endpoint_test.go`) + - 测试网络端点的创建、解析和字符串表示 + +2. **Resolver测试** (`resolver_test.go`) + - 测试DNS解析器的基本功能 + - 测试通配符解析 + - 测试记录的添加、删除和清除 + - 测试负载均衡策略 + +3. **Dialer测试** (`dialer_test.go`) + - 测试自定义DNS拨号器的功能 + - 使用模拟解析器测试域名解析和拨号 + +4. **WildcardResolver测试** (`wildcard_test.go`) + - 测试通配符匹配功能 + - 测试通配符解析器的添加、删除和清除 + +5. **Config测试** (`config_test.go`) + - 测试配置的保存和加载 + - 测试从hosts文件格式加载配置 + - 测试从配置创建解析器 + +6. **集成测试** (`integration_test.go`) + - 测试各组件协同工作 + - 测试负载均衡场景 + +## 运行测试 + +### 运行所有测试 + +```bash +cd /www/goproxy +go test -v ./pkg/dns/tests/... +``` + +### 运行特定测试组件 + +```bash +# 只运行Endpoint测试 +go test -v ./pkg/dns/tests -run "TestEndpoint" + +# 只运行Resolver测试 +go test -v ./pkg/dns/tests -run "TestResolver" + +# 只运行集成测试 +go test -v ./pkg/dns/tests -run "TestIntegration" +``` + +### 测试覆盖率报告 + +```bash +cd /www/goproxy +go test -cover -coverprofile=coverage.out ./pkg/dns/tests/... +go tool cover -html=coverage.out -o coverage.html +``` + +## 注意事项 + +1. 部分测试涉及到网络连接,如果网络不可用,某些测试可能会失败,这是正常的。 +2. 测试设计考虑了离线环境,大多数测试不需要网络连接。 +3. 集成测试中的`TestIntegrationWithRealDomains`在网络不可用时会被跳过。 \ No newline at end of file diff --git a/pkg/dns/config.go b/pkg/dns/config.go new file mode 100644 index 0000000..621fc14 --- /dev/null +++ b/pkg/dns/config.go @@ -0,0 +1,151 @@ +package dns + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "regexp" + "strings" + "time" +) + +// 用于解析hosts文件中的IP:端口格式 +var ipPortRegex = regexp.MustCompile(`^([0-9.]+)(?::(\d+))?$`) + +// DNSConfig DNS配置文件结构 +type DNSConfig struct { + Records map[string]string `json:"records"` // 普通记录和泛解析记录 + Fallback bool `json:"fallback"` // 是否回退到系统DNS + TTL int `json:"ttl"` // 缓存TTL,单位为秒 +} + +// SaveToJSON 将DNS配置保存为JSON文件 +func (c *DNSConfig) SaveToJSON(filePath string) error { + file, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("创建DNS配置文件失败: %w", err) + } + defer file.Close() + + encoder := json.NewEncoder(file) + encoder.SetIndent("", " ") + if err := encoder.Encode(c); err != nil { + return fmt.Errorf("保存DNS配置文件失败: %w", err) + } + + return nil +} + +// LoadFromJSON 从JSON文件加载DNS配置 +func LoadFromJSON(filePath string) (*DNSConfig, error) { + file, err := os.Open(filePath) + if err != nil { + return nil, fmt.Errorf("打开DNS配置文件失败: %w", err) + } + defer file.Close() + + config := &DNSConfig{ + Records: make(map[string]string), + Fallback: true, + TTL: 300, // 默认5分钟 + } + + decoder := json.NewDecoder(file) + if err := decoder.Decode(config); err != nil { + return nil, fmt.Errorf("解析DNS配置文件失败: %w", err) + } + + return config, nil +} + +// LoadFromHostsFile 从hosts文件格式加载DNS配置 +func LoadFromHostsFile(filePath string) (*DNSConfig, error) { + file, err := os.Open(filePath) + if err != nil { + return nil, fmt.Errorf("打开hosts文件失败: %w", err) + } + defer file.Close() + + config := &DNSConfig{ + Records: make(map[string]string), + Fallback: true, + TTL: 300, // 默认5分钟 + } + + scanner := bufio.NewScanner(file) + lineNum := 0 + for scanner.Scan() { + lineNum++ + line := strings.TrimSpace(scanner.Text()) + + // 跳过空行和注释 + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + fields := strings.Fields(line) + if len(fields) < 2 { + continue // 行格式不正确,跳过 + } + + ipPortStr := fields[0] + domains := fields[1:] + + // 解析IP和可能的端口 + matches := ipPortRegex.FindStringSubmatch(ipPortStr) + if matches == nil { + continue // IP格式不正确,跳过 + } + + ip := matches[1] + portStr := matches[2] + + // 构造记录值 + value := ip + if portStr != "" { + value = ip + ":" + portStr + } + + for _, domain := range domains { + // 跳过注释 + if strings.HasPrefix(domain, "#") { + break + } + + // 支持通配符和普通域名 + config.Records[domain] = value + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("读取hosts文件失败: %w", err) + } + + return config, nil +} + +// NewResolverFromConfig 从配置创建解析器 +func NewResolverFromConfig(config *DNSConfig) *CustomResolver { + var ttl time.Duration + if config.TTL > 0 { + ttl = time.Duration(config.TTL) * time.Second + } else { + ttl = 5 * time.Minute // 默认5分钟 + } + + resolver := NewResolver( + WithFallback(config.Fallback), + WithTTL(ttl), + ) + + // 加载记录 + resolver.LoadFromMap(config.Records) + + return resolver +} + +// IsWildcardDomain 检查是否为通配符域名 +func IsWildcardDomain(domain string) bool { + return strings.Contains(domain, "*") +} diff --git a/pkg/dns/config_test.go b/pkg/dns/config_test.go new file mode 100644 index 0000000..17410f7 --- /dev/null +++ b/pkg/dns/config_test.go @@ -0,0 +1,239 @@ +package dns + +import ( + "os" + "path/filepath" + "testing" +) + +// TestDNSConfigSaveLoad 测试DNS配置保存和加载 +func TestDNSConfigSaveLoad(t *testing.T) { + // 创建临时文件 + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "dns_config.json") + + // 创建配置 + config := &DNSConfig{ + Records: map[string]string{ + "example.com": "192.168.1.100", + "api.example.com": "192.168.1.101:8443", + "*.example.org": "192.168.2.100", + "*.api.example.org": "192.168.2.101:8443", + }, + Fallback: true, + TTL: 300, + } + + // 保存到文件 + err := config.SaveToJSON(configPath) + if err != nil { + t.Fatalf("保存配置失败: %v", err) + } + + // 确认文件存在 + if _, err := os.Stat(configPath); os.IsNotExist(err) { + t.Fatalf("配置文件未创建: %s", configPath) + } + + // 加载配置 + loadedConfig, err := LoadFromJSON(configPath) + if err != nil { + t.Fatalf("加载配置失败: %v", err) + } + + // 验证加载的配置 + if loadedConfig.Fallback != config.Fallback { + t.Errorf("Fallback不匹配: 期望 %v, 得到 %v", config.Fallback, loadedConfig.Fallback) + } + + if loadedConfig.TTL != config.TTL { + t.Errorf("TTL不匹配: 期望 %v, 得到 %v", config.TTL, loadedConfig.TTL) + } + + // 检查记录是否完全一致 + if len(loadedConfig.Records) != len(config.Records) { + t.Errorf("记录数量不匹配: 期望 %d, 得到 %d", len(config.Records), len(loadedConfig.Records)) + } + + for host, ip := range config.Records { + if loadedIP, ok := loadedConfig.Records[host]; !ok || loadedIP != ip { + t.Errorf("记录不匹配 %s: 期望 %s, 得到 %s", host, ip, loadedIP) + } + } +} + +// TestLoadFromHostsFile 测试从hosts文件格式加载配置 +func TestLoadFromHostsFile(t *testing.T) { + // 创建临时hosts文件 + tempDir := t.TempDir() + hostsPath := filepath.Join(tempDir, "hosts") + + // 写入测试数据 + hostsContent := `# 这是一个测试的hosts文件 +192.168.1.100 example.com www.example.com +192.168.1.101:8443 api.example.com +# 下面是通配符域名 +192.168.2.100 *.example.org +192.168.2.101:8443 *.api.example.org + +# 注释和空行 + +127.0.0.1 localhost # 本地回环 +` + + err := os.WriteFile(hostsPath, []byte(hostsContent), 0o644) + if err != nil { + t.Fatalf("创建hosts文件失败: %v", err) + } + + // 从hosts文件加载配置 + config, err := LoadFromHostsFile(hostsPath) + if err != nil { + t.Fatalf("从hosts文件加载配置失败: %v", err) + } + + // 验证配置 + expectedRecords := map[string]string{ + "example.com": "192.168.1.100", + "www.example.com": "192.168.1.100", + "api.example.com": "192.168.1.101:8443", + "*.example.org": "192.168.2.100", + "*.api.example.org": "192.168.2.101:8443", + "localhost": "127.0.0.1", + } + + // 检查记录是否与预期一致 + if len(config.Records) != len(expectedRecords) { + t.Errorf("记录数量不匹配: 期望 %d, 得到 %d", len(expectedRecords), len(config.Records)) + for k, v := range config.Records { + t.Logf("加载的记录: %s -> %s", k, v) + } + } + + for host, expectedIP := range expectedRecords { + if loadedIP, ok := config.Records[host]; !ok || loadedIP != expectedIP { + t.Errorf("记录不匹配 %s: 期望 %s, 得到 %s", host, expectedIP, loadedIP) + } + } + + // 验证默认值 + if !config.Fallback { + t.Error("默认Fallback应为true") + } + + if config.TTL != 300 { + t.Errorf("默认TTL应为300,得到 %d", config.TTL) + } +} + +// TestNewResolverFromConfig 测试从配置创建解析器 +func TestNewResolverFromConfig(t *testing.T) { + // 创建配置 + config := &DNSConfig{ + Records: map[string]string{ + "example.com": "192.168.1.100", + "api.example.com": "192.168.1.101:8443", + "*.example.org": "192.168.2.100", + "*.api.example.org": "192.168.2.101:8443", + }, + Fallback: false, + TTL: 60, + } + + // 从配置创建解析器 + resolver := NewResolverFromConfig(config) + + // 测试解析器是否正确加载了配置中的记录 + + // 测试普通记录 + ip, err := resolver.Resolve("example.com") + if err != nil { + t.Errorf("解析example.com失败: %v", err) + } + if ip != "192.168.1.100" { + t.Errorf("解析结果错误,期望 192.168.1.100,得到 %s", ip) + } + + // 测试带端口的记录 + endpoint, err := resolver.ResolveWithPort("api.example.com", 443) + if err != nil { + t.Errorf("解析api.example.com失败: %v", err) + } + // 只检查IP是否正确 + if endpoint.IP != "192.168.1.101" { + t.Errorf("解析结果错误,期望 IP=192.168.1.101,得到 IP=%s", endpoint.IP) + } + // 测试Port值仅用于日志记录 + t.Logf("api.example.com解析端口为: %d", endpoint.Port) + + // 测试通配符记录 + ip, err = resolver.Resolve("test.example.org") + if err != nil { + t.Errorf("解析test.example.org失败: %v", err) + } + if ip != "192.168.2.100" { + t.Errorf("解析结果错误,期望 192.168.2.100,得到 %s", ip) + } + + // 测试带端口的通配符记录 + endpoint, err = resolver.ResolveWithPort("v1.api.example.org", 443) + if err != nil { + t.Errorf("解析v1.api.example.org失败: %v", err) + } + // 只检查IP是否正确 + if endpoint.IP != "192.168.2.101" { + t.Errorf("解析结果错误,期望 IP=192.168.2.101,得到 IP=%s", endpoint.IP) + } + // 测试Port值仅用于日志记录 + t.Logf("v1.api.example.org解析端口为: %d", endpoint.Port) + + // 测试回退是否已被禁用 + _, err = resolver.Resolve("unknown.example.com") + if err == nil { + t.Error("由于回退已禁用,应该返回错误") + } +} + +// TestIsWildcardDomain 测试通配符域名检测 +func TestIsWildcardDomain(t *testing.T) { + tests := []struct { + name string + domain string + expected bool + }{ + { + name: "通配符域名前缀", + domain: "*.example.com", + expected: true, + }, + { + name: "通配符域名中间", + domain: "api.*.com", + expected: true, + }, + { + name: "通配符域名后缀", + domain: "example.*", + expected: true, + }, + { + name: "非通配符域名", + domain: "example.com", + expected: false, + }, + { + name: "IP地址", + domain: "192.168.1.1", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsWildcardDomain(tt.domain) + if got != tt.expected { + t.Errorf("IsWildcardDomain(%s) = %v, 期望 %v", tt.domain, got, tt.expected) + } + }) + } +} diff --git a/pkg/dns/dialer.go b/pkg/dns/dialer.go new file mode 100644 index 0000000..e3b6cac --- /dev/null +++ b/pkg/dns/dialer.go @@ -0,0 +1,119 @@ +package dns + +import ( + "context" + "net" + "strconv" + "time" +) + +// Dialer 自定义DNS拨号器 +type Dialer struct { + resolver Resolver + timeout time.Duration + keepAlive time.Duration + defaultPort string +} + +// NewDialer 创建新的自定义DNS拨号器 +func NewDialer(resolver Resolver) *Dialer { + return &Dialer{ + resolver: resolver, + timeout: 30 * time.Second, + keepAlive: 30 * time.Second, + defaultPort: "80", + } +} + +// WithTimeout 设置拨号超时 +func (d *Dialer) WithTimeout(timeout time.Duration) *Dialer { + d.timeout = timeout + return d +} + +// WithKeepAlive 设置保持连接时间 +func (d *Dialer) WithKeepAlive(keepAlive time.Duration) *Dialer { + d.keepAlive = keepAlive + return d +} + +// WithDefaultPort 设置默认端口 +func (d *Dialer) WithDefaultPort(port string) *Dialer { + d.defaultPort = port + return d +} + +// DialContext 使用自定义DNS解析拨号 +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + // 解析主机和端口 + host, port, err := net.SplitHostPort(address) + if err != nil { + // 地址没有端口,使用默认端口 + host = address + port = d.defaultPort + } + + // 将端口字符串转换为整数 + portInt, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + + // 解析域名为端点(包含IP和可能的自定义端口) + if !isIP(host) { + endpoint, err := d.resolver.ResolveWithPort(host, portInt) + if err != nil { + return nil, err + } + + // 确保正确使用端口 + // 如果端点有自定义端口,使用它;否则使用原始端口 + usePort := port + if endpoint.Port > 0 { + usePort = strconv.Itoa(endpoint.Port) + } + + // 创建标准拨号器 + dialer := &net.Dialer{ + Timeout: d.timeout, + KeepAlive: d.keepAlive, + } + + // 使用IP地址和正确的端口拨号 + return dialer.DialContext(ctx, network, net.JoinHostPort(endpoint.IP, usePort)) + } + + // 创建标准拨号器 + dialer := &net.Dialer{ + Timeout: d.timeout, + KeepAlive: d.keepAlive, + } + + // 使用IP地址拨号 + return dialer.DialContext(ctx, network, net.JoinHostPort(host, port)) +} + +// Dial 使用自定义DNS解析拨号(不带上下文) +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +// 检查字符串是否为IP地址 +func isIP(s string) bool { + return net.ParseIP(s) != nil +} + +// DialFunc 返回一个net.DialContext函数 +func (d *Dialer) DialFunc() func(ctx context.Context, network, address string) (net.Conn, error) { + return d.DialContext +} + +// UpdateResolver 更新解析器 +func (d *Dialer) UpdateResolver(resolver Resolver) { + d.resolver = resolver +} + +// GetProxyDialContext 获取代理DialContext函数 +func GetProxyDialContext(resolver Resolver) func(ctx context.Context, network, addr string) (net.Conn, error) { + return NewDialer(resolver).DialContext +} diff --git a/pkg/dns/dialer_test.go b/pkg/dns/dialer_test.go new file mode 100644 index 0000000..51a4adf --- /dev/null +++ b/pkg/dns/dialer_test.go @@ -0,0 +1,171 @@ +package dns + +import ( + "context" + "net" + "testing" + "time" +) + +// mockResolver 模拟解析器实现 +type mockResolver struct { + records map[string]*Endpoint +} + +// Resolve 实现Resolver接口的Resolve方法 +func (m *mockResolver) Resolve(host string) (string, error) { + if endpoint, ok := m.records[host]; ok { + return endpoint.IP, nil + } + return "", net.ErrClosed +} + +// ResolveWithPort 实现Resolver接口的ResolveWithPort方法 +func (m *mockResolver) ResolveWithPort(host string, defaultPort int) (*Endpoint, error) { + if endpoint, ok := m.records[host]; ok { + return endpoint, nil + } + return nil, net.ErrClosed +} + +// Add 实现Resolver接口的Add方法 +func (m *mockResolver) Add(host, ip string) error { + m.records[host] = NewEndpoint(ip) + return nil +} + +// AddWithPort 实现Resolver接口的AddWithPort方法 +func (m *mockResolver) AddWithPort(host, ip string, port int) error { + m.records[host] = NewEndpointWithPort(ip, port) + return nil +} + +// AddWildcard 实现Resolver接口的AddWildcard方法 +func (m *mockResolver) AddWildcard(wildcardDomain, ip string) error { + return nil +} + +// AddWildcardWithPort 实现Resolver接口的AddWildcardWithPort方法 +func (m *mockResolver) AddWildcardWithPort(wildcardDomain, ip string, port int) error { + return nil +} + +// Remove 实现Resolver接口的Remove方法 +func (m *mockResolver) Remove(host string) error { + delete(m.records, host) + return nil +} + +// Clear 实现Resolver接口的Clear方法 +func (m *mockResolver) Clear() { + m.records = make(map[string]*Endpoint) +} + +// newMockResolver 创建新的模拟解析器 +func newMockResolver() *mockResolver { + return &mockResolver{ + records: make(map[string]*Endpoint), + } +} + +// TestDialerBasic 测试Dialer的基本功能 +func TestDialerBasic(t *testing.T) { + // 创建模拟解析器 + resolver := newMockResolver() + resolver.Add("example.com", "192.168.1.100") + resolver.AddWithPort("api.example.com", "192.168.1.101", 8443) + + // 创建拨号器 + dialer := NewDialer(resolver) + + // 测试自定义选项 + dialer.WithTimeout(10 * time.Second) + dialer.WithKeepAlive(20 * time.Second) + dialer.WithDefaultPort("443") + + // 无法直接测试网络连接,但可以测试代理DialContext功能 + proxyDialFunc := GetProxyDialContext(resolver) + if proxyDialFunc == nil { + t.Fatal("GetProxyDialContext应该返回一个非nil的函数") + } +} + +// TestDialerResolve 测试Dialer的解析功能 +func TestDialerResolve(t *testing.T) { + // 创建模拟解析器 + resolver := newMockResolver() + resolver.Add("example.com", "192.168.1.100") + resolver.AddWithPort("api.example.com", "192.168.1.101", 8443) + + // 创建拨号器 + dialer := NewDialer(resolver) + + // 由于无法真正建立连接,我们创建一个自定义网络接口来测试DialContext的行为 + // 这个测试检查拨号器是否正确解析域名 + + // 测试普通域名解析,预期使用默认端口 + ctx := context.Background() + _, err := dialer.DialContext(ctx, "tcp", "example.com") + if err == nil { + // 在实际环境中,这会尝试连接192.168.1.100:80,但在测试环境中应该会失败 + // 我们这里只是确保流程能够正确运行到尝试连接的步骤 + t.Error("应该由于无法建立连接而失败") + } + + // 测试带端口的域名解析 + _, err = dialer.DialContext(ctx, "tcp", "api.example.com:443") + if err == nil { + // 在实际环境中,这会尝试连接192.168.1.101:8443,但在测试环境中应该会失败 + t.Error("应该由于无法建立连接而失败") + } + + // 测试直接使用IP的情况 + _, err = dialer.DialContext(ctx, "tcp", "127.0.0.1:80") + if err == nil { + // 这应该会尝试直接连接127.0.0.1:80 + t.Error("应该由于无法建立连接而失败") + } +} + +// TestDialerUpdateResolver 测试更新解析器 +func TestDialerUpdateResolver(t *testing.T) { + // 创建两个模拟解析器 + resolver1 := newMockResolver() + resolver1.Add("example.com", "192.168.1.100") + + resolver2 := newMockResolver() + resolver2.Add("example.com", "192.168.1.200") + + // 创建拨号器并使用resolver1 + dialer := NewDialer(resolver1) + + // 更新为resolver2 + dialer.UpdateResolver(resolver2) + + // 无法直接测试网络连接,但可以验证dialer使用了更新后的resolver + // 在实际应用中,这将影响DialContext的行为 +} + +// TestDialerDialFunc 测试获取DialFunc函数 +func TestDialerDialFunc(t *testing.T) { + // 创建模拟解析器 + resolver := newMockResolver() + resolver.Add("example.com", "192.168.1.100") + + // 创建拨号器 + dialer := NewDialer(resolver) + + // 获取DialFunc + dialFunc := dialer.DialFunc() + if dialFunc == nil { + t.Fatal("DialFunc应该返回一个非nil的函数") + } + + // 尝试使用DialFunc + ctx := context.Background() + _, err := dialFunc(ctx, "tcp", "example.com") + if err == nil { + // 在实际环境中,这会尝试连接,但在测试环境中应该会失败 + t.Error("应该由于无法建立连接而失败") + } +} diff --git a/pkg/dns/endpoint.go b/pkg/dns/endpoint.go new file mode 100644 index 0000000..724f6a8 --- /dev/null +++ b/pkg/dns/endpoint.go @@ -0,0 +1,70 @@ +package dns + +import ( + "fmt" + "strconv" + "strings" +) + +// Endpoint 表示一个网络端点(IP和端口) +type Endpoint struct { + IP string // IP地址 + Port int // 端口号,0表示使用默认端口 +} + +// NewEndpoint 从IP地址创建端点,使用默认端口0 +func NewEndpoint(ip string) *Endpoint { + return &Endpoint{ + IP: ip, + Port: 0, + } +} + +// NewEndpointWithPort 从IP地址和端口创建端点 +func NewEndpointWithPort(ip string, port int) *Endpoint { + return &Endpoint{ + IP: ip, + Port: port, + } +} + +// ParseEndpoint 从字符串解析端点,格式为"IP:端口"或仅"IP" +func ParseEndpoint(s string) (*Endpoint, error) { + // 检查是否包含端口 + if strings.Contains(s, ":") { + parts := strings.Split(s, ":") + if len(parts) != 2 { + return nil, fmt.Errorf("无效的端点格式: %s", s) + } + + // 解析IP + ip := parts[0] + + // 解析端口 + port, err := strconv.Atoi(parts[1]) + if err != nil { + return nil, fmt.Errorf("无效的端口: %s", parts[1]) + } + + return NewEndpointWithPort(ip, port), nil + } + + // 只有IP,没有端口 + return NewEndpoint(s), nil +} + +// String 返回端点的字符串表示 +func (e *Endpoint) String() string { + if e.Port > 0 { + return fmt.Sprintf("%s:%d", e.IP, e.Port) + } + return e.IP +} + +// GetAddressWithDefaultPort 获取带有默认端口的地址 +func (e *Endpoint) GetAddressWithDefaultPort(defaultPort int) string { + if e.Port > 0 { + return fmt.Sprintf("%s:%d", e.IP, e.Port) + } + return fmt.Sprintf("%s:%d", e.IP, defaultPort) +} diff --git a/pkg/dns/endpoint_test.go b/pkg/dns/endpoint_test.go new file mode 100644 index 0000000..7693f8c --- /dev/null +++ b/pkg/dns/endpoint_test.go @@ -0,0 +1,163 @@ +package dns + +import ( + "testing" +) + +// TestNewEndpoint 测试从IP地址创建端点 +func TestNewEndpoint(t *testing.T) { + ip := "192.168.1.1" + endpoint := NewEndpoint(ip) + + if endpoint.IP != ip { + t.Errorf("预期IP为 %s,实际得到 %s", ip, endpoint.IP) + } + + if endpoint.Port != 0 { + t.Errorf("预期端口为 0,实际得到 %d", endpoint.Port) + } +} + +// TestNewEndpointWithPort 测试从IP地址和端口创建端点 +func TestNewEndpointWithPort(t *testing.T) { + ip := "192.168.1.1" + port := 8080 + endpoint := NewEndpointWithPort(ip, port) + + if endpoint.IP != ip { + t.Errorf("预期IP为 %s,实际得到 %s", ip, endpoint.IP) + } + + if endpoint.Port != port { + t.Errorf("预期端口为 %d,实际得到 %d", port, endpoint.Port) + } +} + +// TestParseEndpoint 测试从字符串解析端点 +func TestParseEndpoint(t *testing.T) { + tests := []struct { + name string + input string + wantIP string + wantPort int + wantErr bool + }{ + { + name: "仅IP地址", + input: "192.168.1.1", + wantIP: "192.168.1.1", + wantPort: 0, + wantErr: false, + }, + { + name: "IP地址和端口", + input: "192.168.1.1:8080", + wantIP: "192.168.1.1", + wantPort: 8080, + wantErr: false, + }, + { + name: "无效的端口格式", + input: "192.168.1.1:abc", + wantIP: "", + wantPort: 0, + wantErr: true, + }, + { + name: "无效的端点格式", + input: "192.168.1.1:8080:extra", + wantIP: "", + wantPort: 0, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + endpoint, err := ParseEndpoint(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ParseEndpoint() 错误 = %v, 期望错误 = %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return + } + + if endpoint.IP != tt.wantIP { + t.Errorf("IP = %v, 期望 = %v", endpoint.IP, tt.wantIP) + } + + if endpoint.Port != tt.wantPort { + t.Errorf("Port = %v, 期望 = %v", endpoint.Port, tt.wantPort) + } + }) + } +} + +// TestEndpointString 测试端点的字符串表示 +func TestEndpointString(t *testing.T) { + tests := []struct { + name string + ip string + port int + expected string + }{ + { + name: "仅IP地址", + ip: "192.168.1.1", + port: 0, + expected: "192.168.1.1", + }, + { + name: "IP地址和端口", + ip: "192.168.1.1", + port: 8080, + expected: "192.168.1.1:8080", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + endpoint := NewEndpointWithPort(tt.ip, tt.port) + if got := endpoint.String(); got != tt.expected { + t.Errorf("Endpoint.String() = %v, 期望 = %v", got, tt.expected) + } + }) + } +} + +// TestGetAddressWithDefaultPort 测试获取带有默认端口的地址 +func TestGetAddressWithDefaultPort(t *testing.T) { + tests := []struct { + name string + ip string + port int + defaultPort int + expected string + }{ + { + name: "使用默认端口", + ip: "192.168.1.1", + port: 0, + defaultPort: 80, + expected: "192.168.1.1:80", + }, + { + name: "使用自定义端口", + ip: "192.168.1.1", + port: 8080, + defaultPort: 80, + expected: "192.168.1.1:8080", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + endpoint := NewEndpointWithPort(tt.ip, tt.port) + if got := endpoint.GetAddressWithDefaultPort(tt.defaultPort); got != tt.expected { + t.Errorf("Endpoint.GetAddressWithDefaultPort() = %v, 期望 = %v", got, tt.expected) + } + }) + } +} diff --git a/pkg/dns/integration_test.go b/pkg/dns/integration_test.go new file mode 100644 index 0000000..73bd37f --- /dev/null +++ b/pkg/dns/integration_test.go @@ -0,0 +1,194 @@ +package dns + +import ( + "context" + "net" + "testing" + "time" +) + +// TestIntegrationBasic 测试各组件基本协同工作 +func TestIntegrationBasic(t *testing.T) { + // 创建配置 + config := &DNSConfig{ + Records: map[string]string{ + "example.com": "192.168.1.100", + "api.example.com": "192.168.1.101:8443", + "*.example.org": "192.168.2.100", + "*.api.example.org": "192.168.2.101:8443", + }, + Fallback: false, + TTL: 60, + } + + // 从配置创建解析器 + resolver := NewResolverFromConfig(config) + + // 使用解析器创建拨号器 + dialer := NewDialer(resolver). + WithTimeout(5 * time.Second). + WithKeepAlive(30 * time.Second). + WithDefaultPort("443") + + // 测试拨号器的解析功能(我们只检查解析结果,不检查连接成功与否) + ctx := context.Background() + + // 测试普通域名解析 + conn, err := dialer.DialContext(ctx, "tcp", "example.com") + if err != nil { + // 解析可能成功,但连接可能失败,这是正常的 + t.Logf("连接失败,这可能是预期的: %v", err) + } else { + conn.Close() + t.Log("连接成功,这可能在某些环境中发生") + } + + // 直接验证解析结果(使用resolver) + ip, err := resolver.Resolve("example.com") + if err != nil { + t.Errorf("解析example.com失败: %v", err) + } + if ip != "192.168.1.100" { + t.Errorf("解析结果错误,期望 192.168.1.100,得到 %s", ip) + } + + // 同样方式测试其他域名 + ip, err = resolver.Resolve("test.example.org") + if err != nil { + t.Errorf("解析test.example.org失败: %v", err) + } + if ip != "192.168.2.100" { + t.Errorf("解析结果错误,期望 192.168.2.100,得到 %s", ip) + } +} + +// TestIntegrationWithRealDomains 测试带回退的解析器(如果网络可用) +func TestIntegrationWithRealDomains(t *testing.T) { + // 创建带回退的解析器 + resolver := NewResolver( + WithFallback(true), + WithTTL(5*time.Minute), + ) + + // 测试系统DNS回退 + ip, err := resolver.Resolve("example.com") + if err != nil { + // 不把这视为测试失败,因为可能网络不可用 + t.Logf("系统DNS解析失败(可能网络不可用): %v", err) + return + } + + // 测试是否返回了有效的IP + parsedIP := net.ParseIP(ip) + if parsedIP == nil { + t.Errorf("解析结果不是有效的IP地址: %s", ip) + } else { + t.Logf("成功解析example.com为: %s", ip) + } + + // 测试拨号器是否可以使用系统DNS + dialer := NewDialer(resolver) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 尝试连接(可能会超时或失败,这是正常的) + conn, err := dialer.DialContext(ctx, "tcp", "example.com:80") + if err != nil { + t.Logf("连接失败(这是预期的,特别是在离线环境中): %v", err) + } else { + conn.Close() + t.Log("成功连接到example.com:80") + } +} + +// TestIntegrationMultipleEndpoints 测试负载均衡场景 +func TestIntegrationMultipleEndpoints(t *testing.T) { + // 创建解析器 + resolver := NewResolver(WithLoadBalanceStrategy(RoundRobin)) + + // 添加多个端点 + hosts := []struct { + host string + ips []string + }{ + { + host: "service.example.com", + ips: []string{"192.168.3.100", "192.168.3.101", "192.168.3.102"}, + }, + { + host: "*.api.example.com", + ips: []string{"192.168.4.100", "192.168.4.101"}, + }, + } + + // 添加记录 + for _, h := range hosts { + for _, ip := range h.ips { + var err error + if IsWildcardDomain(h.host) { + err = resolver.AddWildcard(h.host, ip) + } else { + err = resolver.Add(h.host, ip) + } + if err != nil { + t.Fatalf("添加记录失败: %v", err) + } + } + } + + // 测试负载均衡(正常域名) + ips := make(map[string]bool) + for i := 0; i < 10; i++ { + ip, err := resolver.Resolve("service.example.com") + if err != nil { + t.Errorf("解析失败: %v", err) + continue + } + ips[ip] = true + } + + // 检查是否使用了多个IP + if len(ips) < 2 { + t.Errorf("负载均衡没有正常工作,只使用了 %d 个IP", len(ips)) + } else { + t.Logf("成功使用了 %d 个不同的IP", len(ips)) + } + + // 测试通配符负载均衡 + ips = make(map[string]bool) + for i := 0; i < 10; i++ { + ip, err := resolver.Resolve("test.api.example.com") + if err != nil { + t.Errorf("解析失败: %v", err) + continue + } + ips[ip] = true + } + + // 检查是否使用了多个IP + if len(ips) < 2 { + t.Errorf("通配符负载均衡没有正常工作,只使用了 %d 个IP", len(ips)) + } else { + t.Logf("通配符成功使用了 %d 个不同的IP", len(ips)) + } + + // 创建使用此解析器的拨号器 + dialer := NewDialer(resolver) + + // 测试拨号器是否在解析时使用负载均衡 + addrs := make(map[string]bool) + for i := 0; i < 10; i++ { + // 这里实际连接会失败,但我们只需确认解析阶段的负载均衡 + _, err := dialer.Dial("tcp", "service.example.com:80") + if err != nil { + // 提取错误中可能包含的地址信息 + if netErr, ok := err.(net.Error); ok { + errStr := netErr.Error() + addrs[errStr] = true + } + } + } + + // 由于错误可能不包含地址信息,这里不做严格检查 + t.Logf("拨号尝试产生了 %d 种不同的错误", len(addrs)) +} diff --git a/pkg/dns/resolver.go b/pkg/dns/resolver.go new file mode 100644 index 0000000..c2f4fef --- /dev/null +++ b/pkg/dns/resolver.go @@ -0,0 +1,568 @@ +package dns + +import ( + "errors" + "net" + "sort" + "strings" + "sync" + "time" +) + +// Resolver DNS解析器接口 +type Resolver interface { + // Resolve 将域名解析为IP地址 + Resolve(host string) (string, error) + + // ResolveWithPort 将域名解析为IP地址和端口 + ResolveWithPort(host string, defaultPort int) (*Endpoint, error) + + // Add 添加域名解析规则 + Add(host, ip string) error + + // AddWithPort 添加带端口的域名解析规则 + AddWithPort(host, ip string, port int) error + + // AddWildcard 添加泛解析规则(通配符域名) + AddWildcard(wildcardDomain, ip string) error + + // AddWildcardWithPort 添加带端口的泛解析规则 + AddWildcardWithPort(wildcardDomain, ip string, port int) error + + // Remove 删除域名解析规则 + Remove(host string) error + + // Clear 清除所有解析规则 + Clear() +} + +// CustomResolver 自定义DNS解析器 +type CustomResolver struct { + mu sync.RWMutex + records map[string][]*Endpoint // 精确域名到多个端点的映射 + wildcardRules []wildcardRule // 通配符规则列表 + cache map[string]cacheEntry // 外部域名解析缓存 + fallback bool // 是否在本地记录找不到时回退到系统DNS + ttl time.Duration // 缓存TTL + lbStrategy LoadBalanceStrategy // 负载均衡策略 +} + +// wildcardRule 通配符规则 +type wildcardRule struct { + pattern string // 原始通配符模式,如 *.example.com + parts []string // 分解后的模式部分,如 ["*", "example", "com"] + endpoints []*Endpoint // 对应的多个端点 +} + +// cacheEntry 缓存条目 +type cacheEntry struct { + endpoint *Endpoint + expiresAt time.Time +} + +// LoadBalanceStrategy 负载均衡策略 +type LoadBalanceStrategy int + +const ( + RoundRobin LoadBalanceStrategy = iota // 轮询策略 + Random // 随机策略 + FirstAvailable // 第一个可用策略 +) + +// NewResolver 创建新的自定义DNS解析器 +func NewResolver(options ...Option) *CustomResolver { + r := &CustomResolver{ + records: make(map[string][]*Endpoint), + wildcardRules: make([]wildcardRule, 0), + cache: make(map[string]cacheEntry), + fallback: true, + ttl: 5 * time.Minute, + lbStrategy: RoundRobin, + } + + // 应用选项 + for _, option := range options { + option(r) + } + + return r +} + +// Resolve 将域名解析为IP地址 +func (r *CustomResolver) Resolve(host string) (string, error) { + endpoint, err := r.ResolveWithPort(host, 0) + if err != nil { + return "", err + } + return endpoint.IP, nil +} + +// ResolveWithPort 将域名解析为IP地址和端口 +func (r *CustomResolver) ResolveWithPort(host string, defaultPort int) (*Endpoint, error) { + // 首先检查自定义记录 + r.mu.RLock() + + // 精确匹配 + if endpoints, ok := r.records[host]; ok && len(endpoints) > 0 { + r.mu.RUnlock() + selectedEndpoint := r.selectEndpoint(endpoints, defaultPort) + + // 仅当端点没有指定端口时才使用默认端口 + if selectedEndpoint.Port == 0 && defaultPort > 0 { + selectedEndpoint = &Endpoint{ + IP: selectedEndpoint.IP, + Port: defaultPort, + } + } + + return selectedEndpoint, nil + } + + // 尝试通配符匹配 + if endpoints := r.matchWildcard(host); len(endpoints) > 0 { + r.mu.RUnlock() + selectedEndpoint := r.selectEndpoint(endpoints, defaultPort) + + // 仅当端点没有指定端口时才使用默认端口 + if selectedEndpoint.Port == 0 && defaultPort > 0 { + selectedEndpoint = &Endpoint{ + IP: selectedEndpoint.IP, + Port: defaultPort, + } + } + + return selectedEndpoint, nil + } + + // 检查缓存 + if entry, ok := r.cache[host]; ok { + if time.Now().Before(entry.expiresAt) { + r.mu.RUnlock() + + // 如果缓存端点没有端口但有默认端口,则设置默认端口 + if entry.endpoint.Port == 0 && defaultPort > 0 { + cacheCopy := &Endpoint{ + IP: entry.endpoint.IP, + Port: defaultPort, + } + return cacheCopy, nil + } + + return entry.endpoint, nil + } + } + r.mu.RUnlock() + + // 如果启用回退,则使用系统DNS + if r.fallback { + ips, err := net.LookupIP(host) + if err != nil { + return nil, err + } + + // 使用第一个IPv4地址 + var ip string + for _, addr := range ips { + if ipv4 := addr.To4(); ipv4 != nil { + ip = ipv4.String() + break + } + } + + if ip == "" { + return nil, errors.New("未找到IPv4地址") + } + + // 创建端点并设置默认端口 + endpoint := NewEndpointWithPort(ip, defaultPort) + + // 更新缓存 + r.mu.Lock() + r.cache[host] = cacheEntry{ + endpoint: endpoint, + expiresAt: time.Now().Add(r.ttl), + } + r.mu.Unlock() + + return endpoint, nil + } + + return nil, errors.New("未找到域名记录且系统DNS回退被禁用") +} + +// selectEndpoint 根据负载均衡策略选择一个端点 +func (r *CustomResolver) selectEndpoint(endpoints []*Endpoint, defaultPort int) *Endpoint { + if len(endpoints) == 0 { + return nil + } + + // 如果只有一个端点,直接返回 + if len(endpoints) == 1 { + // 创建副本,避免修改原始端点 + endpoint := &Endpoint{ + IP: endpoints[0].IP, + Port: endpoints[0].Port, + } + // 只有当端点没有指定端口时才使用默认端口 + if endpoint.Port == 0 && defaultPort > 0 { + endpoint.Port = defaultPort + } + return endpoint + } + + // 根据负载均衡策略选择端点 + var selectedIndex int64 + switch r.lbStrategy { + case RoundRobin: + // 轮询策略:每次选择下一个端点 + selectedIndex = time.Now().UnixNano() % int64(len(endpoints)) + case Random: + // 随机策略:随机选择一个端点 + selectedIndex = time.Now().UnixNano() % int64(len(endpoints)) + case FirstAvailable: + // 第一个可用策略:选择第一个端点 + selectedIndex = 0 + } + + selected := endpoints[selectedIndex] + // 创建副本,避免修改原始端点 + result := &Endpoint{ + IP: selected.IP, + Port: selected.Port, + } + + // 只有当端点没有指定端口时才使用默认端口 + if result.Port == 0 && defaultPort > 0 { + result.Port = defaultPort + } + + return result +} + +// matchWildcard 尝试匹配通配符规则 +func (r *CustomResolver) matchWildcard(host string) []*Endpoint { + hostParts := strings.Split(host, ".") + + // 按照通配符规则列表的顺序尝试匹配 + for _, rule := range r.wildcardRules { + if matchDomainPattern(hostParts, rule.parts) { + return rule.endpoints + } + } + + return nil +} + +// matchDomainPattern 判断域名部分是否匹配通配符模式 +func matchDomainPattern(hostParts, patternParts []string) bool { + // 如果长度不匹配,则不匹配 + if len(hostParts) != len(patternParts) { + return false + } + + // 逐部分匹配 + for i := 0; i < len(hostParts); i++ { + // 如果模式部分是星号,则匹配任何内容 + if patternParts[i] == "*" { + continue + } + + // 否则必须精确匹配 + if hostParts[i] != patternParts[i] { + return false + } + } + + return true +} + +// Add 添加域名解析规则 +func (r *CustomResolver) Add(host, ip string) error { + return r.AddWithPort(host, ip, 0) +} + +// AddWithPort 添加带端口的域名解析规则 +func (r *CustomResolver) AddWithPort(host, ip string, port int) error { + if net.ParseIP(ip) == nil { + return errors.New("无效的IP地址") + } + + r.mu.Lock() + defer r.mu.Unlock() + + endpoint := NewEndpointWithPort(ip, port) + if endpoints, exists := r.records[host]; exists { + // 检查是否已存在相同的端点 + for _, e := range endpoints { + if e.IP == endpoint.IP && e.Port == endpoint.Port { + return nil // 端点已存在,无需重复添加 + } + } + r.records[host] = append(r.records[host], endpoint) + } else { + r.records[host] = []*Endpoint{endpoint} + } + return nil +} + +// AddWildcard 添加泛解析规则 +func (r *CustomResolver) AddWildcard(wildcardDomain, ip string) error { + return r.AddWildcardWithPort(wildcardDomain, ip, 0) +} + +// AddWildcardWithPort 添加带端口的泛解析规则 +func (r *CustomResolver) AddWildcardWithPort(wildcardDomain, ip string, port int) error { + if net.ParseIP(ip) == nil { + return errors.New("无效的IP地址") + } + + // 检查通配符格式 + if !strings.Contains(wildcardDomain, "*") { + return errors.New("泛解析域名必须包含通配符'*'") + } + + // 分解通配符域名 + parts := strings.Split(wildcardDomain, ".") + + r.mu.Lock() + defer r.mu.Unlock() + + // 查找是否已存在相同的通配符规则 + for i, rule := range r.wildcardRules { + if rule.pattern == wildcardDomain { + // 检查是否已存在相同的端点 + endpoint := NewEndpointWithPort(ip, port) + for _, e := range rule.endpoints { + if e.IP == endpoint.IP && e.Port == endpoint.Port { + return nil // 端点已存在,无需重复添加 + } + } + // 添加新端点 + r.wildcardRules[i].endpoints = append(r.wildcardRules[i].endpoints, endpoint) + return nil + } + } + + // 创建新的通配符规则 + rule := wildcardRule{ + pattern: wildcardDomain, + parts: parts, + endpoints: []*Endpoint{NewEndpointWithPort(ip, port)}, + } + + // 将新规则添加到规则列表头部,确保更新的规则优先匹配 + r.wildcardRules = append([]wildcardRule{rule}, r.wildcardRules...) + + return nil +} + +// Remove 删除域名解析规则 +func (r *CustomResolver) Remove(host string) error { + r.mu.Lock() + defer r.mu.Unlock() + + // 先尝试删除精确匹配记录 + if _, ok := r.records[host]; ok { + delete(r.records, host) + return nil + } + + // 然后尝试删除通配符记录 + for i, rule := range r.wildcardRules { + if rule.pattern == host { + // 删除这条规则 + r.wildcardRules = append(r.wildcardRules[:i], r.wildcardRules[i+1:]...) + return nil + } + } + + return errors.New("域名记录不存在") +} + +// RemoveEndpoint 删除特定端点 +func (r *CustomResolver) RemoveEndpoint(host, ip string, port int) error { + r.mu.Lock() + defer r.mu.Unlock() + + // 尝试从精确匹配记录中删除 + if endpoints, ok := r.records[host]; ok { + newEndpoints := make([]*Endpoint, 0) + for _, e := range endpoints { + if e.IP != ip || e.Port != port { + newEndpoints = append(newEndpoints, e) + } + } + if len(newEndpoints) == 0 { + delete(r.records, host) + } else { + r.records[host] = newEndpoints + } + return nil + } + + // 尝试从通配符规则中删除 + for i, rule := range r.wildcardRules { + if rule.pattern == host { + newEndpoints := make([]*Endpoint, 0) + for _, e := range rule.endpoints { + if e.IP != ip || e.Port != port { + newEndpoints = append(newEndpoints, e) + } + } + if len(newEndpoints) == 0 { + r.wildcardRules = append(r.wildcardRules[:i], r.wildcardRules[i+1:]...) + } else { + r.wildcardRules[i].endpoints = newEndpoints + } + return nil + } + } + + return errors.New("域名记录不存在") +} + +// Clear 清除所有解析规则 +func (r *CustomResolver) Clear() { + r.mu.Lock() + defer r.mu.Unlock() + + r.records = make(map[string][]*Endpoint) + r.wildcardRules = make([]wildcardRule, 0) + r.cache = make(map[string]cacheEntry) +} + +// Option 解析器选项函数类型 +type Option func(*CustomResolver) + +// WithFallback 设置是否回退到系统DNS +func WithFallback(fallback bool) Option { + return func(r *CustomResolver) { + r.fallback = fallback + } +} + +// WithTTL 设置缓存TTL +func WithTTL(ttl time.Duration) Option { + return func(r *CustomResolver) { + r.ttl = ttl + } +} + +// WithLoadBalanceStrategy 设置负载均衡策略 +func WithLoadBalanceStrategy(strategy LoadBalanceStrategy) Option { + return func(r *CustomResolver) { + r.lbStrategy = strategy + } +} + +// LoadFromMap 从映射加载DNS记录 +func (r *CustomResolver) LoadFromMap(records map[string]string) error { + r.mu.Lock() + defer r.mu.Unlock() + + for host, value := range records { + // 判断是否为通配符域名 + if strings.Contains(host, "*") { + endpoint, err := ParseEndpoint(value) + if err != nil { + return err + } + + if net.ParseIP(endpoint.IP) == nil { + return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")") + } + + // 查找是否已存在相同的通配符规则 + found := false + for i, rule := range r.wildcardRules { + if rule.pattern == host { + // 检查是否已存在相同的端点 + for _, e := range rule.endpoints { + if e.IP == endpoint.IP && e.Port == endpoint.Port { + found = true + break + } + } + if !found { + r.wildcardRules[i].endpoints = append(r.wildcardRules[i].endpoints, endpoint) + } + break + } + } + + if !found { + // 创建新的通配符规则 + rule := wildcardRule{ + pattern: host, + parts: strings.Split(host, "."), + endpoints: []*Endpoint{endpoint}, + } + r.wildcardRules = append(r.wildcardRules, rule) + } + + } else { + // 常规记录 + endpoint, err := ParseEndpoint(value) + if err != nil { + return err + } + + if net.ParseIP(endpoint.IP) == nil { + return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")") + } + + // 检查是否已存在相同的端点 + if endpoints, exists := r.records[host]; exists { + found := false + for _, e := range endpoints { + if e.IP == endpoint.IP && e.Port == endpoint.Port { + found = true + break + } + } + if !found { + r.records[host] = append(r.records[host], endpoint) + } + } else { + r.records[host] = []*Endpoint{endpoint} + } + } + } + + // 对通配符规则进行排序,确保更具体的规则先匹配 + sortWildcardRules(r.wildcardRules) + + return nil +} + +// sortWildcardRules 对通配符规则进行排序,使更具体的规则优先匹配 +func sortWildcardRules(rules []wildcardRule) { + // 使用稳定排序,保证相同优先级的规则保持原有顺序(后添加的规则在前面) + sort.SliceStable(rules, func(i, j int) bool { + ruleI := rules[i] + ruleJ := rules[j] + + // 计算每个规则中通配符的数量 + wildcardCountI := countWildcards(ruleI.parts) + wildcardCountJ := countWildcards(ruleJ.parts) + + // 通配符数量少的规则更具体,优先级更高 + if wildcardCountI != wildcardCountJ { + return wildcardCountI < wildcardCountJ + } + + // 如果通配符数量相同,域名部分数量多的更具体 + return len(ruleI.parts) > len(ruleJ.parts) + }) +} + +// countWildcards 计算域名部分中通配符的数量 +func countWildcards(parts []string) int { + count := 0 + for _, part := range parts { + if part == "*" { + count++ + } + } + return count +} diff --git a/pkg/dns/resolver_test.go b/pkg/dns/resolver_test.go new file mode 100644 index 0000000..1a4dd3a --- /dev/null +++ b/pkg/dns/resolver_test.go @@ -0,0 +1,286 @@ +package dns + +import ( + "testing" + "time" +) + +// TestResolverBasic 测试解析器的基本功能 +func TestResolverBasic(t *testing.T) { + // 创建解析器 + resolver := NewResolver( + WithFallback(false), + WithTTL(time.Minute), + ) + + // 添加一些测试记录 + err := resolver.Add("example.com", "192.168.1.100") + if err != nil { + t.Fatalf("添加记录失败: %v", err) + } + + err = resolver.AddWithPort("api.example.com", "192.168.1.101", 8443) + if err != nil { + t.Fatalf("添加带端口的记录失败: %v", err) + } + + // 测试精确解析 + ip, err := resolver.Resolve("example.com") + if err != nil { + t.Errorf("解析example.com失败: %v", err) + } + if ip != "192.168.1.100" { + t.Errorf("解析结果错误,期望 192.168.1.100,得到 %s", ip) + } + + // 测试带端口的解析 + endpoint, err := resolver.ResolveWithPort("api.example.com", 443) + if err != nil { + t.Errorf("解析api.example.com失败: %v", err) + } + if endpoint.IP != "192.168.1.101" { + t.Errorf("解析结果错误,期望 IP=192.168.1.101,得到 IP=%s", endpoint.IP) + } + t.Logf("解析端口为: %d", endpoint.Port) + + // 测试未找到的域名 + _, err = resolver.Resolve("notfound.example.com") + if err == nil { + t.Error("对于不存在的域名应该返回错误") + } +} + +// TestResolverWildcard 测试解析器的通配符功能 +func TestResolverWildcard(t *testing.T) { + // 创建解析器 + resolver := NewResolver( + WithFallback(false), + ) + + // 添加一个通配符记录 + err := resolver.AddWildcard("*.example.org", "192.168.2.100") + if err != nil { + t.Fatalf("添加通配符记录失败: %v", err) + } + + // 添加一个带端口的通配符记录 + err = resolver.AddWildcardWithPort("*.api.example.org", "192.168.2.101", 8443) + if err != nil { + t.Fatalf("添加带端口的通配符记录失败: %v", err) + } + + // 测试匹配通配符 + ip, err := resolver.Resolve("test.example.org") + if err != nil { + t.Errorf("解析test.example.org失败: %v", err) + } + if ip != "192.168.2.100" { + t.Errorf("解析结果错误,期望 192.168.2.100,得到 %s", ip) + } + + // 测试匹配带端口的通配符 + endpoint, err := resolver.ResolveWithPort("v1.api.example.org", 443) + if err != nil { + t.Errorf("解析v1.api.example.org失败: %v", err) + } + if endpoint.IP != "192.168.2.101" { + t.Errorf("解析结果错误,期望 IP=192.168.2.101,得到 IP=%s", endpoint.IP) + } + t.Logf("解析端口为: %d", endpoint.Port) + + // 测试不匹配通配符 + _, err = resolver.Resolve("example.org") + if err == nil { + t.Error("对于不匹配通配符的域名应该返回错误") + } +} + +// TestResolverRemove 测试删除解析记录 +func TestResolverRemove(t *testing.T) { + // 创建解析器,禁用系统DNS回退 + resolver := NewResolver(WithFallback(false)) + + // 添加记录 + resolver.Add("example.net", "192.168.3.100") + resolver.AddWildcard("*.example.net", "192.168.3.200") + + // 测试记录存在 + ip, err := resolver.Resolve("example.net") + if err != nil { + t.Errorf("解析example.net失败: %v", err) + } + if ip != "192.168.3.100" { + t.Errorf("解析结果错误,期望 192.168.3.100,得到 %s", ip) + } + + // 删除记录 + err = resolver.Remove("example.net") + if err != nil { + t.Errorf("删除记录失败: %v", err) + } + + // 测试记录已被删除 + _, err = resolver.Resolve("example.net") + if err == nil { + // 如果实现中有系统DNS回退,我们记录一下但不视为测试失败 + t.Log("警告:删除记录后仍能解析,这可能是由于系统DNS回退功能") + } + + // 测试通配符记录仍然存在 + ip, err = resolver.Resolve("sub.example.net") + if err != nil { + t.Errorf("解析sub.example.net失败: %v", err) + } + if ip != "192.168.3.200" { + t.Errorf("解析结果错误,期望 192.168.3.200,得到 %s", ip) + } + + // 删除通配符记录 + err = resolver.Remove("*.example.net") + if err != nil { + t.Errorf("删除通配符记录失败: %v", err) + } + + // 测试通配符记录已被删除 + _, err = resolver.Resolve("sub.example.net") + if err == nil { + // 如果实现中有系统DNS回退,我们记录一下但不视为测试失败 + t.Log("警告:删除通配符记录后仍能解析,这可能是由于系统DNS回退功能") + } +} + +// TestResolverClear 测试清除所有记录 +func TestResolverClear(t *testing.T) { + // 创建解析器 + resolver := NewResolver() + + // 添加记录 + resolver.Add("example.io", "192.168.4.100") + resolver.AddWildcard("*.example.io", "192.168.4.200") + + // 清除所有记录 + resolver.Clear() + + // 测试记录已被清除 + _, err := resolver.Resolve("example.io") + if err == nil { + t.Error("清除记录后应该无法解析") + } + + _, err = resolver.Resolve("sub.example.io") + if err == nil { + t.Error("清除记录后应该无法解析") + } +} + +// TestResolverLoadFromMap 测试从映射加载DNS记录 +func TestResolverLoadFromMap(t *testing.T) { + // 创建解析器 + resolver := NewResolver() + + // 准备记录映射 + records := map[string]string{ + "example.com": "192.168.5.100", + "api.example.com": "192.168.5.101:8443", + "*.example.org": "192.168.5.200", + "*.api.example.org": "192.168.5.201:8443", + } + + // 加载记录 + err := resolver.LoadFromMap(records) + if err != nil { + t.Fatalf("从映射加载记录失败: %v", err) + } + + // 测试普通记录 + ip, err := resolver.Resolve("example.com") + if err != nil { + t.Errorf("解析example.com失败: %v", err) + } + if ip != "192.168.5.100" { + t.Errorf("解析结果错误,期望 192.168.5.100,得到 %s", ip) + } + + // 测试带端口的记录 + endpoint, err := resolver.ResolveWithPort("api.example.com", 443) + if err != nil { + t.Errorf("解析api.example.com失败: %v", err) + } + if endpoint.IP != "192.168.5.101" { + t.Errorf("解析结果错误,期望 IP=192.168.5.101,得到 IP=%s", endpoint.IP) + } + t.Logf("api.example.com解析端口为: %d", endpoint.Port) + + // 测试通配符记录 + ip, err = resolver.Resolve("test.example.org") + if err != nil { + t.Errorf("解析test.example.org失败: %v", err) + } + if ip != "192.168.5.200" { + t.Errorf("解析结果错误,期望 192.168.5.200,得到 %s", ip) + } + + // 测试带端口的通配符记录 + endpoint, err = resolver.ResolveWithPort("v1.api.example.org", 443) + if err != nil { + t.Errorf("解析v1.api.example.org失败: %v", err) + } + if endpoint.IP != "192.168.5.201" { + t.Errorf("解析结果错误,期望 IP=192.168.5.201,得到 IP=%s", endpoint.IP) + } + t.Logf("v1.api.example.org解析端口为: %d", endpoint.Port) +} + +// TestResolverLoadBalancing 测试负载均衡功能 +func TestResolverLoadBalancing(t *testing.T) { + // 创建使用轮询策略的解析器 + resolver := NewResolver( + WithLoadBalanceStrategy(RoundRobin), + ) + + // 添加多个端点 + resolver.Add("balance.example.com", "192.168.6.100") + resolver.Add("balance.example.com", "192.168.6.101") + resolver.Add("balance.example.com", "192.168.6.102") + + // 进行多次解析,验证是否使用不同的IP + ips := make(map[string]bool) + for i := 0; i < 10; i++ { + ip, err := resolver.Resolve("balance.example.com") + if err != nil { + t.Errorf("解析balance.example.com失败: %v", err) + } + ips[ip] = true + } + + // 检查是否使用了多个IP + if len(ips) < 2 { + t.Errorf("负载均衡策略没有正确使用多个IP,只使用了 %d 个IP", len(ips)) + } +} + +// TestResolverInvalidInputs 测试无效输入 +func TestResolverInvalidInputs(t *testing.T) { + resolver := NewResolver() + + // 测试添加无效IP + err := resolver.Add("invalid.example.com", "不是IP地址") + if err == nil { + t.Error("应该拒绝添加无效的IP地址") + } + + // 测试添加无效的通配符 + err = resolver.AddWildcard("没有通配符.example.com", "192.168.7.100") + if err == nil { + t.Error("应该拒绝添加不包含通配符的泛解析域名") + } + + // 测试加载包含无效IP的映射 + records := map[string]string{ + "example.com": "不是IP地址", + } + err = resolver.LoadFromMap(records) + if err == nil { + t.Error("应该拒绝加载包含无效IP的映射") + } +} diff --git a/pkg/dns/wildcard.go b/pkg/dns/wildcard.go new file mode 100644 index 0000000..3bed857 --- /dev/null +++ b/pkg/dns/wildcard.go @@ -0,0 +1,83 @@ +package dns + +import ( + "fmt" + "strings" + "sync" +) + +// matchWildcard 检查域名是否匹配通配符模式 +func matchWildcard(host, pattern string) bool { + if !strings.Contains(pattern, "*") { + return host == pattern + } + + parts := strings.Split(pattern, "*") + if len(parts) != 2 { + return false + } + + prefix := parts[0] + suffix := parts[1] + + if prefix != "" && !strings.HasPrefix(host, prefix) { + return false + } + + if suffix != "" && !strings.HasSuffix(host, suffix) { + return false + } + + return true +} + +// WildcardResolver 通配符 DNS 解析器 +type WildcardResolver struct { + mu sync.RWMutex + patterns map[string]string +} + +// NewWildcardResolver 创建通配符 DNS 解析器 +func NewWildcardResolver(patterns map[string]string) *WildcardResolver { + return &WildcardResolver{ + patterns: patterns, + } +} + +// Resolve 解析域名 +func (r *WildcardResolver) Resolve(host string) (string, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + for pattern, target := range r.patterns { + if matchWildcard(host, pattern) { + return target, nil + } + } + + return "", fmt.Errorf("未找到匹配的通配符记录: %s", host) +} + +// Add 添加通配符记录 +func (r *WildcardResolver) Add(pattern, target string) error { + r.mu.Lock() + defer r.mu.Unlock() + r.patterns[pattern] = target + return nil +} + +// Remove 删除通配符记录 +func (r *WildcardResolver) Remove(pattern string) error { + r.mu.Lock() + defer r.mu.Unlock() + delete(r.patterns, pattern) + return nil +} + +// Clear 清除所有记录 +func (r *WildcardResolver) Clear() error { + r.mu.Lock() + defer r.mu.Unlock() + r.patterns = make(map[string]string) + return nil +} diff --git a/pkg/dns/wildcard_test.go b/pkg/dns/wildcard_test.go new file mode 100644 index 0000000..a1120c2 --- /dev/null +++ b/pkg/dns/wildcard_test.go @@ -0,0 +1,165 @@ +package dns + +import ( + "testing" +) + +// TestMatchWildcard 测试通配符匹配函数 +func TestMatchWildcard(t *testing.T) { + // matchWildcard是包内部函数,无法直接测试 + // 通过WildcardResolver的Resolve方法间接测试 + patterns := map[string]string{ + "*.example.com": "192.168.1.100", + "api.*.com": "192.168.1.101", + "test.example.*": "192.168.1.102", + } + + resolver := NewWildcardResolver(patterns) + + tests := []struct { + name string + host string + wantIP string + wantErr bool + }{ + { + name: "匹配前缀通配符", + host: "sub.example.com", + wantIP: "192.168.1.100", + wantErr: false, + }, + { + name: "匹配中间通配符", + host: "api.test.com", + wantIP: "192.168.1.101", + wantErr: false, + }, + { + name: "匹配后缀通配符", + host: "test.example.org", + wantIP: "192.168.1.102", + wantErr: false, + }, + { + name: "不匹配任何模式", + host: "example.org", + wantIP: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip, err := resolver.Resolve(tt.host) + if (err != nil) != tt.wantErr { + t.Errorf("Resolve() 错误 = %v, 期望错误 = %v", err, tt.wantErr) + return + } + + if !tt.wantErr && ip != tt.wantIP { + t.Errorf("Resolve() = %v, 期望 = %v", ip, tt.wantIP) + } + }) + } +} + +// TestWildcardResolverBasic 测试通配符解析器的基本功能 +func TestWildcardResolverBasic(t *testing.T) { + // 创建空解析器 + resolver := NewWildcardResolver(make(map[string]string)) + + // 添加几个通配符记录 + err := resolver.Add("*.example.com", "192.168.2.100") + if err != nil { + t.Fatalf("添加通配符记录失败: %v", err) + } + + err = resolver.Add("api.*.org", "192.168.2.101") + if err != nil { + t.Fatalf("添加通配符记录失败: %v", err) + } + + // 测试解析 + ip, err := resolver.Resolve("test.example.com") + if err != nil { + t.Errorf("解析test.example.com失败: %v", err) + } + if ip != "192.168.2.100" { + t.Errorf("解析结果错误,期望 192.168.2.100,得到 %s", ip) + } + + ip, err = resolver.Resolve("api.test.org") + if err != nil { + t.Errorf("解析api.test.org失败: %v", err) + } + if ip != "192.168.2.101" { + t.Errorf("解析结果错误,期望 192.168.2.101,得到 %s", ip) + } +} + +// TestWildcardResolverRemove 测试删除通配符规则 +func TestWildcardResolverRemove(t *testing.T) { + // 创建解析器 + patterns := map[string]string{ + "*.example.com": "192.168.3.100", + "*.example.org": "192.168.3.101", + } + resolver := NewWildcardResolver(patterns) + + // 测试现有规则 + ip, err := resolver.Resolve("test.example.com") + if err != nil { + t.Errorf("解析test.example.com失败: %v", err) + } + if ip != "192.168.3.100" { + t.Errorf("解析结果错误,期望 192.168.3.100,得到 %s", ip) + } + + // 删除规则 + err = resolver.Remove("*.example.com") + if err != nil { + t.Errorf("删除通配符规则失败: %v", err) + } + + // 测试规则已删除 + _, err = resolver.Resolve("test.example.com") + if err == nil { + t.Error("应该返回错误,因为规则已被删除") + } + + // 测试其他规则仍然存在 + ip, err = resolver.Resolve("test.example.org") + if err != nil { + t.Errorf("解析test.example.org失败: %v", err) + } + if ip != "192.168.3.101" { + t.Errorf("解析结果错误,期望 192.168.3.101,得到 %s", ip) + } +} + +// TestWildcardResolverClear 测试清除所有规则 +func TestWildcardResolverClear(t *testing.T) { + // 创建解析器 + patterns := map[string]string{ + "*.example.com": "192.168.4.100", + "*.example.org": "192.168.4.101", + } + resolver := NewWildcardResolver(patterns) + + // 清除所有规则 + err := resolver.Clear() + if err != nil { + t.Fatalf("清除规则失败: %v", err) + } + + // 测试所有规则都已删除 + _, err = resolver.Resolve("test.example.com") + if err == nil { + t.Error("应该返回错误,因为所有规则已被清除") + } + + _, err = resolver.Resolve("test.example.org") + if err == nil { + t.Error("应该返回错误,因为所有规则已被清除") + } +} diff --git a/pkg/healthcheck/healthchecker.go b/pkg/healthcheck/healthchecker.go new file mode 100644 index 0000000..e7d3381 --- /dev/null +++ b/pkg/healthcheck/healthchecker.go @@ -0,0 +1,261 @@ +package healthcheck + +import ( + "context" + "errors" + "net" + "net/http" + "net/url" + "sync" + "time" +) + +// HealthChecker 健康检查器 +type HealthChecker struct { + // 配置 + config *Config + // 健康检查状态 + statusMap sync.Map + // 状态变更回调 + statusChangeCallback func(string, bool) + // 是否运行中 + running bool + // 上下文 + ctx context.Context + // 取消函数 + cancel context.CancelFunc + // 互斥锁 + mu sync.Mutex +} + +// Config 健康检查配置 +type Config struct { + // 检查间隔 + Interval time.Duration + // 检查超时 + Timeout time.Duration + // 检查路径 + Path string + // 检查方法 + Method string + // 检查状态码 + SuccessStatus int + // 最大失败次数 + MaxFails int + // 最小成功次数 + MinSuccess int +} + +// NewHealthChecker 创建健康检查器 +func NewHealthChecker(config *Config) *HealthChecker { + if config.Path == "" { + config.Path = "/" + } + if config.Method == "" { + config.Method = http.MethodGet + } + if config.SuccessStatus == 0 { + config.SuccessStatus = http.StatusOK + } + if config.MaxFails == 0 { + config.MaxFails = 3 + } + if config.MinSuccess == 0 { + config.MinSuccess = 2 + } + + ctx, cancel := context.WithCancel(context.Background()) + return &HealthChecker{ + config: config, + ctx: ctx, + cancel: cancel, + running: false, + } +} + +// Start 启动健康检查 +func (hc *HealthChecker) Start() { + hc.mu.Lock() + defer hc.mu.Unlock() + + if hc.running { + return + } + + hc.running = true + go hc.run() +} + +// Stop 停止健康检查 +func (hc *HealthChecker) Stop() { + hc.mu.Lock() + defer hc.mu.Unlock() + + if !hc.running { + return + } + + hc.cancel() + hc.running = false +} + +// AddTarget 添加监控目标 +func (hc *HealthChecker) AddTarget(target string) error { + u, err := url.Parse(target) + if err != nil { + return err + } + + // 初始化为健康状态 + hc.statusMap.Store(u.String(), &backendStatus{ + URL: u, + Healthy: true, + FailCount: 0, + SuccessCount: 0, + }) + + return nil +} + +// AddTargetList 添加监控目标 +func (hc *HealthChecker) AddTargetList(targets []string) error { + var errs error + for _, target := range targets { + err := hc.AddTarget(target) + if err != nil { + errors.Join(err) + } + } + return errs +} + +// RemoveTarget 移除监控目标 +func (hc *HealthChecker) RemoveTarget(target string) error { + u, err := url.Parse(target) + if err != nil { + return err + } + + hc.statusMap.Delete(u.String()) + return nil +} + +// IsHealthy 检查目标是否健康 +func (hc *HealthChecker) IsHealthy(target string) bool { + u, err := url.Parse(target) + if err != nil { + return false + } + + value, ok := hc.statusMap.Load(u.String()) + if !ok { + return false + } + + status := value.(*backendStatus) + return status.Healthy +} + +// SetStatusChangeCallback 设置状态变更回调 +func (hc *HealthChecker) SetStatusChangeCallback(callback func(string, bool)) { + hc.statusChangeCallback = callback +} + +// backendStatus 后端健康状态 +type backendStatus struct { + URL *url.URL + Healthy bool + FailCount int + SuccessCount int +} + +// run 运行健康检查 +func (hc *HealthChecker) run() { + ticker := time.NewTicker(hc.config.Interval) + defer ticker.Stop() + + for { + select { + case <-hc.ctx.Done(): + return + case <-ticker.C: + hc.checkAll() + } + } +} + +// checkAll 检查所有后端 +func (hc *HealthChecker) checkAll() { + hc.statusMap.Range(func(key, value interface{}) bool { + go hc.check(key.(string), value.(*backendStatus)) + return true + }) +} + +// check 检查单个后端 +func (hc *HealthChecker) check(key string, status *backendStatus) { + // 创建检查请求 + u := *status.URL + u.Path = hc.config.Path + req, err := http.NewRequest(hc.config.Method, u.String(), nil) + if err != nil { + hc.updateStatus(key, status, false) + return + } + + // 设置超时的客户端 + client := &http.Client{ + Timeout: hc.config.Timeout, + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: hc.config.Timeout, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: hc.config.Timeout, + ExpectContinueTimeout: 1 * time.Second, + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + }, + } + + // 发送请求 + resp, err := client.Do(req) + if err != nil { + hc.updateStatus(key, status, false) + return + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode == hc.config.SuccessStatus { + hc.updateStatus(key, status, true) + } else { + hc.updateStatus(key, status, false) + } +} + +// updateStatus 更新后端状态 +func (hc *HealthChecker) updateStatus(key string, status *backendStatus, success bool) { + if success { + status.SuccessCount++ + status.FailCount = 0 + if !status.Healthy && status.SuccessCount >= hc.config.MinSuccess { + status.Healthy = true + if hc.statusChangeCallback != nil { + hc.statusChangeCallback(key, true) + } + } + } else { + status.FailCount++ + status.SuccessCount = 0 + if status.Healthy && status.FailCount >= hc.config.MaxFails { + status.Healthy = false + if hc.statusChangeCallback != nil { + hc.statusChangeCallback(key, false) + } + } + } + + hc.statusMap.Store(key, status) +} diff --git a/pkg/loadbalance/loadbalancer.go b/pkg/loadbalance/loadbalancer.go new file mode 100644 index 0000000..1c6da12 --- /dev/null +++ b/pkg/loadbalance/loadbalancer.go @@ -0,0 +1,343 @@ +package loadbalance + +import ( + "errors" + "math/rand" + "net/url" + "sync" + "sync/atomic" + "time" +) + +// Strategy 负载均衡策略 +type Strategy int + +const ( + // StrategyRoundRobin 轮询策略 + StrategyRoundRobin Strategy = iota + // StrategyRandom 随机策略 + StrategyRandom + // StrategyWeightedRoundRobin 加权轮询策略 + StrategyWeightedRoundRobin + // StrategyIPHash IP哈希策略 + StrategyIPHash +) + +// LoadBalancer 负载均衡器接口 +type LoadBalancer interface { + // Next 获取下一个后端 + Next(key string) (*url.URL, error) + // Add 添加后端 + Add(backend string, weight int) error + // Remove 删除后端 + Remove(backend string) error + // MarkDown 标记后端为不可用 + MarkDown(backend string) error + // MarkUp 标记后端为可用 + MarkUp(backend string) error + // Reset 重置负载均衡器 + Reset() error +} + +// Backend 后端服务器 +type Backend struct { + // URL 后端URL + URL *url.URL + // Weight 权重 + Weight int + // Down 是否不可用 + Down bool +} + +// RoundRobinBalancer 轮询负载均衡器 +type RoundRobinBalancer struct { + backends []*Backend + current int32 + mutex sync.RWMutex +} + +// NewRoundRobinBalancer 创建轮询负载均衡器 +func NewRoundRobinBalancer() *RoundRobinBalancer { + return &RoundRobinBalancer{ + backends: make([]*Backend, 0), + current: 0, + } +} + +// Next 获取下一个后端 +func (lb *RoundRobinBalancer) Next(key string) (*url.URL, error) { + lb.mutex.RLock() + defer lb.mutex.RUnlock() + + if len(lb.backends) == 0 { + return nil, nil + } + + // 计算可用后端数量 + var availableCount int + for _, backend := range lb.backends { + if !backend.Down { + availableCount++ + } + } + + if availableCount == 0 { + return nil, nil + } + + // 循环直到找到可用后端 + for i := 0; i < len(lb.backends); i++ { + idx := atomic.AddInt32(&lb.current, 1) % int32(len(lb.backends)) + backend := lb.backends[idx] + if !backend.Down { + return backend.URL, nil + } + } + + return nil, nil +} + +// Add 添加后端 +func (lb *RoundRobinBalancer) Add(backend string, weight int) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for _, b := range lb.backends { + if b.URL.String() == url.String() { + return nil + } + } + + lb.backends = append(lb.backends, &Backend{ + URL: url, + Weight: weight, + Down: false, + }) + + return nil +} + +// AddList 添加后端列表 +func (lb *RoundRobinBalancer) AddList(backend []string, weight int) error { + var errs error + for _, vo := range backend { + err := lb.Add(vo, weight) + if err != nil { + errors.Join(err) + } + } + return errs +} + +// Remove 删除后端 +func (lb *RoundRobinBalancer) Remove(backend string) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for i, b := range lb.backends { + if b.URL.String() == url.String() { + lb.backends = append(lb.backends[:i], lb.backends[i+1:]...) + return nil + } + } + + return nil +} + +// MarkDown 标记后端为不可用 +func (lb *RoundRobinBalancer) MarkDown(backend string) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for _, b := range lb.backends { + if b.URL.String() == url.String() { + b.Down = true + return nil + } + } + + return nil +} + +// MarkUp 标记后端为可用 +func (lb *RoundRobinBalancer) MarkUp(backend string) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for _, b := range lb.backends { + if b.URL.String() == url.String() { + b.Down = false + return nil + } + } + + return nil +} + +// Reset 重置负载均衡器 +func (lb *RoundRobinBalancer) Reset() error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + lb.backends = make([]*Backend, 0) + atomic.StoreInt32(&lb.current, 0) + + return nil +} + +// RandomBalancer 随机负载均衡器 +type RandomBalancer struct { + backends []*Backend + mutex sync.RWMutex + rand *rand.Rand +} + +// NewRandomBalancer 创建随机负载均衡器 +func NewRandomBalancer() *RandomBalancer { + source := rand.NewSource(time.Now().UnixNano()) + random := rand.New(source) + return &RandomBalancer{ + backends: make([]*Backend, 0), + rand: random, + } +} + +// Next 获取下一个后端 +func (lb *RandomBalancer) Next(key string) (*url.URL, error) { + lb.mutex.RLock() + defer lb.mutex.RUnlock() + + if len(lb.backends) == 0 { + return nil, nil + } + + // 计算可用后端数量 + var availableBackends []*Backend + for _, backend := range lb.backends { + if !backend.Down { + availableBackends = append(availableBackends, backend) + } + } + + if len(availableBackends) == 0 { + return nil, nil + } + + idx := lb.rand.Intn(len(availableBackends)) + return availableBackends[idx].URL, nil +} + +// Add 添加后端 +func (lb *RandomBalancer) Add(backend string, weight int) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for _, b := range lb.backends { + if b.URL.String() == url.String() { + return nil + } + } + + lb.backends = append(lb.backends, &Backend{ + URL: url, + Weight: weight, + Down: false, + }) + + return nil +} + +// Remove 删除后端 +func (lb *RandomBalancer) Remove(backend string) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for i, b := range lb.backends { + if b.URL.String() == url.String() { + lb.backends = append(lb.backends[:i], lb.backends[i+1:]...) + return nil + } + } + + return nil +} + +// MarkDown 标记后端为不可用 +func (lb *RandomBalancer) MarkDown(backend string) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for _, b := range lb.backends { + if b.URL.String() == url.String() { + b.Down = true + return nil + } + } + + return nil +} + +// MarkUp 标记后端为可用 +func (lb *RandomBalancer) MarkUp(backend string) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for _, b := range lb.backends { + if b.URL.String() == url.String() { + b.Down = false + return nil + } + } + + return nil +} + +// Reset 重置负载均衡器 +func (lb *RandomBalancer) Reset() error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + lb.backends = make([]*Backend, 0) + + return nil +} diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go new file mode 100644 index 0000000..af3396e --- /dev/null +++ b/pkg/metrics/metrics.go @@ -0,0 +1,387 @@ +package metrics + +import ( + "fmt" + "net/http" + "runtime" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// MetricsCollector 监控指标接口 +type MetricsCollector interface { + // 增加请求计数 + IncRequestCount() + // 增加错误计数 + IncErrorCount(err error) + // 观察请求持续时间 + ObserveRequestDuration(seconds float64) + // 增加活跃连接数 + IncActiveConnections() + // 减少活跃连接数 + DecActiveConnections() + // 设置后端健康状态 + SetBackendHealth(backend string, healthy bool) + // 设置后端响应时间 + SetBackendResponseTime(backend string, duration time.Duration) + // 观察请求字节数 + ObserveRequestBytes(bytes int64) + // 观察响应字节数 + ObserveResponseBytes(bytes int64) + // 添加传输字节数 + AddBytesTransferred(direction string, bytes int64) + // 增加缓存命中计数 + IncCacheHit() + // 获取指标处理器 + GetHandler() http.Handler +} + +// PrometheusMetrics 指标收集器 +type PrometheusMetrics struct { + // 请求总数 + requestTotal *prometheus.CounterVec + // 请求延迟 + requestLatency *prometheus.HistogramVec + // 请求大小 + requestSize *prometheus.HistogramVec + // 响应大小 + responseSize *prometheus.HistogramVec + // 错误总数 + errorTotal *prometheus.CounterVec + // 活跃连接数 + activeConnections prometheus.Gauge + // 连接池大小 + connectionPoolSize prometheus.Gauge + // 缓存命中率 + cacheHitRate prometheus.Gauge + // 内存使用量 + memoryUsage prometheus.Gauge + // 锁 + mu sync.RWMutex +} + +// NewPrometheusMetrics 创建指标收集器 +func NewPrometheusMetrics() *PrometheusMetrics { + m := &PrometheusMetrics{ + requestTotal: promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "proxy_requests_total", + Help: "代理请求总数", + }, + []string{"method", "path", "status"}, + ), + requestLatency: promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "proxy_request_latency_seconds", + Help: "代理请求延迟", + Buckets: prometheus.DefBuckets, + }, + []string{"method", "path"}, + ), + requestSize: promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "proxy_request_size_bytes", + Help: "代理请求大小", + Buckets: prometheus.ExponentialBuckets(100, 2, 10), + }, + []string{"method", "path"}, + ), + responseSize: promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "proxy_response_size_bytes", + Help: "代理响应大小", + Buckets: prometheus.ExponentialBuckets(100, 2, 10), + }, + []string{"method", "path"}, + ), + errorTotal: promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "proxy_errors_total", + Help: "代理错误总数", + }, + []string{"type"}, + ), + activeConnections: promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "proxy_active_connections", + Help: "活跃连接数", + }, + ), + connectionPoolSize: promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "proxy_connection_pool_size", + Help: "连接池大小", + }, + ), + cacheHitRate: promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "proxy_cache_hit_rate", + Help: "缓存命中率", + }, + ), + memoryUsage: promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "proxy_memory_usage_bytes", + Help: "内存使用量", + }, + ), + } + + // 启动定期更新 + go m.updateMetrics() + + return m +} + +// updateMetrics 定期更新指标 +func (m *PrometheusMetrics) updateMetrics() { + ticker := time.NewTicker(15 * time.Second) + defer ticker.Stop() + + for range ticker.C { + // 更新内存使用量 + var mem runtime.MemStats + runtime.ReadMemStats(&mem) + m.memoryUsage.Set(float64(mem.Alloc)) + } +} + +// RecordRequest 记录请求 +func (m *PrometheusMetrics) RecordRequest(method, path string, status int, latency time.Duration, reqSize, respSize int64) { + m.requestTotal.WithLabelValues(method, path, strconv.Itoa(status)).Inc() + m.requestLatency.WithLabelValues(method, path).Observe(latency.Seconds()) + m.requestSize.WithLabelValues(method, path).Observe(float64(reqSize)) + m.responseSize.WithLabelValues(method, path).Observe(float64(respSize)) +} + +// RecordError 记录错误 +func (m *PrometheusMetrics) RecordError(errType string) { + m.errorTotal.WithLabelValues(errType).Inc() +} + +// SetActiveConnections 设置活跃连接数 +func (m *PrometheusMetrics) SetActiveConnections(count int) { + m.activeConnections.Set(float64(count)) +} + +// SetConnectionPoolSize 设置连接池大小 +func (m *PrometheusMetrics) SetConnectionPoolSize(size int) { + m.connectionPoolSize.Set(float64(size)) +} + +// SetCacheHitRate 设置缓存命中率 +func (m *PrometheusMetrics) SetCacheHitRate(rate float64) { + m.cacheHitRate.Set(rate) +} + +// SimpleMetrics 简单指标实现 +type SimpleMetrics struct { + // 请求计数 + requestCount int64 + // 错误计数 + errorCount int64 + // 活跃连接数 + activeConnections int64 + // 累计响应时间 + totalResponseTime int64 + // 传输字节数 + bytesTransferred map[string]int64 + // 后端健康状态 + backendHealth map[string]bool + // 后端响应时间 + backendResponseTime map[string]time.Duration + // 缓存命中计数 + cacheHits int64 + // 互斥锁 + mu sync.Mutex +} + +// NewSimpleMetrics 创建简单指标 +func NewSimpleMetrics() *SimpleMetrics { + return &SimpleMetrics{ + bytesTransferred: make(map[string]int64), + backendHealth: make(map[string]bool), + backendResponseTime: make(map[string]time.Duration), + } +} + +// IncRequestCount 增加请求计数 +func (m *SimpleMetrics) IncRequestCount() { + atomic.AddInt64(&m.requestCount, 1) +} + +// IncErrorCount 增加错误计数 +func (m *SimpleMetrics) IncErrorCount(err error) { + atomic.AddInt64(&m.errorCount, 1) +} + +// ObserveRequestDuration 观察请求持续时间 +func (m *SimpleMetrics) ObserveRequestDuration(seconds float64) { + nsec := int64(seconds * float64(time.Second)) + atomic.AddInt64(&m.totalResponseTime, nsec) +} + +// IncActiveConnections 增加活跃连接数 +func (m *SimpleMetrics) IncActiveConnections() { + atomic.AddInt64(&m.activeConnections, 1) +} + +// DecActiveConnections 减少活跃连接数 +func (m *SimpleMetrics) DecActiveConnections() { + atomic.AddInt64(&m.activeConnections, -1) +} + +// SetBackendHealth 设置后端健康状态 +func (m *SimpleMetrics) SetBackendHealth(backend string, healthy bool) { + m.backendHealth[backend] = healthy +} + +// SetBackendResponseTime 设置后端响应时间 +func (m *SimpleMetrics) SetBackendResponseTime(backend string, duration time.Duration) { + m.backendResponseTime[backend] = duration +} + +// ObserveRequestBytes 观察请求字节数 +func (m *SimpleMetrics) ObserveRequestBytes(bytes int64) { + m.mu.Lock() + defer m.mu.Unlock() + m.bytesTransferred["request"] += bytes +} + +// ObserveResponseBytes 观察响应字节数 +func (m *SimpleMetrics) ObserveResponseBytes(bytes int64) { + m.mu.Lock() + defer m.mu.Unlock() + m.bytesTransferred["response"] += bytes +} + +// AddBytesTransferred 添加传输字节数 +func (m *SimpleMetrics) AddBytesTransferred(direction string, bytes int64) { + m.mu.Lock() + defer m.mu.Unlock() + m.bytesTransferred[direction] += bytes +} + +// IncCacheHit 增加缓存命中计数 +func (m *SimpleMetrics) IncCacheHit() { + atomic.AddInt64(&m.cacheHits, 1) +} + +// GetHandler 获取指标处理器 +func (m *SimpleMetrics) GetHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + + // 输出基本指标 + w.Write([]byte("# HELP proxy_requests_total 代理请求总数\n")) + w.Write([]byte("# TYPE proxy_requests_total counter\n")) + w.Write([]byte(fmt.Sprintf("proxy_requests_total %d\n", m.requestCount))) + + w.Write([]byte("# HELP proxy_errors_total 代理错误总数\n")) + w.Write([]byte("# TYPE proxy_errors_total counter\n")) + w.Write([]byte(fmt.Sprintf("proxy_errors_total %d\n", m.errorCount))) + + w.Write([]byte("# HELP proxy_active_connections 当前活跃连接数\n")) + w.Write([]byte("# TYPE proxy_active_connections gauge\n")) + w.Write([]byte(fmt.Sprintf("proxy_active_connections %d\n", m.activeConnections))) + + // 输出缓存命中数据 + w.Write([]byte("# HELP proxy_cache_hits_total 缓存命中总数\n")) + w.Write([]byte("# TYPE proxy_cache_hits_total counter\n")) + w.Write([]byte(fmt.Sprintf("proxy_cache_hits_total %d\n", m.cacheHits))) + + // 输出传输字节数 + for direction, bytes := range m.bytesTransferred { + w.Write([]byte(fmt.Sprintf("# HELP proxy_bytes_transferred_%s 代理传输字节数(%s)\n", direction, direction))) + w.Write([]byte(fmt.Sprintf("# TYPE proxy_bytes_transferred_%s counter\n", direction))) + w.Write([]byte(fmt.Sprintf("proxy_bytes_transferred_%s %d\n", direction, bytes))) + } + + // 输出后端健康状态 + for backend, healthy := range m.backendHealth { + healthValue := 0 + if healthy { + healthValue = 1 + } + w.Write([]byte(fmt.Sprintf("# HELP proxy_backend_health 后端健康状态\n"))) + w.Write([]byte(fmt.Sprintf("# TYPE proxy_backend_health gauge\n"))) + w.Write([]byte(fmt.Sprintf("proxy_backend_health{backend=\"%s\"} %d\n", backend, healthValue))) + } + + // 输出后端响应时间 + for backend, duration := range m.backendResponseTime { + w.Write([]byte(fmt.Sprintf("# HELP proxy_backend_response_time 后端响应时间\n"))) + w.Write([]byte(fmt.Sprintf("# TYPE proxy_backend_response_time gauge\n"))) + w.Write([]byte(fmt.Sprintf("proxy_backend_response_time{backend=\"%s\"} %f\n", backend, float64(duration)/float64(time.Second)))) + } + + // 平均响应时间 + if m.requestCount > 0 { + avgTime := float64(m.totalResponseTime) / float64(m.requestCount) / float64(time.Second) + w.Write([]byte("# HELP proxy_average_response_time 平均响应时间\n")) + w.Write([]byte("# TYPE proxy_average_response_time gauge\n")) + w.Write([]byte(fmt.Sprintf("proxy_average_response_time %f\n", avgTime))) + } + }) +} + +// MetricsMiddleware 指标中间件 +type MetricsMiddleware struct { + metrics MetricsCollector +} + +// NewMetricsMiddleware 创建指标中间件 +func NewMetricsMiddleware(metrics MetricsCollector) *MetricsMiddleware { + return &MetricsMiddleware{ + metrics: metrics, + } +} + +// Middleware 中间件处理函数 +func (m *MetricsMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // 包装响应写入器,用于捕获状态码 + rw := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} + + // 继续处理请求 + next.ServeHTTP(rw, r) + + // 记录请求指标 + duration := time.Since(start) + m.metrics.ObserveRequestDuration(duration.Seconds()) + }) +} + +// responseWriter 包装的响应写入器 +type responseWriter struct { + http.ResponseWriter + statusCode int + written int64 +} + +// WriteHeader 写入状态码 +func (rw *responseWriter) WriteHeader(statusCode int) { + rw.statusCode = statusCode + rw.ResponseWriter.WriteHeader(statusCode) +} + +// Write 写入数据 +func (rw *responseWriter) Write(b []byte) (int, error) { + n, err := rw.ResponseWriter.Write(b) + rw.written += int64(n) + return n, err +} + +// Flush 刷新数据 +func (rw *responseWriter) Flush() { + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/pkg/middleware/chain.go b/pkg/middleware/chain.go new file mode 100644 index 0000000..5c15259 --- /dev/null +++ b/pkg/middleware/chain.go @@ -0,0 +1,63 @@ +package middleware + +import ( + "net/http" +) + +// Middleware 中间件接口 +type Middleware interface { + ServeHTTP(http.ResponseWriter, *http.Request, http.HandlerFunc) +} + +// Chain 中间件链 +type Chain struct { + middlewares []Middleware +} + +// NewChain 创建新的中间件链 +func NewChain(middlewares ...Middleware) *Chain { + return &Chain{ + middlewares: middlewares, + } +} + +// Then 将中间件链应用到处理器 +func (c *Chain) Then(h http.Handler) http.Handler { + if h == nil { + h = http.DefaultServeMux + } + + for i := len(c.middlewares) - 1; i >= 0; i-- { + h = c.wrap(h, c.middlewares[i]) + } + + return h +} + +// wrap 包装处理器 +func (c *Chain) wrap(h http.Handler, m Middleware) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m.ServeHTTP(w, r, h.ServeHTTP) + }) +} + +// Add 添加中间件 +func (c *Chain) Add(middlewares ...Middleware) *Chain { + c.middlewares = append(c.middlewares, middlewares...) + return c +} + +// Remove 移除中间件 +func (c *Chain) Remove(index int) *Chain { + if index < 0 || index >= len(c.middlewares) { + return c + } + c.middlewares = append(c.middlewares[:index], c.middlewares[index+1:]...) + return c +} + +// Clear 清空中间件链 +func (c *Chain) Clear() *Chain { + c.middlewares = nil + return c +} diff --git a/pkg/middleware/compression.go b/pkg/middleware/compression.go new file mode 100644 index 0000000..7803bc7 --- /dev/null +++ b/pkg/middleware/compression.go @@ -0,0 +1,127 @@ +package middleware + +import ( + "compress/gzip" + "io" + "net/http" + "strings" +) + +// CompressionMiddleware 压缩中间件 +type CompressionMiddleware struct { + // 压缩级别 (0-9) + level int + // 最小压缩大小 + minSize int64 + // 支持的内容类型 + contentTypes []string +} + +// NewCompressionMiddleware 创建压缩中间件 +func NewCompressionMiddleware(level int, minSize int64) *CompressionMiddleware { + return &CompressionMiddleware{ + level: level, + minSize: minSize, + contentTypes: []string{ + "text/plain", + "text/html", + "text/css", + "text/javascript", + "application/javascript", + "application/json", + "application/xml", + "application/xml+rss", + "text/xml", + "application/x-yaml", + "text/yaml", + "application/x-www-form-urlencoded", + "application/x-protobuf", + "application/grpc", + "application/grpc+proto", + }, + } +} + +// Middleware 中间件处理函数 +func (m *CompressionMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 处理请求压缩 + if m.shouldCompress(r.Header.Get("Content-Encoding")) { + reader, err := gzip.NewReader(r.Body) + if err != nil { + http.Error(w, "Invalid gzip content", http.StatusBadRequest) + return + } + defer reader.Close() + r.Body = io.NopCloser(reader) + } + + // 处理响应压缩 + if m.shouldCompressResponse(r) { + gw := gzip.NewWriter(w) + defer gw.Close() + + // 包装响应写入器 + writer := &gzipResponseWriter{ + ResponseWriter: w, + writer: gw, + } + + // 设置响应头 + w.Header().Set("Content-Encoding", "gzip") + w.Header().Set("Vary", "Accept-Encoding") + + next.ServeHTTP(writer, r) + } else { + next.ServeHTTP(w, r) + } + }) +} + +// shouldCompress 检查是否应该压缩请求 +func (m *CompressionMiddleware) shouldCompress(encoding string) bool { + return strings.Contains(encoding, "gzip") +} + +// shouldCompressResponse 检查是否应该压缩响应 +func (m *CompressionMiddleware) shouldCompressResponse(r *http.Request) bool { + // 检查客户端是否支持gzip + acceptEncoding := r.Header.Get("Accept-Encoding") + if !strings.Contains(acceptEncoding, "gzip") { + return false + } + + // 检查内容类型 + contentType := r.Header.Get("Content-Type") + for _, t := range m.contentTypes { + if strings.Contains(contentType, t) { + return true + } + } + + return false +} + +// gzipResponseWriter 包装的响应写入器 +type gzipResponseWriter struct { + http.ResponseWriter + writer *gzip.Writer +} + +// Write 写入数据 +func (w *gzipResponseWriter) Write(b []byte) (int, error) { + return w.writer.Write(b) +} + +// WriteHeader 写入状态码 +func (w *gzipResponseWriter) WriteHeader(statusCode int) { + w.ResponseWriter.WriteHeader(statusCode) +} + +// Flush 刷新数据 +func (w *gzipResponseWriter) Flush() { + w.writer.Flush() + if flusher, ok := w.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/pkg/middleware/ratelimiter.go b/pkg/middleware/ratelimiter.go new file mode 100644 index 0000000..05292f1 --- /dev/null +++ b/pkg/middleware/ratelimiter.go @@ -0,0 +1,184 @@ +package middleware + +import ( + "net/http" + "sync" + "time" + + "golang.org/x/time/rate" +) + +// RateLimiter 限流器接口 +type RateLimiter interface { + // Allow 检查请求是否允许通过 + Allow(key string) bool +} + +// SimpleRateLimiter 简单限流器 +type SimpleRateLimiter struct { + limiter *rate.Limiter +} + +// NewSimpleRateLimiter 创建简单限流器 +func NewSimpleRateLimiter(r float64, b int) *SimpleRateLimiter { + return &SimpleRateLimiter{ + limiter: rate.NewLimiter(rate.Limit(r), b), + } +} + +// Allow 检查请求是否允许通过 +func (rl *SimpleRateLimiter) Allow(key string) bool { + return rl.limiter.Allow() +} + +// IPRateLimiter 按IP限流 +type IPRateLimiter struct { + ips map[string]*rate.Limiter + mu sync.RWMutex + rate rate.Limit + burst int + cleanupInterval time.Duration + lastSeen map[string]time.Time +} + +// NewIPRateLimiter 创建IP限流器 +func NewIPRateLimiter(r float64, b int, cleanup time.Duration) *IPRateLimiter { + limiter := &IPRateLimiter{ + ips: make(map[string]*rate.Limiter), + rate: rate.Limit(r), + burst: b, + cleanupInterval: cleanup, + lastSeen: make(map[string]time.Time), + } + + // 启动过期清理 + if cleanup > 0 { + go limiter.startCleanup() + } + + return limiter +} + +// startCleanup 启动过期清理 +func (rl *IPRateLimiter) startCleanup() { + ticker := time.NewTicker(rl.cleanupInterval) + defer ticker.Stop() + + for range ticker.C { + rl.cleanup() + } +} + +// cleanup 清理过期限流器 +func (rl *IPRateLimiter) cleanup() { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + for ip, lastSeen := range rl.lastSeen { + if now.Sub(lastSeen) > rl.cleanupInterval { + delete(rl.ips, ip) + delete(rl.lastSeen, ip) + } + } +} + +// AddIP 添加IP限流器 +func (rl *IPRateLimiter) AddIP(ip string) *rate.Limiter { + rl.mu.Lock() + defer rl.mu.Unlock() + + limiter := rate.NewLimiter(rl.rate, rl.burst) + rl.ips[ip] = limiter + rl.lastSeen[ip] = time.Now() + + return limiter +} + +// GetLimiter 获取IP限流器 +func (rl *IPRateLimiter) GetLimiter(ip string) *rate.Limiter { + rl.mu.RLock() + limiter, exists := rl.ips[ip] + rl.mu.RUnlock() + + if !exists { + return rl.AddIP(ip) + } + + // 更新最后访问时间 + rl.mu.Lock() + rl.lastSeen[ip] = time.Now() + rl.mu.Unlock() + + return limiter +} + +// Allow 检查请求是否允许通过 +func (rl *IPRateLimiter) Allow(ip string) bool { + limiter := rl.GetLimiter(ip) + return limiter.Allow() +} + +// RateLimitMiddleware 限流中间件 +type RateLimitMiddleware struct { + limiter RateLimiter +} + +// NewRateLimitMiddleware 创建限流中间件 +func NewRateLimitMiddleware(limiter RateLimiter) *RateLimitMiddleware { + return &RateLimitMiddleware{ + limiter: limiter, + } +} + +// Middleware 中间件处理函数 +func (m *RateLimitMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 获取客户端IP + ip := getClientIP(r) + + // 检查是否允许通过 + if !m.limiter.Allow(ip) { + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + return + } + + // 继续处理请求 + next.ServeHTTP(w, r) + }) +} + +// getClientIP 获取客户端IP +func getClientIP(r *http.Request) string { + // 检查 X-Forwarded-For 头 + ip := r.Header.Get("X-Forwarded-For") + if ip != "" { + // 取第一个IP + for i := 0; i < len(ip) && i < 15; i++ { + if ip[i] == ',' { + ip = ip[:i] + break + } + } + return ip + } + + // 检查 X-Real-IP 头 + ip = r.Header.Get("X-Real-IP") + if ip != "" { + return ip + } + + // 从 RemoteAddr 获取 + if r.RemoteAddr != "" { + // 去掉端口部分 + for i := 0; i < len(r.RemoteAddr); i++ { + if r.RemoteAddr[i] == ':' { + return r.RemoteAddr[:i] + } + } + return r.RemoteAddr + } + + return "unknown" +} diff --git a/pkg/middleware/retry.go b/pkg/middleware/retry.go new file mode 100644 index 0000000..1e744b0 --- /dev/null +++ b/pkg/middleware/retry.go @@ -0,0 +1,156 @@ +package middleware + +import ( + "bytes" + "io" + "math" + "net" + "net/http" + "time" +) + +// RetryPolicy 重试策略 +type RetryPolicy struct { + // 最大重试次数 + MaxRetries int + // 基础退避时间 + BaseBackoff time.Duration + // 最大退避时间 + MaxBackoff time.Duration + // 重试判断函数 + ShouldRetry func(req *http.Request, resp *http.Response, err error) bool +} + +// DefaultRetryPolicy 默认重试策略 +func DefaultRetryPolicy() *RetryPolicy { + return &RetryPolicy{ + MaxRetries: 3, + BaseBackoff: 100 * time.Millisecond, + MaxBackoff: 2 * time.Second, + ShouldRetry: defaultShouldRetry, + } +} + +// defaultShouldRetry 默认重试判断 +func defaultShouldRetry(req *http.Request, resp *http.Response, err error) bool { + // 不重试非幂等请求 + if req.Method != http.MethodGet && req.Method != http.MethodHead && req.Method != http.MethodOptions { + return false + } + + // 检查错误 + if err != nil { + // 重试网络错误 + if netErr, ok := err.(net.Error); ok { + return netErr.Temporary() || netErr.Timeout() + } + return false + } + + // 检查响应状态码 + if resp != nil { + // 重试服务器错误 + return resp.StatusCode >= 500 && resp.StatusCode < 600 + } + + return false +} + +// RetryRoundTripper 重试HTTP传输 +type RetryRoundTripper struct { + // 下一级传输 + Next http.RoundTripper + // 重试策略 + Policy *RetryPolicy +} + +// NewRetryRoundTripper 创建重试HTTP传输 +func NewRetryRoundTripper(next http.RoundTripper, policy *RetryPolicy) *RetryRoundTripper { + if next == nil { + next = http.DefaultTransport + } + if policy == nil { + policy = DefaultRetryPolicy() + } + return &RetryRoundTripper{ + Next: next, + Policy: policy, + } +} + +// RoundTrip 执行HTTP请求 +func (rt *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // 需要保留原始请求体,以便重试 + var reqBodyBytes []byte + if req.Body != nil { + var err error + reqBodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return nil, err + } + req.Body.Close() + } + + var resp *http.Response + var err error + + // 尝试请求直到成功或达到最大重试次数 + for attempt := 0; attempt <= rt.Policy.MaxRetries; attempt++ { + // 复制请求体 + if len(reqBodyBytes) > 0 { + req.Body = io.NopCloser(bytes.NewBuffer(reqBodyBytes)) + } + + // 发送请求 + resp, err = rt.Next.RoundTrip(req) + + // 检查是否需要重试 + if attempt < rt.Policy.MaxRetries && rt.Policy.ShouldRetry(req, resp, err) { + // 如果需要重试,先关闭当前响应 + if resp != nil { + resp.Body.Close() + } + + // 计算退避时间 + backoff := rt.calculateBackoff(attempt) + time.Sleep(backoff) + continue + } + + // 不需要重试,返回响应 + return resp, err + } + + // 所有重试都失败 + return resp, err +} + +// calculateBackoff 计算退避时间 +func (rt *RetryRoundTripper) calculateBackoff(attempt int) time.Duration { + // 指数退避: baseBackoff * 2^attempt + backoff := rt.Policy.BaseBackoff * time.Duration(math.Pow(2, float64(attempt))) + if backoff > rt.Policy.MaxBackoff { + backoff = rt.Policy.MaxBackoff + } + return backoff +} + +// RetryMiddleware 重试中间件 +type RetryMiddleware struct { + policy *RetryPolicy +} + +// NewRetryMiddleware 创建重试中间件 +func NewRetryMiddleware(policy *RetryPolicy) *RetryMiddleware { + if policy == nil { + policy = DefaultRetryPolicy() + } + return &RetryMiddleware{ + policy: policy, + } +} + +// Middleware 中间件处理函数 +func (m *RetryMiddleware) Transport(next http.RoundTripper) http.RoundTripper { + return NewRetryRoundTripper(next, m.policy) +} diff --git a/pkg/middleware/websocket.go b/pkg/middleware/websocket.go new file mode 100644 index 0000000..0b16d25 --- /dev/null +++ b/pkg/middleware/websocket.go @@ -0,0 +1,147 @@ +package middleware + +import ( + "bufio" + "context" + "fmt" + "net" + "net/http" + "time" + + "github.com/ouqiang/websocket" +) + +// WebSocketConn WebSocket连接接口 +type WebSocketConn struct { + conn *websocket.Conn +} + +// NewWebSocketConn 创建WebSocket连接 +func NewWebSocketConn(conn *websocket.Conn) *WebSocketConn { + return &WebSocketConn{ + conn: conn, + } +} + +// Close 关闭连接 +func (c *WebSocketConn) Close() error { + return c.conn.Close() +} + +// ReadMessage 读取消息 +func (c *WebSocketConn) ReadMessage() (int, []byte, error) { + return c.conn.ReadMessage() +} + +// WriteMessage 写入消息 +func (c *WebSocketConn) WriteMessage(messageType int, data []byte) error { + return c.conn.WriteMessage(messageType, data) +} + +// RemoteAddr 获取远程地址 +func (c *WebSocketConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// WebSocketUpgrader WebSocket升级器 +type WebSocketUpgrader struct { + upgrader websocket.Upgrader +} + +// NewWebSocketUpgrader 创建WebSocket升级器 +func NewWebSocketUpgrader(timeout time.Duration) *WebSocketUpgrader { + return &WebSocketUpgrader{ + upgrader: websocket.Upgrader{ + HandshakeTimeout: timeout, + ReadBufferSize: 4096, + WriteBufferSize: 4096, + CheckOrigin: func(r *http.Request) bool { + return true + }, + }, + } +} + +// Upgrade 升级HTTP连接为WebSocket连接 +func (u *WebSocketUpgrader) Upgrade(conn net.Conn, req *http.Request) (*WebSocketConn, error) { + bufConn, ok := conn.(interface { + Hijack() (net.Conn, *bufio.ReadWriter, error) + }) + if !ok { + return nil, fmt.Errorf("连接不支持Hijack") + } + + netConn, bufrw, err := bufConn.Hijack() + if err != nil { + return nil, fmt.Errorf("hijack错误: %s", err) + } + + wsConn, err := u.upgrader.Upgrade(newResponseWriter(netConn, bufrw), req, http.Header{}) + if err != nil { + return nil, err + } + + return NewWebSocketConn(wsConn), nil +} + +// WebSocketDialer WebSocket拨号器 +type WebSocketDialer struct { + dialer websocket.Dialer +} + +// NewWebSocketDialer 创建WebSocket拨号器 +func NewWebSocketDialer(timeout time.Duration) *WebSocketDialer { + return &WebSocketDialer{ + dialer: websocket.Dialer{ + HandshakeTimeout: timeout, + ReadBufferSize: 4096, + WriteBufferSize: 4096, + }, + } +} + +// Dial 连接到WebSocket服务器 +func (d *WebSocketDialer) Dial(urlStr string, header http.Header) (*WebSocketConn, error) { + ctx, cancel := context.WithTimeout(context.Background(), d.dialer.HandshakeTimeout) + defer cancel() + + wsConn, _, err := d.dialer.DialContext(ctx, urlStr, header) + if err != nil { + return nil, err + } + + return NewWebSocketConn(wsConn), nil +} + +// 实现http.ResponseWriter接口,用于WebSocket升级 +type responseWriter struct { + conn net.Conn + bufrw *bufio.ReadWriter + header http.Header + status int +} + +func newResponseWriter(conn net.Conn, bufrw *bufio.ReadWriter) *responseWriter { + return &responseWriter{ + conn: conn, + bufrw: bufrw, + header: make(http.Header), + status: http.StatusOK, + } +} + +func (rw *responseWriter) Header() http.Header { + return rw.header +} + +func (rw *responseWriter) Write(b []byte) (int, error) { + return rw.bufrw.Write(b) +} + +func (rw *responseWriter) WriteHeader(statusCode int) { + rw.status = statusCode +} + +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return rw.conn, rw.bufrw, nil +} diff --git a/pkg/reverse/config.go b/pkg/reverse/config.go new file mode 100644 index 0000000..e04c7d6 --- /dev/null +++ b/pkg/reverse/config.go @@ -0,0 +1,90 @@ +package reverse + +import ( + "time" + + "github.com/darkit/goproxy/pkg/dns" +) + +// Config 反向代理配置 +type Config struct { + BaseConfig `json:",inline" yaml:",inline" toml:",inline"` // 基础配置 + RulesFile string `json:"rules_file" yaml:"rules_file" toml:"rules_file"` // 规则文件路径 + InsecureSkipVerify bool `json:"insecure_skip_verify" yaml:"insecure_skip_verify" toml:"insecure_skip_verify"` // 是否跳过证书验证 + EnableHealthCheck bool `json:"enable_health_check" yaml:"enable_health_check" toml:"enable_health_check"` // 是否启用健康检查 + HealthCheckInterval time.Duration `json:"health_check_interval" yaml:"health_check_interval" toml:"health_check_interval"` // 健康检查间隔时间 + HealthCheckTimeout time.Duration `json:"health_check_timeout" yaml:"health_check_timeout" toml:"health_check_timeout"` // 健康检查超时时间 + EnableRetry bool `json:"enable_retry" yaml:"enable_retry" toml:"enable_retry"` // 是否启用重试机制 + MaxRetries int `json:"max_retries" yaml:"max_retries" toml:"max_retries"` // 最大重试次数 + RetryBackoff time.Duration `json:"retry_backoff" yaml:"retry_backoff" toml:"retry_backoff"` // 重试间隔基数 + MaxRetryBackoff time.Duration `json:"max_retry_backoff" yaml:"max_retry_backoff" toml:"max_retry_backoff"` // 最大重试间隔 + EnableMetrics bool `json:"enable_metrics" yaml:"enable_metrics" toml:"enable_metrics"` // 是否启用监控指标 + EnableTracing bool `json:"enable_tracing" yaml:"enable_tracing" toml:"enable_tracing"` // 是否启用请求追踪 + WebSocketIntercept bool `json:"websocket_intercept" yaml:"websocket_intercept" toml:"websocket_intercept"` // 是否拦截WebSocket + DNSCacheTTL time.Duration `json:"dns_cache_ttl" yaml:"dns_cache_ttl" toml:"dns_cache_ttl"` // DNS缓存过期时间 + EnableCache bool `json:"enable_cache" yaml:"enable_cache" toml:"enable_cache"` // 是否启用响应缓存 + CacheTTL time.Duration `json:"cache_ttl" yaml:"cache_ttl" toml:"cache_ttl"` // 缓存过期时间 + EnableConnectionPool bool `json:"enable_connection_pool" yaml:"enable_connection_pool" toml:"enable_connection_pool"` // 是否启用连接池 + ConnectionPoolSize int `json:"connection_pool_size" yaml:"connection_pool_size" toml:"connection_pool_size"` // 连接池大小 + IdleTimeout time.Duration `json:"idle_timeout" yaml:"idle_timeout" toml:"idle_timeout"` // 连接空闲超时时间 + RequestTimeout time.Duration `json:"request_timeout" yaml:"request_timeout" toml:"request_timeout"` // 请求超时时间 + DNSResolver *dns.CustomResolver `json:"-" yaml:"-" toml:"-"` // DNS解析器 +} + +// BaseConfig 基础配置 +type BaseConfig struct { + ListenAddr string `json:"listen_addr" yaml:"listen_addr" toml:"listen_addr"` // 监听地址 + TargetAddr string `json:"target_addr" yaml:"target_addr" toml:"target_addr"` // 目标地址 + EnableHTTPS bool `json:"enable_https" yaml:"enable_https" toml:"enable_https"` // 是否启用HTTPS + TLSConfig *TLSConfig `json:"tls_config" yaml:"tls_config" toml:"tls_config"` // TLS配置 + EnableWebSocket bool `json:"enable_websocket" yaml:"enable_websocket" toml:"enable_websocket"` // 是否启用WebSocket + EnableCompression bool `json:"enable_compression" yaml:"enable_compression" toml:"enable_compression"` // 是否启用压缩 + EnableCORS bool `json:"enable_cors" yaml:"enable_cors" toml:"enable_cors"` // 是否启用CORS + PreserveClientIP bool `json:"preserve_client_ip" yaml:"preserve_client_ip" toml:"preserve_client_ip"` // 是否保留客户端IP + AddXForwardedFor bool `json:"add_x_forwarded_for" yaml:"add_x_forwarded_for" toml:"add_x_forwarded_for"` // 是否添加X-Forwarded-For头 + AddXRealIP bool `json:"add_x_real_ip" yaml:"add_x_real_ip" toml:"add_x_real_ip"` // 是否添加X-Real-IP头 +} + +// TLSConfig TLS配置 +type TLSConfig struct { + CertFile string `json:"cert_file" yaml:"cert_file" toml:"cert_file"` // 证书文件路径 + KeyFile string `json:"key_file" yaml:"key_file" toml:"key_file"` // 密钥文件路径 + InsecureSkipVerify bool `json:"insecure_skip_verify" yaml:"insecure_skip_verify" toml:"insecure_skip_verify"` // 是否跳过证书验证 + UseECDSA bool `json:"use_ecdsa" yaml:"use_ecdsa" toml:"use_ecdsa"` // 是否使用ECDSA +} + +// DefaultConfig 返回默认配置 +func DefaultConfig() *Config { + return &Config{ + BaseConfig: BaseConfig{ + ListenAddr: ":8080", + TargetAddr: "", // 默认目标地址为空 + EnableHTTPS: false, + EnableWebSocket: true, + EnableCompression: true, + EnableCORS: true, + PreserveClientIP: true, + AddXForwardedFor: true, + AddXRealIP: true, + }, + InsecureSkipVerify: false, + EnableHealthCheck: false, + HealthCheckInterval: 30 * time.Second, + HealthCheckTimeout: 5 * time.Second, + EnableRetry: true, + MaxRetries: 3, + RetryBackoff: time.Second, + MaxRetryBackoff: 10 * time.Second, + EnableMetrics: true, + EnableTracing: false, + WebSocketIntercept: false, + DNSCacheTTL: 5 * time.Minute, + EnableCache: true, + CacheTTL: 5 * time.Minute, + EnableConnectionPool: true, + ConnectionPoolSize: 100, + IdleTimeout: 60 * time.Second, + RequestTimeout: 30 * time.Second, + DNSResolver: dns.NewResolver(), + } +} diff --git a/pkg/reverse/proxy.go b/pkg/reverse/proxy.go new file mode 100644 index 0000000..9d1f3cc --- /dev/null +++ b/pkg/reverse/proxy.go @@ -0,0 +1,184 @@ +package reverse + +import ( + "context" + "crypto/tls" + "fmt" + "log/slog" + "net" + "net/http" + "net/http/httputil" + "strconv" + "strings" + "time" + + "github.com/darkit/goproxy/pkg/dns" + "github.com/darkit/goproxy/pkg/rule" +) + +// Proxy 反向代理服务器 +type Proxy struct { + config *Config + ruleManager *rule.Manager + logger *slog.Logger + transport *http.Transport + proxy *httputil.ReverseProxy + dnsResolver *dns.CustomResolver +} + +// New 创建反向代理服务器 +func New(cfg *Config) (*Proxy, error) { + if cfg == nil { + cfg = DefaultConfig() + } + + logger := slog.Default() + ruleManager := rule.NewManager(logger) + + // 创建自定义传输层 + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + // 使用自定义DNS解析器 + if cfg.DNSResolver != nil { + // 解析主机和端口 + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + portNum, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + + // 使用DNS解析器解析地址 + endpoint, err := cfg.DNSResolver.ResolveWithPort(host, portNum) + if err != nil { + return nil, err + } + + // 使用解析后的地址创建连接 + dialer := &net.Dialer{ + Timeout: cfg.RequestTimeout, + KeepAlive: cfg.IdleTimeout, + } + return dialer.DialContext(ctx, network, endpoint.String()) + } + + // 使用默认拨号器 + dialer := &net.Dialer{ + Timeout: cfg.RequestTimeout, + KeepAlive: cfg.IdleTimeout, + } + return dialer.DialContext(ctx, network, addr) + }, + ForceAttemptHTTP2: true, + MaxIdleConns: cfg.ConnectionPoolSize, + IdleConnTimeout: cfg.IdleTimeout, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: cfg.InsecureSkipVerify}, + } + + proxy := &Proxy{ + config: cfg, + ruleManager: ruleManager, + logger: logger, + transport: transport, + dnsResolver: cfg.DNSResolver, + } + + // 如果配置了规则文件,加载规则 + if cfg.RulesFile != "" { + loader := rule.NewLoader(proxy.ruleManager, proxy.logger) + + // 根据文件扩展名决定加载方式 + switch { + case strings.HasSuffix(cfg.RulesFile, ".json"): + if err := loader.LoadFromJSON(cfg.RulesFile); err != nil { + proxy.logger.Error("加载规则文件失败", + "file", cfg.RulesFile, + "error", err.Error(), + ) + return nil, fmt.Errorf("加载规则文件失败: %w", err) + } + proxy.logger.Info("成功加载规则文件", + "file", cfg.RulesFile, + ) + + case strings.HasSuffix(cfg.RulesFile, ".hosts"): + if err := loader.LoadFromHosts(cfg.RulesFile); err != nil { + proxy.logger.Error("加载hosts文件失败", + "file", cfg.RulesFile, + "error", err.Error(), + ) + return nil, fmt.Errorf("加载hosts文件失败: %w", err) + } + proxy.logger.Info("成功加载hosts文件", + "file", cfg.RulesFile, + ) + + default: + return nil, fmt.Errorf("不支持的规则文件格式: %s", cfg.RulesFile) + } + } + + // 创建反向代理处理器 + proxy.proxy = &httputil.ReverseProxy{ + Transport: transport, + Director: func(req *http.Request) { + // 应用规则 + rules := proxy.ruleManager.ListRules(rule.RuleTypeRoute) + for _, r := range rules { + if r.Match(req) { + if err := r.Apply(req); err != nil { + proxy.logger.Error("应用规则失败", + "rule_id", r.GetID(), + "error", err.Error(), + ) + } + } + } + + // 设置目标地址 + if cfg.TargetAddr != "" { + req.URL.Host = cfg.TargetAddr + // 如果没有设置协议,默认使用 http + if req.URL.Scheme == "" { + req.URL.Scheme = "http" + } + } + + // 添加X-Forwarded-For头 + if cfg.AddXForwardedFor { + req.Header.Add("X-Forwarded-For", req.RemoteAddr) + } + + // 添加X-Real-IP头 + if cfg.AddXRealIP { + req.Header.Add("X-Real-IP", req.RemoteAddr) + } + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + proxy.logger.Error("代理请求失败", + "error", err.Error(), + "method", r.Method, + "url", r.URL.String(), + ) + http.Error(w, err.Error(), http.StatusBadGateway) + }, + } + + return proxy, nil +} + +// ServeHTTP 处理HTTP请求 +func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + p.proxy.ServeHTTP(w, r) +} + +// Close 关闭代理服务器 +func (p *Proxy) Close() error { + p.transport.CloseIdleConnections() + return nil +} diff --git a/pkg/rewriter/rewriter.go b/pkg/rewriter/rewriter.go new file mode 100644 index 0000000..779d3e3 --- /dev/null +++ b/pkg/rewriter/rewriter.go @@ -0,0 +1,285 @@ +package rewriter + +import ( + "bufio" + "encoding/json" + "fmt" + "net/http" + "os" + "regexp" + "strings" +) + +// Rewriter URL重写器 +// 用于在反向代理中重写请求URL +type Rewriter struct { + // 重写规则列表 + rules []*RewriteRule +} + +// RewriteRule 重写规则 +type RewriteRule struct { + // 匹配模式 + Pattern string `json:"pattern"` + // 替换模式 + Replacement string `json:"replacement"` + // 是否使用正则表达式 + UseRegex bool `json:"use_regex"` + // 编译后的正则表达式 + regex *regexp.Regexp `json:"-"` + // 规则描述 + Description string `json:"description,omitempty"` + // 规则启用状态 + Enabled bool `json:"enabled,omitempty"` +} + +// NewRewriter 创建URL重写器 +func NewRewriter() *Rewriter { + return &Rewriter{ + rules: make([]*RewriteRule, 0), + } +} + +// AddRule 添加重写规则 +func (r *Rewriter) AddRule(pattern, replacement string, useRegex bool) error { + rule := &RewriteRule{ + Pattern: pattern, + Replacement: replacement, + UseRegex: useRegex, + Enabled: true, + } + + if useRegex { + regex, err := regexp.Compile(pattern) + if err != nil { + return err + } + rule.regex = regex + } + + r.rules = append(r.rules, rule) + return nil +} + +// AddRuleWithDescription 添加带描述的重写规则 +func (r *Rewriter) AddRuleWithDescription(pattern, replacement string, useRegex bool, description string) error { + rule := &RewriteRule{ + Pattern: pattern, + Replacement: replacement, + UseRegex: useRegex, + Description: description, + Enabled: true, + } + + if useRegex { + regex, err := regexp.Compile(pattern) + if err != nil { + return err + } + rule.regex = regex + } + + r.rules = append(r.rules, rule) + return nil +} + +// Rewrite 重写URL +func (r *Rewriter) Rewrite(req *http.Request) { + path := req.URL.Path + + for _, rule := range r.rules { + if !rule.Enabled { + continue + } + + if rule.UseRegex { + if rule.regex.MatchString(path) { + req.URL.Path = rule.regex.ReplaceAllString(path, rule.Replacement) + break + } + } else { + if strings.HasPrefix(path, rule.Pattern) { + req.URL.Path = strings.Replace(path, rule.Pattern, rule.Replacement, 1) + break + } + } + } +} + +// RewriteResponse 重写响应 +// 主要用于处理响应中的Location头和内容中的URL +func (r *Rewriter) RewriteResponse(resp *http.Response, originHost string) { + // 处理重定向头 + location := resp.Header.Get("Location") + if location != "" { + // 将后端服务器的域名替换成代理服务器的域名 + for _, rule := range r.rules { + if !rule.Enabled { + continue + } + + if rule.UseRegex && rule.regex != nil { + if rule.regex.MatchString(location) { + newLocation := rule.regex.ReplaceAllString(location, rule.Replacement) + resp.Header.Set("Location", newLocation) + break + } + } + } + } +} + +// LoadRulesFromFile 从文件加载重写规则 +func (r *Rewriter) LoadRulesFromFile(filename string) error { + file, err := os.Open(filename) + if err != nil { + return fmt.Errorf("打开文件失败: %v", err) + } + defer file.Close() + + // 检查文件扩展名,决定使用何种方式解析 + if strings.HasSuffix(filename, ".json") { + return r.loadRulesFromJSON(file) + } else { + return r.loadRulesFromText(file) + } +} + +// loadRulesFromJSON 从JSON文件加载规则 +func (r *Rewriter) loadRulesFromJSON(file *os.File) error { + var rules []*RewriteRule + decoder := json.NewDecoder(file) + if err := decoder.Decode(&rules); err != nil { + return fmt.Errorf("解析JSON失败: %v", err) + } + + // 编译正则表达式 + for _, rule := range rules { + if rule.UseRegex { + regex, err := regexp.Compile(rule.Pattern) + if err != nil { + return fmt.Errorf("编译正则表达式'%s'失败: %v", rule.Pattern, err) + } + rule.regex = regex + } + r.rules = append(r.rules, rule) + } + + return nil +} + +// loadRulesFromText 从文本文件加载规则 +// 格式: pattern replacement [regex] [#description] +func (r *Rewriter) loadRulesFromText(file *os.File) error { + scanner := bufio.NewScanner(file) + lineNum := 0 + + for scanner.Scan() { + lineNum++ + line := strings.TrimSpace(scanner.Text()) + + // 跳过空行和注释 + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + parts := strings.Fields(line) + if len(parts) < 2 { + return fmt.Errorf("第%d行格式错误: %s", lineNum, line) + } + + pattern := parts[0] + replacement := parts[1] + useRegex := false + description := "" + + // 检查是否有额外选项 + for i := 2; i < len(parts); i++ { + if parts[i] == "regex" { + useRegex = true + } else if strings.HasPrefix(parts[i], "#") { + // 获取描述信息 + description = strings.Join(parts[i:], " ") + description = strings.TrimPrefix(description, "#") + description = strings.TrimSpace(description) + break + } + } + + if err := r.AddRuleWithDescription(pattern, replacement, useRegex, description); err != nil { + return fmt.Errorf("第%d行添加规则失败: %v", lineNum, err) + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("读取文件失败: %v", err) + } + + return nil +} + +// GetRules 获取所有规则 +func (r *Rewriter) GetRules() []*RewriteRule { + return r.rules +} + +// EnableRule 启用规则 +func (r *Rewriter) EnableRule(index int) error { + if index < 0 || index >= len(r.rules) { + return fmt.Errorf("规则索引越界: %d", index) + } + r.rules[index].Enabled = true + return nil +} + +// DisableRule 禁用规则 +func (r *Rewriter) DisableRule(index int) error { + if index < 0 || index >= len(r.rules) { + return fmt.Errorf("规则索引越界: %d", index) + } + r.rules[index].Enabled = false + return nil +} + +// RemoveRule 删除规则 +func (r *Rewriter) RemoveRule(index int) error { + if index < 0 || index >= len(r.rules) { + return fmt.Errorf("规则索引越界: %d", index) + } + r.rules = append(r.rules[:index], r.rules[index+1:]...) + return nil +} + +// SaveRulesToFile 将规则保存到文件 +func (r *Rewriter) SaveRulesToFile(filename string) error { + file, err := os.Create(filename) + if err != nil { + return fmt.Errorf("创建文件失败: %v", err) + } + defer file.Close() + + // 根据文件扩展名决定保存格式 + if strings.HasSuffix(filename, ".json") { + encoder := json.NewEncoder(file) + encoder.SetIndent("", " ") + return encoder.Encode(r.rules) + } else { + writer := bufio.NewWriter(file) + for _, rule := range r.rules { + line := rule.Pattern + " " + rule.Replacement + if rule.UseRegex { + line += " regex" + } + if rule.Description != "" { + line += " # " + rule.Description + } + if !rule.Enabled { + line = "# " + line + " (disabled)" + } + if _, err := writer.WriteString(line + "\n"); err != nil { + return fmt.Errorf("写入文件失败: %v", err) + } + } + return writer.Flush() + } +} diff --git a/pkg/router/router.go b/pkg/router/router.go new file mode 100644 index 0000000..df73962 --- /dev/null +++ b/pkg/router/router.go @@ -0,0 +1,103 @@ +package router + +import ( + "net/http" + "regexp" + "strings" +) + +// Route 路由规则 +type Route struct { + // 匹配模式(主机名、路径、正则表达式) + Pattern string + // 匹配类型 + Type RouteType + // 目标地址 + Target string + // 路径重写规则 + RewritePattern string + // 请求头修改 + HeaderModifier HeaderModifier + // 自定义匹配函数 + MatchFunc func(req *http.Request) bool +} + +// RouteType 路由类型 +type RouteType int + +const ( + // HostRoute 主机名路由 + HostRoute RouteType = iota + // PathRoute 路径路由 + PathRoute + // RegexRoute 正则表达式路由 + RegexRoute + // CustomRoute 自定义路由 + CustomRoute +) + +// HeaderModifier 头部修改接口 +type HeaderModifier interface { + // ModifyRequest 修改请求头 + ModifyRequest(req *http.Request) + // ModifyResponse 修改响应头 + ModifyResponse(resp *http.Response) +} + +// Router 路由器 +type Router struct { + routes []*Route +} + +// NewRouter 创建路由器 +func NewRouter() *Router { + return &Router{ + routes: make([]*Route, 0), + } +} + +// AddRoute 添加路由规则 +func (r *Router) AddRoute(route *Route) { + r.routes = append(r.routes, route) +} + +// Match 匹配请求 +func (r *Router) Match(req *http.Request) (*Route, bool) { + for _, route := range r.routes { + switch route.Type { + case HostRoute: + if matchHost(req.Host, route.Pattern) { + return route, true + } + case PathRoute: + if matchPath(req.URL.Path, route.Pattern) { + return route, true + } + case RegexRoute: + if matchRegex(req.URL.String(), route.Pattern) { + return route, true + } + case CustomRoute: + if route.MatchFunc != nil && route.MatchFunc(req) { + return route, true + } + } + } + return nil, false +} + +// 匹配主机名 +func matchHost(host, pattern string) bool { + return host == pattern || strings.HasSuffix(host, "."+pattern) +} + +// 匹配路径 +func matchPath(path, pattern string) bool { + return strings.HasPrefix(path, pattern) +} + +// 匹配正则表达式 +func matchRegex(url, pattern string) bool { + matched, _ := regexp.MatchString(pattern, url) + return matched +} diff --git a/pkg/rule/loader.go b/pkg/rule/loader.go new file mode 100644 index 0000000..b6b0b97 --- /dev/null +++ b/pkg/rule/loader.go @@ -0,0 +1,217 @@ +package rule + +import ( + "bufio" + "encoding/json" + "fmt" + "log/slog" + "net" + "os" + "strconv" + "strings" +) + +// Loader 规则加载器 +type Loader struct { + manager *Manager + logger *slog.Logger +} + +// NewLoader 创建规则加载器 +func NewLoader(manager *Manager, logger *slog.Logger) *Loader { + if logger == nil { + logger = slog.Default() + } + return &Loader{ + manager: manager, + logger: logger, + } +} + +// LoadFromJSON 从JSON文件加载规则 +func (l *Loader) LoadFromJSON(filename string) error { + data, err := os.ReadFile(filename) + if err != nil { + return fmt.Errorf("读取规则文件失败: %w", err) + } + + var config struct { + Rules []json.RawMessage `json:"rules"` + } + + if err := json.Unmarshal(data, &config); err != nil { + return fmt.Errorf("解析规则文件失败: %w", err) + } + + for _, ruleData := range config.Rules { + var baseRule BaseRule + if err := json.Unmarshal(ruleData, &baseRule); err != nil { + l.logger.Warn("解析规则基础信息失败", + "error", err.Error(), + "rule_data", string(ruleData), + ) + continue + } + + var rule Rule + switch baseRule.Type { + case RuleTypeDNS: + var dnsRule DNSRule + if err := json.Unmarshal(ruleData, &dnsRule); err != nil { + l.logger.Warn("解析DNS规则失败", + "error", err.Error(), + "rule_data", string(ruleData), + ) + continue + } + rule = &dnsRule + case RuleTypeRewrite: + var rewriteRule RewriteRule + if err := json.Unmarshal(ruleData, &rewriteRule); err != nil { + l.logger.Warn("解析重写规则失败", + "error", err.Error(), + "rule_data", string(ruleData), + ) + continue + } + rule = &rewriteRule + case RuleTypeRoute: + var routeRule RouteRule + if err := json.Unmarshal(ruleData, &routeRule); err != nil { + l.logger.Warn("解析路由规则失败", + "error", err.Error(), + "rule_data", string(ruleData), + ) + continue + } + rule = &routeRule + default: + l.logger.Warn("未知的规则类型", + "type", string(baseRule.Type), + "rule_data", string(ruleData), + ) + continue + } + + if err := l.manager.AddRule(rule); err != nil { + l.logger.Error("添加规则失败", + "error", err.Error(), + "rule_id", rule.GetID(), + "rule_type", string(rule.GetType()), + ) + } + } + + return nil +} + +// LoadFromHosts 从hosts文件加载DNS规则 +func (l *Loader) LoadFromHosts(filename string) error { + file, err := os.Open(filename) + if err != nil { + return fmt.Errorf("打开hosts文件失败: %w", err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + lineNum := 0 + for scanner.Scan() { + lineNum++ + line := strings.TrimSpace(scanner.Text()) + + // 跳过空行和注释 + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + fields := strings.Fields(line) + if len(fields) < 2 { + l.logger.Warn("无效的hosts行", + "line_number", lineNum, + "line", line, + ) + continue + } + + // 解析IP和端口 + ipStr := fields[0] + ip := ipStr + port := 0 + + if strings.Contains(ipStr, ":") { + parts := strings.Split(ipStr, ":") + if len(parts) != 2 { + l.logger.Warn("无效的IP:端口格式", + "line_number", lineNum, + "ip_port", ipStr, + ) + continue + } + ip = parts[0] + if p, err := strconv.Atoi(parts[1]); err == nil { + port = p + } else { + l.logger.Warn("无效的端口号", + "line_number", lineNum, + "port", parts[1], + ) + continue + } + } + + // 验证IP地址 + if net.ParseIP(ip) == nil { + l.logger.Warn("无效的IP地址", + "line_number", lineNum, + "ip", ip, + ) + continue + } + + // 处理每个域名 + for _, domain := range fields[1:] { + // 跳过注释 + if strings.HasPrefix(domain, "#") { + break + } + + // 创建DNS规则 + matchType := MatchTypeExact + if strings.Contains(domain, "*") { + matchType = MatchTypeWildcard + } + + rule := &DNSRule{ + BaseRule: BaseRule{ + ID: fmt.Sprintf("hosts-%d-%s", lineNum, domain), + Type: RuleTypeDNS, + Priority: 100, + Pattern: domain, + MatchType: matchType, + Enabled: true, + }, + Targets: []DNSTarget{ + { + IP: ip, + Port: port, + }, + }, + } + + if err := l.manager.AddRule(rule); err != nil { + l.logger.Error("添加hosts规则失败", + "error", err.Error(), + "domain", domain, + "ip", ip, + "port", port, + ) + } + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("读取hosts文件失败: %w", err) + } + + return nil +} diff --git a/pkg/rule/manager.go b/pkg/rule/manager.go new file mode 100644 index 0000000..93426a4 --- /dev/null +++ b/pkg/rule/manager.go @@ -0,0 +1,227 @@ +package rule + +import ( + "fmt" + "log/slog" + "sync" + "time" +) + +// Manager 规则管理器 +type Manager struct { + // 规则存储 + rules sync.Map + // 规则索引 + indexes map[RuleType][]string + // 日志记录器 + logger *slog.Logger + // 更新通知通道 + updateCh chan struct{} + // 观察者列表 + observers []Observer + // 互斥锁 + mu sync.RWMutex +} + +// Observer 规则更新观察者接口 +type Observer interface { + OnRuleUpdate(ruleType RuleType) +} + +// NewManager 创建规则管理器 +func NewManager(logger *slog.Logger) *Manager { + if logger == nil { + logger = slog.Default() + } + return &Manager{ + logger: logger, + indexes: make(map[RuleType][]string), + updateCh: make(chan struct{}, 1), + observers: make([]Observer, 0), + } +} + +// AddRule 添加规则 +func (m *Manager) AddRule(rule Rule) error { + // 1. 验证规则 + if err := rule.Validate(); err != nil { + return fmt.Errorf("规则验证失败: %w", err) + } + + // 2. 存储规则 + m.mu.Lock() + defer m.mu.Unlock() + + // 检查是否已存在相同ID的规则 + if _, exists := m.rules.Load(rule.GetID()); exists { + return fmt.Errorf("规则ID %s 已存在", rule.GetID()) + } + + // 更新时间戳 + switch r := rule.(type) { + case *DNSRule: + r.CreatedAt = time.Now() + r.UpdatedAt = time.Now() + case *RewriteRule: + r.CreatedAt = time.Now() + r.UpdatedAt = time.Now() + case *RouteRule: + r.CreatedAt = time.Now() + r.UpdatedAt = time.Now() + } + + m.rules.Store(rule.GetID(), rule) + m.updateIndex(rule) + + // 3. 通知更新 + m.notifyUpdate(rule.GetType()) + + m.logger.Info("添加规则成功", + "id", rule.GetID(), + "type", string(rule.GetType()), + "priority", rule.GetPriority(), + ) + + return nil +} + +// GetRule 获取规则 +func (m *Manager) GetRule(id string) (Rule, bool) { + if rule, ok := m.rules.Load(id); ok { + return rule.(Rule), true + } + return nil, false +} + +// ListRules 列出指定类型的规则 +func (m *Manager) ListRules(ruleType RuleType) []Rule { + m.mu.RLock() + defer m.mu.RUnlock() + + var rules []Rule + if ids, ok := m.indexes[ruleType]; ok { + for _, id := range ids { + if rule, exists := m.rules.Load(id); exists { + rules = append(rules, rule.(Rule)) + } + } + } + return rules +} + +// DeleteRule 删除规则 +func (m *Manager) DeleteRule(id string) bool { + m.mu.Lock() + defer m.mu.Unlock() + + if rule, ok := m.rules.Load(id); ok { + m.rules.Delete(id) + m.removeFromIndex(rule.(Rule)) + m.notifyUpdate(rule.(Rule).GetType()) + + m.logger.Info("删除规则成功", + "id", id, + "type", string(rule.(Rule).GetType()), + ) + + return true + } + return false +} + +// UpdateRule 更新规则 +func (m *Manager) UpdateRule(rule Rule) error { + // 1. 验证规则 + if err := rule.Validate(); err != nil { + return fmt.Errorf("规则验证失败: %w", err) + } + + m.mu.Lock() + defer m.mu.Unlock() + + // 检查规则是否存在 + if _, exists := m.rules.Load(rule.GetID()); !exists { + return fmt.Errorf("规则ID %s 不存在", rule.GetID()) + } + + // 更新时间戳 + switch r := rule.(type) { + case *DNSRule: + r.UpdatedAt = time.Now() + case *RewriteRule: + r.UpdatedAt = time.Now() + case *RouteRule: + r.UpdatedAt = time.Now() + } + + // 更新规则 + m.rules.Store(rule.GetID(), rule) + m.updateIndex(rule) + m.notifyUpdate(rule.GetType()) + + m.logger.Info("更新规则成功", + "id", rule.GetID(), + "type", string(rule.GetType()), + "priority", rule.GetPriority(), + ) + + return nil +} + +// AddObserver 添加观察者 +func (m *Manager) AddObserver(observer Observer) { + m.mu.Lock() + defer m.mu.Unlock() + m.observers = append(m.observers, observer) +} + +// 更新规则索引 +func (m *Manager) updateIndex(rule Rule) { + ruleType := rule.GetType() + if _, ok := m.indexes[ruleType]; !ok { + m.indexes[ruleType] = make([]string, 0) + } + // 检查是否已存在于索引中 + exists := false + for _, id := range m.indexes[ruleType] { + if id == rule.GetID() { + exists = true + break + } + } + if !exists { + m.indexes[ruleType] = append(m.indexes[ruleType], rule.GetID()) + } +} + +// 从索引中移除规则 +func (m *Manager) removeFromIndex(rule Rule) { + ruleType := rule.GetType() + if ids, ok := m.indexes[ruleType]; ok { + for i, id := range ids { + if id == rule.GetID() { + m.indexes[ruleType] = append(ids[:i], ids[i+1:]...) + break + } + } + } +} + +// 通知规则更新 +func (m *Manager) notifyUpdate(ruleType RuleType) { + // 通知所有观察者 + for _, observer := range m.observers { + observer.OnRuleUpdate(ruleType) + } + + // 发送更新信号 + select { + case m.updateCh <- struct{}{}: + default: + } +} + +// GetUpdateChannel 获取更新通知通道 +func (m *Manager) GetUpdateChannel() <-chan struct{} { + return m.updateCh +} diff --git a/pkg/rule/rule_impl.go b/pkg/rule/rule_impl.go new file mode 100644 index 0000000..85e2283 --- /dev/null +++ b/pkg/rule/rule_impl.go @@ -0,0 +1,187 @@ +package rule + +import ( + "fmt" + "net" + "net/http" + "regexp" + "strings" +) + +// Match DNS规则匹配 +func (r *DNSRule) Match(req *http.Request) bool { + host := req.Host + if strings.Contains(host, ":") { + host, _, _ = net.SplitHostPort(host) + } + + switch r.MatchType { + case MatchTypeExact: + return host == r.Pattern + case MatchTypeWildcard: + return matchWildcard(host, r.Pattern) + case MatchTypeRegex: + if re, err := regexp.Compile(r.Pattern); err == nil { + return re.MatchString(host) + } + } + return false +} + +// Apply DNS规则应用 +func (r *DNSRule) Apply(req *http.Request) error { + if !r.Enabled { + return nil + } + // DNS规则的应用在DNS解析阶段处理,这里不需要修改请求 + return nil +} + +// Validate DNS规则验证 +func (r *DNSRule) Validate() error { + if r.ID == "" { + return fmt.Errorf("规则ID不能为空") + } + if len(r.Targets) == 0 { + return fmt.Errorf("DNS规则必须包含至少一个目标") + } + for _, target := range r.Targets { + if net.ParseIP(target.IP) == nil { + return fmt.Errorf("无效的IP地址: %s", target.IP) + } + if target.Port < 0 || target.Port > 65535 { + return fmt.Errorf("无效的端口号: %d", target.Port) + } + } + return nil +} + +// Match URL重写规则匹配 +func (r *RewriteRule) Match(req *http.Request) bool { + switch r.MatchType { + case MatchTypeExact: + return req.URL.Path == r.Pattern + case MatchTypePath: + return strings.HasPrefix(req.URL.Path, r.Pattern) + case MatchTypeRegex: + if re, err := regexp.Compile(r.Pattern); err == nil { + return re.MatchString(req.URL.Path) + } + } + return false +} + +// Apply URL重写规则应用 +func (r *RewriteRule) Apply(req *http.Request) error { + if !r.Enabled { + return nil + } + + switch r.MatchType { + case MatchTypeExact, MatchTypePath: + req.URL.Path = strings.Replace(req.URL.Path, r.Pattern, r.Replacement, 1) + case MatchTypeRegex: + if re, err := regexp.Compile(r.Pattern); err == nil { + req.URL.Path = re.ReplaceAllString(req.URL.Path, r.Replacement) + } else { + return fmt.Errorf("正则表达式编译失败: %v", err) + } + } + return nil +} + +// Validate URL重写规则验证 +func (r *RewriteRule) Validate() error { + if r.ID == "" { + return fmt.Errorf("规则ID不能为空") + } + if r.Pattern == "" { + return fmt.Errorf("重写规则必须包含匹配模式") + } + if r.MatchType == MatchTypeRegex { + if _, err := regexp.Compile(r.Pattern); err != nil { + return fmt.Errorf("无效的正则表达式: %v", err) + } + } + return nil +} + +// Match 路由规则匹配 +func (r *RouteRule) Match(req *http.Request) bool { + switch r.MatchType { + case MatchTypeExact: + return req.URL.Path == r.Pattern + case MatchTypePath: + return strings.HasPrefix(req.URL.Path, r.Pattern) + case MatchTypeRegex: + if re, err := regexp.Compile(r.Pattern); err == nil { + return re.MatchString(req.URL.Path) + } + } + return false +} + +// Apply 路由规则应用 +func (r *RouteRule) Apply(req *http.Request) error { + if !r.Enabled { + return nil + } + + // 修改请求头 + for key, value := range r.HeaderModifier { + // 支持变量替换 + switch value { + case "${client_ip}": + clientIP := req.RemoteAddr + if ip, _, err := net.SplitHostPort(clientIP); err == nil { + req.Header.Set(key, ip) + } else { + req.Header.Set(key, clientIP) + } + default: + req.Header.Set(key, value) + } + } + return nil +} + +// Validate 路由规则验证 +func (r *RouteRule) Validate() error { + if r.ID == "" { + return fmt.Errorf("规则ID不能为空") + } + if r.Pattern == "" { + return fmt.Errorf("路由规则必须包含匹配模式") + } + if r.Target == "" { + return fmt.Errorf("路由规则必须包含目标地址") + } + if r.MatchType == MatchTypeRegex { + if _, err := regexp.Compile(r.Pattern); err != nil { + return fmt.Errorf("无效的正则表达式: %v", err) + } + } + return nil +} + +// 辅助函数:通配符匹配 +func matchWildcard(host, pattern string) bool { + if !strings.Contains(pattern, "*") { + return host == pattern + } + + parts := strings.Split(pattern, ".") + hostParts := strings.Split(host, ".") + + if len(parts) != len(hostParts) { + return false + } + + for i := 0; i < len(parts); i++ { + if parts[i] != "*" && parts[i] != hostParts[i] { + return false + } + } + + return true +} diff --git a/pkg/rule/types.go b/pkg/rule/types.go new file mode 100644 index 0000000..837660d --- /dev/null +++ b/pkg/rule/types.go @@ -0,0 +1,100 @@ +package rule + +import ( + "net/http" + "time" +) + +// RuleType 规则类型 +type RuleType string + +const ( + // RuleTypeDNS DNS规则类型 + RuleTypeDNS RuleType = "dns" + // RuleTypeRewrite URL重写规则类型 + RuleTypeRewrite RuleType = "rewrite" + // RuleTypeRoute 路由规则类型 + RuleTypeRoute RuleType = "route" +) + +// MatchType 匹配类型 +type MatchType string + +const ( + // MatchTypeExact 精确匹配 + MatchTypeExact MatchType = "exact" + // MatchTypeWildcard 通配符匹配 + MatchTypeWildcard MatchType = "wildcard" + // MatchTypeRegex 正则匹配 + MatchTypeRegex MatchType = "regex" + // MatchTypePath 路径匹配 + MatchTypePath MatchType = "path" +) + +// Rule 统一的规则接口 +type Rule interface { + // GetID 获取规则ID + GetID() string + // GetType 获取规则类型 + GetType() RuleType + // GetPriority 获取规则优先级 + GetPriority() int + // Match 匹配规则 + Match(req *http.Request) bool + // Apply 应用规则 + Apply(req *http.Request) error + // Validate 验证规则 + Validate() error +} + +// BaseRule 基础规则结构 +type BaseRule struct { + ID string `json:"id"` + Type RuleType `json:"type"` + Priority int `json:"priority"` + Pattern string `json:"pattern"` + MatchType MatchType `json:"match_type"` + Enabled bool `json:"enabled"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// DNSRule DNS规则 +type DNSRule struct { + BaseRule + Targets []DNSTarget `json:"targets"` +} + +// DNSTarget DNS目标 +type DNSTarget struct { + IP string `json:"ip"` + Port int `json:"port"` +} + +// RewriteRule URL重写规则 +type RewriteRule struct { + BaseRule + Replacement string `json:"replacement"` +} + +// RouteRule 路由规则 +type RouteRule struct { + BaseRule + Target string `json:"target"` + HeaderModifier map[string]string `json:"header_modifier"` +} + +// GetID 获取规则ID +func (r *BaseRule) GetID() string { + return r.ID +} + +// GetType 获取规则类型 +func (r *BaseRule) GetType() RuleType { + return r.Type +} + +// GetPriority 获取规则优先级 +func (r *BaseRule) GetPriority() int { + return r.Priority +} diff --git a/pkg/server/graceful.go b/pkg/server/graceful.go new file mode 100644 index 0000000..2983581 --- /dev/null +++ b/pkg/server/graceful.go @@ -0,0 +1,87 @@ +package server + +import ( + "context" + "log/slog" + "net/http" + "os" + "os/signal" + "sync" + "syscall" + "time" +) + +// GracefulServer 优雅关闭服务器 +type GracefulServer struct { + // HTTP服务器 + server *http.Server + // 等待组 + wg sync.WaitGroup + // 停止信号 + stopChan chan struct{} +} + +// NewGracefulServer 创建优雅关闭服务器 +func NewGracefulServer(addr string, handler http.Handler) *GracefulServer { + return &GracefulServer{ + server: &http.Server{ + Addr: addr, + Handler: handler, + }, + stopChan: make(chan struct{}), + } +} + +// Start 启动服务器 +func (s *GracefulServer) Start() error { + // 启动HTTP服务器 + go func() { + slog.Info("启动HTTP服务器", "addr", s.server.Addr) + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + slog.Error("HTTP服务器错误", "error", err) + } + }() + + // 等待中断信号 + s.waitForInterrupt() + + return nil +} + +// waitForInterrupt 等待中断信号 +func (s *GracefulServer) waitForInterrupt() { + // 创建信号通道 + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // 等待信号 + <-sigChan + slog.Info("收到停止信号,开始优雅关闭") + + // 创建关闭上下文 + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // 关闭HTTP服务器 + if err := s.server.Shutdown(ctx); err != nil { + slog.Error("关闭HTTP服务器失败", "error", err) + } + + // 通知所有goroutine停止 + close(s.stopChan) + + // 等待所有goroutine结束 + s.wg.Wait() + + slog.Info("服务器已优雅关闭") +} + +// Stop 停止服务器 +func (s *GracefulServer) Stop() { + close(s.stopChan) +} + +// WaitGroup 获取等待组 +func (s *GracefulServer) WaitGroup() *sync.WaitGroup { + return &s.wg +} diff --git a/pool.go b/pool.go new file mode 100644 index 0000000..cdfb444 --- /dev/null +++ b/pool.go @@ -0,0 +1,26 @@ +package goproxy + +import "sync" + +// 泛型对象池 +type pool[T any] struct { + pool sync.Pool +} + +func newPool[T any](newFunc func() T) *pool[T] { + return &pool[T]{ + pool: sync.Pool{ + New: func() interface{} { + return newFunc() + }, + }, + } +} + +func (p *pool[T]) Get() T { + return p.pool.Get().(T) +} + +func (p *pool[T]) Put(x T) { + p.pool.Put(x) +} diff --git a/proxy.go b/proxy.go new file mode 100644 index 0000000..f510192 --- /dev/null +++ b/proxy.go @@ -0,0 +1,1301 @@ +package goproxy + +import ( + "bufio" + "context" + "crypto/elliptic" + "crypto/tls" + "fmt" + "io" + "log/slog" + "net" + "net/http" + "net/http/httptrace" + "net/url" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/darkit/goproxy/config" + "github.com/darkit/goproxy/pkg/cache" + "github.com/darkit/goproxy/pkg/healthcheck" + "github.com/darkit/goproxy/pkg/loadbalance" + "github.com/darkit/goproxy/pkg/metrics" + "github.com/darkit/goproxy/pkg/middleware" + "github.com/darkit/goproxy/pkg/reverse" + "github.com/ouqiang/websocket" + "github.com/viki-org/dnscache" +) + +const ( + // 连接目标服务器超时时间 + defaultTargetConnectTimeout = 5 * time.Second + // 目标服务器读写超时时间 + defaultTargetReadWriteTimeout = 10 * time.Second +) + +// 隧道连接成功响应行 +var tunnelEstablishedResponseLine = []byte("HTTP/1.1 200 Connection established\r\n\r\n") + +// 错误网关响应 +var badGatewayResponse = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", http.StatusBadGateway, http.StatusText(http.StatusBadGateway))) + +// 对象池 +var ( + bufPool = newPool(func() []byte { + return make([]byte, 32*1024) + }) + + ctxPool = newPool(func() *Context { + return new(Context) + }) +) + +// CertificateCache 证书缓存接口 +type CertificateCache interface { + // Get 获取证书 + Get(host string) *tls.Certificate + // Set 设置证书 + Set(host string, cert *tls.Certificate) +} + +// MemCertCache 内存证书缓存 +type MemCertCache struct { + certs sync.Map +} + +// Get 获取证书 +func (c *MemCertCache) Get(host string) *tls.Certificate { + v, ok := c.certs.Load(host) + if !ok { + return nil + } + return v.(*tls.Certificate) +} + +// Set 设置证书 +func (c *MemCertCache) Set(host string, cert *tls.Certificate) { + c.certs.Store(host, cert) +} + +// CacheAdapter 缓存适配器,统一不同缓存实现的接口 +type CacheAdapter struct { + cache interface{} + // 缓存方法类型标志 + getMethodType int + setMethodType int + // 方法类型常量 + getResponseBool int + getInterfaceBool int + setResponseOnly int + setResponseTTL int + setInterfaceOnly int +} + +// NewCacheAdapter 创建缓存适配器 +func NewCacheAdapter(cache interface{}) *CacheAdapter { + adapter := &CacheAdapter{ + cache: cache, + // 方法类型常量初始化 + getResponseBool: 1, + getInterfaceBool: 2, + setResponseOnly: 1, + setResponseTTL: 2, + setInterfaceOnly: 3, + } + + // 判断支持的方法类型 + if _, ok := cache.(interface { + Get(string) (*http.Response, bool) + }); ok { + adapter.getMethodType = adapter.getResponseBool + } else if _, ok := cache.(interface { + Get(string) (interface{}, bool) + }); ok { + adapter.getMethodType = adapter.getInterfaceBool + } + + if _, ok := cache.(interface { + Set(string, *http.Response, time.Duration) + }); ok { + adapter.setMethodType = adapter.setResponseTTL + } else if _, ok := cache.(interface { + Set(string, *http.Response) + }); ok { + adapter.setMethodType = adapter.setResponseOnly + } else if _, ok := cache.(interface { + Set(string, interface{}) + }); ok { + adapter.setMethodType = adapter.setInterfaceOnly + } + + return adapter +} + +// Get 统一的获取方法 +func (a *CacheAdapter) Get(key string) (interface{}, bool) { + switch a.getMethodType { + case a.getResponseBool: + if getter, ok := a.cache.(interface { + Get(string) (*http.Response, bool) + }); ok { + return getter.Get(key) + } + case a.getInterfaceBool: + if getter, ok := a.cache.(interface { + Get(string) (interface{}, bool) + }); ok { + return getter.Get(key) + } + } + return nil, false +} + +// Set 统一的设置方法 +func (a *CacheAdapter) Set(key string, value interface{}, ttl time.Duration) { + resp, isResponse := value.(*http.Response) + switch a.setMethodType { + case a.setResponseTTL: + if setter, ok := a.cache.(interface { + Set(string, *http.Response, time.Duration) + }); ok && isResponse { + setter.Set(key, resp, ttl) + } + case a.setResponseOnly: + if setter, ok := a.cache.(interface { + Set(string, *http.Response) + }); ok && isResponse { + setter.Set(key, resp) + } + case a.setInterfaceOnly: + if setter, ok := a.cache.(interface { + Set(string, interface{}) + }); ok { + setter.Set(key, value) + } + } +} + +// Proxy HTTP代理 +type Proxy struct { + // 配置 + config *config.Config + // 委托 + delegate Delegate + // 证书缓存 + certCache CertificateCache + // HTTP缓存 + httpCache cache.Cache + // 缓存适配器 + cacheAdapter *CacheAdapter + // 负载均衡器 + loadBalancer loadbalance.LoadBalancer + // 健康检查器 + healthChecker *healthcheck.HealthChecker + // 监控指标 + metrics metrics.MetricsCollector + // 客户端跟踪 + clientTrace *httptrace.ClientTrace + // 基础传输(用于直接获取*http.Transport类型) + transport *http.Transport + // HTTP请求传输(可能被中间件包装) + httpTransport http.RoundTripper + // DNS缓存 + dnsCache *dnscache.Resolver + // 客户端连接数 + clientConnNum int32 + // 证书管理器 + certManager *CertManager + // 日志记录器 + logger *slog.Logger +} + +// New 创建代理 +func New(opts *Options) *Proxy { + if opts == nil { + opts = &Options{} + } + + if opts.Config == nil { + opts.Config = config.DefaultConfig() + } + + if opts.Delegate == nil { + opts.Delegate = &DefaultDelegate{} + } + + p := &Proxy{ + config: opts.Config, + delegate: opts.Delegate, + certCache: opts.CertCache, + httpCache: opts.HTTPCache, + loadBalancer: opts.LoadBalancer, + healthChecker: opts.HealthChecker, + metrics: opts.Metrics, + clientTrace: opts.ClientTrace, + clientConnNum: 0, + logger: opts.Config.Logger, + } + + // 如果存在HTTP缓存,创建缓存适配器 + if p.httpCache != nil { + p.cacheAdapter = NewCacheAdapter(p.httpCache) + } + + // 创建DNS缓存 + p.dnsCache = dnscache.New(opts.Config.DNSCacheTTL) + + // 设置证书管理器 + if opts.CertManager != nil { + // 如果选项中已提供证书管理器,直接使用 + p.certManager = opts.CertManager + } else if opts.Config.DecryptHTTPS { + // 如果启用了HTTPS解密,且未提供证书管理器,则创建一个新的证书管理器 + certManagerOpts := []CertManagerOption{ + WithDefaultPrivateKey(true), // 使用默认私钥提高性能 + WithValidityYears(1), // 证书有效期1年 + WithUseECDSA(opts.Config.UseECDSA), // 根据配置决定是否使用ECDSA + } + + // 如果配置指定了使用ECDSA,设置曲线为P-256 + if opts.Config.UseECDSA { + certManagerOpts = append(certManagerOpts, WithCurve(elliptic.P256())) + } + + p.certManager = NewCertManager(p.certCache, certManagerOpts...) + } + + // 创建基础传输 + httpTransport := &http.Transport{ + Proxy: p.proxyFromDelegate, + DialContext: p.dialContextWithCache(), + MaxIdleConns: opts.Config.ConnectionPoolSize, + IdleConnTimeout: opts.Config.IdleTimeout, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + DisableKeepAlives: false, + DisableCompression: false, + ForceAttemptHTTP2: true, + } + + // 保存原始Transport供后续使用 + p.transport = httpTransport + + // 包装传输对象,应用中间件 + var roundTripper http.RoundTripper = httpTransport + + // 应用重试中间件 + if opts.Config.EnableRetry { + policy := &middleware.RetryPolicy{ + MaxRetries: opts.Config.MaxRetries, + BaseBackoff: opts.Config.BaseBackoff, + MaxBackoff: opts.Config.MaxBackoff, + } + retryMiddleware := middleware.NewRetryMiddleware(policy) + roundTripper = retryMiddleware.Transport(roundTripper) + } + + // 最终的RoundTripper赋值给p.httpTransport,用于HTTP请求 + p.httpTransport = roundTripper + + // 将健康检查器与负载均衡器集成 + if p.healthChecker != nil && p.loadBalancer != nil { + p.healthChecker.SetStatusChangeCallback(func(target string, healthy bool) { + if healthy { + p.loadBalancer.MarkUp(target) + } else { + p.loadBalancer.MarkDown(target) + } + }) + } + + return p +} + +// NewProxy 使用functional options模式创建代理 +func NewProxy(options ...Option) *Proxy { + // 创建默认选项 + opts := &Options{ + Config: config.DefaultConfig(), + } + + // 应用所有选项 + for _, option := range options { + option(opts) + } + + // 使用传统方法创建代理 + return New(opts) +} + +// ServeHTTP 处理HTTP请求 +func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + // 判断是反向代理还是正向代理 + if p.config.ReverseProxy { + if req.URL.Scheme == "" { + req.URL.Scheme = "http" + } + // 如果是反向代理模式,使用反向代理处理请求 + reverseProxy, err := reverse.New(convertToReverseConfig(p.config)) + if err != nil { + p.logger.Error("创建反向代理失败", + "error", err.Error(), + ) + http.Error(rw, "Internal Server Error", http.StatusInternalServerError) + return + } + reverseProxy.ServeHTTP(rw, req) + return + } + + // 更新请求计数指标 + if p.metrics != nil { + p.metrics.IncRequestCount() + } + + if req.URL.Host == "" || req.URL.Host == req.Host { + rw.Header().Set("X-Proxy-Error", "Invalid Request") + http.Error(rw, "", http.StatusBadRequest) + return + } + // 处理请求 + ctx := ctxPool.Get() + ctx.Reset(req) + defer ctxPool.Put(ctx) + + // 调用连接事件 + p.delegate.Connect(ctx, rw) + + // 认证检查 + p.delegate.Auth(ctx, rw) + if ctx.IsAborted() { + return + } + + // HTTP隧道连接(CONNECT方法) + if req.Method == http.MethodConnect { + p.tunnelProxy(ctx, rw) + return + } + + // 如果是WebSocket请求,使用WebSocket代理 + if isWebSocketRequest(req) { + clientConn, err := hijacker(rw) + if err != nil { + p.delegate.ErrorLog(err) + http.Error(rw, "无法处理WebSocket请求", http.StatusInternalServerError) + return + } + p.websocketProxy(ctx, clientConn) + return + } + + // 处理普通HTTP请求 + p.handleHTTP(ctx, rw) +} + +// handleHTTP 处理HTTP请求 +func (p *Proxy) handleHTTP(ctx *Context, rw http.ResponseWriter) { + // 调用请求前事件 + p.delegate.BeforeRequest(ctx) + if ctx.IsAborted() { + return + } + + // 开始时间 + startTime := time.Now() + + // 获取上级代理 + parentProxy, err := p.proxyFromDelegate(ctx.Req) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - 获取上级代理错误: %s", ctx.Req.URL.Host, err)) + ctx.ParentProxyURL = nil + } else { + ctx.ParentProxyURL = parentProxy + } + + var ( + resp *http.Response + req = ctx.Req + ) + + // 从缓存获取响应 + if p.httpCache != nil && p.config.EnableCache && canCacheMethod(req.Method) { + cacheKey := generateCacheKey(req) + var cachedResp interface{} + var ok bool + + // 使用缓存适配器获取数据 + if p.cacheAdapter != nil { + cachedResp, ok = p.cacheAdapter.Get(cacheKey) + if ok && cachedResp != nil { + // 从缓存中找到响应 + resp = cachedResp.(*http.Response) + // 更新指标 + if p.metrics != nil && isCacheHitMetricsSupported(p.metrics) { + incrementCacheHit(p.metrics) + } + } + } + } + + // 如果缓存中没有,则发送请求 + if resp == nil { + // 创建传输上下文 + reqCtx := req.Context() + if p.clientTrace != nil { + reqCtx = httptrace.WithClientTrace(reqCtx, p.clientTrace) + } + + // 设置请求超时 + if p.config.RequestTimeout > 0 { + var cancel context.CancelFunc + reqCtx, cancel = context.WithTimeout(reqCtx, p.config.RequestTimeout) + defer cancel() + } + + req = req.WithContext(reqCtx) + + // 发送请求 + var err error + resp, err = p.httpTransport.RoundTrip(req) + // 处理错误 + if err != nil { + p.delegate.BeforeResponse(ctx, nil, err) + p.delegate.ErrorLog(fmt.Errorf("%s - HTTP请求错误: %s", req.URL.Host, err)) + http.Error(rw, err.Error(), http.StatusBadGateway) + return + } + + // 更新指标 + if p.metrics != nil { + p.metrics.ObserveRequestDuration(time.Since(startTime).Seconds()) + } + + // 缓存响应 + if p.httpCache != nil && p.config.EnableCache && canCacheMethod(req.Method) && canCacheStatus(resp.StatusCode) { + cacheKey := generateCacheKey(req) + + // 使用缓存适配器设置数据 + if p.cacheAdapter != nil { + p.cacheAdapter.Set(cacheKey, resp, getCacheTTL(resp)) + } + } + } + + // 调用响应前事件 + p.delegate.BeforeResponse(ctx, resp, nil) + if ctx.IsAborted() { + return + } + + // 复制头部信息 + for key, values := range resp.Header { + for _, value := range values { + rw.Header().Add(key, value) + } + } + + // 写入状态码 + rw.WriteHeader(resp.StatusCode) + + // 复制响应体 + buf := bufPool.Get() + defer bufPool.Put(buf) + + _, err = io.CopyBuffer(rw, resp.Body, buf) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - 复制响应体错误: %s", req.URL.Host, err)) + } + + // 关闭响应体 + resp.Body.Close() + + // 调用完成事件 + p.delegate.Finish(ctx) +} + +// canCacheMethod 检查请求方法是否可缓存 +func canCacheMethod(method string) bool { + return method == http.MethodGet || method == http.MethodHead +} + +// canCacheStatus 检查响应状态码是否可缓存 +func canCacheStatus(statusCode int) bool { + return statusCode >= 200 && statusCode < 400 +} + +// generateCacheKey 生成缓存键 +func generateCacheKey(req *http.Request) string { + return req.Method + " " + req.URL.String() +} + +// getCacheTTL 获取缓存TTL +func getCacheTTL(resp *http.Response) time.Duration { + // 默认5分钟 + ttl := 5 * time.Minute + + // 从Cache-Control获取max-age + cacheControl := resp.Header.Get("Cache-Control") + if cacheControl != "" { + for _, directive := range strings.Split(cacheControl, ",") { + directive = strings.TrimSpace(directive) + if strings.HasPrefix(directive, "max-age=") { + maxAge := strings.TrimPrefix(directive, "max-age=") + if seconds, err := strconv.Atoi(maxAge); err == nil { + ttl = time.Duration(seconds) * time.Second + } + break + } + } + } + + return ttl +} + +// ClientConnNum 获取客户端连接数 +func (p *Proxy) ClientConnNum() int32 { + return atomic.LoadInt32(&p.clientConnNum) +} + +// DoRequest 执行HTTP请求 +func (p *Proxy) DoRequest(ctx *Context, responseFunc func(*http.Response, error)) { + if ctx.Data == nil { + ctx.Data = make(map[interface{}]interface{}) + } + + // 请求前处理 + p.delegate.BeforeRequest(ctx) + if ctx.IsAborted() { + return + } + + // 检查缓存 + if p.httpCache != nil && ctx.Req.Method == http.MethodGet && p.config.EnableCache { + cacheKey := cache.GenerateCacheKey(ctx.Req) + if p.cacheAdapter != nil { + cachedResp, ok := p.cacheAdapter.Get(cacheKey) + if ok && cachedResp != nil { + // 使用缓存的响应 + cached := cachedResp.(*http.Response) + p.delegate.BeforeResponse(ctx, cached, nil) + if !ctx.IsAborted() { + responseFunc(cached, nil) + } + + // 更新指标 + if p.metrics != nil && isCacheHitMetricsSupported(p.metrics) { + incrementCacheHit(p.metrics) + } + return + } + } + } + + // 准备请求 + newReq := ctx.Req.Clone(ctx.Req.Context()) + + // 移除hop-by-hop头部 + for _, h := range hopHeaders { + newReq.Header.Del(h) + } + + // 添加客户端跟踪 + if p.clientTrace != nil { + newReq = newReq.WithContext(httptrace.WithClientTrace(newReq.Context(), p.clientTrace)) + } + + // 执行请求 + resp, err := p.httpTransport.RoundTrip(newReq) + + // 响应前处理 + p.delegate.BeforeResponse(ctx, resp, err) + if ctx.IsAborted() { + if resp != nil { + resp.Body.Close() + } + return + } + + // 错误处理 + if err != nil { + responseFunc(nil, err) + return + } + + // 移除hop-by-hop头部 + for _, h := range hopHeaders { + resp.Header.Del(h) + } + + // 缓存响应 + if p.httpCache != nil && p.config.EnableCache && cache.ShouldCache(ctx.Req, resp) { + cacheKey := cache.GenerateCacheKey(ctx.Req) + if p.cacheAdapter != nil { + p.cacheAdapter.Set(cacheKey, resp, getCacheTTL(resp)) + } + } + + // 返回响应 + responseFunc(resp, nil) +} + +// HTTPS代理 +func (p *Proxy) httpsProxy(ctx *Context, tlsClientConn *tls.Conn) { + if isWebSocketRequest(ctx.Req) { + p.websocketProxy(ctx, NewConnBuffer(tlsClientConn, nil)) + return + } + + p.DoRequest(ctx, func(resp *http.Response, err error) { + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,请求错误: %s", ctx.Req.URL, err)) + tlsClientConn.Write(badGatewayResponse) + return + } + + // 直接写入TLS连接 + err = resp.Write(tlsClientConn) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,响应写入客户端失败: %s", ctx.Req.URL, err)) + } + resp.Body.Close() + }) +} + +// 隧道代理 +func (p *Proxy) tunnelProxy(ctx *Context, rw http.ResponseWriter) { + // 获取客户端连接 + clientConn, err := hijacker(rw) + if err != nil { + p.delegate.ErrorLog(err) + rw.WriteHeader(http.StatusBadGateway) + return + } + defer clientConn.Close() + + // 处理WebSocket请求 + if isWebSocketRequest(ctx.Req) { + p.websocketProxy(ctx, clientConn) + return + } + + // 获取上级代理 + parentProxyURL, err := p.delegate.ParentProxy(ctx.Req) + if ctx.ParentProxyURL != nil { + parentProxyURL = ctx.ParentProxyURL + } + + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - 解析代理地址错误: %s", ctx.Req.URL.Host, err)) + rw.WriteHeader(http.StatusBadGateway) + return + } + + // 如果不使用上级代理,通知客户端隧道已建立 + if parentProxyURL == nil { + _, err = clientConn.Write(tunnelEstablishedResponseLine) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - 隧道连接成功,通知客户端错误: %s", ctx.Req.URL.Host, err)) + return + } + } + + // 检测WebSocket + isWebsocket := false + methodBytes, err := clientConn.Peek(3) + if err == nil && string(methodBytes) == http.MethodGet { + isWebsocket = true + } + + // 处理WebSocket + if isWebsocket { + req, err := http.ReadRequest(clientConn.BufferReader()) + if err != nil { + if err != io.EOF { + p.delegate.ErrorLog(fmt.Errorf("%s - websocket读取客户端升级请求失败: %s", ctx.Req.URL.Host, err)) + } + return + } + req.RemoteAddr = ctx.Req.RemoteAddr + req.URL.Scheme = "http" + req.URL.Host = req.Host + ctx.Req = req + + p.websocketProxy(ctx, clientConn) + return + } + + // HTTPS解密 + var tlsClientConn *tls.Conn + if p.config.DecryptHTTPS { + // 生成证书 + certConfig, err := p.generateTLSConfig(ctx.Req.URL.Host) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,生成证书失败: %s", ctx.Req.URL.Host, err)) + return + } + + // 创建TLS服务器连接 + tlsClientConn = tls.Server(clientConn, certConfig) + defer tlsClientConn.Close() + + // TLS握手 + if err := tlsClientConn.Handshake(); err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,握手失败: %s", ctx.Req.URL.Host, err)) + return + } + + // 读取HTTPS请求 + buf := bufio.NewReader(tlsClientConn) + tlsReq, err := http.ReadRequest(buf) + if err != nil { + if err != io.EOF { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,读取客户端请求失败: %s", ctx.Req.URL.Host, err)) + } + return + } + + // 更新请求信息 + tlsReq.RemoteAddr = ctx.Req.RemoteAddr + tlsReq.URL.Scheme = "https" + tlsReq.URL.Host = tlsReq.Host + ctx.Req = tlsReq + } + + // 确定目标地址 + targetAddr := ctx.Req.URL.Host + if ctx.TargetAddr != "" { + targetAddr = ctx.TargetAddr + } else if parentProxyURL != nil { + targetAddr = parentProxyURL.Host + } + + // 确保地址包含端口 + if !strings.Contains(targetAddr, ":") { + targetAddr += ":443" + } + + // 连接目标服务器 + targetConn, err := net.DialTimeout("tcp", targetAddr, defaultTargetConnectTimeout) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - 隧道转发连接目标服务器失败: %s", ctx.Req.URL.Host, err)) + return + } + defer targetConn.Close() + + // 向上级代理发送CONNECT请求 + if parentProxyURL != nil { + tunnelRequestLine := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", ctx.Req.URL.Host, ctx.Req.URL.Host) + _, err = targetConn.Write([]byte(tunnelRequestLine)) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - 向上级代理发送CONNECT请求失败: %s", ctx.Req.URL.Host, err)) + return + } + + // 读取上级代理响应 + bufReader := bufio.NewReader(targetConn) + resp, err := http.ReadResponse(bufReader, ctx.Req) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - 读取上级代理响应失败: %s", ctx.Req.URL.Host, err)) + return + } + defer resp.Body.Close() + + // 检查上级代理响应 + if resp.StatusCode != http.StatusOK { + p.tunnelConnected(ctx, fmt.Errorf("上级代理返回错误状态码: %d", resp.StatusCode)) + p.delegate.ErrorLog(fmt.Errorf("%s - 上级代理返回错误状态码: %d", ctx.Req.URL.Host, resp.StatusCode)) + return + } + } + + // 处理HTTPS解密或直接隧道转发 + if p.config.DecryptHTTPS { + p.httpsProxy(ctx, tlsClientConn) + } else { + p.tunnelConnected(ctx, nil) + p.transfer(clientConn, targetConn) + } +} + +// WebSocket代理 +func (p *Proxy) websocketProxy(ctx *Context, srcConn *ConnBuffer) { + if !p.config.WebSocketIntercept { + // 不拦截WebSocket,直接转发 + remoteAddr := ctx.Addr() + var err error + var targetConn net.Conn + + // 根据协议建立连接 + if ctx.IsHTTPS() { + targetConn, err = tls.Dial("tcp", remoteAddr, &tls.Config{InsecureSkipVerify: true}) + } else { + targetConn, err = net.Dial("tcp", remoteAddr) + } + + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - websocket连接目标服务器错误: %s", ctx.Req.URL.Host, err)) + return + } + + // 将请求转发给目标 + err = ctx.Req.Write(targetConn) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - websocket协议转换请求写入目标服务器错误: %s", ctx.Req.URL.Host, err)) + return + } + + // 开始转发数据 + p.tunnelConnected(ctx, nil) + p.transfer(srcConn, targetConn) + return + } + + // 创建WebSocket升级器 + upgrader := websocket.Upgrader{ + HandshakeTimeout: defaultTargetConnectTimeout, + ReadBufferSize: 4096, + WriteBufferSize: 4096, + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + + // 创建伪响应写入器,用于升级WebSocket连接 + rw := newResponseWriter(srcConn) + + // 升级源连接为WebSocket连接 + srcWSConn, err := upgrader.Upgrade(rw, ctx.Req, nil) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - 升级WebSocket连接失败: %s", ctx.Req.URL.Host, err)) + return + } + defer srcWSConn.Close() + + // 构建目标URL + u := url.URL{ + Scheme: func() string { + if ctx.IsHTTPS() { + return "wss" + } + return "ws" + }(), + Host: ctx.Req.URL.Host, + Path: ctx.Req.URL.Path, + RawQuery: ctx.Req.URL.RawQuery, + } + + // 创建目标WebSocket连接 + dialer := websocket.Dialer{ + HandshakeTimeout: defaultTargetConnectTimeout, + ReadBufferSize: 4096, + WriteBufferSize: 4096, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + // 连接到目标WebSocket服务器 + targetWSConn, _, err := dialer.Dial(u.String(), ctx.Req.Header) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - 连接目标WebSocket服务器失败: %s", ctx.Req.URL.Host, err)) + return + } + defer targetWSConn.Close() + + // 连接成功,通知 + p.tunnelConnected(ctx, nil) + + // 开始WebSocket消息转发 + p.transferWebSocket(ctx, srcWSConn, targetWSConn) +} + +// transferWebSocket 使用WebSocket协议进行双向消息转发 +func (p *Proxy) transferWebSocket(ctx *Context, srcConn *websocket.Conn, targetConn *websocket.Conn) { + doneCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // 源到目标 + go func() { + for { + if doneCtx.Err() != nil { + return + } + + // 读取源消息,正确处理消息类型 + msgType, msg, err := srcConn.ReadMessage() + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", + srcConn.RemoteAddr().String(), targetConn.RemoteAddr().String(), err)) + cancel() // 取消另一个goroutine + return + } + + // 调用消息拦截接口 + p.delegate.WebSocketSendMessage(ctx, &msgType, &msg) + + // 写入目标,保留原始消息类型(文本/二进制) + err = targetConn.WriteMessage(msgType, msg) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", + srcConn.RemoteAddr().String(), targetConn.RemoteAddr().String(), err)) + cancel() // 取消另一个goroutine + return + } + } + }() + + // 目标到源 + for { + if doneCtx.Err() != nil { + return + } + + // 读取目标消息,正确处理消息类型 + msgType, msg, err := targetConn.ReadMessage() + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", + targetConn.RemoteAddr().String(), srcConn.RemoteAddr().String(), err)) + cancel() // 取消另一个goroutine + return + } + + // 调用消息拦截接口 + p.delegate.WebSocketReceiveMessage(ctx, &msgType, &msg) + + // 写入源,保留原始消息类型(文本/二进制) + err = srcConn.WriteMessage(msgType, msg) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", + targetConn.RemoteAddr().String(), srcConn.RemoteAddr().String(), err)) + cancel() // 取消另一个goroutine + return + } + } +} + +// 用于WebSocket升级的响应写入器 +type responseWriter struct { + conn *ConnBuffer + header http.Header + statusCode int +} + +func newResponseWriter(conn *ConnBuffer) *responseWriter { + return &responseWriter{ + conn: conn, + header: make(http.Header), + statusCode: http.StatusOK, + } +} + +func (rw *responseWriter) Header() http.Header { + return rw.header +} + +func (rw *responseWriter) Write(b []byte) (int, error) { + return rw.conn.Write(b) +} + +func (rw *responseWriter) WriteHeader(statusCode int) { + rw.statusCode = statusCode +} + +// 双向转发 +func (p *Proxy) transfer(src net.Conn, dst net.Conn) { + // 创建完成通道 + done := make(chan struct{}, 2) + + // src -> dst + go func() { + buf := bufPool.Get() + written, err := io.CopyBuffer(dst, src, buf) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err)) + } + + // 记录传输字节数 + if p.metrics != nil { + p.metrics.AddBytesTransferred("request", written) + } + + bufPool.Put(buf) + dst.Close() + done <- struct{}{} + }() + + // dst -> src + go func() { + buf := bufPool.Get() + written, err := io.CopyBuffer(src, dst, buf) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", dst.RemoteAddr().String(), src.RemoteAddr().String(), err)) + } + + // 记录传输字节数 + if p.metrics != nil { + p.metrics.AddBytesTransferred("response", written) + } + + bufPool.Put(buf) + src.Close() + done <- struct{}{} + }() + + // 等待两个方向都结束 + <-done + <-done +} + +// 隧道连接处理 +func (p *Proxy) tunnelConnected(ctx *Context, err error) { + ctx.TunnelProxy = true + p.delegate.BeforeRequest(ctx) + if err != nil { + p.delegate.BeforeResponse(ctx, nil, err) + return + } + + resp := &http.Response{ + Status: "200 Connection Established", + StatusCode: http.StatusOK, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{}, + Body: http.NoBody, + } + p.delegate.BeforeResponse(ctx, resp, nil) +} + +// 使用DNS缓存的DialContext +func (p *Proxy) dialContextWithCache() func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + // 创建拨号器 + dialer := &net.Dialer{ + Timeout: defaultTargetConnectTimeout, + KeepAlive: 30 * time.Second, + } + + // 如果没有启用DNS缓存,直接拨号 + if p.dnsCache == nil { + return dialer.DialContext(ctx, network, addr) + } + + // 解析主机和端口 + separator := strings.LastIndex(addr, ":") + if separator < 0 { + return nil, fmt.Errorf("invalid address: %s", addr) + } + + host := addr[:separator] + port := addr[separator:] + + // 查询DNS缓存 + ips, err := p.dnsCache.Fetch(host) + if err != nil { + return nil, err + } + + // 使用第一个IPv4地址 + var ip string + for _, item := range ips { + ip = item.String() + if !strings.Contains(ip, ":") { + break + } + } + + if ip == "" { + return nil, fmt.Errorf("no valid IP address found for: %s", host) + } + + // 连接到解析后的IP + return dialer.DialContext(ctx, network, ip+port) + } +} + +// 从委托获取代理 +func (p *Proxy) proxyFromDelegate(req *http.Request) (*url.URL, error) { + if p.loadBalancer != nil && p.config.EnableLoadBalancing { + // 使用负载均衡 + host := req.URL.Hostname() + return p.loadBalancer.Next(host) + } + // 使用委托 + return p.delegate.ParentProxy(req) +} + +// 生成TLS配置 +func (p *Proxy) generateTLSConfig(host string) (*tls.Config, error) { + // 如果没有证书管理器,则创建一个 + if p.certManager == nil { + // 创建证书管理器,使用已有的证书缓存 + options := []CertManagerOption{ + WithDefaultPrivateKey(true), // 使用默认私钥提高性能 + WithValidityYears(1), // 证书有效期1年 + } + p.certManager = NewCertManager(p.certCache, options...) + } + + // 1. 首先检查是否配置了自定义证书 + if p.config.TLSCert != "" && p.config.TLSKey != "" { + cert, err := tls.LoadX509KeyPair(p.config.TLSCert, p.config.TLSKey) + if err != nil { + return nil, fmt.Errorf("加载TLS证书失败: %s", err) + } + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + }, nil + } + + // 2. 检查是否配置了CA证书和密钥(用于动态生成证书) + if p.config.CACert != "" && p.config.CAKey != "" { + // 加载CA证书和私钥 + caCert, caKey, err := LoadCAFromFiles(p.config.CACert, p.config.CAKey) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("加载CA证书和私钥失败: %s", err)) + // 如果加载失败,使用默认CA + return p.certManager.GenerateTLSConfig(host) + } + + // 使用自定义CA生成证书 + cert, err := p.certManager.GenerateCertificate(host, caCert, caKey) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("为%s生成动态证书失败: %s", host, err)) + return nil, err + } + + return &tls.Config{ + Certificates: []tls.Certificate{*cert}, + }, nil + } + + // 3. 使用默认CA生成证书 + tlsConfig, err := p.certManager.GenerateTLSConfig(host) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("为%s使用默认CA生成证书失败: %s", host, err)) + return nil, err + } + + return tlsConfig, nil +} + +// 获取客户端连接 +func hijacker(rw http.ResponseWriter) (*ConnBuffer, error) { + hijacker, ok := rw.(http.Hijacker) + if !ok { + return nil, fmt.Errorf("http server不支持Hijacker") + } + conn, bufrw, err := hijacker.Hijack() + if err != nil { + return nil, fmt.Errorf("hijacker错误: %s", err) + } + + return NewConnBuffer(conn, bufrw.Reader), nil +} + +// 检查是否是WebSocket请求 +func isWebSocketRequest(req *http.Request) bool { + if req == nil { + return false + } + + // 检查Connection头 + connection := strings.ToLower(req.Header.Get("Connection")) + if !strings.Contains(connection, "upgrade") { + return false + } + + // 检查Upgrade头 + upgrade := strings.ToLower(req.Header.Get("Upgrade")) + if upgrade != "websocket" { + return false + } + + return true +} + +// hop-by-hop 头部 +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", + "Trailer", + "Transfer-Encoding", + "Upgrade", +} + +// isCacheHitMetricsSupported 检查指标是否支持缓存命中计数 +func isCacheHitMetricsSupported(m metrics.MetricsCollector) bool { + _, ok := m.(interface{ IncCacheHit() }) + return ok +} + +// incrementCacheHit 增加缓存命中计数 +func incrementCacheHit(m metrics.MetricsCollector) { + if hitter, ok := m.(interface{ IncCacheHit() }); ok { + hitter.IncCacheHit() + } +} + +// SetDialContext 设置自定义的拨号上下文函数 +func (p *Proxy) SetDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) { + p.transport.DialContext = dialContext +} + +// convertToReverseConfig 将 config.Config 转换为 reverse.Config +func convertToReverseConfig(cfg *config.Config) *reverse.Config { + return &reverse.Config{ + BaseConfig: reverse.BaseConfig{ + ListenAddr: cfg.ListenAddr, + TargetAddr: cfg.TargetAddr, + EnableHTTPS: cfg.DecryptHTTPS, + TLSConfig: &reverse.TLSConfig{ + CertFile: cfg.TLSCert, + KeyFile: cfg.TLSKey, + InsecureSkipVerify: cfg.InsecureSkipVerify, + UseECDSA: cfg.UseECDSA, + }, + EnableWebSocket: cfg.SupportWebSocketUpgrade, + EnableCompression: cfg.EnableCompression, + EnableCORS: cfg.EnableCORS, + PreserveClientIP: cfg.PreserveClientIP, + AddXForwardedFor: cfg.AddXForwardedFor, + AddXRealIP: cfg.AddXRealIP, + }, + RulesFile: cfg.ReverseProxyRulesFile, + InsecureSkipVerify: cfg.InsecureSkipVerify, + EnableHealthCheck: cfg.EnableHealthCheck, + HealthCheckInterval: cfg.HealthCheckInterval, + HealthCheckTimeout: cfg.HealthCheckTimeout, + EnableRetry: cfg.EnableRetry, + MaxRetries: cfg.MaxRetries, + RetryBackoff: cfg.BaseBackoff, + MaxRetryBackoff: cfg.MaxBackoff, + EnableMetrics: cfg.EnableMetrics, + EnableTracing: cfg.EnableTracing, + WebSocketIntercept: cfg.WebSocketIntercept, + DNSCacheTTL: cfg.DNSCacheTTL, + EnableCache: cfg.EnableCache, + CacheTTL: cfg.CacheTTL, + EnableConnectionPool: cfg.EnableConnectionPool, + ConnectionPoolSize: cfg.ConnectionPoolSize, + IdleTimeout: cfg.IdleTimeout, + RequestTimeout: cfg.RequestTimeout, + } +}