From 21e0a73e5ced33f1ea17fe0f9212b2b38199d000 Mon Sep 17 00:00:00 2001 From: DarkiT Date: Thu, 13 Mar 2025 15:56:33 +0800 Subject: [PATCH] init --- .gitignore | 3 + README.md | 454 ++++++++++ SUMMARY.md | 73 ++ cmd/example/main.go | 165 ++++ cmd/reverse_proxy_example/main.go | 185 ++++ delegate.go | 132 +++ go.mod | 10 + go.sum | 8 + internal/cache/cache.go | 220 +++++ internal/config/config.go | 135 +++ internal/healthcheck/healthchecker.go | 248 ++++++ internal/loadbalance/loadbalancer.go | 330 ++++++++ internal/metrics/metrics.go | 250 ++++++ internal/middleware/ratelimiter.go | 184 ++++ internal/middleware/retry.go | 156 ++++ internal/proxy/certificate.go | 552 ++++++++++++ internal/proxy/conn_buffer.go | 134 +++ internal/proxy/context.go | 132 +++ internal/proxy/delegate.go | 123 +++ internal/proxy/options.go | 262 ++++++ internal/proxy/proxy.go | 1130 +++++++++++++++++++++++++ internal/proxy/reverse_proxy.go | 277 ++++++ internal/rewriter/rewriter.go | 98 +++ internal/router/router.go | 103 +++ mitm-proxy.crt | 14 + proxy.go | 753 ++++++++++++++++ 26 files changed, 6131 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 SUMMARY.md create mode 100644 cmd/example/main.go create mode 100644 cmd/reverse_proxy_example/main.go create mode 100644 delegate.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/cache/cache.go create mode 100644 internal/config/config.go create mode 100644 internal/healthcheck/healthchecker.go create mode 100644 internal/loadbalance/loadbalancer.go create mode 100644 internal/metrics/metrics.go create mode 100644 internal/middleware/ratelimiter.go create mode 100644 internal/middleware/retry.go create mode 100644 internal/proxy/certificate.go create mode 100644 internal/proxy/conn_buffer.go create mode 100644 internal/proxy/context.go create mode 100644 internal/proxy/delegate.go create mode 100644 internal/proxy/options.go create mode 100644 internal/proxy/proxy.go create mode 100644 internal/proxy/reverse_proxy.go create mode 100644 internal/rewriter/rewriter.go create mode 100644 internal/router/router.go create mode 100644 mitm-proxy.crt create mode 100644 proxy.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4306551 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.idea +./delegate.go +./proxy.go diff --git a/README.md b/README.md new file mode 100644 index 0000000..2f3bde8 --- /dev/null +++ b/README.md @@ -0,0 +1,454 @@ +# GoProxy + +GoProxy是一个功能强大的Go语言HTTP代理库,支持HTTP、HTTPS和WebSocket代理,并提供了丰富的功能和扩展点。 + +## 功能特性 + +- 支持HTTP、HTTPS和WebSocket代理 +- 支持正向代理和反向代理 +- 支持HTTPS解密(中间人模式) + - 自定义CA证书和私钥 + - 动态证书生成与缓存 + - 通配符域名证书支持 + - 支持RSA和ECDSA证书算法选择 +- 支持上游代理链 +- 支持负载均衡(轮询、随机、权重等) +- 支持健康检查 +- 支持请求重试 +- 支持HTTP缓存 +- 支持请求限流 +- 支持监控指标收集 +- 支持自定义处理逻辑(委托模式) +- 支持DNS缓存 +- 支持URL重写(反向代理模式) + +## 安装 + +```bash +go get github.com/goproxy +``` + +## 快速开始 + +### 正向代理 + +```go +package main + +import ( + "log" + "net/http" + + "github.com/goproxy/internal/proxy" +) + +func main() { + // 创建代理 + p := proxy.New(nil) + + // 启动HTTP服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +### 启用HTTPS解密 + +```go +package main + +import ( + "log" + "net/http" + + "github.com/goproxy/internal/config" + "github.com/goproxy/internal/proxy" +) + +func main() { + // 创建配置 + cfg := config.DefaultConfig() + cfg.DecryptHTTPS = true + cfg.CACert = "ca.crt" // CA证书路径 + cfg.CAKey = "ca.key" // CA私钥路径 + cfg.UseECDSA = true // 使用ECDSA生成证书(默认为false,使用RSA) + + // 可选:使用自定义TLS证书 + // cfg.TLSCert = "server.crt" + // cfg.TLSKey = "server.key" + + // 创建证书缓存 + certCache := &proxy.MemCertCache{} + + // 创建代理 + p := proxy.New(&proxy.Options{ + Config: cfg, + CertCache: certCache, + }) + + // 启动HTTP服务器 + log.Println("HTTPS解密代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +> **注意**: 使用HTTPS解密功能时,需要在客户端安装CA证书,否则会出现证书警告。 + +### 反向代理 + +```go +package main + +import ( + "log" + "net/http" + + "github.com/goproxy/internal/config" + "github.com/goproxy/internal/proxy" +) + +func main() { + // 创建配置 + cfg := config.DefaultConfig() + cfg.ReverseProxy = true + cfg.EnableURLRewrite = true + cfg.AddXForwardedFor = true + cfg.AddXRealIP = true + + // 创建自定义委托 + delegate := &ReverseProxyDelegate{ + backend: "localhost:8081", + } + + // 创建代理 + p := proxy.New(&proxy.Options{ + Config: cfg, + Delegate: delegate, + }) + + // 启动HTTP服务器 + log.Println("反向代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} + +// ReverseProxyDelegate 反向代理委托 +type ReverseProxyDelegate struct { + proxy.DefaultDelegate + backend string +} + +// ResolveBackend 解析后端服务器 +func (d *ReverseProxyDelegate) ResolveBackend(req *http.Request) (string, error) { + return d.backend, nil +} +``` + +### 自定义委托 + +```go +package main + +import ( + "log" + "net/http" + + "github.com/goproxy/internal/proxy" +) + +func main() { + // 创建自定义委托 + delegate := &CustomDelegate{} + + // 创建代理 + p := proxy.New(&proxy.Options{ + Delegate: delegate, + }) + + // 启动HTTP服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} + +// CustomDelegate 自定义委托 +type CustomDelegate struct { + proxy.DefaultDelegate +} + +// BeforeRequest 请求前事件 +func (d *CustomDelegate) BeforeRequest(ctx *proxy.Context) { + log.Printf("请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.String()) +} + +// BeforeResponse 响应前事件 +func (d *CustomDelegate) BeforeResponse(ctx *proxy.Context, resp *http.Response, err error) { + if err != nil { + log.Printf("响应错误: %v\n", err) + return + } + log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status) +} +``` + +### 完整示例 + +- 正向代理示例: [cmd/example/main.go](cmd/example/main.go) +- 反向代理示例: [cmd/reverse_proxy_example/main.go](cmd/reverse_proxy_example/main.go) + +### 使用函数式选项模式 + +```go +package main + +import ( + "log" + "net/http" + "time" + + "github.com/goproxy/internal/metrics" + "github.com/goproxy/internal/proxy" +) + +func main() { + // 创建监控指标 + metricsCollector := metrics.NewSimpleMetrics() + + // 创建证书缓存 + certCache := &proxy.MemCertCache{} + + // 使用函数式选项模式创建代理 + p := proxy.NewProxy( + // 启用HTTPS解密 + proxy.WithDecryptHTTPS(certCache), + proxy.WithCACertAndKey("ca.crt", "ca.key"), + + // 设置监控指标 + proxy.WithMetrics(metricsCollector), + + // 设置请求超时和连接池 + proxy.WithRequestTimeout(30 * time.Second), + proxy.WithConnectionPoolSize(100), + proxy.WithIdleTimeout(90 * time.Second), + + // 启用DNS缓存 + proxy.WithDNSCacheTTL(10 * time.Minute), + + // 启用请求重试 + proxy.WithEnableRetry(3, 1*time.Second, 10*time.Second), + + // 启用CORS支持 + proxy.WithEnableCORS(true), + ) + + // 启动HTTP服务器和监控服务器 + go func() { + log.Println("监控服务器启动在 :8081") + http.Handle("/metrics", metricsCollector) + if err := http.ListenAndServe(":8081", nil); err != nil { + log.Fatalf("监控服务器启动失败: %v", err) + } + }() + + // 启动代理服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} +``` + +## 架构设计 + +GoProxy采用模块化设计,主要包含以下模块: + +- **代理核心(Proxy)**:处理HTTP请求和响应,实现代理功能 +- **反向代理(ReverseProxy)**:处理反向代理请求,支持URL重写和请求修改 +- **路由(Router)**:基于主机名、路径、正则表达式等规则路由请求到不同的后端 +- **URL重写(Rewriter)**:重写请求URL和响应中的URL +- **代理上下文(Context)**:保存请求上下文信息,用于在处理过程中传递数据 +- **代理委托(Delegate)**:定义代理处理请求的各个阶段的回调方法,用于自定义处理逻辑 +- **连接缓冲区(ConnBuffer)**:封装网络连接,提供缓冲读写功能 +- **负载均衡(LoadBalancer)**:实现负载均衡算法,支持轮询、随机、权重等 +- **健康检查(HealthChecker)**:检查上游服务器的健康状态,自动剔除不健康的服务器 +- **缓存(Cache)**:实现HTTP缓存,减少重复请求 +- **缓存适配器(CacheAdapter)**:统一不同缓存实现的接口,提高代码可读性和性能 +- **证书生成(CertGenerator)**:动态生成TLS证书,支持HTTPS解密 +- **限流(RateLimit)**:实现请求限流,防止过载 +- **监控(Metrics)**:收集代理运行指标,用于监控和分析 +- **重试(Retry)**:实现请求重试,提高请求成功率 + +## 配置选项 + +GoProxy提供了丰富的配置选项,可以通过`Options`结构体进行配置: + +```go +type Options struct { + // 配置 + Config *config.Config + // 委托 + Delegate Delegate + // 证书缓存 + CertCache CertificateCache + // HTTP缓存 + HTTPCache cache.Cache + // 负载均衡器 + LoadBalancer loadbalance.LoadBalancer + // 健康检查器 + HealthChecker *healthcheck.HealthChecker + // 监控指标 + Metrics metrics.Metrics + // 客户端跟踪 + ClientTrace *httptrace.ClientTrace +} +``` + +### 函数式选项模式 + +GoProxy 现在支持函数式选项模式(Functional Options Pattern),通过一系列的 `With` 方法提供更加灵活和可读性更高的配置方式。此模式的优势在于: + +- 参数配置更加直观和清晰 +- 可以灵活选择需要的配置项,不必记忆参数顺序 +- 代码可读性更高,便于维护 +- 可以逐步添加新的配置选项而不破坏兼容性 + +可以使用 `NewProxy` 函数和函数式选项创建代理: + +```go +// 创建一个简单的代理 +proxy := proxy.NewProxy() + +// 创建一个功能丰富的代理 +proxy := proxy.NewProxy( + proxy.WithConfig(config.DefaultConfig()), + proxy.WithHTTPCache(myCache), + proxy.WithDecryptHTTPS(myCertCache), + proxy.WithCACertAndKey("ca.crt", "ca.key"), + proxy.WithMetrics(myMetrics), + proxy.WithLoadBalancer(myLoadBalancer), + proxy.WithRequestTimeout(10 * time.Second), + proxy.WithEnableCORS(true) +) +``` + +### 可用的 With 方法 + +GoProxy 提供了以下 With 方法用于配置代理的各个方面: + +#### 基础配置选项 +- `WithConfig(cfg *config.Config)`: 设置代理配置 +- `WithDisableKeepAlive(disableKeepAlive bool)`: 设置连接是否重用 +- `WithTransport(t *http.Transport)`: 使用自定义HTTP传输 +- `WithClientTrace(t *httptrace.ClientTrace)`: 设置HTTP客户端跟踪 + +#### 功能模块选项 +- `WithDelegate(delegate Delegate)`: 设置委托类 +- `WithHTTPCache(c cache.Cache)`: 设置HTTP缓存 +- `WithLoadBalancer(lb loadbalance.LoadBalancer)`: 设置负载均衡器 +- `WithHealthChecker(hc *healthcheck.HealthChecker)`: 设置健康检查器 +- `WithMetrics(m metrics.Metrics)`: 设置监控指标 + +#### 功能开启选项 +- `WithDecryptHTTPS(c CertificateCache)`: 启用中间人代理解密HTTPS +- `WithEnableECDSA(enable bool)`: 启用ECDSA证书生成(默认使用RSA) +- `WithEnableWebsocketIntercept()`: 启用WebSocket拦截 +- `WithReverseProxy(enable bool)`: 启用反向代理模式 +- `WithEnableRetry(maxRetries int, baseBackoff, maxBackoff time.Duration)`: 启用请求重试 +- `WithRateLimit(rps float64)`: 设置请求限流 +- `WithURLRewrite(enable bool)`: 启用URL重写 +- `WithEnableCORS(enable bool)`: 启用CORS支持 + +#### 证书相关选项 +- `WithTLSCertAndKey(certPath, keyPath string)`: 设置TLS证书和密钥 +- `WithCACertAndKey(caCertPath, caKeyPath string)`: 设置CA证书和密钥 + +#### 性能和超时相关选项 +- `WithConnectionPoolSize(size int)`: 设置连接池大小 +- `WithIdleTimeout(timeout time.Duration)`: 设置空闲超时时间 +- `WithRequestTimeout(timeout time.Duration)`: 设置请求超时时间 +- `WithDNSCacheTTL(ttl time.Duration)`: 设置DNS缓存TTL + +### 配置HTTPS解密 + +要启用HTTPS解密功能,需要在配置中设置以下选项: + +```go +config := &config.Config{ + // 启用HTTPS解密 + DecryptHTTPS: true, + + // 方式一:使用CA证书和私钥动态生成证书 + CACert: "path/to/ca.crt", // CA证书路径 + CAKey: "path/to/ca.key", // CA私钥路径 + + // 选择证书生成算法(可选) + UseECDSA: true, // 使用ECDSA生成证书(默认为false,使用RSA) + + // 方式二:使用固定的TLS证书和私钥 + // TLSCert: "path/to/server.crt", + // TLSKey: "path/to/server.key", +} +``` + +或者使用函数式选项模式: + +```go +proxy := proxy.NewProxy( + proxy.WithDecryptHTTPS(&proxy.MemCertCache{}), + proxy.WithCACertAndKey("path/to/ca.crt", "path/to/ca.key"), + proxy.WithEnableECDSA(true), // 使用ECDSA生成证书 + // 或者使用静态TLS证书 + // proxy.WithTLSCertAndKey("path/to/server.crt", "path/to/server.key") +) +``` + +同时,建议配置证书缓存以提高性能: + +```go +certCache := &proxy.MemCertCache{} +``` + +## 扩展点 + +GoProxy提供了多个扩展点,可以通过实现相应的接口进行扩展: + +- **Delegate**:代理委托接口,用于自定义代理处理逻辑 +- **LoadBalancer**:负载均衡接口,用于实现自定义负载均衡算法 +- **Cache**:缓存接口,用于实现自定义缓存策略 +- **CertificateCache**:证书缓存接口,用于自定义证书存储方式 +- **Metrics**:监控接口,用于实现自定义监控指标收集 + +## 反向代理特性 + +GoProxy的反向代理模式提供以下特性: + +- **URL重写**:支持基于前缀和正则表达式的URL重写 +- **路由规则**:支持基于主机名、路径、正则表达式等的路由规则 +- **请求修改**:支持修改发往后端服务器的请求 +- **响应修改**:支持修改来自后端服务器的响应 +- **保留客户端信息**:支持添加X-Forwarded-For和X-Real-IP头 +- **CORS支持**:支持自动添加CORS头 +- **WebSocket支持**:支持WebSocket协议的透明代理 +- **负载均衡**:支持多种负载均衡算法 +- **健康检查**:支持对后端服务器进行健康检查 +- **监控指标**:支持收集反向代理的监控指标 + +## 贡献 + +欢迎贡献代码、报告问题或提出建议。请遵循以下步骤: + +1. Fork 项目 +2. 创建特性分支 (`git checkout -b feature/amazing-feature`) +3. 提交更改 (`git commit -m 'Add some amazing feature'`) +4. 推送到分支 (`git push origin feature/amazing-feature`) +5. 创建 Pull Request + +## 许可证 + +本项目采用 MIT 许可证,详情请参阅 [LICENSE](LICENSE) 文件。 \ No newline at end of file diff --git a/SUMMARY.md b/SUMMARY.md new file mode 100644 index 0000000..e3fe759 --- /dev/null +++ b/SUMMARY.md @@ -0,0 +1,73 @@ +# GoProxy 项目总结 + +## 项目概述 + +GoProxy 是一个功能强大的 Go 语言 HTTP 代理库,支持 HTTP、HTTPS 和 WebSocket 代理,并提供了丰富的功能和扩展点。该项目采用模块化设计,各个模块之间职责明确,耦合度低,便于扩展和维护。 + +## 已完成的模块 + +1. **配置模块(config)**:提供代理配置选项,包括连接池大小、超时时间、是否启用缓存、负载均衡等。 + +2. **代理上下文(context)**:保存请求上下文信息,用于在代理处理过程中传递数据,包括原始请求、目标地址、上级代理等。 + +3. **代理委托(delegate)**:定义代理处理请求的各个阶段的回调方法,用于自定义处理逻辑,包括连接、认证、请求前、响应前等事件。 + +4. **连接缓冲区(conn_buffer)**:封装网络连接,提供缓冲读写功能,简化网络 IO 操作。 + +5. **负载均衡(loadbalance)**:实现负载均衡算法,支持轮询、随机、权重等,自动选择合适的上游服务器。 + +6. **健康检查(healthcheck)**:检查上游服务器的健康状态,自动剔除不健康的服务器,提高代理可靠性。 + +7. **缓存(cache)**:实现 HTTP 缓存,减少重复请求,提高代理性能。 + +8. **缓存适配器(cache_adapter)**:统一不同缓存实现的接口,使用适配器模式提高代码可读性和执行效率,支持多种缓存实现方式。 + +9. **证书生成(cert_generator)**:动态生成TLS证书,支持HTTPS解密(中间人模式),可基于CA证书和私钥创建域名证书,并支持证书缓存。支持RSA和ECDSA两种算法,用户可根据需要选择安全性和性能的平衡。 + +10. **限流(ratelimit)**:实现请求限流,防止过载,保护上游服务器。 + +11. **监控(metrics)**:收集代理运行指标,用于监控和分析,包括请求数、响应时间、错误数等。 + +12. **重试(retry)**:实现请求重试,提高请求成功率,处理临时性故障。 + +13. **代理核心(proxy)**:处理 HTTP 请求和响应,实现代理功能,包括 HTTP、HTTPS 和 WebSocket 代理。 + +## 项目特点 + +1. **模块化设计**:各个模块职责明确,耦合度低,便于扩展和维护。 + +2. **丰富的功能**:支持 HTTP、HTTPS 和 WebSocket 代理,并提供了负载均衡、健康检查、缓存、限流、监控等功能。 + +3. **灵活的扩展点**:提供了多个扩展点,可以通过实现相应的接口进行扩展,如代理委托、负载均衡、缓存、监控等。 + +4. **高性能**:采用 Go 语言的并发特性,实现高性能的代理服务。优化的缓存适配器和证书缓存机制进一步提升性能。 + +5. **可靠性**:通过健康检查、重试等机制,提高代理的可靠性。 + +6. **可观测性**:通过监控指标收集,提高代理的可观测性。 + +7. **安全性**:支持HTTPS解密功能,可用于调试、安全审计和内容过滤。提供RSA和ECDSA双算法选择,满足不同的安全需求和性能场景。 + +8. **支持更多证书格式**:增加对PEM、DER等更多证书格式的支持,以及对更多密钥算法(如Ed25519等)的支持。 + +## 使用示例 + +我们提供了一个完整的示例程序,展示了如何使用 GoProxy 库创建一个功能完善的代理服务器,包括负载均衡、健康检查、缓存、监控等功能。另外还提供了启用HTTPS解密功能的示例,展示如何使用CA证书动态生成站点证书。 + +## 未来计划 + +1. **完善文档**:编写更详细的文档,包括 API 文档、使用示例、最佳实践等。 + +2. **增加测试**:增加单元测试和集成测试,提高代码质量和可靠性。 + +3. **性能优化**:进一步优化代理性能,减少资源消耗。 + +4. **增加更多功能**:如请求过滤、内容修改、安全检查等。 + +5. **提供更多扩展点**:如请求路由、请求转换等。 + +6. **支持更多证书格式**:增加对PEM、DER等更多证书格式的支持。 + +## 总结 + +GoProxy 是一个功能强大、设计良好的 Go 语言 HTTP 代理库,可以满足各种代理需求,如开发调试、负载均衡、API 网关等。通过模块化设计和丰富的扩展点,GoProxy 可以灵活地适应各种场景,是构建代理服务的理想选择。最近的增强包括缓存适配器优化和完整的HTTPS解密功能实现,使GoProxy更加完善和高效。 \ No newline at end of file diff --git a/cmd/example/main.go b/cmd/example/main.go new file mode 100644 index 0000000..4abd4c2 --- /dev/null +++ b/cmd/example/main.go @@ -0,0 +1,165 @@ +package main + +import ( + "flag" + "log" + "net/http" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/goproxy/internal/cache" + "github.com/goproxy/internal/config" + "github.com/goproxy/internal/healthcheck" + "github.com/goproxy/internal/loadbalance" + "github.com/goproxy/internal/metrics" + "github.com/goproxy/internal/proxy" +) + +var ( + // 监听地址 + addr = flag.String("addr", ":8080", "代理服务器监听地址") + // 上游代理服务器 + upstream = flag.String("upstream", "", "上游代理服务器地址,多个地址用逗号分隔") + // 是否启用负载均衡 + enableLoadBalance = flag.Bool("enable-lb", false, "是否启用负载均衡") + // 是否启用健康检查 + enableHealthCheck = flag.Bool("enable-hc", false, "是否启用健康检查") + // 是否启用缓存 + enableCache = flag.Bool("enable-cache", false, "是否启用缓存") + // 是否启用重试 + enableRetry = flag.Bool("enable-retry", false, "是否启用重试") + // 是否启用监控 + enableMetrics = flag.Bool("enable-metrics", false, "是否启用监控") + // 监控地址 + metricsAddr = flag.String("metrics-addr", ":8081", "监控服务器监听地址") +) + +// 解析目标地址 +func parseTargets(targets string) []string { + if targets == "" { + return nil + } + return strings.Split(targets, ",") +} + +func main() { + // 解析命令行参数 + flag.Parse() + + // 创建配置 + cfg := config.DefaultConfig() + cfg.EnableLoadBalancing = *enableLoadBalance + cfg.EnableHealthCheck = *enableHealthCheck + cfg.EnableCache = *enableCache + cfg.EnableRetry = *enableRetry + + // 创建选项 + opts := &proxy.Options{ + Config: cfg, + } + + // 创建负载均衡器 + if *enableLoadBalance && *upstream != "" { + lb := loadbalance.NewRoundRobinBalancer() + for _, target := range parseTargets(*upstream) { + lb.Add(target, 1) + } + opts.LoadBalancer = lb + } + + // 创建健康检查器 + if *enableHealthCheck && opts.LoadBalancer != nil { + hc := healthcheck.NewHealthChecker(&healthcheck.Config{ + Interval: time.Second * 10, + Timeout: time.Second * 2, + MaxFails: 3, + MinSuccess: 2, + }) + opts.HealthChecker = hc + } + + // 创建缓存 + if *enableCache { + c := cache.NewMemoryCache(time.Minute*5, time.Minute*5, 1000) + opts.HTTPCache = c + } + + // 创建监控 + if *enableMetrics { + m := metrics.NewSimpleMetrics() + opts.Metrics = m + + // 启动监控服务器 + go func() { + mux := http.NewServeMux() + handler := m.GetHandler() + mux.Handle("/metrics", handler) + log.Printf("监控服务器启动在 %s\n", *metricsAddr) + if err := http.ListenAndServe(*metricsAddr, mux); err != nil { + log.Fatalf("监控服务器启动失败: %v", err) + } + }() + } + + // 创建自定义委托 + delegate := &CustomDelegate{} + opts.Delegate = delegate + + // 创建代理 + p := proxy.New(opts) + + // 创建HTTP服务器 + server := &http.Server{ + Addr: *addr, + Handler: p, + } + + // 启动HTTP服务器 + go func() { + log.Printf("代理服务器启动在 %s\n", *addr) + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("代理服务器启动失败: %v", err) + } + }() + + // 等待信号 + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + + log.Println("正在关闭代理服务器...") + server.Close() + log.Println("代理服务器已关闭") +} + +// CustomDelegate 自定义委托 +type CustomDelegate struct { + proxy.DefaultDelegate +} + +// Connect 连接事件 +func (d *CustomDelegate) Connect(ctx *proxy.Context, rw http.ResponseWriter) { + log.Printf("收到连接: %s -> %s\n", ctx.Req.RemoteAddr, ctx.Req.URL.Host) +} + +// BeforeRequest 请求前事件 +func (d *CustomDelegate) BeforeRequest(ctx *proxy.Context) { + log.Printf("请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.String()) +} + +// BeforeResponse 响应前事件 +func (d *CustomDelegate) BeforeResponse(ctx *proxy.Context, resp *http.Response, err error) { + if err != nil { + log.Printf("响应错误: %v\n", err) + return + } + log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status) +} + +// ErrorLog 错误日志 +func (d *CustomDelegate) ErrorLog(err error) { + log.Printf("错误: %v\n", err) +} diff --git a/cmd/reverse_proxy_example/main.go b/cmd/reverse_proxy_example/main.go new file mode 100644 index 0000000..19edef4 --- /dev/null +++ b/cmd/reverse_proxy_example/main.go @@ -0,0 +1,185 @@ +package main + +import ( + "flag" + "log" + "net/http" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/goproxy/internal/config" + "github.com/goproxy/internal/metrics" + "github.com/goproxy/internal/proxy" +) + +var ( + // 监听地址 + addr = flag.String("addr", ":8080", "反向代理服务器监听地址") + // 后端服务器 + backend = flag.String("backend", "localhost:8081", "后端服务器地址") + // 路由规则文件 + routeFile = flag.String("route-file", "", "路由规则文件路径") + // 是否启用URL重写 + enableRewrite = flag.Bool("enable-rewrite", false, "是否启用URL重写") + // 是否启用缓存 + enableCache = flag.Bool("enable-cache", false, "是否启用缓存") + // 是否启用压缩 + enableCompression = flag.Bool("enable-compression", false, "是否启用压缩") + // 是否启用监控 + enableMetrics = flag.Bool("enable-metrics", false, "是否启用监控") + // 监控地址 + metricsAddr = flag.String("metrics-addr", ":8082", "监控服务器监听地址") + // 是否添加X-Forwarded-For + addXForwardedFor = flag.Bool("add-x-forwarded-for", true, "是否添加X-Forwarded-For头") + // 是否添加X-Real-IP + addXRealIP = flag.Bool("add-x-real-ip", true, "是否添加X-Real-IP头") + // 是否启用CORS + enableCORS = flag.Bool("enable-cors", false, "是否启用CORS") + // 路径前缀 + pathPrefix = flag.String("path-prefix", "", "路径前缀,将从请求路径中移除") +) + +func main() { + // 解析命令行参数 + flag.Parse() + + // 创建配置 + cfg := config.DefaultConfig() + cfg.ReverseProxy = true + cfg.EnableCache = *enableCache + cfg.EnableCompression = *enableCompression + cfg.EnableURLRewrite = *enableRewrite + cfg.AddXForwardedFor = *addXForwardedFor + cfg.AddXRealIP = *addXRealIP + cfg.EnableCORS = *enableCORS + cfg.ReverseProxyRulesFile = *routeFile + + // 创建选项 + opts := &proxy.Options{ + Config: cfg, + } + + // 创建监控 + if *enableMetrics { + m := metrics.NewSimpleMetrics() + opts.Metrics = m + + // 启动监控服务器 + go func() { + mux := http.NewServeMux() + handler := m.GetHandler() + mux.Handle("/metrics", handler) + log.Printf("监控服务器启动在 %s\n", *metricsAddr) + if err := http.ListenAndServe(*metricsAddr, mux); err != nil { + log.Fatalf("监控服务器启动失败: %v", err) + } + }() + } + + // 创建自定义委托 + delegate := &ReverseProxyDelegate{ + backend: *backend, + prefix: *pathPrefix, + } + opts.Delegate = delegate + + // 创建代理 + p := proxy.New(opts) + + // 如果有路径前缀,添加重写规则 + if *pathPrefix != "" { + reverseProxy := p.NewReverseProxy() + log.Printf("添加路径重写规则: 从请求路径移除前缀 %s\n", *pathPrefix) + reverseProxy.AddRewriteRule(*pathPrefix, "", false) + } + + // 创建HTTP服务器 + server := &http.Server{ + Addr: *addr, + Handler: p, + } + + // 启动HTTP服务器 + go func() { + log.Printf("反向代理服务器启动在 %s,后端服务器为 %s\n", *addr, *backend) + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("代理服务器启动失败: %v", err) + } + }() + + // 等待信号 + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + + log.Println("正在关闭代理服务器...") + server.Close() + log.Println("代理服务器已关闭") +} + +// ReverseProxyDelegate 反向代理委托 +type ReverseProxyDelegate struct { + proxy.DefaultDelegate + backend string + prefix string +} + +// ResolveBackend 解析后端服务器 +func (d *ReverseProxyDelegate) ResolveBackend(req *http.Request) (string, error) { + // 这里可以实现基于请求路径、主机名等的路由逻辑 + return d.backend, nil +} + +// ModifyRequest 修改请求 +func (d *ReverseProxyDelegate) ModifyRequest(req *http.Request) { + // 移除路径前缀 + if d.prefix != "" && strings.HasPrefix(req.URL.Path, d.prefix) { + req.URL.Path = strings.TrimPrefix(req.URL.Path, d.prefix) + if req.URL.Path == "" { + req.URL.Path = "/" + } + } + + // 添加自定义请求头 + req.Header.Set("X-Proxy-Time", time.Now().Format(time.RFC3339)) +} + +// ModifyResponse 修改响应 +func (d *ReverseProxyDelegate) ModifyResponse(resp *http.Response) error { + // 添加自定义响应头 + resp.Header.Set("X-Proxied-By", "GoProxy") + return nil +} + +// Connect 连接事件 +func (d *ReverseProxyDelegate) Connect(ctx *proxy.Context, rw http.ResponseWriter) { + log.Printf("收到连接: %s -> %s %s\n", ctx.Req.RemoteAddr, ctx.Req.Method, ctx.Req.URL.Path) +} + +// BeforeRequest 请求前事件 +func (d *ReverseProxyDelegate) BeforeRequest(ctx *proxy.Context) { + log.Printf("处理请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.Path) +} + +// BeforeResponse 响应前事件 +func (d *ReverseProxyDelegate) BeforeResponse(ctx *proxy.Context, resp *http.Response, err error) { + if err != nil { + log.Printf("响应错误: %v\n", err) + return + } + log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status) +} + +// ErrorLog 错误日志 +func (d *ReverseProxyDelegate) ErrorLog(err error) { + log.Printf("错误: %v\n", err) +} + +// HandleError 处理错误 +func (d *ReverseProxyDelegate) HandleError(rw http.ResponseWriter, req *http.Request, err error) { + log.Printf("处理错误: %v\n", err) + http.Error(rw, "代理服务器错误: "+err.Error(), http.StatusBadGateway) +} diff --git a/delegate.go b/delegate.go new file mode 100644 index 0000000..1a617d8 --- /dev/null +++ b/delegate.go @@ -0,0 +1,132 @@ +// Copyright 2018 ouqiang authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +package goproxy + +import ( + "log" + "net/http" + "net/url" + "strings" +) + +// Context 代理上下文 +type Context struct { + Req *http.Request + Data map[interface{}]interface{} + TunnelProxy bool + abort bool +} + +func (c *Context) IsHTTPS() bool { + return c.Req.URL.Scheme == "https" +} + +var defaultPorts = map[string]string{ + "https": "443", + "http": "80", + "": "80", +} + +func (c *Context) WebsocketUrl() *url.URL { + u := new(url.URL) + *u = *c.Req.URL + if c.IsHTTPS() { + u.Scheme = "wss" + } else { + u.Scheme = "ws" + } + + return u +} + +func (c *Context) Addr() string { + addr := c.Req.Host + + if !strings.Contains(c.Req.URL.Host, ":") { + addr += ":" + defaultPorts[c.Req.URL.Scheme] + } + + return addr +} + +// Abort 中断执行 +func (c *Context) Abort() { + c.abort = true +} + +// IsAborted 是否已中断执行 +func (c *Context) IsAborted() bool { + return c.abort +} + +// Reset 重置 +func (c *Context) Reset(req *http.Request) { + c.Req = req + c.Data = make(map[interface{}]interface{}) + c.abort = false + c.TunnelProxy = false +} + +type Delegate interface { + // Connect 收到客户端连接 + Connect(ctx *Context, rw http.ResponseWriter) + // Auth 代理身份认证 + Auth(ctx *Context, rw http.ResponseWriter) + // BeforeRequest HTTP请求前 设置X-Forwarded-For, 修改Header、Body + BeforeRequest(ctx *Context) + // BeforeResponse 响应发送到客户端前, 修改Header、Body、Status Code + BeforeResponse(ctx *Context, resp *http.Response, err error) + // WebSocketSendMessage websocket发送消息 + WebSocketSendMessage(ctx *Context, messageType *int, p *[]byte) + // WebSockerReceiveMessage websocket接收 消息 + WebSocketReceiveMessage(ctx *Context, messageType *int, p *[]byte) + // ParentProxy 上级代理 + ParentProxy(*http.Request) (*url.URL, error) + // Finish 本次请求结束 + Finish(ctx *Context) + // 记录错误信息 + ErrorLog(err error) +} + +var _ Delegate = &DefaultDelegate{} + +// DefaultDelegate 默认Handler什么也不做 +type DefaultDelegate struct { + Delegate +} + +func (h *DefaultDelegate) Connect(ctx *Context, rw http.ResponseWriter) {} + +func (h *DefaultDelegate) Auth(ctx *Context, rw http.ResponseWriter) {} + +func (h *DefaultDelegate) BeforeRequest(ctx *Context) {} + +func (h *DefaultDelegate) BeforeResponse(ctx *Context, resp *http.Response, err error) {} + +func (h *DefaultDelegate) ParentProxy(req *http.Request) (*url.URL, error) { + return http.ProxyFromEnvironment(req) +} + +// WebSocketSendMessage websocket发送消息 +func (h *DefaultDelegate) WebSocketSendMessage(ctx *Context, messageType *int, payload *[]byte) {} + +// WebSockerReceiveMessage websocket接收 消息 +func (h *DefaultDelegate) WebSocketReceiveMessage(ctx *Context, messageType *int, payload *[]byte) {} + +func (h *DefaultDelegate) Finish(ctx *Context) {} + +func (h *DefaultDelegate) ErrorLog(err error) { + log.Println(err) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..65f955a --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module github.com/goproxy + +go 1.24.0 + +require ( + github.com/ouqiang/goproxy v1.3.2 + github.com/ouqiang/websocket v1.6.2 + github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 + golang.org/x/time v0.11.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..3e22dbe --- /dev/null +++ b/go.sum @@ -0,0 +1,8 @@ +github.com/ouqiang/goproxy v1.3.2 h1:+3uBRrM0RU4LFcsH0lbWsdUCoHIzoRxk+ISPbIS3lTk= +github.com/ouqiang/goproxy v1.3.2/go.mod h1:yF0a+DlUi0Zff28iUeuqLov90bivevUX9uOn3Yk9rww= +github.com/ouqiang/websocket v1.6.2 h1:LGQIySbQO3ahZCl34v9xBVb0yncDk8yIcuEIbWBab/U= +github.com/ouqiang/websocket v1.6.2/go.mod h1:fIROJIHRlQwgCyUFTMzaaIcs4HIwUj2xlOW43u9Sf+M= +github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 h1:EVObHAr8DqpoJCVv6KYTle8FEImKhtkfcZetNqxDoJQ= +github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8/go.mod h1:dniwbG03GafCjFohMDmz6Zc6oCuiqgH6tGNyXTkHzXE= +golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= +golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..1e06f2d --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,220 @@ +package cache + +import ( + "bytes" + "crypto/md5" + "encoding/hex" + "io" + "net/http" + "strings" + "sync" + "time" +) + +// Cache 缓存接口 +type Cache interface { + // Get 获取缓存 + Get(key string) (*http.Response, bool) + // Set 设置缓存 + Set(key string, resp *http.Response) + // Delete 删除缓存 + Delete(key string) + // Clear 清空缓存 + Clear() +} + +// MemoryCache 内存缓存实现 +type MemoryCache struct { + // 缓存内容 + items sync.Map + // 过期时间 + ttl time.Duration + // 清理间隔 + cleanupInterval time.Duration + // 最大条目数 + maxEntries int + // 当前条目数 + size int32 + // 互斥锁 + mu sync.Mutex +} + +// CacheItem 缓存项 +type CacheItem struct { + response *http.Response + responseBody []byte + expiry time.Time +} + +// NewMemoryCache 创建内存缓存 +func NewMemoryCache(ttl, cleanupInterval time.Duration, maxEntries int) *MemoryCache { + cache := &MemoryCache{ + ttl: ttl, + cleanupInterval: cleanupInterval, + maxEntries: maxEntries, + } + + // 启动过期清理 + if cleanupInterval > 0 { + go cache.startCleanup() + } + + return cache +} + +// Get 获取缓存 +func (c *MemoryCache) Get(key string) (*http.Response, bool) { + value, ok := c.items.Load(key) + if !ok { + return nil, false + } + + item := value.(*CacheItem) + if time.Now().After(item.expiry) { + c.Delete(key) + return nil, false + } + + // 克隆响应,避免修改原始数据 + resp := cloneResponse(item.response, item.responseBody) + return resp, true +} + +// Set 设置缓存 +func (c *MemoryCache) Set(key string, resp *http.Response) { + // 检查缓存是否已满 + c.mu.Lock() + if c.maxEntries > 0 && c.size >= int32(c.maxEntries) { + c.mu.Unlock() + return + } + c.size++ + c.mu.Unlock() + + // 读取并保存响应体 + var bodyBytes []byte + if resp.Body != nil { + bodyBytes, _ = io.ReadAll(resp.Body) + resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + } + + item := &CacheItem{ + response: resp, + responseBody: bodyBytes, + expiry: time.Now().Add(c.ttl), + } + + c.items.Store(key, item) +} + +// Delete 删除缓存 +func (c *MemoryCache) Delete(key string) { + c.items.Delete(key) + c.mu.Lock() + c.size-- + if c.size < 0 { + c.size = 0 + } + c.mu.Unlock() +} + +// Clear 清空缓存 +func (c *MemoryCache) Clear() { + c.items = sync.Map{} + c.mu.Lock() + c.size = 0 + c.mu.Unlock() +} + +// startCleanup 启动过期清理 +func (c *MemoryCache) startCleanup() { + ticker := time.NewTicker(c.cleanupInterval) + defer ticker.Stop() + + for range ticker.C { + now := time.Now() + c.items.Range(func(key, value interface{}) bool { + item := value.(*CacheItem) + if now.After(item.expiry) { + c.Delete(key.(string)) + } + return true + }) + } +} + +// GenerateCacheKey 生成缓存键 +func GenerateCacheKey(req *http.Request) string { + // 忽略一些可变的头部 + ignoredHeaders := map[string]bool{ + "Connection": true, + "Keep-Alive": true, + "Proxy-Authenticate": true, + "Proxy-Authorization": true, + "TE": true, + "Trailers": true, + "Transfer-Encoding": true, + "Upgrade": true, + } + + // 提取缓存键组件 + components := []string{ + req.Method, + req.URL.String(), + } + + // 添加选择性头部 + for key, values := range req.Header { + if !ignoredHeaders[key] { + for _, value := range values { + components = append(components, key+":"+value) + } + } + } + + // 连接并计算哈希 + data := strings.Join(components, "|") + hash := md5.New() + hash.Write([]byte(data)) + return hex.EncodeToString(hash.Sum(nil)) +} + +// cloneResponse 克隆HTTP响应 +func cloneResponse(resp *http.Response, body []byte) *http.Response { + clone := *resp + clone.Body = io.NopCloser(bytes.NewBuffer(body)) + clone.Header = make(http.Header) + + for k, v := range resp.Header { + clone.Header[k] = v + } + + return &clone +} + +// ShouldCache 判断请求是否应该缓存 +func ShouldCache(req *http.Request, resp *http.Response) bool { + // 只缓存GET请求 + if req.Method != http.MethodGet { + return false + } + + // 检查响应状态码 + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusNotModified && + resp.StatusCode != http.StatusMovedPermanently && + resp.StatusCode != http.StatusPermanentRedirect { + return false + } + + // 检查Cache-Control头 + cacheControl := resp.Header.Get("Cache-Control") + if strings.Contains(cacheControl, "no-store") || + strings.Contains(cacheControl, "no-cache") || + strings.Contains(cacheControl, "private") { + return false + } + + return true +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..eeebd59 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,135 @@ +package config + +import ( + "time" +) + +// Config 代理配置 +type Config struct { + // 监听地址 + ListenAddr string + // 是否启用负载均衡 + EnableLoadBalancing bool + // 负载均衡后端列表 + Backends []string + // 是否启用限流 + EnableRateLimit bool + // 每秒请求速率限制 + RateLimit float64 + // 并发请求峰值限制 + MaxBurst int + // 最大连接数 + MaxConnections int + // 是否启用连接池 + EnableConnectionPool bool + // 连接池大小 + ConnectionPoolSize int + // 连接空闲超时时间 + IdleTimeout time.Duration + // 请求超时时间 + RequestTimeout time.Duration + // 是否启用响应缓存 + EnableCache bool + // 缓存过期时间 + CacheTTL time.Duration + // 是否启用HTTPS解密 + DecryptHTTPS bool + // TLS证书文件路径 + TLSCert string + // TLS密钥文件路径 + TLSKey string + // CA证书文件路径(用于生成动态证书) + CACert string + // CA密钥文件路径(用于生成动态证书) + CAKey string + // 是否启用健康检查 + EnableHealthCheck bool + // 健康检查间隔时间 + HealthCheckInterval time.Duration + // 健康检查超时时间 + HealthCheckTimeout time.Duration + // 是否启用重试机制 + EnableRetry bool + // 最大重试次数 + MaxRetries int + // 重试间隔基数 + RetryBackoff time.Duration + // 最大重试间隔 + MaxRetryBackoff time.Duration + // 是否启用监控指标 + EnableMetrics bool + // 是否启用请求追踪 + EnableTracing bool + // 是否拦截WebSocket + WebSocketIntercept bool + // DNS缓存过期时间 + DNSCacheTTL time.Duration + // 是否作为反向代理 + ReverseProxy bool + // 反向代理规则文件路径 + ReverseProxyRulesFile string + // 是否启用URL重写 + EnableURLRewrite bool + // 是否保留客户端IP + PreserveClientIP bool + // 是否启用压缩 + EnableCompression bool + // 是否自动添加CORS头 + EnableCORS bool + // 重写Host头 + RewriteHostHeader bool + // 是否添加X-Forwarded-For头 + AddXForwardedFor bool + // 是否添加X-Real-IP头 + AddXRealIP bool + // 是否支持Websocket升级 + SupportWebSocketUpgrade bool + // 是否使用ECDSA生成证书(默认使用RSA) + UseECDSA bool +} + +// DefaultConfig 默认配置 +func DefaultConfig() *Config { + return &Config{ + ListenAddr: ":8080", + EnableLoadBalancing: false, + Backends: []string{}, + EnableRateLimit: false, + RateLimit: 100, + MaxBurst: 50, + MaxConnections: 1000, + EnableConnectionPool: true, + ConnectionPoolSize: 100, + IdleTimeout: 60 * time.Second, + RequestTimeout: 30 * time.Second, + EnableCache: false, + CacheTTL: 5 * time.Minute, + DecryptHTTPS: false, + TLSCert: "", + TLSKey: "", + CACert: "", + CAKey: "", + EnableHealthCheck: false, + HealthCheckInterval: 10 * time.Second, + HealthCheckTimeout: 5 * time.Second, + EnableRetry: true, + MaxRetries: 3, + RetryBackoff: 100 * time.Millisecond, + MaxRetryBackoff: 2 * time.Second, + EnableMetrics: false, + EnableTracing: false, + WebSocketIntercept: false, + DNSCacheTTL: 5 * time.Minute, + ReverseProxy: false, + ReverseProxyRulesFile: "", + EnableURLRewrite: false, + PreserveClientIP: true, + EnableCompression: false, + EnableCORS: false, + RewriteHostHeader: false, + AddXForwardedFor: true, + AddXRealIP: true, + SupportWebSocketUpgrade: true, + UseECDSA: false, + } +} diff --git a/internal/healthcheck/healthchecker.go b/internal/healthcheck/healthchecker.go new file mode 100644 index 0000000..5859597 --- /dev/null +++ b/internal/healthcheck/healthchecker.go @@ -0,0 +1,248 @@ +package healthcheck + +import ( + "context" + "net" + "net/http" + "net/url" + "sync" + "time" +) + +// HealthChecker 健康检查器 +type HealthChecker struct { + // 配置 + config *Config + // 健康检查状态 + statusMap sync.Map + // 状态变更回调 + statusChangeCallback func(string, bool) + // 是否运行中 + running bool + // 上下文 + ctx context.Context + // 取消函数 + cancel context.CancelFunc + // 互斥锁 + mu sync.Mutex +} + +// Config 健康检查配置 +type Config struct { + // 检查间隔 + Interval time.Duration + // 检查超时 + Timeout time.Duration + // 检查路径 + Path string + // 检查方法 + Method string + // 检查状态码 + SuccessStatus int + // 最大失败次数 + MaxFails int + // 最小成功次数 + MinSuccess int +} + +// NewHealthChecker 创建健康检查器 +func NewHealthChecker(config *Config) *HealthChecker { + if config.Path == "" { + config.Path = "/" + } + if config.Method == "" { + config.Method = http.MethodGet + } + if config.SuccessStatus == 0 { + config.SuccessStatus = http.StatusOK + } + if config.MaxFails == 0 { + config.MaxFails = 3 + } + if config.MinSuccess == 0 { + config.MinSuccess = 2 + } + + ctx, cancel := context.WithCancel(context.Background()) + return &HealthChecker{ + config: config, + ctx: ctx, + cancel: cancel, + running: false, + } +} + +// Start 启动健康检查 +func (hc *HealthChecker) Start() { + hc.mu.Lock() + defer hc.mu.Unlock() + + if hc.running { + return + } + + hc.running = true + go hc.run() +} + +// Stop 停止健康检查 +func (hc *HealthChecker) Stop() { + hc.mu.Lock() + defer hc.mu.Unlock() + + if !hc.running { + return + } + + hc.cancel() + hc.running = false +} + +// AddTarget 添加监控目标 +func (hc *HealthChecker) AddTarget(target string) error { + u, err := url.Parse(target) + if err != nil { + return err + } + + // 初始化为健康状态 + hc.statusMap.Store(u.String(), &backendStatus{ + URL: u, + Healthy: true, + FailCount: 0, + SuccessCount: 0, + }) + + return nil +} + +// RemoveTarget 移除监控目标 +func (hc *HealthChecker) RemoveTarget(target string) error { + u, err := url.Parse(target) + if err != nil { + return err + } + + hc.statusMap.Delete(u.String()) + return nil +} + +// IsHealthy 检查目标是否健康 +func (hc *HealthChecker) IsHealthy(target string) bool { + u, err := url.Parse(target) + if err != nil { + return false + } + + value, ok := hc.statusMap.Load(u.String()) + if !ok { + return false + } + + status := value.(*backendStatus) + return status.Healthy +} + +// SetStatusChangeCallback 设置状态变更回调 +func (hc *HealthChecker) SetStatusChangeCallback(callback func(string, bool)) { + hc.statusChangeCallback = callback +} + +// backendStatus 后端健康状态 +type backendStatus struct { + URL *url.URL + Healthy bool + FailCount int + SuccessCount int +} + +// run 运行健康检查 +func (hc *HealthChecker) run() { + ticker := time.NewTicker(hc.config.Interval) + defer ticker.Stop() + + for { + select { + case <-hc.ctx.Done(): + return + case <-ticker.C: + hc.checkAll() + } + } +} + +// checkAll 检查所有后端 +func (hc *HealthChecker) checkAll() { + hc.statusMap.Range(func(key, value interface{}) bool { + go hc.check(key.(string), value.(*backendStatus)) + return true + }) +} + +// check 检查单个后端 +func (hc *HealthChecker) check(key string, status *backendStatus) { + // 创建检查请求 + u := *status.URL + u.Path = hc.config.Path + req, err := http.NewRequest(hc.config.Method, u.String(), nil) + if err != nil { + hc.updateStatus(key, status, false) + return + } + + // 设置超时的客户端 + client := &http.Client{ + Timeout: hc.config.Timeout, + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: hc.config.Timeout, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: hc.config.Timeout, + ExpectContinueTimeout: 1 * time.Second, + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + }, + } + + // 发送请求 + resp, err := client.Do(req) + if err != nil { + hc.updateStatus(key, status, false) + return + } + defer resp.Body.Close() + + // 检查响应状态 + if resp.StatusCode == hc.config.SuccessStatus { + hc.updateStatus(key, status, true) + } else { + hc.updateStatus(key, status, false) + } +} + +// updateStatus 更新后端状态 +func (hc *HealthChecker) updateStatus(key string, status *backendStatus, success bool) { + if success { + status.SuccessCount++ + status.FailCount = 0 + if !status.Healthy && status.SuccessCount >= hc.config.MinSuccess { + status.Healthy = true + if hc.statusChangeCallback != nil { + hc.statusChangeCallback(key, true) + } + } + } else { + status.FailCount++ + status.SuccessCount = 0 + if status.Healthy && status.FailCount >= hc.config.MaxFails { + status.Healthy = false + if hc.statusChangeCallback != nil { + hc.statusChangeCallback(key, false) + } + } + } + + hc.statusMap.Store(key, status) +} diff --git a/internal/loadbalance/loadbalancer.go b/internal/loadbalance/loadbalancer.go new file mode 100644 index 0000000..656e643 --- /dev/null +++ b/internal/loadbalance/loadbalancer.go @@ -0,0 +1,330 @@ +package loadbalance + +import ( + "math/rand" + "net/url" + "sync" + "sync/atomic" + "time" +) + +// Strategy 负载均衡策略 +type Strategy int + +const ( + // StrategyRoundRobin 轮询策略 + StrategyRoundRobin Strategy = iota + // StrategyRandom 随机策略 + StrategyRandom + // StrategyWeightedRoundRobin 加权轮询策略 + StrategyWeightedRoundRobin + // StrategyIPHash IP哈希策略 + StrategyIPHash +) + +// LoadBalancer 负载均衡器接口 +type LoadBalancer interface { + // Next 获取下一个后端 + Next(key string) (*url.URL, error) + // Add 添加后端 + Add(backend string, weight int) error + // Remove 删除后端 + Remove(backend string) error + // MarkDown 标记后端为不可用 + MarkDown(backend string) error + // MarkUp 标记后端为可用 + MarkUp(backend string) error + // Reset 重置负载均衡器 + Reset() error +} + +// Backend 后端服务器 +type Backend struct { + // URL 后端URL + URL *url.URL + // Weight 权重 + Weight int + // Down 是否不可用 + Down bool +} + +// RoundRobinBalancer 轮询负载均衡器 +type RoundRobinBalancer struct { + backends []*Backend + current int32 + mutex sync.RWMutex +} + +// NewRoundRobinBalancer 创建轮询负载均衡器 +func NewRoundRobinBalancer() *RoundRobinBalancer { + return &RoundRobinBalancer{ + backends: make([]*Backend, 0), + current: 0, + } +} + +// Next 获取下一个后端 +func (lb *RoundRobinBalancer) Next(key string) (*url.URL, error) { + lb.mutex.RLock() + defer lb.mutex.RUnlock() + + if len(lb.backends) == 0 { + return nil, nil + } + + // 计算可用后端数量 + var availableCount int + for _, backend := range lb.backends { + if !backend.Down { + availableCount++ + } + } + + if availableCount == 0 { + return nil, nil + } + + // 循环直到找到可用后端 + for i := 0; i < len(lb.backends); i++ { + idx := atomic.AddInt32(&lb.current, 1) % int32(len(lb.backends)) + backend := lb.backends[idx] + if !backend.Down { + return backend.URL, nil + } + } + + return nil, nil +} + +// Add 添加后端 +func (lb *RoundRobinBalancer) Add(backend string, weight int) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for _, b := range lb.backends { + if b.URL.String() == url.String() { + return nil + } + } + + lb.backends = append(lb.backends, &Backend{ + URL: url, + Weight: weight, + Down: false, + }) + + return nil +} + +// Remove 删除后端 +func (lb *RoundRobinBalancer) Remove(backend string) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for i, b := range lb.backends { + if b.URL.String() == url.String() { + lb.backends = append(lb.backends[:i], lb.backends[i+1:]...) + return nil + } + } + + return nil +} + +// MarkDown 标记后端为不可用 +func (lb *RoundRobinBalancer) MarkDown(backend string) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for _, b := range lb.backends { + if b.URL.String() == url.String() { + b.Down = true + return nil + } + } + + return nil +} + +// MarkUp 标记后端为可用 +func (lb *RoundRobinBalancer) MarkUp(backend string) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for _, b := range lb.backends { + if b.URL.String() == url.String() { + b.Down = false + return nil + } + } + + return nil +} + +// Reset 重置负载均衡器 +func (lb *RoundRobinBalancer) Reset() error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + lb.backends = make([]*Backend, 0) + atomic.StoreInt32(&lb.current, 0) + + return nil +} + +// RandomBalancer 随机负载均衡器 +type RandomBalancer struct { + backends []*Backend + mutex sync.RWMutex + rand *rand.Rand +} + +// NewRandomBalancer 创建随机负载均衡器 +func NewRandomBalancer() *RandomBalancer { + source := rand.NewSource(time.Now().UnixNano()) + random := rand.New(source) + return &RandomBalancer{ + backends: make([]*Backend, 0), + rand: random, + } +} + +// Next 获取下一个后端 +func (lb *RandomBalancer) Next(key string) (*url.URL, error) { + lb.mutex.RLock() + defer lb.mutex.RUnlock() + + if len(lb.backends) == 0 { + return nil, nil + } + + // 计算可用后端数量 + var availableBackends []*Backend + for _, backend := range lb.backends { + if !backend.Down { + availableBackends = append(availableBackends, backend) + } + } + + if len(availableBackends) == 0 { + return nil, nil + } + + idx := lb.rand.Intn(len(availableBackends)) + return availableBackends[idx].URL, nil +} + +// Add 添加后端 +func (lb *RandomBalancer) Add(backend string, weight int) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for _, b := range lb.backends { + if b.URL.String() == url.String() { + return nil + } + } + + lb.backends = append(lb.backends, &Backend{ + URL: url, + Weight: weight, + Down: false, + }) + + return nil +} + +// Remove 删除后端 +func (lb *RandomBalancer) Remove(backend string) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for i, b := range lb.backends { + if b.URL.String() == url.String() { + lb.backends = append(lb.backends[:i], lb.backends[i+1:]...) + return nil + } + } + + return nil +} + +// MarkDown 标记后端为不可用 +func (lb *RandomBalancer) MarkDown(backend string) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for _, b := range lb.backends { + if b.URL.String() == url.String() { + b.Down = true + return nil + } + } + + return nil +} + +// MarkUp 标记后端为可用 +func (lb *RandomBalancer) MarkUp(backend string) error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + url, err := url.Parse(backend) + if err != nil { + return err + } + + for _, b := range lb.backends { + if b.URL.String() == url.String() { + b.Down = false + return nil + } + } + + return nil +} + +// Reset 重置负载均衡器 +func (lb *RandomBalancer) Reset() error { + lb.mutex.Lock() + defer lb.mutex.Unlock() + + lb.backends = make([]*Backend, 0) + + return nil +} diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go new file mode 100644 index 0000000..502c72f --- /dev/null +++ b/internal/metrics/metrics.go @@ -0,0 +1,250 @@ +package metrics + +import ( + "fmt" + "net/http" + "sync" + "sync/atomic" + "time" +) + +// Metrics 监控指标接口 +type Metrics interface { + // 增加请求计数 + IncRequestCount() + // 增加错误计数 + IncErrorCount(err error) + // 观察请求持续时间 + ObserveRequestDuration(seconds float64) + // 增加活跃连接数 + IncActiveConnections() + // 减少活跃连接数 + DecActiveConnections() + // 设置后端健康状态 + SetBackendHealth(backend string, healthy bool) + // 设置后端响应时间 + SetBackendResponseTime(backend string, duration time.Duration) + // 观察请求字节数 + ObserveRequestBytes(bytes int64) + // 观察响应字节数 + ObserveResponseBytes(bytes int64) + // 添加传输字节数 + AddBytesTransferred(direction string, bytes int64) + // 增加缓存命中计数 + IncCacheHit() + // 获取指标处理器 + GetHandler() http.Handler +} + +// SimpleMetrics 简单指标实现 +type SimpleMetrics struct { + // 请求计数 + requestCount int64 + // 错误计数 + errorCount int64 + // 活跃连接数 + activeConnections int64 + // 累计响应时间 + totalResponseTime int64 + // 传输字节数 + bytesTransferred map[string]int64 + // 后端健康状态 + backendHealth map[string]bool + // 后端响应时间 + backendResponseTime map[string]time.Duration + // 缓存命中计数 + cacheHits int64 + // 互斥锁 + mu sync.Mutex +} + +// NewSimpleMetrics 创建简单指标 +func NewSimpleMetrics() *SimpleMetrics { + return &SimpleMetrics{ + bytesTransferred: make(map[string]int64), + backendHealth: make(map[string]bool), + backendResponseTime: make(map[string]time.Duration), + } +} + +// IncRequestCount 增加请求计数 +func (m *SimpleMetrics) IncRequestCount() { + atomic.AddInt64(&m.requestCount, 1) +} + +// IncErrorCount 增加错误计数 +func (m *SimpleMetrics) IncErrorCount(err error) { + atomic.AddInt64(&m.errorCount, 1) +} + +// ObserveRequestDuration 观察请求持续时间 +func (m *SimpleMetrics) ObserveRequestDuration(seconds float64) { + nsec := int64(seconds * float64(time.Second)) + atomic.AddInt64(&m.totalResponseTime, nsec) +} + +// IncActiveConnections 增加活跃连接数 +func (m *SimpleMetrics) IncActiveConnections() { + atomic.AddInt64(&m.activeConnections, 1) +} + +// DecActiveConnections 减少活跃连接数 +func (m *SimpleMetrics) DecActiveConnections() { + atomic.AddInt64(&m.activeConnections, -1) +} + +// SetBackendHealth 设置后端健康状态 +func (m *SimpleMetrics) SetBackendHealth(backend string, healthy bool) { + m.backendHealth[backend] = healthy +} + +// SetBackendResponseTime 设置后端响应时间 +func (m *SimpleMetrics) SetBackendResponseTime(backend string, duration time.Duration) { + m.backendResponseTime[backend] = duration +} + +// ObserveRequestBytes 观察请求字节数 +func (m *SimpleMetrics) ObserveRequestBytes(bytes int64) { + m.mu.Lock() + defer m.mu.Unlock() + m.bytesTransferred["request"] += bytes +} + +// ObserveResponseBytes 观察响应字节数 +func (m *SimpleMetrics) ObserveResponseBytes(bytes int64) { + m.mu.Lock() + defer m.mu.Unlock() + m.bytesTransferred["response"] += bytes +} + +// AddBytesTransferred 添加传输字节数 +func (m *SimpleMetrics) AddBytesTransferred(direction string, bytes int64) { + m.mu.Lock() + defer m.mu.Unlock() + m.bytesTransferred[direction] += bytes +} + +// IncCacheHit 增加缓存命中计数 +func (m *SimpleMetrics) IncCacheHit() { + atomic.AddInt64(&m.cacheHits, 1) +} + +// GetHandler 获取指标处理器 +func (m *SimpleMetrics) GetHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + + // 输出基本指标 + w.Write([]byte("# HELP proxy_requests_total 代理请求总数\n")) + w.Write([]byte("# TYPE proxy_requests_total counter\n")) + w.Write([]byte(fmt.Sprintf("proxy_requests_total %d\n", m.requestCount))) + + w.Write([]byte("# HELP proxy_errors_total 代理错误总数\n")) + w.Write([]byte("# TYPE proxy_errors_total counter\n")) + w.Write([]byte(fmt.Sprintf("proxy_errors_total %d\n", m.errorCount))) + + w.Write([]byte("# HELP proxy_active_connections 当前活跃连接数\n")) + w.Write([]byte("# TYPE proxy_active_connections gauge\n")) + w.Write([]byte(fmt.Sprintf("proxy_active_connections %d\n", m.activeConnections))) + + // 输出缓存命中数据 + w.Write([]byte("# HELP proxy_cache_hits_total 缓存命中总数\n")) + w.Write([]byte("# TYPE proxy_cache_hits_total counter\n")) + w.Write([]byte(fmt.Sprintf("proxy_cache_hits_total %d\n", m.cacheHits))) + + // 输出传输字节数 + for direction, bytes := range m.bytesTransferred { + w.Write([]byte(fmt.Sprintf("# HELP proxy_bytes_transferred_%s 代理传输字节数(%s)\n", direction, direction))) + w.Write([]byte(fmt.Sprintf("# TYPE proxy_bytes_transferred_%s counter\n", direction))) + w.Write([]byte(fmt.Sprintf("proxy_bytes_transferred_%s %d\n", direction, bytes))) + } + + // 输出后端健康状态 + for backend, healthy := range m.backendHealth { + healthValue := 0 + if healthy { + healthValue = 1 + } + w.Write([]byte(fmt.Sprintf("# HELP proxy_backend_health 后端健康状态\n"))) + w.Write([]byte(fmt.Sprintf("# TYPE proxy_backend_health gauge\n"))) + w.Write([]byte(fmt.Sprintf("proxy_backend_health{backend=\"%s\"} %d\n", backend, healthValue))) + } + + // 输出后端响应时间 + for backend, duration := range m.backendResponseTime { + w.Write([]byte(fmt.Sprintf("# HELP proxy_backend_response_time 后端响应时间\n"))) + w.Write([]byte(fmt.Sprintf("# TYPE proxy_backend_response_time gauge\n"))) + w.Write([]byte(fmt.Sprintf("proxy_backend_response_time{backend=\"%s\"} %f\n", backend, float64(duration)/float64(time.Second)))) + } + + // 平均响应时间 + if m.requestCount > 0 { + avgTime := float64(m.totalResponseTime) / float64(m.requestCount) / float64(time.Second) + w.Write([]byte("# HELP proxy_average_response_time 平均响应时间\n")) + w.Write([]byte("# TYPE proxy_average_response_time gauge\n")) + w.Write([]byte(fmt.Sprintf("proxy_average_response_time %f\n", avgTime))) + } + }) +} + +// PrometheusMetrics Prometheus指标实现 +type PrometheusMetrics struct { + // 可以通过引入prometheus客户端库实现更完整的指标收集 + // 此处省略具体实现 +} + +// MetricsMiddleware 指标中间件 +type MetricsMiddleware struct { + metrics Metrics +} + +// NewMetricsMiddleware 创建指标中间件 +func NewMetricsMiddleware(metrics Metrics) *MetricsMiddleware { + return &MetricsMiddleware{ + metrics: metrics, + } +} + +// Middleware 中间件处理函数 +func (m *MetricsMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + // 包装响应写入器,用于捕获状态码 + rw := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} + + // 继续处理请求 + next.ServeHTTP(rw, r) + + // 记录请求指标 + duration := time.Since(start) + m.metrics.ObserveRequestDuration(duration.Seconds()) + }) +} + +// responseWriter 包装的响应写入器 +type responseWriter struct { + http.ResponseWriter + statusCode int + written int64 +} + +// WriteHeader 写入状态码 +func (rw *responseWriter) WriteHeader(statusCode int) { + rw.statusCode = statusCode + rw.ResponseWriter.WriteHeader(statusCode) +} + +// Write 写入数据 +func (rw *responseWriter) Write(b []byte) (int, error) { + n, err := rw.ResponseWriter.Write(b) + rw.written += int64(n) + return n, err +} + +// Flush 刷新数据 +func (rw *responseWriter) Flush() { + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go new file mode 100644 index 0000000..05292f1 --- /dev/null +++ b/internal/middleware/ratelimiter.go @@ -0,0 +1,184 @@ +package middleware + +import ( + "net/http" + "sync" + "time" + + "golang.org/x/time/rate" +) + +// RateLimiter 限流器接口 +type RateLimiter interface { + // Allow 检查请求是否允许通过 + Allow(key string) bool +} + +// SimpleRateLimiter 简单限流器 +type SimpleRateLimiter struct { + limiter *rate.Limiter +} + +// NewSimpleRateLimiter 创建简单限流器 +func NewSimpleRateLimiter(r float64, b int) *SimpleRateLimiter { + return &SimpleRateLimiter{ + limiter: rate.NewLimiter(rate.Limit(r), b), + } +} + +// Allow 检查请求是否允许通过 +func (rl *SimpleRateLimiter) Allow(key string) bool { + return rl.limiter.Allow() +} + +// IPRateLimiter 按IP限流 +type IPRateLimiter struct { + ips map[string]*rate.Limiter + mu sync.RWMutex + rate rate.Limit + burst int + cleanupInterval time.Duration + lastSeen map[string]time.Time +} + +// NewIPRateLimiter 创建IP限流器 +func NewIPRateLimiter(r float64, b int, cleanup time.Duration) *IPRateLimiter { + limiter := &IPRateLimiter{ + ips: make(map[string]*rate.Limiter), + rate: rate.Limit(r), + burst: b, + cleanupInterval: cleanup, + lastSeen: make(map[string]time.Time), + } + + // 启动过期清理 + if cleanup > 0 { + go limiter.startCleanup() + } + + return limiter +} + +// startCleanup 启动过期清理 +func (rl *IPRateLimiter) startCleanup() { + ticker := time.NewTicker(rl.cleanupInterval) + defer ticker.Stop() + + for range ticker.C { + rl.cleanup() + } +} + +// cleanup 清理过期限流器 +func (rl *IPRateLimiter) cleanup() { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + for ip, lastSeen := range rl.lastSeen { + if now.Sub(lastSeen) > rl.cleanupInterval { + delete(rl.ips, ip) + delete(rl.lastSeen, ip) + } + } +} + +// AddIP 添加IP限流器 +func (rl *IPRateLimiter) AddIP(ip string) *rate.Limiter { + rl.mu.Lock() + defer rl.mu.Unlock() + + limiter := rate.NewLimiter(rl.rate, rl.burst) + rl.ips[ip] = limiter + rl.lastSeen[ip] = time.Now() + + return limiter +} + +// GetLimiter 获取IP限流器 +func (rl *IPRateLimiter) GetLimiter(ip string) *rate.Limiter { + rl.mu.RLock() + limiter, exists := rl.ips[ip] + rl.mu.RUnlock() + + if !exists { + return rl.AddIP(ip) + } + + // 更新最后访问时间 + rl.mu.Lock() + rl.lastSeen[ip] = time.Now() + rl.mu.Unlock() + + return limiter +} + +// Allow 检查请求是否允许通过 +func (rl *IPRateLimiter) Allow(ip string) bool { + limiter := rl.GetLimiter(ip) + return limiter.Allow() +} + +// RateLimitMiddleware 限流中间件 +type RateLimitMiddleware struct { + limiter RateLimiter +} + +// NewRateLimitMiddleware 创建限流中间件 +func NewRateLimitMiddleware(limiter RateLimiter) *RateLimitMiddleware { + return &RateLimitMiddleware{ + limiter: limiter, + } +} + +// Middleware 中间件处理函数 +func (m *RateLimitMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 获取客户端IP + ip := getClientIP(r) + + // 检查是否允许通过 + if !m.limiter.Allow(ip) { + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + return + } + + // 继续处理请求 + next.ServeHTTP(w, r) + }) +} + +// getClientIP 获取客户端IP +func getClientIP(r *http.Request) string { + // 检查 X-Forwarded-For 头 + ip := r.Header.Get("X-Forwarded-For") + if ip != "" { + // 取第一个IP + for i := 0; i < len(ip) && i < 15; i++ { + if ip[i] == ',' { + ip = ip[:i] + break + } + } + return ip + } + + // 检查 X-Real-IP 头 + ip = r.Header.Get("X-Real-IP") + if ip != "" { + return ip + } + + // 从 RemoteAddr 获取 + if r.RemoteAddr != "" { + // 去掉端口部分 + for i := 0; i < len(r.RemoteAddr); i++ { + if r.RemoteAddr[i] == ':' { + return r.RemoteAddr[:i] + } + } + return r.RemoteAddr + } + + return "unknown" +} diff --git a/internal/middleware/retry.go b/internal/middleware/retry.go new file mode 100644 index 0000000..1e744b0 --- /dev/null +++ b/internal/middleware/retry.go @@ -0,0 +1,156 @@ +package middleware + +import ( + "bytes" + "io" + "math" + "net" + "net/http" + "time" +) + +// RetryPolicy 重试策略 +type RetryPolicy struct { + // 最大重试次数 + MaxRetries int + // 基础退避时间 + BaseBackoff time.Duration + // 最大退避时间 + MaxBackoff time.Duration + // 重试判断函数 + ShouldRetry func(req *http.Request, resp *http.Response, err error) bool +} + +// DefaultRetryPolicy 默认重试策略 +func DefaultRetryPolicy() *RetryPolicy { + return &RetryPolicy{ + MaxRetries: 3, + BaseBackoff: 100 * time.Millisecond, + MaxBackoff: 2 * time.Second, + ShouldRetry: defaultShouldRetry, + } +} + +// defaultShouldRetry 默认重试判断 +func defaultShouldRetry(req *http.Request, resp *http.Response, err error) bool { + // 不重试非幂等请求 + if req.Method != http.MethodGet && req.Method != http.MethodHead && req.Method != http.MethodOptions { + return false + } + + // 检查错误 + if err != nil { + // 重试网络错误 + if netErr, ok := err.(net.Error); ok { + return netErr.Temporary() || netErr.Timeout() + } + return false + } + + // 检查响应状态码 + if resp != nil { + // 重试服务器错误 + return resp.StatusCode >= 500 && resp.StatusCode < 600 + } + + return false +} + +// RetryRoundTripper 重试HTTP传输 +type RetryRoundTripper struct { + // 下一级传输 + Next http.RoundTripper + // 重试策略 + Policy *RetryPolicy +} + +// NewRetryRoundTripper 创建重试HTTP传输 +func NewRetryRoundTripper(next http.RoundTripper, policy *RetryPolicy) *RetryRoundTripper { + if next == nil { + next = http.DefaultTransport + } + if policy == nil { + policy = DefaultRetryPolicy() + } + return &RetryRoundTripper{ + Next: next, + Policy: policy, + } +} + +// RoundTrip 执行HTTP请求 +func (rt *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // 需要保留原始请求体,以便重试 + var reqBodyBytes []byte + if req.Body != nil { + var err error + reqBodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return nil, err + } + req.Body.Close() + } + + var resp *http.Response + var err error + + // 尝试请求直到成功或达到最大重试次数 + for attempt := 0; attempt <= rt.Policy.MaxRetries; attempt++ { + // 复制请求体 + if len(reqBodyBytes) > 0 { + req.Body = io.NopCloser(bytes.NewBuffer(reqBodyBytes)) + } + + // 发送请求 + resp, err = rt.Next.RoundTrip(req) + + // 检查是否需要重试 + if attempt < rt.Policy.MaxRetries && rt.Policy.ShouldRetry(req, resp, err) { + // 如果需要重试,先关闭当前响应 + if resp != nil { + resp.Body.Close() + } + + // 计算退避时间 + backoff := rt.calculateBackoff(attempt) + time.Sleep(backoff) + continue + } + + // 不需要重试,返回响应 + return resp, err + } + + // 所有重试都失败 + return resp, err +} + +// calculateBackoff 计算退避时间 +func (rt *RetryRoundTripper) calculateBackoff(attempt int) time.Duration { + // 指数退避: baseBackoff * 2^attempt + backoff := rt.Policy.BaseBackoff * time.Duration(math.Pow(2, float64(attempt))) + if backoff > rt.Policy.MaxBackoff { + backoff = rt.Policy.MaxBackoff + } + return backoff +} + +// RetryMiddleware 重试中间件 +type RetryMiddleware struct { + policy *RetryPolicy +} + +// NewRetryMiddleware 创建重试中间件 +func NewRetryMiddleware(policy *RetryPolicy) *RetryMiddleware { + if policy == nil { + policy = DefaultRetryPolicy() + } + return &RetryMiddleware{ + policy: policy, + } +} + +// Middleware 中间件处理函数 +func (m *RetryMiddleware) Transport(next http.RoundTripper) http.RoundTripper { + return NewRetryRoundTripper(next, m.policy) +} diff --git a/internal/proxy/certificate.go b/internal/proxy/certificate.go new file mode 100644 index 0000000..a2c842f --- /dev/null +++ b/internal/proxy/certificate.go @@ -0,0 +1,552 @@ +package proxy + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "hash/fnv" + "math/big" + "net" + "os" + "strings" + "time" +) + +// 内置默认CA证书和私钥 +var ( + // 默认根证书 + defaultRootCAPem = []byte(`-----BEGIN CERTIFICATE----- +MIICJzCCAcygAwIBAgIITWWCIQf8/VIwCgYIKoZIzj0EAwIwUzEOMAwGA1UEBhMF +Q2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0GA1UEBxMGWGlhbWVuMRAwDgYDVQQK +EwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMB4XDTIyMDMyNTA1NDgwMFoXDTQyMDQy +NTA1NDgwMFowUzEOMAwGA1UEBhMFQ2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0G +A1UEBxMGWGlhbWVuMRAwDgYDVQQKEwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMFkw +EwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQX +mRgsFV5KHHmxOrVJBFC/nDetmGowkARShWtBsX1Irm4w6i6Qk2QliKOBiTCBhjAO +BgNVHQ8BAf8EBAMCAQYwHQYDVR0lBBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIG +A1UdEwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFBI5TkWYcvUIWsBAdffs833FnBrI +MCIGA1UdEQQbMBmBF3FpbmdxaWFubHVkYW9AZ21haWwuY29tMAoGCCqGSM49BAMC +A0kAMEYCIQCk1DhW7AmIW/n/QLftQq8BHZKLevWYJ813zdrNr5kXlwIhAIVvqglY +9BkYWg4NEe/mVO4C5Vtu4FnzNU9I+rFpXVSO +-----END CERTIFICATE----- +`) + // 默认根私钥 + defaultRootKeyPem = []byte(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIAXeEHO0FtFqQhTvsn/DT4g3rEos97+1Nibp9RfKOKhroAoGCCqGSM49 +AwEHoUQDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQXmRgsFV5KHHmxOrVJBFC/ +nDetmGowkARShWtBsX1Irm4w6i6Qk2QliA== +-----END EC PRIVATE KEY----- +`) +) + +// 加载和初始化默认根证书 +var ( + defaultRootCA *x509.Certificate + defaultRootKey *ecdsa.PrivateKey +) + +func init() { + var err error + block, _ := pem.Decode(defaultRootCAPem) + if block == nil { + panic("解析默认根证书PEM块失败") + } + defaultRootCA, err = x509.ParseCertificate(block.Bytes) + if err != nil { + panic(fmt.Errorf("加载默认根证书失败: %s", err)) + } + + block, _ = pem.Decode(defaultRootKeyPem) + if block == nil { + panic("解析默认根私钥PEM块失败") + } + defaultRootKey, err = x509.ParseECPrivateKey(block.Bytes) + if err != nil { + panic(fmt.Errorf("加载默认根私钥失败: %s", err)) + } +} + +// CertManager 证书管理器 +type CertManager struct { + // 证书缓存 + cache CertificateCache + // 默认私钥,可用于多个证书共享 + defaultPrivateKey interface{} // 改为interface{}以支持不同类型的私钥 + // 默认使用ECDSA P-256曲线 + curve elliptic.Curve + // 证书有效期(年) + validityYears int + // 是否使用ECDSA(否则使用RSA) + useECDSA bool +} + +// NewCertManager 创建证书管理器 +func NewCertManager(cache CertificateCache, options ...CertManagerOption) *CertManager { + manager := &CertManager{ + cache: cache, + curve: elliptic.P256(), // 默认使用P-256曲线 + validityYears: 1, // 默认证书有效期1年 + useECDSA: true, // 默认使用ECDSA + } + + // 应用选项 + for _, option := range options { + option(manager) + } + + return manager +} + +// CertManagerOption 证书管理器选项 +type CertManagerOption func(*CertManager) + +// WithUseECDSA 设置是否使用ECDSA(否则使用RSA) +func WithUseECDSA(useECDSA bool) CertManagerOption { + return func(m *CertManager) { + m.useECDSA = useECDSA + } +} + +// WithDefaultPrivateKey 设置是否使用默认私钥 +func WithDefaultPrivateKey(enable bool) CertManagerOption { + return func(m *CertManager) { + if enable { + if m.useECDSA { + // 生成ECDSA私钥 + priv, err := ecdsa.GenerateKey(m.curve, rand.Reader) + if err != nil { + panic(fmt.Errorf("生成默认ECDSA私钥失败: %s", err)) + } + m.defaultPrivateKey = priv + } else { + // 生成RSA私钥 + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + panic(fmt.Errorf("生成默认RSA私钥失败: %s", err)) + } + m.defaultPrivateKey = priv + } + } + } +} + +// WithCurve 设置椭圆曲线 +func WithCurve(curve elliptic.Curve) CertManagerOption { + return func(m *CertManager) { + m.curve = curve + } +} + +// WithValidityYears 设置证书有效期(年) +func WithValidityYears(years int) CertManagerOption { + return func(m *CertManager) { + if years > 0 { + m.validityYears = years + } + } +} + +// GenerateTLSConfig 为指定主机生成TLS配置 +func (m *CertManager) GenerateTLSConfig(host string) (*tls.Config, error) { + // 处理可能的端口 + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } + + // 检查证书缓存 + if m.cache != nil { + // 检查主域名和子域名 + fields := strings.Split(host, ".") + domains := []string{host} + + // 添加父域名 + if len(fields) > 2 { + domains = append(domains, strings.Join(fields[1:], ".")) + } + + // 查找缓存 + for _, domain := range domains { + if cert := m.cache.Get(domain); cert != nil { + return &tls.Config{ + Certificates: []tls.Certificate{*cert}, + }, nil + } + } + } + + // 生成新证书 + cert, err := m.GenerateCertificate(host, defaultRootCA, defaultRootKey) + if err != nil { + return nil, err + } + + // 缓存证书 + if m.cache != nil { + // 缓存主机名 + m.cache.Set(host, cert) + + // 如果是IP地址,不进行其他处理 + if net.ParseIP(host) != nil { + return &tls.Config{ + Certificates: []tls.Certificate{*cert}, + }, nil + } + + // 缓存域名和子域名证书 + fields := strings.Split(host, ".") + if len(fields) >= 2 { + // 缓存主域名 + domain := strings.Join(fields[1:], ".") + m.cache.Set(domain, cert) + } + } + + return &tls.Config{ + Certificates: []tls.Certificate{*cert}, + }, nil +} + +// GenerateCertificate 生成证书 +func (m *CertManager) GenerateCertificate(host string, rootCA *x509.Certificate, rootKey *ecdsa.PrivateKey) (*tls.Certificate, error) { + // 准备私钥 + var priv interface{} + var pubKey interface{} + var err error + + // 使用默认私钥或生成新私钥 + if m.defaultPrivateKey != nil { + priv = m.defaultPrivateKey + } else if m.useECDSA { + // 生成ECDSA私钥 + ecdsaKey, err := ecdsa.GenerateKey(m.curve, rand.Reader) + if err != nil { + return nil, fmt.Errorf("生成ECDSA私钥失败: %s", err) + } + priv = ecdsaKey + pubKey = &ecdsaKey.PublicKey + } else { + // 生成RSA私钥 + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("生成RSA私钥失败: %s", err) + } + priv = rsaKey + pubKey = &rsaKey.PublicKey + } + + // 获取公钥 + if pubKey == nil { + switch k := priv.(type) { + case *ecdsa.PrivateKey: + pubKey = &k.PublicKey + case *rsa.PrivateKey: + pubKey = &k.PublicKey + default: + return nil, fmt.Errorf("不支持的私钥类型") + } + } + + // 创建证书模板 + template := m.createCertificateTemplate(host) + + // 签名证书 + derBytes, err := x509.CreateCertificate(rand.Reader, template, rootCA, pubKey, rootKey) + if err != nil { + return nil, fmt.Errorf("创建证书失败: %s", err) + } + + // 编码为PEM格式 + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: derBytes, + }) + + // 编码私钥 + var keyPEM []byte + switch k := priv.(type) { + case *ecdsa.PrivateKey: + privBytes, err := x509.MarshalECPrivateKey(k) + if err != nil { + return nil, fmt.Errorf("序列化ECDSA私钥失败: %s", err) + } + keyPEM = pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: privBytes, + }) + case *rsa.PrivateKey: + privBytes := x509.MarshalPKCS1PrivateKey(k) + keyPEM = pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: privBytes, + }) + default: + return nil, fmt.Errorf("不支持的私钥类型") + } + + // 创建TLS证书 + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return nil, fmt.Errorf("创建TLS证书对失败: %s", err) + } + + return &tlsCert, nil +} + +// createCertificateTemplate 创建证书模板 +func (m *CertManager) createCertificateTemplate(host string) *x509.Certificate { + // 使用基于主机名的哈希值作为序列号 + fv := fnv.New64a() + fv.Write([]byte(host)) + serialNumber := big.NewInt(0).SetUint64(fv.Sum64()) + + // 准备模板 + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: host, + Organization: []string{"GoProxy Dynamic CA"}, + Country: []string{"CN"}, + Province: []string{"GuangDong"}, + Locality: []string{"Guangzhou"}, + }, + NotBefore: time.Now().Add(-10 * time.Minute), // 提前10分钟生效,容忍时间偏差 + NotAfter: time.Now().AddDate(m.validityYears, 0, 0), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + IsCA: false, + } + + // 处理IP地址和域名 + ipAddr := net.ParseIP(host) + if ipAddr != nil { + template.IPAddresses = []net.IP{ipAddr} + } else { + // 移除可能的端口部分 + if strings.Contains(host, ":") { + host = strings.Split(host, ":")[0] + } + + // 将主机名添加到DNS名称列表 + template.DNSNames = []string{host} + + // 添加通配符域名支持 + fields := strings.Split(host, ".") + fieldNum := len(fields) + + // 为每一级子域名添加通配符 + for i := 0; i <= (fieldNum - 2); i++ { + wildcardDomain := "*." + strings.Join(fields[i:], ".") + // 避免重复 + if wildcardDomain != host { + template.DNSNames = append(template.DNSNames, wildcardDomain) + } + } + } + + return template +} + +// LoadCAFromFiles 从文件加载CA证书和私钥 +func LoadCAFromFiles(certFile, keyFile string) (*x509.Certificate, *ecdsa.PrivateKey, error) { + // 读取CA证书 + caCertPEM, err := os.ReadFile(certFile) + if err != nil { + return nil, nil, fmt.Errorf("读取CA证书文件失败: %s", err) + } + + block, _ := pem.Decode(caCertPEM) + if block == nil { + return nil, nil, fmt.Errorf("解析CA证书PEM块失败") + } + + caCert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("解析CA证书失败: %s", err) + } + + // 读取CA私钥 + caKeyPEM, err := os.ReadFile(keyFile) + if err != nil { + return nil, nil, fmt.Errorf("读取CA私钥文件失败: %s", err) + } + + block, _ = pem.Decode(caKeyPEM) + if block == nil { + return nil, nil, fmt.Errorf("解析CA私钥PEM块失败") + } + + var caKey *ecdsa.PrivateKey + + // 尝试不同的私钥格式 + switch block.Type { + case "EC PRIVATE KEY": + caKey, err = x509.ParseECPrivateKey(block.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("解析EC私钥失败: %s", err) + } + case "PRIVATE KEY": + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("解析PKCS8私钥失败: %s", err) + } + var ok bool + caKey, ok = key.(*ecdsa.PrivateKey) + if !ok { + return nil, nil, fmt.Errorf("私钥不是ECDSA类型") + } + default: + return nil, nil, fmt.Errorf("不支持的私钥类型: %s", block.Type) + } + + return caCert, caKey, nil +} + +// GetDefaultRootCA 获取默认根证书和私钥 +func GetDefaultRootCA() (*x509.Certificate, *ecdsa.PrivateKey) { + return defaultRootCA, defaultRootKey +} + +// GenerateRootCA 生成新的根证书和私钥 +func GenerateRootCA(validYears int) (*x509.Certificate, *ecdsa.PrivateKey, error) { + // 生成私钥 + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("生成根证书私钥失败: %s", err) + } + + // 随机生成序列号 + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, nil, fmt.Errorf("生成序列号失败: %s", err) + } + + // 创建根证书模板 + notBefore := time.Now().Add(-10 * time.Minute) + notAfter := notBefore.AddDate(validYears, 0, 0) + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: "GoProxy Root CA", + Organization: []string{"GoProxy"}, + Country: []string{"CN"}, + Province: []string{"GuangDong"}, + Locality: []string{"Guangzhou"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 2, + } + + // 自签名 + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + return nil, nil, fmt.Errorf("创建根证书失败: %s", err) + } + + // 解析生成的证书 + cert, err := x509.ParseCertificate(derBytes) + if err != nil { + return nil, nil, fmt.Errorf("解析生成的根证书失败: %s", err) + } + + return cert, priv, nil +} + +// SaveCertificateToFile 将证书和私钥保存到文件 +func SaveCertificateToFile(cert *x509.Certificate, key *ecdsa.PrivateKey, certFile, keyFile string) error { + // 保存证书 + certPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + }) + if err := os.WriteFile(certFile, certPEM, 0644); err != nil { + return fmt.Errorf("保存证书到文件失败: %s", err) + } + + // 保存私钥 + keyBytes, err := x509.MarshalECPrivateKey(key) + if err != nil { + return fmt.Errorf("序列化私钥失败: %s", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: keyBytes, + }) + if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil { + return fmt.Errorf("保存私钥到文件失败: %s", err) + } + + return nil +} + +// GenerateRootCA 生成根证书 +func (m *CertManager) GenerateRootCA() (*x509.Certificate, interface{}, error) { + // 生成私钥 + var priv interface{} + var pubKey interface{} + var err error + + if m.useECDSA { + // 生成ECDSA私钥 + ecdsaKey, err := ecdsa.GenerateKey(m.curve, rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("生成ECDSA根私钥失败: %s", err) + } + priv = ecdsaKey + pubKey = &ecdsaKey.PublicKey + } else { + // 生成RSA私钥 + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, fmt.Errorf("生成RSA根私钥失败: %s", err) + } + priv = rsaKey + pubKey = &rsaKey.PublicKey + } + + // 创建根证书模板 + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "GoProxy Root CA", + Organization: []string{"GoProxy Authority"}, + Country: []string{"CN"}, + Province: []string{"GuangDong"}, + Locality: []string{"Guangzhou"}, + }, + NotBefore: time.Now().Add(-10 * time.Minute), + NotAfter: time.Now().AddDate(10, 0, 0), // 10年有效期 + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 1, + } + + // 自签名 + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, pubKey, priv) + if err != nil { + return nil, nil, fmt.Errorf("创建根证书失败: %s", err) + } + + // 解析证书 + cert, err := x509.ParseCertificate(derBytes) + if err != nil { + return nil, nil, fmt.Errorf("解析生成的根证书失败: %s", err) + } + + return cert, priv, nil +} diff --git a/internal/proxy/conn_buffer.go b/internal/proxy/conn_buffer.go new file mode 100644 index 0000000..6d3ead6 --- /dev/null +++ b/internal/proxy/conn_buffer.go @@ -0,0 +1,134 @@ +package proxy + +import ( + "bufio" + "io" + "net" + "time" +) + +// ConnBuffer 连接缓冲区 +// 封装了底层网络连接和缓冲读取器,提供了更方便的读写接口 +type ConnBuffer struct { + // 底层连接 + conn net.Conn + // 缓冲读取器 + reader *bufio.Reader +} + +// NewConnBuffer 创建连接缓冲区 +func NewConnBuffer(conn net.Conn, reader *bufio.Reader) *ConnBuffer { + if reader == nil { + reader = bufio.NewReader(conn) + } + return &ConnBuffer{ + conn: conn, + reader: reader, + } +} + +// Read 从连接读取数据 +func (c *ConnBuffer) Read(b []byte) (int, error) { + return c.reader.Read(b) +} + +// Write 向连接写入数据 +func (c *ConnBuffer) Write(b []byte) (int, error) { + return c.conn.Write(b) +} + +// Close 关闭连接 +func (c *ConnBuffer) Close() error { + return c.conn.Close() +} + +// LocalAddr 获取本地地址 +func (c *ConnBuffer) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr 获取远程地址 +func (c *ConnBuffer) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline 设置读写超时 +func (c *ConnBuffer) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +// SetReadDeadline 设置读取超时 +func (c *ConnBuffer) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetWriteDeadline 设置写入超时 +func (c *ConnBuffer) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +// BufferReader 获取缓冲读取器 +func (c *ConnBuffer) BufferReader() *bufio.Reader { + return c.reader +} + +// Peek 查看缓冲区中的数据,但不消费 +func (c *ConnBuffer) Peek(n int) ([]byte, error) { + return c.reader.Peek(n) +} + +// ReadByte 读取一个字节 +func (c *ConnBuffer) ReadByte() (byte, error) { + return c.reader.ReadByte() +} + +// UnreadByte 将最后读取的字节放回缓冲区 +func (c *ConnBuffer) UnreadByte() error { + return c.reader.UnreadByte() +} + +// ReadLine 读取一行数据 +func (c *ConnBuffer) ReadLine() (string, error) { + line, isPrefix, err := c.reader.ReadLine() + if err != nil { + return "", err + } + + // 如果一行数据没有读取完整,继续读取 + if isPrefix { + var buf []byte + buf = append(buf, line...) + for isPrefix && err == nil { + line, isPrefix, err = c.reader.ReadLine() + if err != nil { + return "", err + } + buf = append(buf, line...) + } + return string(buf), nil + } + + return string(line), nil +} + +// ReadN 读取指定字节数的数据 +func (c *ConnBuffer) ReadN(n int) ([]byte, error) { + buf := make([]byte, n) + _, err := io.ReadFull(c.reader, buf) + if err != nil { + return nil, err + } + return buf, nil +} + +// Flush 刷新缓冲区 +func (c *ConnBuffer) Flush() error { + // 由于我们只有读取缓冲区,没有写入缓冲区,所以这里不需要实际操作 + return nil +} + +// Reset 重置连接缓冲区 +func (c *ConnBuffer) Reset(conn net.Conn) { + c.conn = conn + c.reader = bufio.NewReader(conn) +} diff --git a/internal/proxy/context.go b/internal/proxy/context.go new file mode 100644 index 0000000..fb28596 --- /dev/null +++ b/internal/proxy/context.go @@ -0,0 +1,132 @@ +package proxy + +import ( + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// Context 代理上下文 +// 包含了代理请求的上下文信息,用于在代理处理过程中传递数据 +type Context struct { + // 原始请求 + Req *http.Request + // 请求开始时间 + StartTime time.Time + // 上下文数据,用于在各个处理阶段传递数据 + Data map[interface{}]interface{} + // 是否是隧道代理 + TunnelProxy bool + // 请求ID + RequestID string + // 目标地址 + TargetAddr string + // 上级代理地址 + ParentProxyURL *url.URL + // 是否中断执行 + abort bool + // 请求标签,用于标记请求类型 + Tags []string + // 是否已中止 + aborted bool + // 互斥锁 + mu sync.RWMutex +} + +// IsHTTPS 是否是HTTPS请求 +func (c *Context) IsHTTPS() bool { + return c.Req.URL.Scheme == "https" || c.Req.Method == http.MethodConnect +} + +// defaultPorts 默认端口映射 +var defaultPorts = map[string]string{ + "https": "443", + "http": "80", + "": "80", +} + +// WebSocketURL 获取WebSocket URL +func (c *Context) WebSocketURL() *url.URL { + u := *c.Req.URL + if c.IsHTTPS() { + u.Scheme = "wss" + } else { + u.Scheme = "ws" + } + return &u +} + +// Addr 获取请求地址 +func (c *Context) Addr() string { + addr := c.Req.Host + + if !strings.Contains(c.Req.URL.Host, ":") { + addr += ":" + defaultPorts[c.Req.URL.Scheme] + } + + return addr +} + +// AddTag 添加请求标签 +func (c *Context) AddTag(tag string) { + c.Tags = append(c.Tags, tag) +} + +// HasTag 检查是否包含指定标签 +func (c *Context) HasTag(tag string) bool { + for _, t := range c.Tags { + if t == tag { + return true + } + } + return false +} + +// Abort 中断执行 +func (c *Context) Abort() { + c.aborted = true +} + +// IsAborted 是否已中断执行 +func (c *Context) IsAborted() bool { + return c.aborted +} + +// Reset 重置上下文 +func (c *Context) Reset(req *http.Request) { + c.Req = req + c.StartTime = time.Now() + c.Data = make(map[interface{}]interface{}) + c.abort = false + c.TunnelProxy = false + c.Tags = make([]string, 0) + c.RequestID = "" + c.TargetAddr = "" + c.ParentProxyURL = nil + c.aborted = false +} + +// Set 设置数据 +func (c *Context) Set(key, value interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.Data == nil { + c.Data = make(map[interface{}]interface{}) + } + c.Data[key] = value +} + +// Get 获取数据 +func (c *Context) Get(key interface{}) (interface{}, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.Data == nil { + return nil, false + } + val, ok := c.Data[key] + return val, ok +} diff --git a/internal/proxy/delegate.go b/internal/proxy/delegate.go new file mode 100644 index 0000000..2c3af01 --- /dev/null +++ b/internal/proxy/delegate.go @@ -0,0 +1,123 @@ +package proxy + +import ( + "net/http" + "net/url" +) + +// Delegate 代理委托接口 +// 定义了代理处理请求的各个阶段的回调方法 +type Delegate interface { + // Connect 连接事件 + Connect(ctx *Context, rw http.ResponseWriter) + + // Auth 认证事件 + Auth(ctx *Context, rw http.ResponseWriter) + + // BeforeRequest 请求前事件 + BeforeRequest(ctx *Context) + + // BeforeResponse 响应前事件 + BeforeResponse(ctx *Context, resp *http.Response, err error) + + // WebSocketSendMessage websocket发送消息拦截 + // WebSocketSendMessage(ctx *Context, messageType *int, p *[]byte) + + // WebSocketReceiveMessage websocket接收消息拦截 + // WebSocketReceiveMessage(ctx *Context, messageType *int, p *[]byte) + + // ParentProxy 获取上级代理 + ParentProxy(req *http.Request) (*url.URL, error) + + // ErrorLog 错误日志 + ErrorLog(err error) + + // Finish 完成事件 + Finish(ctx *Context) + + // 以下是反向代理相关的方法 + + // ResolveBackend 解析后端服务器 + // 在反向代理模式下,根据请求确定应该转发到哪个后端服务器 + ResolveBackend(req *http.Request) (string, error) + + // ModifyRequest 修改请求 + // 在反向代理模式下,可以修改发往后端服务器的请求 + ModifyRequest(req *http.Request) + + // ModifyResponse 修改响应 + // 在反向代理模式下,可以修改来自后端服务器的响应 + ModifyResponse(resp *http.Response) error + + // HandleError 处理错误 + // 在反向代理模式下,可以自定义错误处理逻辑 + HandleError(rw http.ResponseWriter, req *http.Request, err error) +} + +// DefaultDelegate 默认代理委托 +type DefaultDelegate struct{} + +// Connect 连接事件 +func (d *DefaultDelegate) Connect(ctx *Context, rw http.ResponseWriter) { + // 默认实现不做任何处理 +} + +// Auth 认证事件 +func (d *DefaultDelegate) Auth(ctx *Context, rw http.ResponseWriter) { + // 默认实现不做任何处理 +} + +// BeforeRequest 请求前事件 +func (d *DefaultDelegate) BeforeRequest(ctx *Context) { + // 默认实现不做任何处理 +} + +// BeforeResponse 响应前事件 +func (d *DefaultDelegate) BeforeResponse(ctx *Context, resp *http.Response, err error) { + // 默认实现不做任何处理 +} + +// ParentProxy 获取上级代理 +func (d *DefaultDelegate) ParentProxy(req *http.Request) (*url.URL, error) { + // 默认实现不使用上级代理 + return nil, nil +} + +// WebSocketSendMessage websocket发送消息拦截 +// func (h *DefaultDelegate) WebSocketSendMessage(ctx *Context, messageType *int, payload *[]byte) {} + +// WebSocketReceiveMessage websocket接收消息拦截 +// func (h *DefaultDelegate) WebSocketReceiveMessage(ctx *Context, messageType *int, payload *[]byte) {} + +// ErrorLog 错误日志 +func (d *DefaultDelegate) ErrorLog(err error) { + // 默认实现不做任何处理 +} + +// Finish 完成事件 +func (d *DefaultDelegate) Finish(ctx *Context) { + // 默认实现不做任何处理 +} + +// ResolveBackend 解析后端服务器 +func (d *DefaultDelegate) ResolveBackend(req *http.Request) (string, error) { + // 默认实现返回请求中的主机 + return req.Host, nil +} + +// ModifyRequest 修改请求 +func (d *DefaultDelegate) ModifyRequest(req *http.Request) { + // 默认实现不做任何修改 +} + +// ModifyResponse 修改响应 +func (d *DefaultDelegate) ModifyResponse(resp *http.Response) error { + // 默认实现不做任何修改 + return nil +} + +// HandleError 处理错误 +func (d *DefaultDelegate) HandleError(rw http.ResponseWriter, req *http.Request, err error) { + // 默认实现返回502错误 + http.Error(rw, err.Error(), http.StatusBadGateway) +} diff --git a/internal/proxy/options.go b/internal/proxy/options.go new file mode 100644 index 0000000..6087657 --- /dev/null +++ b/internal/proxy/options.go @@ -0,0 +1,262 @@ +package proxy + +import ( + "net/http" + "net/http/httptrace" + "time" + + "github.com/goproxy/internal/cache" + "github.com/goproxy/internal/config" + "github.com/goproxy/internal/healthcheck" + "github.com/goproxy/internal/loadbalance" + "github.com/goproxy/internal/metrics" +) + +// Option 用于配置代理选项的函数类型 +type Option func(*Options) + +// WithConfig 设置代理配置 +func WithConfig(cfg *config.Config) Option { + return func(opt *Options) { + opt.Config = cfg + } +} + +// WithDisableKeepAlive 设置连接是否重用 +func WithDisableKeepAlive(disableKeepAlive bool) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + // 在transport中设置DisableKeepAlives + } +} + +// WithClientTrace 设置HTTP客户端跟踪 +func WithClientTrace(t *httptrace.ClientTrace) Option { + return func(opt *Options) { + opt.ClientTrace = t + } +} + +// WithDelegate 设置委托类 +func WithDelegate(delegate Delegate) Option { + return func(opt *Options) { + opt.Delegate = delegate + } +} + +// WithTransport 使用自定义HTTP传输 +func WithTransport(t *http.Transport) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + // 在New方法中处理transport + } +} + +// WithDecryptHTTPS 启用中间人代理解密HTTPS +func WithDecryptHTTPS(c CertificateCache) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.DecryptHTTPS = true + opt.CertCache = c + } +} + +// WithEnableWebsocketIntercept 启用WebSocket拦截 +func WithEnableWebsocketIntercept() Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + // WebSocket拦截在代理处理逻辑中实现 + } +} + +// WithHTTPCache 设置HTTP缓存 +func WithHTTPCache(c cache.Cache) Option { + return func(opt *Options) { + opt.HTTPCache = c + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.EnableCache = true + } +} + +// WithLoadBalancer 设置负载均衡器 +func WithLoadBalancer(lb loadbalance.LoadBalancer) Option { + return func(opt *Options) { + opt.LoadBalancer = lb + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.EnableLoadBalancing = true + } +} + +// WithHealthChecker 设置健康检查器 +func WithHealthChecker(hc *healthcheck.HealthChecker) Option { + return func(opt *Options) { + opt.HealthChecker = hc + } +} + +// WithMetrics 设置监控指标 +func WithMetrics(m metrics.Metrics) Option { + return func(opt *Options) { + opt.Metrics = m + } +} + +// WithTLSCertAndKey 设置TLS证书和密钥 +func WithTLSCertAndKey(certPath, keyPath string) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.TLSCert = certPath + opt.Config.TLSKey = keyPath + } +} + +// WithCACertAndKey 设置CA证书和密钥(用于生成动态证书) +func WithCACertAndKey(caCertPath, caKeyPath string) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.CACert = caCertPath + opt.Config.CAKey = caKeyPath + } +} + +// WithConnectionPoolSize 设置连接池大小 +func WithConnectionPoolSize(size int) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.ConnectionPoolSize = size + } +} + +// WithIdleTimeout 设置空闲超时时间 +func WithIdleTimeout(timeout time.Duration) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.IdleTimeout = timeout + } +} + +// WithRequestTimeout 设置请求超时时间 +func WithRequestTimeout(timeout time.Duration) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.RequestTimeout = timeout + } +} + +// WithReverseProxy 启用反向代理模式 +func WithReverseProxy(enable bool) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.ReverseProxy = enable + } +} + +// WithEnableRetry 启用请求重试 +func WithEnableRetry(maxRetries int, baseBackoff, maxBackoff time.Duration) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.EnableRetry = true + opt.Config.MaxRetries = maxRetries + opt.Config.RetryBackoff = baseBackoff + opt.Config.MaxRetryBackoff = maxBackoff + } +} + +// WithRateLimit 设置请求限流 +func WithRateLimit(rps float64) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.EnableRateLimit = true + opt.Config.RateLimit = rps + } +} + +// WithDNSCacheTTL 设置DNS缓存TTL +func WithDNSCacheTTL(ttl time.Duration) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.DNSCacheTTL = ttl + } +} + +// WithURLRewrite 启用URL重写 +func WithURLRewrite(enable bool) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.EnableURLRewrite = enable + } +} + +// WithEnableCORS 启用CORS支持 +func WithEnableCORS(enable bool) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.EnableCORS = enable + } +} + +// WithCertManager 设置证书管理器 +// 这是一个内部函数,主要用于在New方法中设置CertManager +func WithCertManager(certManager *CertManager) Option { + return func(opt *Options) { + opt.CertManager = certManager + } +} + +// WithEnableECDSA 启用ECDSA证书生成(默认使用RSA) +func WithEnableECDSA(enable bool) Option { + return func(opt *Options) { + if opt.Config == nil { + opt.Config = config.DefaultConfig() + } + opt.Config.UseECDSA = enable + } +} + +// NewWithOptions 使用选项函数创建代理 +func NewWithOptions(options ...Option) *Proxy { + opts := &Options{ + Config: config.DefaultConfig(), + } + + // 应用所有选项 + for _, option := range options { + option(opts) + } + + return New(opts) +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go new file mode 100644 index 0000000..3da6caf --- /dev/null +++ b/internal/proxy/proxy.go @@ -0,0 +1,1130 @@ +package proxy + +import ( + "bufio" + "context" + "crypto/elliptic" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "net/http/httptrace" + "net/url" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/goproxy/internal/cache" + "github.com/goproxy/internal/config" + "github.com/goproxy/internal/healthcheck" + "github.com/goproxy/internal/loadbalance" + "github.com/goproxy/internal/metrics" + "github.com/goproxy/internal/middleware" + "github.com/viki-org/dnscache" +) + +const ( + // 连接目标服务器超时时间 + defaultTargetConnectTimeout = 5 * time.Second + // 目标服务器读写超时时间 + defaultTargetReadWriteTimeout = 10 * time.Second +) + +// 隧道连接成功响应行 +var tunnelEstablishedResponseLine = []byte("HTTP/1.1 200 Connection established\r\n\r\n") + +// 错误网关响应 +var badGatewayResponse = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", http.StatusBadGateway, http.StatusText(http.StatusBadGateway))) + +// 对象池 +var ( + bufPool = sync.Pool{ + New: func() interface{} { + return make([]byte, 32*1024) + }, + } + + ctxPool = sync.Pool{ + New: func() interface{} { + return new(Context) + }, + } +) + +// CertificateCache 证书缓存接口 +type CertificateCache interface { + // Get 获取证书 + Get(host string) *tls.Certificate + // Set 设置证书 + Set(host string, cert *tls.Certificate) +} + +// MemCertCache 内存证书缓存 +type MemCertCache struct { + certs sync.Map +} + +// Get 获取证书 +func (c *MemCertCache) Get(host string) *tls.Certificate { + v, ok := c.certs.Load(host) + if !ok { + return nil + } + return v.(*tls.Certificate) +} + +// Set 设置证书 +func (c *MemCertCache) Set(host string, cert *tls.Certificate) { + c.certs.Store(host, cert) +} + +// CacheAdapter 缓存适配器,统一不同缓存实现的接口 +type CacheAdapter struct { + cache interface{} + // 缓存方法类型标志 + getMethodType int + setMethodType int + // 方法类型常量 + getResponseBool int + getInterfaceBool int + setResponseOnly int + setResponseTTL int + setInterfaceOnly int +} + +// NewCacheAdapter 创建缓存适配器 +func NewCacheAdapter(cache interface{}) *CacheAdapter { + adapter := &CacheAdapter{ + cache: cache, + // 方法类型常量初始化 + getResponseBool: 1, + getInterfaceBool: 2, + setResponseOnly: 1, + setResponseTTL: 2, + setInterfaceOnly: 3, + } + + // 判断支持的方法类型 + if _, ok := cache.(interface { + Get(string) (*http.Response, bool) + }); ok { + adapter.getMethodType = adapter.getResponseBool + } else if _, ok := cache.(interface { + Get(string) (interface{}, bool) + }); ok { + adapter.getMethodType = adapter.getInterfaceBool + } + + if _, ok := cache.(interface { + Set(string, *http.Response, time.Duration) + }); ok { + adapter.setMethodType = adapter.setResponseTTL + } else if _, ok := cache.(interface { + Set(string, *http.Response) + }); ok { + adapter.setMethodType = adapter.setResponseOnly + } else if _, ok := cache.(interface { + Set(string, interface{}) + }); ok { + adapter.setMethodType = adapter.setInterfaceOnly + } + + return adapter +} + +// Get 统一的获取方法 +func (a *CacheAdapter) Get(key string) (interface{}, bool) { + switch a.getMethodType { + case a.getResponseBool: + if getter, ok := a.cache.(interface { + Get(string) (*http.Response, bool) + }); ok { + return getter.Get(key) + } + case a.getInterfaceBool: + if getter, ok := a.cache.(interface { + Get(string) (interface{}, bool) + }); ok { + return getter.Get(key) + } + } + return nil, false +} + +// Set 统一的设置方法 +func (a *CacheAdapter) Set(key string, value interface{}, ttl time.Duration) { + resp, isResponse := value.(*http.Response) + switch a.setMethodType { + case a.setResponseTTL: + if setter, ok := a.cache.(interface { + Set(string, *http.Response, time.Duration) + }); ok && isResponse { + setter.Set(key, resp, ttl) + } + case a.setResponseOnly: + if setter, ok := a.cache.(interface { + Set(string, *http.Response) + }); ok && isResponse { + setter.Set(key, resp) + } + case a.setInterfaceOnly: + if setter, ok := a.cache.(interface { + Set(string, interface{}) + }); ok { + setter.Set(key, value) + } + } +} + +// Options 代理选项 +type Options struct { + // 配置 + Config *config.Config + // 委托 + Delegate Delegate + // 证书缓存 + CertCache CertificateCache + // HTTP缓存 + HTTPCache cache.Cache + // 负载均衡器 + LoadBalancer loadbalance.LoadBalancer + // 健康检查器 + HealthChecker *healthcheck.HealthChecker + // 监控指标 + Metrics metrics.Metrics + // 客户端跟踪 + ClientTrace *httptrace.ClientTrace + // 证书管理器 + CertManager *CertManager +} + +// Proxy HTTP代理 +type Proxy struct { + // 配置 + config *config.Config + // 委托 + delegate Delegate + // 证书缓存 + certCache CertificateCache + // HTTP缓存 + httpCache cache.Cache + // 缓存适配器 + cacheAdapter *CacheAdapter + // 负载均衡器 + loadBalancer loadbalance.LoadBalancer + // 健康检查器 + healthChecker *healthcheck.HealthChecker + // 监控指标 + metrics metrics.Metrics + // 客户端跟踪 + clientTrace *httptrace.ClientTrace + // 传输 + transport *http.Transport + // DNS缓存 + dnsCache *dnscache.Resolver + // 客户端连接数 + clientConnNum int32 + // 证书管理器 + certManager *CertManager +} + +// New 创建代理 +func New(opts *Options) *Proxy { + if opts == nil { + opts = &Options{} + } + + if opts.Config == nil { + opts.Config = config.DefaultConfig() + } + + if opts.Delegate == nil { + opts.Delegate = &DefaultDelegate{} + } + + p := &Proxy{ + config: opts.Config, + delegate: opts.Delegate, + certCache: opts.CertCache, + httpCache: opts.HTTPCache, + loadBalancer: opts.LoadBalancer, + healthChecker: opts.HealthChecker, + metrics: opts.Metrics, + clientTrace: opts.ClientTrace, + clientConnNum: 0, + } + + // 如果存在HTTP缓存,创建缓存适配器 + if p.httpCache != nil { + p.cacheAdapter = NewCacheAdapter(p.httpCache) + } + + // 创建DNS缓存 + p.dnsCache = dnscache.New(opts.Config.DNSCacheTTL) + + // 设置证书管理器 + if opts.CertManager != nil { + // 如果选项中已提供证书管理器,直接使用 + p.certManager = opts.CertManager + } else if opts.Config.DecryptHTTPS { + // 如果启用了HTTPS解密,且未提供证书管理器,则创建一个新的证书管理器 + certManagerOpts := []CertManagerOption{ + WithDefaultPrivateKey(true), // 使用默认私钥提高性能 + WithValidityYears(1), // 证书有效期1年 + WithUseECDSA(opts.Config.UseECDSA), // 根据配置决定是否使用ECDSA + } + + // 如果配置指定了使用ECDSA,设置曲线为P-256 + if opts.Config.UseECDSA { + certManagerOpts = append(certManagerOpts, WithCurve(elliptic.P256())) + } + + p.certManager = NewCertManager(p.certCache, certManagerOpts...) + } + + // 创建传输 + p.transport = &http.Transport{ + Proxy: p.proxyFromDelegate, + DialContext: p.dialContextWithCache(), + MaxIdleConns: opts.Config.ConnectionPoolSize, + IdleConnTimeout: opts.Config.IdleTimeout, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + DisableKeepAlives: false, + DisableCompression: false, + ForceAttemptHTTP2: true, + } + + // 应用重试中间件 + if opts.Config.EnableRetry { + policy := &middleware.RetryPolicy{ + MaxRetries: opts.Config.MaxRetries, + BaseBackoff: opts.Config.RetryBackoff, + MaxBackoff: opts.Config.MaxRetryBackoff, + } + retryMiddleware := middleware.NewRetryMiddleware(policy) + p.transport = retryMiddleware.Transport(p.transport).(*http.Transport) + } + + // 将健康检查器与负载均衡器集成 + if p.healthChecker != nil && p.loadBalancer != nil { + p.healthChecker.SetStatusChangeCallback(func(target string, healthy bool) { + if healthy { + p.loadBalancer.MarkUp(target) + } else { + p.loadBalancer.MarkDown(target) + } + }) + } + + return p +} + +// NewProxy 使用functional options模式创建代理 +func NewProxy(options ...Option) *Proxy { + // 创建默认选项 + opts := &Options{ + Config: config.DefaultConfig(), + } + + // 应用所有选项 + for _, option := range options { + option(opts) + } + + // 使用传统方法创建代理 + return New(opts) +} + +// ServeHTTP 处理HTTP请求 +func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + // 判断是反向代理还是正向代理 + if p.config.ReverseProxy { + // 如果是反向代理模式,使用反向代理处理请求 + reverseProxy := p.NewReverseProxy() + reverseProxy.ServeHTTP(rw, req) + return + } + + // 以下是正向代理的逻辑 + // ...以下保持原代码不变... + + // 更新请求计数指标 + if p.metrics != nil { + p.metrics.IncRequestCount() + } + + // 处理请求 + ctx := ctxPool.Get().(*Context) + ctx.Reset(req) + defer ctxPool.Put(ctx) + + // 调用连接事件 + p.delegate.Connect(ctx, rw) + + // 认证检查 + p.delegate.Auth(ctx, rw) + if ctx.IsAborted() { + return + } + + // HTTP隧道连接(CONNECT方法) + if req.Method == http.MethodConnect { + p.tunnelProxy(ctx, rw) + return + } + + // 如果是WebSocket请求,使用WebSocket代理 + if isWebSocketRequest(req) { + clientConn, err := hijacker(rw) + if err != nil { + p.delegate.ErrorLog(err) + http.Error(rw, "无法处理WebSocket请求", http.StatusInternalServerError) + return + } + p.websocketProxy(ctx, clientConn) + return + } + + // 处理普通HTTP请求 + p.handleHTTP(ctx, rw) +} + +// handleHTTP 处理HTTP请求 +func (p *Proxy) handleHTTP(ctx *Context, rw http.ResponseWriter) { + // 调用请求前事件 + p.delegate.BeforeRequest(ctx) + if ctx.IsAborted() { + return + } + + // 开始时间 + startTime := time.Now() + + // 获取上级代理 + parentProxy, err := p.proxyFromDelegate(ctx.Req) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - 获取上级代理错误: %s", ctx.Req.URL.Host, err)) + ctx.ParentProxyURL = nil + } else { + ctx.ParentProxyURL = parentProxy + } + + var ( + resp *http.Response + req = ctx.Req + ) + + // 从缓存获取响应 + if p.httpCache != nil && p.config.EnableCache && canCacheMethod(req.Method) { + cacheKey := generateCacheKey(req) + var cachedResp interface{} + var ok bool + + // 使用缓存适配器获取数据 + if p.cacheAdapter != nil { + cachedResp, ok = p.cacheAdapter.Get(cacheKey) + if ok && cachedResp != nil { + // 从缓存中找到响应 + resp = cachedResp.(*http.Response) + // 更新指标 + if p.metrics != nil && isCacheHitMetricsSupported(p.metrics) { + incrementCacheHit(p.metrics) + } + } + } + } + + // 如果缓存中没有,则发送请求 + if resp == nil { + // 创建传输上下文 + reqCtx := req.Context() + if p.clientTrace != nil { + reqCtx = httptrace.WithClientTrace(reqCtx, p.clientTrace) + } + + // 设置请求超时 + if p.config.RequestTimeout > 0 { + var cancel context.CancelFunc + reqCtx, cancel = context.WithTimeout(reqCtx, p.config.RequestTimeout) + defer cancel() + } + + req = req.WithContext(reqCtx) + + // 发送请求 + var err error + resp, err = p.transport.RoundTrip(req) + + // 处理错误 + if err != nil { + p.delegate.BeforeResponse(ctx, nil, err) + p.delegate.ErrorLog(fmt.Errorf("%s - HTTP请求错误: %s", req.URL.Host, err)) + http.Error(rw, err.Error(), http.StatusBadGateway) + return + } + + // 更新指标 + if p.metrics != nil { + p.metrics.ObserveRequestDuration(time.Since(startTime).Seconds()) + } + + // 缓存响应 + if p.httpCache != nil && p.config.EnableCache && canCacheMethod(req.Method) && canCacheStatus(resp.StatusCode) { + cacheKey := generateCacheKey(req) + + // 使用缓存适配器设置数据 + if p.cacheAdapter != nil { + p.cacheAdapter.Set(cacheKey, resp, getCacheTTL(resp)) + } + } + } + + // 调用响应前事件 + p.delegate.BeforeResponse(ctx, resp, nil) + if ctx.IsAborted() { + return + } + + // 复制头部信息 + for key, values := range resp.Header { + for _, value := range values { + rw.Header().Add(key, value) + } + } + + // 写入状态码 + rw.WriteHeader(resp.StatusCode) + + // 复制响应体 + _, err = io.Copy(rw, resp.Body) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - 复制响应体错误: %s", req.URL.Host, err)) + } + + // 关闭响应体 + resp.Body.Close() + + // 调用完成事件 + p.delegate.Finish(ctx) +} + +// canCacheMethod 检查请求方法是否可缓存 +func canCacheMethod(method string) bool { + return method == http.MethodGet || method == http.MethodHead +} + +// canCacheStatus 检查响应状态码是否可缓存 +func canCacheStatus(statusCode int) bool { + return statusCode >= 200 && statusCode < 400 +} + +// generateCacheKey 生成缓存键 +func generateCacheKey(req *http.Request) string { + return req.Method + " " + req.URL.String() +} + +// getCacheTTL 获取缓存TTL +func getCacheTTL(resp *http.Response) time.Duration { + // 默认5分钟 + ttl := 5 * time.Minute + + // 从Cache-Control获取max-age + cacheControl := resp.Header.Get("Cache-Control") + if cacheControl != "" { + for _, directive := range strings.Split(cacheControl, ",") { + directive = strings.TrimSpace(directive) + if strings.HasPrefix(directive, "max-age=") { + maxAge := strings.TrimPrefix(directive, "max-age=") + if seconds, err := strconv.Atoi(maxAge); err == nil { + ttl = time.Duration(seconds) * time.Second + } + break + } + } + } + + return ttl +} + +// ClientConnNum 获取客户端连接数 +func (p *Proxy) ClientConnNum() int32 { + return atomic.LoadInt32(&p.clientConnNum) +} + +// DoRequest 执行HTTP请求 +func (p *Proxy) DoRequest(ctx *Context, responseFunc func(*http.Response, error)) { + if ctx.Data == nil { + ctx.Data = make(map[interface{}]interface{}) + } + + // 请求前处理 + p.delegate.BeforeRequest(ctx) + if ctx.IsAborted() { + return + } + + // 检查缓存 + if p.httpCache != nil && ctx.Req.Method == http.MethodGet && p.config.EnableCache { + cacheKey := cache.GenerateCacheKey(ctx.Req) + if p.cacheAdapter != nil { + cachedResp, ok := p.cacheAdapter.Get(cacheKey) + if ok && cachedResp != nil { + // 使用缓存的响应 + cached := cachedResp.(*http.Response) + p.delegate.BeforeResponse(ctx, cached, nil) + if !ctx.IsAborted() { + responseFunc(cached, nil) + } + + // 更新指标 + if p.metrics != nil && isCacheHitMetricsSupported(p.metrics) { + incrementCacheHit(p.metrics) + } + return + } + } + } + + // 准备请求 + newReq := ctx.Req.Clone(ctx.Req.Context()) + + // 移除hop-by-hop头部 + for _, h := range hopHeaders { + newReq.Header.Del(h) + } + + // 添加客户端跟踪 + if p.clientTrace != nil { + newReq = newReq.WithContext(httptrace.WithClientTrace(newReq.Context(), p.clientTrace)) + } + + // 执行请求 + resp, err := p.transport.RoundTrip(newReq) + + // 响应前处理 + p.delegate.BeforeResponse(ctx, resp, err) + if ctx.IsAborted() { + if resp != nil { + resp.Body.Close() + } + return + } + + // 错误处理 + if err != nil { + responseFunc(nil, err) + return + } + + // 移除hop-by-hop头部 + for _, h := range hopHeaders { + resp.Header.Del(h) + } + + // 缓存响应 + if p.httpCache != nil && p.config.EnableCache && cache.ShouldCache(ctx.Req, resp) { + cacheKey := cache.GenerateCacheKey(ctx.Req) + if p.cacheAdapter != nil { + p.cacheAdapter.Set(cacheKey, resp, getCacheTTL(resp)) + } + } + + // 返回响应 + responseFunc(resp, nil) +} + +// HTTPS代理 +func (p *Proxy) httpsProxy(ctx *Context, tlsClientConn *tls.Conn) { + if isWebSocketRequest(ctx.Req) { + p.websocketProxy(ctx, NewConnBuffer(tlsClientConn, nil)) + return + } + + p.DoRequest(ctx, func(resp *http.Response, err error) { + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,请求错误: %s", ctx.Req.URL, err)) + tlsClientConn.Write(badGatewayResponse) + return + } + + // 直接写入TLS连接 + err = resp.Write(tlsClientConn) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,响应写入客户端失败: %s", ctx.Req.URL, err)) + } + resp.Body.Close() + }) +} + +// 隧道代理 +func (p *Proxy) tunnelProxy(ctx *Context, rw http.ResponseWriter) { + // 获取客户端连接 + clientConn, err := hijacker(rw) + if err != nil { + p.delegate.ErrorLog(err) + rw.WriteHeader(http.StatusBadGateway) + return + } + defer clientConn.Close() + + // 处理WebSocket请求 + if isWebSocketRequest(ctx.Req) { + p.websocketProxy(ctx, clientConn) + return + } + + // 获取上级代理 + parentProxyURL, err := p.delegate.ParentProxy(ctx.Req) + if ctx.ParentProxyURL != nil { + parentProxyURL = ctx.ParentProxyURL + } + + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - 解析代理地址错误: %s", ctx.Req.URL.Host, err)) + rw.WriteHeader(http.StatusBadGateway) + return + } + + // 如果不使用上级代理,通知客户端隧道已建立 + if parentProxyURL == nil { + _, err = clientConn.Write(tunnelEstablishedResponseLine) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - 隧道连接成功,通知客户端错误: %s", ctx.Req.URL.Host, err)) + return + } + } + + // 检测WebSocket + isWebsocket := false + methodBytes, err := clientConn.Peek(3) + if err == nil && string(methodBytes) == http.MethodGet { + isWebsocket = true + } + + // 处理WebSocket + if isWebsocket { + req, err := http.ReadRequest(clientConn.BufferReader()) + if err != nil { + if err != io.EOF { + p.delegate.ErrorLog(fmt.Errorf("%s - websocket读取客户端升级请求失败: %s", ctx.Req.URL.Host, err)) + } + return + } + req.RemoteAddr = ctx.Req.RemoteAddr + req.URL.Scheme = "http" + req.URL.Host = req.Host + ctx.Req = req + + p.websocketProxy(ctx, clientConn) + return + } + + // HTTPS解密 + var tlsClientConn *tls.Conn + if p.config.DecryptHTTPS { + // 生成证书 + certConfig, err := p.generateTLSConfig(ctx.Req.URL.Host) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,生成证书失败: %s", ctx.Req.URL.Host, err)) + return + } + + // 创建TLS服务器连接 + tlsClientConn = tls.Server(clientConn, certConfig) + defer tlsClientConn.Close() + + // TLS握手 + if err := tlsClientConn.Handshake(); err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,握手失败: %s", ctx.Req.URL.Host, err)) + return + } + + // 读取HTTPS请求 + buf := bufio.NewReader(tlsClientConn) + tlsReq, err := http.ReadRequest(buf) + if err != nil { + if err != io.EOF { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密,读取客户端请求失败: %s", ctx.Req.URL.Host, err)) + } + return + } + + // 更新请求信息 + tlsReq.RemoteAddr = ctx.Req.RemoteAddr + tlsReq.URL.Scheme = "https" + tlsReq.URL.Host = tlsReq.Host + ctx.Req = tlsReq + } + + // 确定目标地址 + targetAddr := ctx.Req.URL.Host + if ctx.TargetAddr != "" { + targetAddr = ctx.TargetAddr + } else if parentProxyURL != nil { + targetAddr = parentProxyURL.Host + } + + // 确保地址包含端口 + if !strings.Contains(targetAddr, ":") { + targetAddr += ":443" + } + + // 连接目标服务器 + targetConn, err := net.DialTimeout("tcp", targetAddr, defaultTargetConnectTimeout) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - 隧道转发连接目标服务器失败: %s", ctx.Req.URL.Host, err)) + return + } + defer targetConn.Close() + + // 向上级代理发送CONNECT请求 + if parentProxyURL != nil { + tunnelRequestLine := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", ctx.Req.URL.Host, ctx.Req.URL.Host) + _, err = targetConn.Write([]byte(tunnelRequestLine)) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - 向上级代理发送CONNECT请求失败: %s", ctx.Req.URL.Host, err)) + return + } + + // 读取上级代理响应 + bufReader := bufio.NewReader(targetConn) + resp, err := http.ReadResponse(bufReader, ctx.Req) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - 读取上级代理响应失败: %s", ctx.Req.URL.Host, err)) + return + } + defer resp.Body.Close() + + // 检查上级代理响应 + if resp.StatusCode != http.StatusOK { + p.tunnelConnected(ctx, fmt.Errorf("上级代理返回错误状态码: %d", resp.StatusCode)) + p.delegate.ErrorLog(fmt.Errorf("%s - 上级代理返回错误状态码: %d", ctx.Req.URL.Host, resp.StatusCode)) + return + } + } + + // 处理HTTPS解密或直接隧道转发 + if p.config.DecryptHTTPS { + p.httpsProxy(ctx, tlsClientConn) + } else { + p.tunnelConnected(ctx, nil) + p.transfer(clientConn, targetConn) + } +} + +// WebSocket代理 +func (p *Proxy) websocketProxy(ctx *Context, srcConn *ConnBuffer) { + if !p.config.WebSocketIntercept { + // 不拦截WebSocket,直接转发 + remoteAddr := ctx.Addr() + var err error + var targetConn net.Conn + + // 根据协议建立连接 + if ctx.IsHTTPS() { + targetConn, err = tls.Dial("tcp", remoteAddr, &tls.Config{InsecureSkipVerify: true}) + } else { + targetConn, err = net.Dial("tcp", remoteAddr) + } + + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - websocket连接目标服务器错误: %s", ctx.Req.URL.Host, err)) + return + } + + // 将请求转发给目标 + err = ctx.Req.Write(targetConn) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - websocket协议转换请求写入目标服务器错误: %s", ctx.Req.URL.Host, err)) + return + } + + // 开始转发数据 + p.tunnelConnected(ctx, nil) + p.transfer(srcConn, targetConn) + return + } + + // 简单直接转发WebSocket,不使用WebSocket库 + remoteAddr := ctx.Addr() + var err error + var targetConn net.Conn + + // 根据协议建立连接 + if ctx.IsHTTPS() { + targetConn, err = tls.Dial("tcp", remoteAddr, &tls.Config{InsecureSkipVerify: true}) + } else { + targetConn, err = net.Dial("tcp", remoteAddr) + } + + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - websocket连接目标服务器错误: %s", ctx.Req.URL.Host, err)) + return + } + + // 将请求转发给目标 + err = ctx.Req.Write(targetConn) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - websocket协议转换请求写入目标服务器错误: %s", ctx.Req.URL.Host, err)) + return + } + + // 开始转发数据 + p.tunnelConnected(ctx, nil) + p.transfer(srcConn, targetConn) +} + +// 双向转发 +func (p *Proxy) transfer(src net.Conn, dst net.Conn) { + // 创建完成通道 + done := make(chan struct{}, 2) + + // src -> dst + go func() { + buf := bufPool.Get().([]byte) + written, err := io.CopyBuffer(dst, src, buf) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err)) + } + + // 记录传输字节数 + if p.metrics != nil { + p.metrics.AddBytesTransferred("request", written) + } + + bufPool.Put(buf) + dst.Close() + done <- struct{}{} + }() + + // dst -> src + go func() { + buf := bufPool.Get().([]byte) + written, err := io.CopyBuffer(src, dst, buf) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", dst.RemoteAddr().String(), src.RemoteAddr().String(), err)) + } + + // 记录传输字节数 + if p.metrics != nil { + p.metrics.AddBytesTransferred("response", written) + } + + bufPool.Put(buf) + src.Close() + done <- struct{}{} + }() + + // 等待两个方向都结束 + <-done + <-done +} + +// 隧道连接处理 +func (p *Proxy) tunnelConnected(ctx *Context, err error) { + ctx.TunnelProxy = true + p.delegate.BeforeRequest(ctx) + if err != nil { + p.delegate.BeforeResponse(ctx, nil, err) + return + } + + resp := &http.Response{ + Status: "200 Connection Established", + StatusCode: http.StatusOK, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{}, + Body: http.NoBody, + } + p.delegate.BeforeResponse(ctx, resp, nil) +} + +// 使用DNS缓存的DialContext +func (p *Proxy) dialContextWithCache() func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + // 创建拨号器 + dialer := &net.Dialer{ + Timeout: defaultTargetConnectTimeout, + KeepAlive: 30 * time.Second, + } + + // 如果没有启用DNS缓存,直接拨号 + if p.dnsCache == nil { + return dialer.DialContext(ctx, network, addr) + } + + // 解析主机和端口 + separator := strings.LastIndex(addr, ":") + if separator < 0 { + return nil, fmt.Errorf("invalid address: %s", addr) + } + + host := addr[:separator] + port := addr[separator:] + + // 查询DNS缓存 + ips, err := p.dnsCache.Fetch(host) + if err != nil { + return nil, err + } + + // 使用第一个IPv4地址 + var ip string + for _, item := range ips { + ip = item.String() + if !strings.Contains(ip, ":") { + break + } + } + + if ip == "" { + return nil, fmt.Errorf("no valid IP address found for: %s", host) + } + + // 连接到解析后的IP + return dialer.DialContext(ctx, network, ip+port) + } +} + +// 从委托获取代理 +func (p *Proxy) proxyFromDelegate(req *http.Request) (*url.URL, error) { + if p.loadBalancer != nil && p.config.EnableLoadBalancing { + // 使用负载均衡 + host := req.URL.Hostname() + return p.loadBalancer.Next(host) + } + // 使用委托 + return p.delegate.ParentProxy(req) +} + +// 生成TLS配置 +func (p *Proxy) generateTLSConfig(host string) (*tls.Config, error) { + // 如果没有证书管理器,则创建一个 + if p.certManager == nil { + // 创建证书管理器,使用已有的证书缓存 + options := []CertManagerOption{ + WithDefaultPrivateKey(true), // 使用默认私钥提高性能 + WithValidityYears(1), // 证书有效期1年 + } + p.certManager = NewCertManager(p.certCache, options...) + } + + // 1. 首先检查是否配置了自定义证书 + if p.config.TLSCert != "" && p.config.TLSKey != "" { + cert, err := tls.LoadX509KeyPair(p.config.TLSCert, p.config.TLSKey) + if err != nil { + return nil, fmt.Errorf("加载TLS证书失败: %s", err) + } + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + }, nil + } + + // 2. 检查是否配置了CA证书和密钥(用于动态生成证书) + if p.config.CACert != "" && p.config.CAKey != "" { + // 加载CA证书和私钥 + caCert, caKey, err := LoadCAFromFiles(p.config.CACert, p.config.CAKey) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("加载CA证书和私钥失败: %s", err)) + // 如果加载失败,使用默认CA + return p.certManager.GenerateTLSConfig(host) + } + + // 使用自定义CA生成证书 + cert, err := p.certManager.GenerateCertificate(host, caCert, caKey) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("为%s生成动态证书失败: %s", host, err)) + return nil, err + } + + return &tls.Config{ + Certificates: []tls.Certificate{*cert}, + }, nil + } + + // 3. 使用默认CA生成证书 + tlsConfig, err := p.certManager.GenerateTLSConfig(host) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("为%s使用默认CA生成证书失败: %s", host, err)) + return nil, err + } + + return tlsConfig, nil +} + +// 获取客户端连接 +func hijacker(rw http.ResponseWriter) (*ConnBuffer, error) { + hijacker, ok := rw.(http.Hijacker) + if !ok { + return nil, fmt.Errorf("http server不支持Hijacker") + } + conn, bufrw, err := hijacker.Hijack() + if err != nil { + return nil, fmt.Errorf("hijacker错误: %s", err) + } + + return NewConnBuffer(conn, bufrw.Reader), nil +} + +// 检查是否是WebSocket请求 +func isWebSocketRequest(req *http.Request) bool { + if req == nil { + return false + } + + // 检查Connection头 + connection := strings.ToLower(req.Header.Get("Connection")) + if !strings.Contains(connection, "upgrade") { + return false + } + + // 检查Upgrade头 + upgrade := strings.ToLower(req.Header.Get("Upgrade")) + if upgrade != "websocket" { + return false + } + + return true +} + +// hop-by-hop 头部 +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", + "Trailer", + "Transfer-Encoding", + "Upgrade", +} + +// isCacheHitMetricsSupported 检查指标是否支持缓存命中计数 +func isCacheHitMetricsSupported(m metrics.Metrics) bool { + _, ok := m.(interface{ IncCacheHit() }) + return ok +} + +// incrementCacheHit 增加缓存命中计数 +func incrementCacheHit(m metrics.Metrics) { + if hitter, ok := m.(interface{ IncCacheHit() }); ok { + hitter.IncCacheHit() + } +} diff --git a/internal/proxy/reverse_proxy.go b/internal/proxy/reverse_proxy.go new file mode 100644 index 0000000..e51ea6a --- /dev/null +++ b/internal/proxy/reverse_proxy.go @@ -0,0 +1,277 @@ +package proxy + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "net/http/httputil" + "net/url" + "strings" + "time" + + "github.com/goproxy/internal/rewriter" + "github.com/goproxy/internal/router" +) + +// ReverseProxy 反向代理 +type ReverseProxy struct { + // 代理对象 + proxy *Proxy + // 路由器 + router *router.Router + // URL重写器 + rewriter *rewriter.Rewriter + // HTTP传输对象 + transport http.RoundTripper +} + +// NewReverseProxy 创建反向代理 +func (p *Proxy) NewReverseProxy() *ReverseProxy { + rp := &ReverseProxy{ + proxy: p, + router: router.NewRouter(), + rewriter: rewriter.NewRewriter(), + } + + // 创建自定义的传输对象 + transport := &http.Transport{ + Proxy: func(req *http.Request) (*url.URL, error) { + // 使用代理委托中的方法获取代理 + return p.delegate.ParentProxy(req) + }, + DialContext: p.dialContextWithCache(), + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + MaxIdleConns: p.config.ConnectionPoolSize, + IdleConnTimeout: p.config.IdleTimeout, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + MaxIdleConnsPerHost: p.config.ConnectionPoolSize, + DisableCompression: !p.config.EnableCompression, + } + + rp.transport = transport + + // 如果配置了规则文件,加载规则 + if p.config.ReverseProxyRulesFile != "" { + // 省略加载规则文件的实现 + } + + return rp +} + +// ServeHTTP 处理反向代理请求 +func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + // 获取请求上下文 + ctx := ctxPool.Get().(*Context) + ctx.Reset(req) + defer ctxPool.Put(ctx) + + // 调用连接事件 + rp.proxy.delegate.Connect(ctx, rw) + + // 认证检查 + rp.proxy.delegate.Auth(ctx, rw) + if ctx.IsAborted() { + return + } + + // 请求前处理 + rp.proxy.delegate.BeforeRequest(ctx) + if ctx.IsAborted() { + return + } + + // 解析后端地址 + backend, err := rp.proxy.delegate.ResolveBackend(req) + if err != nil { + rp.proxy.delegate.HandleError(rw, req, err) + return + } + + // 创建请求代理对象 + proxy := httputil.NewSingleHostReverseProxy(&url.URL{ + Scheme: "http", + Host: backend, + }) + + // 使用自定义传输对象 + proxy.Transport = rp.transport + + // 设置自定义错误处理函数 + proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) { + rp.proxy.delegate.HandleError(rw, req, err) + } + + // 设置请求修改函数 + originalDirector := proxy.Director + proxy.Director = func(req *http.Request) { + // 调用原始Director函数 + originalDirector(req) + + // 处理URL重写 + if rp.proxy.config.EnableURLRewrite { + rp.rewriter.Rewrite(req) + } + + // 修改请求头 + if rp.proxy.config.RewriteHostHeader { + req.Host = backend + } + + // 添加X-Forwarded-For头 + if rp.proxy.config.AddXForwardedFor { + clientIP, _, err := net.SplitHostPort(req.RemoteAddr) + if err == nil { + // 如果已经有X-Forwarded-For,添加到末尾 + if prior, ok := req.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + req.Header.Set("X-Forwarded-For", clientIP) + } + } + + // 添加X-Real-IP头 + if rp.proxy.config.AddXRealIP { + clientIP, _, err := net.SplitHostPort(req.RemoteAddr) + if err == nil { + req.Header.Set("X-Real-IP", clientIP) + } + } + + // 设置协议头 + req.Header.Set("X-Forwarded-Proto", "http") + if req.TLS != nil { + req.Header.Set("X-Forwarded-Proto", "https") + } + + // 调用委托的ModifyRequest方法 + rp.proxy.delegate.ModifyRequest(req) + } + + // 设置响应修改函数 + proxy.ModifyResponse = func(resp *http.Response) error { + // 处理响应URL重写 + if rp.proxy.config.EnableURLRewrite && resp != nil { + rp.rewriter.RewriteResponse(resp, req.Host) + } + + // 添加CORS头 + if rp.proxy.config.EnableCORS && resp != nil { + resp.Header.Set("Access-Control-Allow-Origin", "*") + resp.Header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + resp.Header.Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + } + + // 调用委托的ModifyResponse方法 + return rp.proxy.delegate.ModifyResponse(resp) + } + + // 更新监控指标 + if rp.proxy.metrics != nil { + rp.proxy.metrics.IncActiveConnections() + defer rp.proxy.metrics.DecActiveConnections() + + startTime := time.Now() + defer func() { + duration := time.Since(startTime) + rp.proxy.metrics.ObserveRequestDuration(duration.Seconds()) + rp.proxy.metrics.IncRequestCount() + }() + } + + // 处理WebSocket升级 + if rp.proxy.config.SupportWebSocketUpgrade && isWebSocketRequest(req) { + rp.handleWebSocketUpgrade(rw, req, backend) + return + } + + // 处理普通请求 + proxy.ServeHTTP(rw, req) + + // 完成事件 + rp.proxy.delegate.Finish(ctx) +} + +// 处理WebSocket升级 +func (rp *ReverseProxy) handleWebSocketUpgrade(rw http.ResponseWriter, req *http.Request, backend string) { + // 创建WebSocket代理 + target := &url.URL{ + Scheme: "ws", + Host: backend, + } + + if req.TLS != nil { + target.Scheme = "wss" + } + + // 创建连接到后端的WebSocket连接 + backendConn, err := rp.dialBackend(target.String(), req) + if err != nil { + rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("无法连接到后端WebSocket服务: %v", err)) + return + } + defer backendConn.Close() + + // 将请求转发给后端 + err = req.Write(backendConn) + if err != nil { + rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("写入WebSocket请求错误: %v", err)) + return + } + + // 升级客户端连接 + clientConn, err := hijacker(rw) + if err != nil { + rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("升级WebSocket连接错误: %v", err)) + return + } + + // 双向转发数据 + rp.proxy.transfer(clientConn, backendConn) +} + +// 连接到后端 +func (rp *ReverseProxy) dialBackend(url string, req *http.Request) (net.Conn, error) { + ctx, cancel := context.WithTimeout(req.Context(), 15*time.Second) + defer cancel() + + backend := strings.TrimPrefix(url, "ws://") + backend = strings.TrimPrefix(backend, "wss://") + + if strings.Contains(backend, "/") { + backend = backend[:strings.Index(backend, "/")] + } + + // 根据协议选择连接方式 + if strings.HasPrefix(url, "wss://") { + // 使用 tls.Dialer 替代不存在的 tls.DialWithContext + dialer := &tls.Dialer{ + Config: &tls.Config{ + InsecureSkipVerify: true, + }, + } + return dialer.DialContext(ctx, "tcp", backend) + } + + var d net.Dialer + return d.DialContext(ctx, "tcp", backend) +} + +// 添加路由规则 +func (rp *ReverseProxy) AddRoute(pattern string, routeType router.RouteType, target string) { + route := &router.Route{ + Pattern: pattern, + Type: routeType, + Target: target, + } + rp.router.AddRoute(route) +} + +// 添加重写规则 +func (rp *ReverseProxy) AddRewriteRule(pattern, replacement string, useRegex bool) error { + return rp.rewriter.AddRule(pattern, replacement, useRegex) +} diff --git a/internal/rewriter/rewriter.go b/internal/rewriter/rewriter.go new file mode 100644 index 0000000..d89093b --- /dev/null +++ b/internal/rewriter/rewriter.go @@ -0,0 +1,98 @@ +package rewriter + +import ( + "net/http" + "regexp" + "strings" +) + +// Rewriter URL重写器 +// 用于在反向代理中重写请求URL +type Rewriter struct { + // 重写规则列表 + rules []*RewriteRule +} + +// RewriteRule 重写规则 +type RewriteRule struct { + // 匹配模式 + Pattern string + // 替换模式 + Replacement string + // 是否使用正则表达式 + UseRegex bool + // 编译后的正则表达式 + regex *regexp.Regexp +} + +// NewRewriter 创建URL重写器 +func NewRewriter() *Rewriter { + return &Rewriter{ + rules: make([]*RewriteRule, 0), + } +} + +// AddRule 添加重写规则 +func (r *Rewriter) AddRule(pattern, replacement string, useRegex bool) error { + rule := &RewriteRule{ + Pattern: pattern, + Replacement: replacement, + UseRegex: useRegex, + } + + if useRegex { + regex, err := regexp.Compile(pattern) + if err != nil { + return err + } + rule.regex = regex + } + + r.rules = append(r.rules, rule) + return nil +} + +// Rewrite 重写URL +func (r *Rewriter) Rewrite(req *http.Request) { + path := req.URL.Path + + for _, rule := range r.rules { + if rule.UseRegex { + if rule.regex.MatchString(path) { + req.URL.Path = rule.regex.ReplaceAllString(path, rule.Replacement) + break + } + } else { + if strings.HasPrefix(path, rule.Pattern) { + req.URL.Path = strings.Replace(path, rule.Pattern, rule.Replacement, 1) + break + } + } + } +} + +// RewriteResponse 重写响应 +// 主要用于处理响应中的Location头和内容中的URL +func (r *Rewriter) RewriteResponse(resp *http.Response, originHost string) { + // 处理重定向头 + location := resp.Header.Get("Location") + if location != "" { + // 将后端服务器的域名替换成代理服务器的域名 + for _, rule := range r.rules { + if rule.UseRegex && rule.regex != nil { + if rule.regex.MatchString(location) { + newLocation := rule.regex.ReplaceAllString(location, rule.Replacement) + resp.Header.Set("Location", newLocation) + break + } + } + } + } +} + +// LoadRulesFromFile 从文件加载重写规则 +func (r *Rewriter) LoadRulesFromFile(filename string) error { + // 实现从配置文件加载规则的逻辑 + // 这里省略实现细节 + return nil +} diff --git a/internal/router/router.go b/internal/router/router.go new file mode 100644 index 0000000..df73962 --- /dev/null +++ b/internal/router/router.go @@ -0,0 +1,103 @@ +package router + +import ( + "net/http" + "regexp" + "strings" +) + +// Route 路由规则 +type Route struct { + // 匹配模式(主机名、路径、正则表达式) + Pattern string + // 匹配类型 + Type RouteType + // 目标地址 + Target string + // 路径重写规则 + RewritePattern string + // 请求头修改 + HeaderModifier HeaderModifier + // 自定义匹配函数 + MatchFunc func(req *http.Request) bool +} + +// RouteType 路由类型 +type RouteType int + +const ( + // HostRoute 主机名路由 + HostRoute RouteType = iota + // PathRoute 路径路由 + PathRoute + // RegexRoute 正则表达式路由 + RegexRoute + // CustomRoute 自定义路由 + CustomRoute +) + +// HeaderModifier 头部修改接口 +type HeaderModifier interface { + // ModifyRequest 修改请求头 + ModifyRequest(req *http.Request) + // ModifyResponse 修改响应头 + ModifyResponse(resp *http.Response) +} + +// Router 路由器 +type Router struct { + routes []*Route +} + +// NewRouter 创建路由器 +func NewRouter() *Router { + return &Router{ + routes: make([]*Route, 0), + } +} + +// AddRoute 添加路由规则 +func (r *Router) AddRoute(route *Route) { + r.routes = append(r.routes, route) +} + +// Match 匹配请求 +func (r *Router) Match(req *http.Request) (*Route, bool) { + for _, route := range r.routes { + switch route.Type { + case HostRoute: + if matchHost(req.Host, route.Pattern) { + return route, true + } + case PathRoute: + if matchPath(req.URL.Path, route.Pattern) { + return route, true + } + case RegexRoute: + if matchRegex(req.URL.String(), route.Pattern) { + return route, true + } + case CustomRoute: + if route.MatchFunc != nil && route.MatchFunc(req) { + return route, true + } + } + } + return nil, false +} + +// 匹配主机名 +func matchHost(host, pattern string) bool { + return host == pattern || strings.HasSuffix(host, "."+pattern) +} + +// 匹配路径 +func matchPath(path, pattern string) bool { + return strings.HasPrefix(path, pattern) +} + +// 匹配正则表达式 +func matchRegex(url, pattern string) bool { + matched, _ := regexp.MatchString(pattern, url) + return matched +} diff --git a/mitm-proxy.crt b/mitm-proxy.crt new file mode 100644 index 0000000..e287170 --- /dev/null +++ b/mitm-proxy.crt @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICJzCCAcygAwIBAgIITWWCIQf8/VIwCgYIKoZIzj0EAwIwUzEOMAwGA1UEBhMF +Q2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0GA1UEBxMGWGlhbWVuMRAwDgYDVQQK +EwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMB4XDTIyMDMyNTA1NDgwMFoXDTQyMDQy +NTA1NDgwMFowUzEOMAwGA1UEBhMFQ2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0G +A1UEBxMGWGlhbWVuMRAwDgYDVQQKEwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMFkw +EwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQX +mRgsFV5KHHmxOrVJBFC/nDetmGowkARShWtBsX1Irm4w6i6Qk2QliKOBiTCBhjAO +BgNVHQ8BAf8EBAMCAQYwHQYDVR0lBBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIG +A1UdEwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFBI5TkWYcvUIWsBAdffs833FnBrI +MCIGA1UdEQQbMBmBF3FpbmdxaWFubHVkYW9AZ21haWwuY29tMAoGCCqGSM49BAMC +A0kAMEYCIQCk1DhW7AmIW/n/QLftQq8BHZKLevWYJ813zdrNr5kXlwIhAIVvqglY +9BkYWg4NEe/mVO4C5Vtu4FnzNU9I+rFpXVSO +-----END CERTIFICATE----- \ No newline at end of file diff --git a/proxy.go b/proxy.go new file mode 100644 index 0000000..e7ed662 --- /dev/null +++ b/proxy.go @@ -0,0 +1,753 @@ +// Copyright 2018 ouqiang authors +// +// Licensed under the Apache License, Version 2.0 (the "License"): you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. + +// Package goproxy HTTP(S)代理, 支持中间人代理解密HTTPS数据 +package goproxy + +import ( + "bufio" + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "net/http/httptrace" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/viki-org/dnscache" + + "github.com/ouqiang/goproxy/cert" + "github.com/ouqiang/websocket" +) + +const ( + // 连接目标服务器超时时间 + defaultTargetConnectTimeout = 5 * time.Second + // 目标服务器读写超时时间 + defaultTargetReadWriteTimeout = 10 * time.Second +) + +type DialContext func(ctx context.Context, network, addr string) (net.Conn, error) + +// 隧道连接成功响应行 +var tunnelEstablishedResponseLine = []byte("HTTP/1.1 200 Connection established\r\n\r\n") + +var badGateway = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", http.StatusBadGateway, http.StatusText(http.StatusBadGateway))) + +var ( + bufPool = sync.Pool{ + New: func() interface{} { + return make([]byte, 32*1024) + }, + } + + ctxPool = sync.Pool{ + New: func() interface{} { + return new(Context) + }, + } + headerPool = NewHeaderPool() + requestPool = newRequestPool() +) + +type RequestPool struct { + pool sync.Pool +} + +func newRequestPool() *RequestPool { + return &RequestPool{ + pool: sync.Pool{ + New: func() interface{} { + return new(http.Request) + }, + }, + } +} + +func (p *RequestPool) Get() *http.Request { + req := p.pool.Get().(*http.Request) + + req.Method = "" + req.URL = nil + req.Proto = "" + req.ProtoMajor = 0 + req.ProtoMinor = 0 + req.Header = nil + req.Body = nil + req.GetBody = nil + req.ContentLength = 0 + req.TransferEncoding = nil + req.Close = false + req.Host = "" + req.Form = nil + req.PostForm = nil + req.MultipartForm = nil + req.Trailer = nil + req.RemoteAddr = "" + req.RequestURI = "" + req.TLS = nil + req.Cancel = nil + req.Response = nil + + return req +} + +func (p *RequestPool) Put(req *http.Request) { + if req != nil { + p.pool.Put(req) + } +} + +type HeaderPool struct { + pool sync.Pool +} + +func NewHeaderPool() *HeaderPool { + return &HeaderPool{ + pool: sync.Pool{ + New: func() interface{} { + return http.Header{} + }, + }, + } +} + +func (p *HeaderPool) Get() http.Header { + header := p.pool.Get().(http.Header) + for k := range header { + delete(header, k) + } + + return header +} + +func (p *HeaderPool) Put(header http.Header) { + if header != nil { + p.pool.Put(header) + } +} + +// 生成隧道建立请求行 +func makeTunnelRequestLine(addr string) string { + return fmt.Sprintf("CONNECT %s HTTP/1.1\r\n\r\n", addr) +} + +type options struct { + disableKeepAlive bool + delegate Delegate + + decryptHTTPS bool + websocketIntercept bool + certCache cert.Cache + transport *http.Transport + clientTrace *httptrace.ClientTrace +} + +type Option func(*options) + +// WithDisableKeepAlive 连接是否重用 +func WithDisableKeepAlive(disableKeepAlive bool) Option { + return func(opt *options) { + opt.disableKeepAlive = disableKeepAlive + } +} + +func WithClientTrace(t *httptrace.ClientTrace) Option { + return func(opt *options) { + opt.clientTrace = t + } +} + +// WithDelegate 设置委托类 +func WithDelegate(delegate Delegate) Option { + return func(opt *options) { + opt.delegate = delegate + } +} + +// WithTransport 自定义http transport +func WithTransport(t *http.Transport) Option { + return func(opt *options) { + opt.transport = t + } +} + +// WithDecryptHTTPS 中间人代理, 解密HTTPS, 需实现证书缓存接口 +func WithDecryptHTTPS(c cert.Cache) Option { + return func(opt *options) { + opt.decryptHTTPS = true + opt.certCache = c + } +} + +// WithEnableWebsocketIntercept 拦截websocket +func WithEnableWebsocketIntercept() Option { + return func(opt *options) { + opt.websocketIntercept = true + } +} + +// New 创建proxy实例 +func New(opt ...Option) *Proxy { + opts := &options{} + for _, o := range opt { + o(opts) + } + if opts.delegate == nil { + opts.delegate = &DefaultDelegate{} + } + if opts.transport == nil { + opts.transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + MaxIdleConns: 100, + MaxConnsPerHost: 10, + IdleConnTimeout: 10 * time.Second, + TLSHandshakeTimeout: 5 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + } + + p := &Proxy{} + p.delegate = opts.delegate + p.websocketIntercept = opts.websocketIntercept + p.decryptHTTPS = opts.decryptHTTPS + if p.decryptHTTPS { + p.cert = cert.NewCertificate(opts.certCache, true) + } + p.transport = opts.transport + p.transport.DialContext = p.dialContext() + p.dnsCache = dnscache.New(5 * time.Minute) + p.transport.DisableKeepAlives = opts.disableKeepAlive + p.transport.Proxy = p.delegate.ParentProxy + p.clientTrace = opts.clientTrace + + return p +} + +// Proxy 实现了http.Handler接口 +type Proxy struct { + delegate Delegate + clientConnNum int32 + decryptHTTPS bool + websocketIntercept bool + cert *cert.Certificate + transport *http.Transport + clientTrace *httptrace.ClientTrace + dnsCache *dnscache.Resolver +} + +var _ http.Handler = &Proxy{} + +// ServeHTTP 实现了http.Handler接口 +func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if req.URL.Host == "" { + req.URL.Host = req.Host + } + atomic.AddInt32(&p.clientConnNum, 1) + ctx := ctxPool.Get().(*Context) + ctx.Reset(req) + + defer func() { + p.delegate.Finish(ctx) + ctxPool.Put(ctx) + atomic.AddInt32(&p.clientConnNum, -1) + }() + p.delegate.Connect(ctx, rw) + if ctx.abort { + return + } + p.delegate.Auth(ctx, rw) + if ctx.abort { + return + } + + switch { + case ctx.Req.Method == http.MethodConnect: + p.tunnelProxy(ctx, rw) + case websocket.IsWebSocketUpgrade(ctx.Req): + p.tunnelProxy(ctx, rw) + default: + p.httpProxy(ctx, rw) + } +} + +// ClientConnNum 获取客户端连接数 +func (p *Proxy) ClientConnNum() int32 { + return atomic.LoadInt32(&p.clientConnNum) +} + +// DoRequest 执行HTTP请求,并调用responseFunc处理response +func (p *Proxy) DoRequest(ctx *Context, responseFunc func(*http.Response, error)) { + if ctx.Data == nil { + ctx.Data = make(map[interface{}]interface{}) + } + p.delegate.BeforeRequest(ctx) + if ctx.abort { + return + } + newReq := requestPool.Get() + *newReq = *ctx.Req + newHeader := headerPool.Get() + CloneHeader(newReq.Header, newHeader) + newReq.Header = newHeader + for _, item := range hopHeaders { + if newReq.Header.Get(item) != "" { + newReq.Header.Del(item) + } + } + if p.clientTrace != nil { + newReq = newReq.WithContext(httptrace.WithClientTrace(newReq.Context(), p.clientTrace)) + } + + resp, err := p.transport.RoundTrip(newReq) + p.delegate.BeforeResponse(ctx, resp, err) + if ctx.abort { + return + } + if err == nil { + for _, h := range hopHeaders { + resp.Header.Del(h) + } + } + responseFunc(resp, err) + headerPool.Put(newHeader) + requestPool.Put(newReq) +} + +// HTTP代理 +func (p *Proxy) httpProxy(ctx *Context, rw http.ResponseWriter) { + ctx.Req.URL.Scheme = "http" + p.DoRequest(ctx, func(resp *http.Response, err error) { + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - HTTP请求错误: %s", ctx.Req.URL, err)) + rw.WriteHeader(http.StatusBadGateway) + return + } + defer func() { + _ = resp.Body.Close() + }() + CopyHeader(rw.Header(), resp.Header) + rw.WriteHeader(resp.StatusCode) + buf := bufPool.Get().([]byte) + _, _ = io.CopyBuffer(rw, resp.Body, buf) + bufPool.Put(buf) + }) +} + +// HTTPS代理 +func (p *Proxy) httpsProxy(ctx *Context, tlsClientConn *tls.Conn) { + if websocket.IsWebSocketUpgrade(ctx.Req) { + p.websocketProxy(ctx, NewConnBuffer(tlsClientConn, nil)) + return + } + p.DoRequest(ctx, func(resp *http.Response, err error) { + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 请求错误: %s", ctx.Req.URL, err)) + _, _ = tlsClientConn.Write(badGateway) + return + } + err = resp.Write(tlsClientConn) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, response写入客户端失败, %s", ctx.Req.URL, err)) + } + _ = resp.Body.Close() + }) +} + +// 隧道代理 +func (p *Proxy) tunnelProxy(ctx *Context, rw http.ResponseWriter) { + clientConn, err := hijacker(rw) + if err != nil { + p.delegate.ErrorLog(err) + rw.WriteHeader(http.StatusBadGateway) + return + } + defer func() { + _ = clientConn.Close() + }() + + if websocket.IsWebSocketUpgrade(ctx.Req) { + p.websocketProxy(ctx, clientConn) + return + } + + parentProxyURL, err := p.delegate.ParentProxy(ctx.Req) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - 解析代理地址错误: %s", ctx.Req.URL.Host, err)) + rw.WriteHeader(http.StatusBadGateway) + return + } + if parentProxyURL == nil { + _, err = clientConn.Write(tunnelEstablishedResponseLine) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("%s - 隧道连接成功,通知客户端错误: %s", ctx.Req.URL.Host, err)) + return + } + } + + isWebsocket := p.detectConnProtocol(clientConn) + if isWebsocket { + req, err := http.ReadRequest(clientConn.BufferReader()) + if err != nil { + if err != io.EOF { + p.delegate.ErrorLog(fmt.Errorf("%s - websocket读取客户端升级请求失败: %s", ctx.Req.URL.Host, err)) + } + return + } + req.RemoteAddr = ctx.Req.RemoteAddr + req.URL.Scheme = "http" + req.URL.Host = req.Host + ctx.Req = req + + p.websocketProxy(ctx, clientConn) + return + } + var tlsClientConn *tls.Conn + if p.decryptHTTPS { + tlsConfig, err := p.cert.GenerateTlsConfig(ctx.Req.URL.Host) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 生成证书失败: %s", ctx.Req.URL.Host, err)) + return + } + tlsClientConn = tls.Server(clientConn, tlsConfig) + defer func() { + _ = tlsClientConn.Close() + }() + if err := tlsClientConn.Handshake(); err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 握手失败: %s", ctx.Req.URL.Host, err)) + return + } + + buf := bufio.NewReader(tlsClientConn) + tlsReq, err := http.ReadRequest(buf) + if err != nil { + if err != io.EOF { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 读取客户端请求失败: %s", ctx.Req.URL.Host, err)) + } + return + } + tlsReq.RemoteAddr = ctx.Req.RemoteAddr + tlsReq.URL.Scheme = "https" + tlsReq.URL.Host = tlsReq.Host + ctx.Req = tlsReq + } + + targetAddr := ctx.Req.URL.Host + if parentProxyURL != nil { + targetAddr = parentProxyURL.Host + } + if !strings.Contains(targetAddr, ":") { + targetAddr += ":443" + } + + targetConn, err := net.DialTimeout("tcp", targetAddr, defaultTargetConnectTimeout) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - 隧道转发连接目标服务器失败: %s", ctx.Req.URL.Host, err)) + return + } + defer func() { + _ = targetConn.Close() + }() + if parentProxyURL != nil { + tunnelRequestLine := makeTunnelRequestLine(ctx.Req.URL.Host) + _, _ = targetConn.Write([]byte(tunnelRequestLine)) + } + + if p.decryptHTTPS { + p.httpsProxy(ctx, tlsClientConn) + } else { + p.tunnelConnected(ctx, nil) + p.transfer(clientConn, targetConn) + } +} + +// WebSocket代理 +func (p *Proxy) websocketProxy(ctx *Context, srcConn *ConnBuffer) { + if !p.websocketIntercept { + remoteAddr := ctx.Addr() + var err error + var targetConn net.Conn + if ctx.IsHTTPS() { + targetConn, err = tls.Dial("tcp", remoteAddr, &tls.Config{InsecureSkipVerify: true}) + } else { + targetConn, err = net.Dial("tcp", remoteAddr) + } + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - websocket连接目标服务器错误: %s", ctx.Req.URL.Host, err)) + return + } + err = ctx.Req.Write(targetConn) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - websocket协议转换请求写入目标服务器错误: %s", ctx.Req.URL.Host, err)) + return + } + p.tunnelConnected(ctx, nil) + p.transfer(srcConn, targetConn) + return + } + + up := &websocket.Upgrader{ + HandshakeTimeout: defaultTargetConnectTimeout, + ReadBufferSize: 4096, + WriteBufferSize: 4096, + CheckOrigin: func(r *http.Request) bool { + return true + }, + } + srcWSConn, err := up.Upgrade(srcConn, ctx.Req, http.Header{}) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - 源连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err)) + return + } + + u := ctx.WebsocketUrl() + d := websocket.Dialer{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + } + + dialTimeoutCtx, cancel := context.WithTimeout(context.Background(), defaultTargetConnectTimeout) + defer cancel() + targetWSConn, _, err := d.DialContext(dialTimeoutCtx, u.String(), ctx.Req.Header) + if err != nil { + p.tunnelConnected(ctx, err) + p.delegate.ErrorLog(fmt.Errorf("%s - 目标连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err)) + return + } + p.tunnelConnected(ctx, nil) + p.transferWebsocket(ctx, srcWSConn, targetWSConn) +} + +// 探测连接协议 +func (p *Proxy) detectConnProtocol(connBuf *ConnBuffer) (isWebsocket bool) { + methodBytes, err := connBuf.Peek(3) + if err != nil { + return false + } + method := string(methodBytes) + if method != http.MethodGet { + return false + } + + return true +} + +// webSocket双向转发 +func (p *Proxy) transferWebsocket(ctx *Context, srcConn *websocket.Conn, targetConn *websocket.Conn) { + doneCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + for { + if doneCtx.Err() != nil { + return + } + + msgType, msg, err := srcConn.ReadMessage() + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err)) + return + } + p.delegate.WebSocketSendMessage(ctx, &msgType, &msg) + err = targetConn.WriteMessage(msgType, msg) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err)) + return + } + } + }() + + for { + if doneCtx.Err() != nil { + return + } + + msgType, msg, err := targetConn.ReadMessage() + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err)) + return + } + p.delegate.WebSocketReceiveMessage(ctx, &msgType, &msg) + err = srcConn.WriteMessage(msgType, msg) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err)) + return + } + } +} + +// 双向转发 +func (p *Proxy) transfer(src net.Conn, dst net.Conn) { + go func() { + buf := bufPool.Get().([]byte) + _, err := io.CopyBuffer(src, dst, buf) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", dst.RemoteAddr().String(), src.RemoteAddr().String(), err)) + } + bufPool.Put(buf) + _ = src.Close() + _ = dst.Close() + }() + + buf := bufPool.Get().([]byte) + _, err := io.CopyBuffer(dst, src, buf) + if err != nil { + p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err)) + } + bufPool.Put(buf) + _ = dst.Close() + _ = src.Close() +} + +func (p *Proxy) tunnelConnected(ctx *Context, err error) { + ctx.TunnelProxy = true + p.delegate.BeforeRequest(ctx) + if err != nil { + p.delegate.BeforeResponse(ctx, nil, err) + return + } + resp := &http.Response{ + Status: "200 OK", + StatusCode: http.StatusOK, + Proto: "1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{}, + Body: http.NoBody, + } + p.delegate.BeforeResponse(ctx, resp, nil) +} + +func (p *Proxy) dialContext() DialContext { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + dialer := &net.Dialer{ + Timeout: defaultTargetConnectTimeout, + } + separator := strings.LastIndex(addr, ":") + ips, err := p.dnsCache.Fetch(addr[:separator]) + if err != nil { + return nil, err + } + var ip string + for _, item := range ips { + ip = item.String() + if !strings.Contains(ip, ":") { + break + } + } + + addr = ip + addr[separator:] + + return dialer.DialContext(ctx, network, addr) + } +} + +// 获取底层连接 +func hijacker(rw http.ResponseWriter) (*ConnBuffer, error) { + hijacker, ok := rw.(http.Hijacker) + if !ok { + return nil, fmt.Errorf("http server不支持Hijacker") + } + conn, buf, err := hijacker.Hijack() + if err != nil { + return nil, fmt.Errorf("hijacker错误: %s", err) + } + + return NewConnBuffer(conn, buf), nil +} + +// CopyHeader 浅拷贝Header +func CopyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +// CloneHeader 深拷贝Header +func CloneHeader(h http.Header, h2 http.Header) { + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } +} + +var hopHeaders = []string{ + "Proxy-Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", + "Trailer", + "Transfer-Encoding", +} + +type ConnBuffer struct { + net.Conn + buf *bufio.ReadWriter +} + +func NewConnBuffer(conn net.Conn, buf *bufio.ReadWriter) *ConnBuffer { + if buf == nil { + buf = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + } + return &ConnBuffer{ + Conn: conn, + buf: buf, + } +} + +func (cb *ConnBuffer) BufferReader() *bufio.Reader { + return cb.buf.Reader +} + +func (cb *ConnBuffer) Read(b []byte) (n int, err error) { + return cb.buf.Read(b) +} + +func (cb *ConnBuffer) Peek(n int) ([]byte, error) { + return cb.buf.Peek(n) +} + +func (cb *ConnBuffer) Write(p []byte) (n int, err error) { + n, err = cb.buf.Write(p) + if err != nil { + return 0, err + } + + return n, cb.buf.Flush() +} + +func (cb *ConnBuffer) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return cb.Conn, cb.buf, nil +} + +func (cb *ConnBuffer) WriteHeader(_ int) {} + +func (cb *ConnBuffer) Header() http.Header { return nil }