init
This commit is contained in:
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
.idea
|
||||
./delegate.go
|
||||
./proxy.go
|
454
README.md
Normal file
454
README.md
Normal file
@@ -0,0 +1,454 @@
|
||||
# GoProxy
|
||||
|
||||
GoProxy是一个功能强大的Go语言HTTP代理库,支持HTTP、HTTPS和WebSocket代理,并提供了丰富的功能和扩展点。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 支持HTTP、HTTPS和WebSocket代理
|
||||
- 支持正向代理和反向代理
|
||||
- 支持HTTPS解密(中间人模式)
|
||||
- 自定义CA证书和私钥
|
||||
- 动态证书生成与缓存
|
||||
- 通配符域名证书支持
|
||||
- 支持RSA和ECDSA证书算法选择
|
||||
- 支持上游代理链
|
||||
- 支持负载均衡(轮询、随机、权重等)
|
||||
- 支持健康检查
|
||||
- 支持请求重试
|
||||
- 支持HTTP缓存
|
||||
- 支持请求限流
|
||||
- 支持监控指标收集
|
||||
- 支持自定义处理逻辑(委托模式)
|
||||
- 支持DNS缓存
|
||||
- 支持URL重写(反向代理模式)
|
||||
|
||||
## 安装
|
||||
|
||||
```bash
|
||||
go get github.com/goproxy
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 正向代理
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/goproxy/internal/proxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建代理
|
||||
p := proxy.New(nil)
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 启用HTTPS解密
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/goproxy/internal/config"
|
||||
"github.com/goproxy/internal/proxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建配置
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.DecryptHTTPS = true
|
||||
cfg.CACert = "ca.crt" // CA证书路径
|
||||
cfg.CAKey = "ca.key" // CA私钥路径
|
||||
cfg.UseECDSA = true // 使用ECDSA生成证书(默认为false,使用RSA)
|
||||
|
||||
// 可选:使用自定义TLS证书
|
||||
// cfg.TLSCert = "server.crt"
|
||||
// cfg.TLSKey = "server.key"
|
||||
|
||||
// 创建证书缓存
|
||||
certCache := &proxy.MemCertCache{}
|
||||
|
||||
// 创建代理
|
||||
p := proxy.New(&proxy.Options{
|
||||
Config: cfg,
|
||||
CertCache: certCache,
|
||||
})
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("HTTPS解密代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> **注意**: 使用HTTPS解密功能时,需要在客户端安装CA证书,否则会出现证书警告。
|
||||
|
||||
### 反向代理
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/goproxy/internal/config"
|
||||
"github.com/goproxy/internal/proxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建配置
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.ReverseProxy = true
|
||||
cfg.EnableURLRewrite = true
|
||||
cfg.AddXForwardedFor = true
|
||||
cfg.AddXRealIP = true
|
||||
|
||||
// 创建自定义委托
|
||||
delegate := &ReverseProxyDelegate{
|
||||
backend: "localhost:8081",
|
||||
}
|
||||
|
||||
// 创建代理
|
||||
p := proxy.New(&proxy.Options{
|
||||
Config: cfg,
|
||||
Delegate: delegate,
|
||||
})
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("反向代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ReverseProxyDelegate 反向代理委托
|
||||
type ReverseProxyDelegate struct {
|
||||
proxy.DefaultDelegate
|
||||
backend string
|
||||
}
|
||||
|
||||
// ResolveBackend 解析后端服务器
|
||||
func (d *ReverseProxyDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||
return d.backend, nil
|
||||
}
|
||||
```
|
||||
|
||||
### 自定义委托
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/goproxy/internal/proxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建自定义委托
|
||||
delegate := &CustomDelegate{}
|
||||
|
||||
// 创建代理
|
||||
p := proxy.New(&proxy.Options{
|
||||
Delegate: delegate,
|
||||
})
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// CustomDelegate 自定义委托
|
||||
type CustomDelegate struct {
|
||||
proxy.DefaultDelegate
|
||||
}
|
||||
|
||||
// BeforeRequest 请求前事件
|
||||
func (d *CustomDelegate) BeforeRequest(ctx *proxy.Context) {
|
||||
log.Printf("请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.String())
|
||||
}
|
||||
|
||||
// BeforeResponse 响应前事件
|
||||
func (d *CustomDelegate) BeforeResponse(ctx *proxy.Context, resp *http.Response, err error) {
|
||||
if err != nil {
|
||||
log.Printf("响应错误: %v\n", err)
|
||||
return
|
||||
}
|
||||
log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status)
|
||||
}
|
||||
```
|
||||
|
||||
### 完整示例
|
||||
|
||||
- 正向代理示例: [cmd/example/main.go](cmd/example/main.go)
|
||||
- 反向代理示例: [cmd/reverse_proxy_example/main.go](cmd/reverse_proxy_example/main.go)
|
||||
|
||||
### 使用函数式选项模式
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/goproxy/internal/metrics"
|
||||
"github.com/goproxy/internal/proxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建监控指标
|
||||
metricsCollector := metrics.NewSimpleMetrics()
|
||||
|
||||
// 创建证书缓存
|
||||
certCache := &proxy.MemCertCache{}
|
||||
|
||||
// 使用函数式选项模式创建代理
|
||||
p := proxy.NewProxy(
|
||||
// 启用HTTPS解密
|
||||
proxy.WithDecryptHTTPS(certCache),
|
||||
proxy.WithCACertAndKey("ca.crt", "ca.key"),
|
||||
|
||||
// 设置监控指标
|
||||
proxy.WithMetrics(metricsCollector),
|
||||
|
||||
// 设置请求超时和连接池
|
||||
proxy.WithRequestTimeout(30 * time.Second),
|
||||
proxy.WithConnectionPoolSize(100),
|
||||
proxy.WithIdleTimeout(90 * time.Second),
|
||||
|
||||
// 启用DNS缓存
|
||||
proxy.WithDNSCacheTTL(10 * time.Minute),
|
||||
|
||||
// 启用请求重试
|
||||
proxy.WithEnableRetry(3, 1*time.Second, 10*time.Second),
|
||||
|
||||
// 启用CORS支持
|
||||
proxy.WithEnableCORS(true),
|
||||
)
|
||||
|
||||
// 启动HTTP服务器和监控服务器
|
||||
go func() {
|
||||
log.Println("监控服务器启动在 :8081")
|
||||
http.Handle("/metrics", metricsCollector)
|
||||
if err := http.ListenAndServe(":8081", nil); err != nil {
|
||||
log.Fatalf("监控服务器启动失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 启动代理服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 架构设计
|
||||
|
||||
GoProxy采用模块化设计,主要包含以下模块:
|
||||
|
||||
- **代理核心(Proxy)**:处理HTTP请求和响应,实现代理功能
|
||||
- **反向代理(ReverseProxy)**:处理反向代理请求,支持URL重写和请求修改
|
||||
- **路由(Router)**:基于主机名、路径、正则表达式等规则路由请求到不同的后端
|
||||
- **URL重写(Rewriter)**:重写请求URL和响应中的URL
|
||||
- **代理上下文(Context)**:保存请求上下文信息,用于在处理过程中传递数据
|
||||
- **代理委托(Delegate)**:定义代理处理请求的各个阶段的回调方法,用于自定义处理逻辑
|
||||
- **连接缓冲区(ConnBuffer)**:封装网络连接,提供缓冲读写功能
|
||||
- **负载均衡(LoadBalancer)**:实现负载均衡算法,支持轮询、随机、权重等
|
||||
- **健康检查(HealthChecker)**:检查上游服务器的健康状态,自动剔除不健康的服务器
|
||||
- **缓存(Cache)**:实现HTTP缓存,减少重复请求
|
||||
- **缓存适配器(CacheAdapter)**:统一不同缓存实现的接口,提高代码可读性和性能
|
||||
- **证书生成(CertGenerator)**:动态生成TLS证书,支持HTTPS解密
|
||||
- **限流(RateLimit)**:实现请求限流,防止过载
|
||||
- **监控(Metrics)**:收集代理运行指标,用于监控和分析
|
||||
- **重试(Retry)**:实现请求重试,提高请求成功率
|
||||
|
||||
## 配置选项
|
||||
|
||||
GoProxy提供了丰富的配置选项,可以通过`Options`结构体进行配置:
|
||||
|
||||
```go
|
||||
type Options struct {
|
||||
// 配置
|
||||
Config *config.Config
|
||||
// 委托
|
||||
Delegate Delegate
|
||||
// 证书缓存
|
||||
CertCache CertificateCache
|
||||
// HTTP缓存
|
||||
HTTPCache cache.Cache
|
||||
// 负载均衡器
|
||||
LoadBalancer loadbalance.LoadBalancer
|
||||
// 健康检查器
|
||||
HealthChecker *healthcheck.HealthChecker
|
||||
// 监控指标
|
||||
Metrics metrics.Metrics
|
||||
// 客户端跟踪
|
||||
ClientTrace *httptrace.ClientTrace
|
||||
}
|
||||
```
|
||||
|
||||
### 函数式选项模式
|
||||
|
||||
GoProxy 现在支持函数式选项模式(Functional Options Pattern),通过一系列的 `With` 方法提供更加灵活和可读性更高的配置方式。此模式的优势在于:
|
||||
|
||||
- 参数配置更加直观和清晰
|
||||
- 可以灵活选择需要的配置项,不必记忆参数顺序
|
||||
- 代码可读性更高,便于维护
|
||||
- 可以逐步添加新的配置选项而不破坏兼容性
|
||||
|
||||
可以使用 `NewProxy` 函数和函数式选项创建代理:
|
||||
|
||||
```go
|
||||
// 创建一个简单的代理
|
||||
proxy := proxy.NewProxy()
|
||||
|
||||
// 创建一个功能丰富的代理
|
||||
proxy := proxy.NewProxy(
|
||||
proxy.WithConfig(config.DefaultConfig()),
|
||||
proxy.WithHTTPCache(myCache),
|
||||
proxy.WithDecryptHTTPS(myCertCache),
|
||||
proxy.WithCACertAndKey("ca.crt", "ca.key"),
|
||||
proxy.WithMetrics(myMetrics),
|
||||
proxy.WithLoadBalancer(myLoadBalancer),
|
||||
proxy.WithRequestTimeout(10 * time.Second),
|
||||
proxy.WithEnableCORS(true)
|
||||
)
|
||||
```
|
||||
|
||||
### 可用的 With 方法
|
||||
|
||||
GoProxy 提供了以下 With 方法用于配置代理的各个方面:
|
||||
|
||||
#### 基础配置选项
|
||||
- `WithConfig(cfg *config.Config)`: 设置代理配置
|
||||
- `WithDisableKeepAlive(disableKeepAlive bool)`: 设置连接是否重用
|
||||
- `WithTransport(t *http.Transport)`: 使用自定义HTTP传输
|
||||
- `WithClientTrace(t *httptrace.ClientTrace)`: 设置HTTP客户端跟踪
|
||||
|
||||
#### 功能模块选项
|
||||
- `WithDelegate(delegate Delegate)`: 设置委托类
|
||||
- `WithHTTPCache(c cache.Cache)`: 设置HTTP缓存
|
||||
- `WithLoadBalancer(lb loadbalance.LoadBalancer)`: 设置负载均衡器
|
||||
- `WithHealthChecker(hc *healthcheck.HealthChecker)`: 设置健康检查器
|
||||
- `WithMetrics(m metrics.Metrics)`: 设置监控指标
|
||||
|
||||
#### 功能开启选项
|
||||
- `WithDecryptHTTPS(c CertificateCache)`: 启用中间人代理解密HTTPS
|
||||
- `WithEnableECDSA(enable bool)`: 启用ECDSA证书生成(默认使用RSA)
|
||||
- `WithEnableWebsocketIntercept()`: 启用WebSocket拦截
|
||||
- `WithReverseProxy(enable bool)`: 启用反向代理模式
|
||||
- `WithEnableRetry(maxRetries int, baseBackoff, maxBackoff time.Duration)`: 启用请求重试
|
||||
- `WithRateLimit(rps float64)`: 设置请求限流
|
||||
- `WithURLRewrite(enable bool)`: 启用URL重写
|
||||
- `WithEnableCORS(enable bool)`: 启用CORS支持
|
||||
|
||||
#### 证书相关选项
|
||||
- `WithTLSCertAndKey(certPath, keyPath string)`: 设置TLS证书和密钥
|
||||
- `WithCACertAndKey(caCertPath, caKeyPath string)`: 设置CA证书和密钥
|
||||
|
||||
#### 性能和超时相关选项
|
||||
- `WithConnectionPoolSize(size int)`: 设置连接池大小
|
||||
- `WithIdleTimeout(timeout time.Duration)`: 设置空闲超时时间
|
||||
- `WithRequestTimeout(timeout time.Duration)`: 设置请求超时时间
|
||||
- `WithDNSCacheTTL(ttl time.Duration)`: 设置DNS缓存TTL
|
||||
|
||||
### 配置HTTPS解密
|
||||
|
||||
要启用HTTPS解密功能,需要在配置中设置以下选项:
|
||||
|
||||
```go
|
||||
config := &config.Config{
|
||||
// 启用HTTPS解密
|
||||
DecryptHTTPS: true,
|
||||
|
||||
// 方式一:使用CA证书和私钥动态生成证书
|
||||
CACert: "path/to/ca.crt", // CA证书路径
|
||||
CAKey: "path/to/ca.key", // CA私钥路径
|
||||
|
||||
// 选择证书生成算法(可选)
|
||||
UseECDSA: true, // 使用ECDSA生成证书(默认为false,使用RSA)
|
||||
|
||||
// 方式二:使用固定的TLS证书和私钥
|
||||
// TLSCert: "path/to/server.crt",
|
||||
// TLSKey: "path/to/server.key",
|
||||
}
|
||||
```
|
||||
|
||||
或者使用函数式选项模式:
|
||||
|
||||
```go
|
||||
proxy := proxy.NewProxy(
|
||||
proxy.WithDecryptHTTPS(&proxy.MemCertCache{}),
|
||||
proxy.WithCACertAndKey("path/to/ca.crt", "path/to/ca.key"),
|
||||
proxy.WithEnableECDSA(true), // 使用ECDSA生成证书
|
||||
// 或者使用静态TLS证书
|
||||
// proxy.WithTLSCertAndKey("path/to/server.crt", "path/to/server.key")
|
||||
)
|
||||
```
|
||||
|
||||
同时,建议配置证书缓存以提高性能:
|
||||
|
||||
```go
|
||||
certCache := &proxy.MemCertCache{}
|
||||
```
|
||||
|
||||
## 扩展点
|
||||
|
||||
GoProxy提供了多个扩展点,可以通过实现相应的接口进行扩展:
|
||||
|
||||
- **Delegate**:代理委托接口,用于自定义代理处理逻辑
|
||||
- **LoadBalancer**:负载均衡接口,用于实现自定义负载均衡算法
|
||||
- **Cache**:缓存接口,用于实现自定义缓存策略
|
||||
- **CertificateCache**:证书缓存接口,用于自定义证书存储方式
|
||||
- **Metrics**:监控接口,用于实现自定义监控指标收集
|
||||
|
||||
## 反向代理特性
|
||||
|
||||
GoProxy的反向代理模式提供以下特性:
|
||||
|
||||
- **URL重写**:支持基于前缀和正则表达式的URL重写
|
||||
- **路由规则**:支持基于主机名、路径、正则表达式等的路由规则
|
||||
- **请求修改**:支持修改发往后端服务器的请求
|
||||
- **响应修改**:支持修改来自后端服务器的响应
|
||||
- **保留客户端信息**:支持添加X-Forwarded-For和X-Real-IP头
|
||||
- **CORS支持**:支持自动添加CORS头
|
||||
- **WebSocket支持**:支持WebSocket协议的透明代理
|
||||
- **负载均衡**:支持多种负载均衡算法
|
||||
- **健康检查**:支持对后端服务器进行健康检查
|
||||
- **监控指标**:支持收集反向代理的监控指标
|
||||
|
||||
## 贡献
|
||||
|
||||
欢迎贡献代码、报告问题或提出建议。请遵循以下步骤:
|
||||
|
||||
1. Fork 项目
|
||||
2. 创建特性分支 (`git checkout -b feature/amazing-feature`)
|
||||
3. 提交更改 (`git commit -m 'Add some amazing feature'`)
|
||||
4. 推送到分支 (`git push origin feature/amazing-feature`)
|
||||
5. 创建 Pull Request
|
||||
|
||||
## 许可证
|
||||
|
||||
本项目采用 MIT 许可证,详情请参阅 [LICENSE](LICENSE) 文件。
|
73
SUMMARY.md
Normal file
73
SUMMARY.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# GoProxy 项目总结
|
||||
|
||||
## 项目概述
|
||||
|
||||
GoProxy 是一个功能强大的 Go 语言 HTTP 代理库,支持 HTTP、HTTPS 和 WebSocket 代理,并提供了丰富的功能和扩展点。该项目采用模块化设计,各个模块之间职责明确,耦合度低,便于扩展和维护。
|
||||
|
||||
## 已完成的模块
|
||||
|
||||
1. **配置模块(config)**:提供代理配置选项,包括连接池大小、超时时间、是否启用缓存、负载均衡等。
|
||||
|
||||
2. **代理上下文(context)**:保存请求上下文信息,用于在代理处理过程中传递数据,包括原始请求、目标地址、上级代理等。
|
||||
|
||||
3. **代理委托(delegate)**:定义代理处理请求的各个阶段的回调方法,用于自定义处理逻辑,包括连接、认证、请求前、响应前等事件。
|
||||
|
||||
4. **连接缓冲区(conn_buffer)**:封装网络连接,提供缓冲读写功能,简化网络 IO 操作。
|
||||
|
||||
5. **负载均衡(loadbalance)**:实现负载均衡算法,支持轮询、随机、权重等,自动选择合适的上游服务器。
|
||||
|
||||
6. **健康检查(healthcheck)**:检查上游服务器的健康状态,自动剔除不健康的服务器,提高代理可靠性。
|
||||
|
||||
7. **缓存(cache)**:实现 HTTP 缓存,减少重复请求,提高代理性能。
|
||||
|
||||
8. **缓存适配器(cache_adapter)**:统一不同缓存实现的接口,使用适配器模式提高代码可读性和执行效率,支持多种缓存实现方式。
|
||||
|
||||
9. **证书生成(cert_generator)**:动态生成TLS证书,支持HTTPS解密(中间人模式),可基于CA证书和私钥创建域名证书,并支持证书缓存。支持RSA和ECDSA两种算法,用户可根据需要选择安全性和性能的平衡。
|
||||
|
||||
10. **限流(ratelimit)**:实现请求限流,防止过载,保护上游服务器。
|
||||
|
||||
11. **监控(metrics)**:收集代理运行指标,用于监控和分析,包括请求数、响应时间、错误数等。
|
||||
|
||||
12. **重试(retry)**:实现请求重试,提高请求成功率,处理临时性故障。
|
||||
|
||||
13. **代理核心(proxy)**:处理 HTTP 请求和响应,实现代理功能,包括 HTTP、HTTPS 和 WebSocket 代理。
|
||||
|
||||
## 项目特点
|
||||
|
||||
1. **模块化设计**:各个模块职责明确,耦合度低,便于扩展和维护。
|
||||
|
||||
2. **丰富的功能**:支持 HTTP、HTTPS 和 WebSocket 代理,并提供了负载均衡、健康检查、缓存、限流、监控等功能。
|
||||
|
||||
3. **灵活的扩展点**:提供了多个扩展点,可以通过实现相应的接口进行扩展,如代理委托、负载均衡、缓存、监控等。
|
||||
|
||||
4. **高性能**:采用 Go 语言的并发特性,实现高性能的代理服务。优化的缓存适配器和证书缓存机制进一步提升性能。
|
||||
|
||||
5. **可靠性**:通过健康检查、重试等机制,提高代理的可靠性。
|
||||
|
||||
6. **可观测性**:通过监控指标收集,提高代理的可观测性。
|
||||
|
||||
7. **安全性**:支持HTTPS解密功能,可用于调试、安全审计和内容过滤。提供RSA和ECDSA双算法选择,满足不同的安全需求和性能场景。
|
||||
|
||||
8. **支持更多证书格式**:增加对PEM、DER等更多证书格式的支持,以及对更多密钥算法(如Ed25519等)的支持。
|
||||
|
||||
## 使用示例
|
||||
|
||||
我们提供了一个完整的示例程序,展示了如何使用 GoProxy 库创建一个功能完善的代理服务器,包括负载均衡、健康检查、缓存、监控等功能。另外还提供了启用HTTPS解密功能的示例,展示如何使用CA证书动态生成站点证书。
|
||||
|
||||
## 未来计划
|
||||
|
||||
1. **完善文档**:编写更详细的文档,包括 API 文档、使用示例、最佳实践等。
|
||||
|
||||
2. **增加测试**:增加单元测试和集成测试,提高代码质量和可靠性。
|
||||
|
||||
3. **性能优化**:进一步优化代理性能,减少资源消耗。
|
||||
|
||||
4. **增加更多功能**:如请求过滤、内容修改、安全检查等。
|
||||
|
||||
5. **提供更多扩展点**:如请求路由、请求转换等。
|
||||
|
||||
6. **支持更多证书格式**:增加对PEM、DER等更多证书格式的支持。
|
||||
|
||||
## 总结
|
||||
|
||||
GoProxy 是一个功能强大、设计良好的 Go 语言 HTTP 代理库,可以满足各种代理需求,如开发调试、负载均衡、API 网关等。通过模块化设计和丰富的扩展点,GoProxy 可以灵活地适应各种场景,是构建代理服务的理想选择。最近的增强包括缓存适配器优化和完整的HTTPS解密功能实现,使GoProxy更加完善和高效。
|
165
cmd/example/main.go
Normal file
165
cmd/example/main.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/goproxy/internal/cache"
|
||||
"github.com/goproxy/internal/config"
|
||||
"github.com/goproxy/internal/healthcheck"
|
||||
"github.com/goproxy/internal/loadbalance"
|
||||
"github.com/goproxy/internal/metrics"
|
||||
"github.com/goproxy/internal/proxy"
|
||||
)
|
||||
|
||||
var (
|
||||
// 监听地址
|
||||
addr = flag.String("addr", ":8080", "代理服务器监听地址")
|
||||
// 上游代理服务器
|
||||
upstream = flag.String("upstream", "", "上游代理服务器地址,多个地址用逗号分隔")
|
||||
// 是否启用负载均衡
|
||||
enableLoadBalance = flag.Bool("enable-lb", false, "是否启用负载均衡")
|
||||
// 是否启用健康检查
|
||||
enableHealthCheck = flag.Bool("enable-hc", false, "是否启用健康检查")
|
||||
// 是否启用缓存
|
||||
enableCache = flag.Bool("enable-cache", false, "是否启用缓存")
|
||||
// 是否启用重试
|
||||
enableRetry = flag.Bool("enable-retry", false, "是否启用重试")
|
||||
// 是否启用监控
|
||||
enableMetrics = flag.Bool("enable-metrics", false, "是否启用监控")
|
||||
// 监控地址
|
||||
metricsAddr = flag.String("metrics-addr", ":8081", "监控服务器监听地址")
|
||||
)
|
||||
|
||||
// 解析目标地址
|
||||
func parseTargets(targets string) []string {
|
||||
if targets == "" {
|
||||
return nil
|
||||
}
|
||||
return strings.Split(targets, ",")
|
||||
}
|
||||
|
||||
func main() {
|
||||
// 解析命令行参数
|
||||
flag.Parse()
|
||||
|
||||
// 创建配置
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.EnableLoadBalancing = *enableLoadBalance
|
||||
cfg.EnableHealthCheck = *enableHealthCheck
|
||||
cfg.EnableCache = *enableCache
|
||||
cfg.EnableRetry = *enableRetry
|
||||
|
||||
// 创建选项
|
||||
opts := &proxy.Options{
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
// 创建负载均衡器
|
||||
if *enableLoadBalance && *upstream != "" {
|
||||
lb := loadbalance.NewRoundRobinBalancer()
|
||||
for _, target := range parseTargets(*upstream) {
|
||||
lb.Add(target, 1)
|
||||
}
|
||||
opts.LoadBalancer = lb
|
||||
}
|
||||
|
||||
// 创建健康检查器
|
||||
if *enableHealthCheck && opts.LoadBalancer != nil {
|
||||
hc := healthcheck.NewHealthChecker(&healthcheck.Config{
|
||||
Interval: time.Second * 10,
|
||||
Timeout: time.Second * 2,
|
||||
MaxFails: 3,
|
||||
MinSuccess: 2,
|
||||
})
|
||||
opts.HealthChecker = hc
|
||||
}
|
||||
|
||||
// 创建缓存
|
||||
if *enableCache {
|
||||
c := cache.NewMemoryCache(time.Minute*5, time.Minute*5, 1000)
|
||||
opts.HTTPCache = c
|
||||
}
|
||||
|
||||
// 创建监控
|
||||
if *enableMetrics {
|
||||
m := metrics.NewSimpleMetrics()
|
||||
opts.Metrics = m
|
||||
|
||||
// 启动监控服务器
|
||||
go func() {
|
||||
mux := http.NewServeMux()
|
||||
handler := m.GetHandler()
|
||||
mux.Handle("/metrics", handler)
|
||||
log.Printf("监控服务器启动在 %s\n", *metricsAddr)
|
||||
if err := http.ListenAndServe(*metricsAddr, mux); err != nil {
|
||||
log.Fatalf("监控服务器启动失败: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 创建自定义委托
|
||||
delegate := &CustomDelegate{}
|
||||
opts.Delegate = delegate
|
||||
|
||||
// 创建代理
|
||||
p := proxy.New(opts)
|
||||
|
||||
// 创建HTTP服务器
|
||||
server := &http.Server{
|
||||
Addr: *addr,
|
||||
Handler: p,
|
||||
}
|
||||
|
||||
// 启动HTTP服务器
|
||||
go func() {
|
||||
log.Printf("代理服务器启动在 %s\n", *addr)
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待信号
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
|
||||
log.Println("正在关闭代理服务器...")
|
||||
server.Close()
|
||||
log.Println("代理服务器已关闭")
|
||||
}
|
||||
|
||||
// CustomDelegate 自定义委托
|
||||
type CustomDelegate struct {
|
||||
proxy.DefaultDelegate
|
||||
}
|
||||
|
||||
// Connect 连接事件
|
||||
func (d *CustomDelegate) Connect(ctx *proxy.Context, rw http.ResponseWriter) {
|
||||
log.Printf("收到连接: %s -> %s\n", ctx.Req.RemoteAddr, ctx.Req.URL.Host)
|
||||
}
|
||||
|
||||
// BeforeRequest 请求前事件
|
||||
func (d *CustomDelegate) BeforeRequest(ctx *proxy.Context) {
|
||||
log.Printf("请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.String())
|
||||
}
|
||||
|
||||
// BeforeResponse 响应前事件
|
||||
func (d *CustomDelegate) BeforeResponse(ctx *proxy.Context, resp *http.Response, err error) {
|
||||
if err != nil {
|
||||
log.Printf("响应错误: %v\n", err)
|
||||
return
|
||||
}
|
||||
log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
// ErrorLog 错误日志
|
||||
func (d *CustomDelegate) ErrorLog(err error) {
|
||||
log.Printf("错误: %v\n", err)
|
||||
}
|
185
cmd/reverse_proxy_example/main.go
Normal file
185
cmd/reverse_proxy_example/main.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/goproxy/internal/config"
|
||||
"github.com/goproxy/internal/metrics"
|
||||
"github.com/goproxy/internal/proxy"
|
||||
)
|
||||
|
||||
var (
|
||||
// 监听地址
|
||||
addr = flag.String("addr", ":8080", "反向代理服务器监听地址")
|
||||
// 后端服务器
|
||||
backend = flag.String("backend", "localhost:8081", "后端服务器地址")
|
||||
// 路由规则文件
|
||||
routeFile = flag.String("route-file", "", "路由规则文件路径")
|
||||
// 是否启用URL重写
|
||||
enableRewrite = flag.Bool("enable-rewrite", false, "是否启用URL重写")
|
||||
// 是否启用缓存
|
||||
enableCache = flag.Bool("enable-cache", false, "是否启用缓存")
|
||||
// 是否启用压缩
|
||||
enableCompression = flag.Bool("enable-compression", false, "是否启用压缩")
|
||||
// 是否启用监控
|
||||
enableMetrics = flag.Bool("enable-metrics", false, "是否启用监控")
|
||||
// 监控地址
|
||||
metricsAddr = flag.String("metrics-addr", ":8082", "监控服务器监听地址")
|
||||
// 是否添加X-Forwarded-For
|
||||
addXForwardedFor = flag.Bool("add-x-forwarded-for", true, "是否添加X-Forwarded-For头")
|
||||
// 是否添加X-Real-IP
|
||||
addXRealIP = flag.Bool("add-x-real-ip", true, "是否添加X-Real-IP头")
|
||||
// 是否启用CORS
|
||||
enableCORS = flag.Bool("enable-cors", false, "是否启用CORS")
|
||||
// 路径前缀
|
||||
pathPrefix = flag.String("path-prefix", "", "路径前缀,将从请求路径中移除")
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 解析命令行参数
|
||||
flag.Parse()
|
||||
|
||||
// 创建配置
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.ReverseProxy = true
|
||||
cfg.EnableCache = *enableCache
|
||||
cfg.EnableCompression = *enableCompression
|
||||
cfg.EnableURLRewrite = *enableRewrite
|
||||
cfg.AddXForwardedFor = *addXForwardedFor
|
||||
cfg.AddXRealIP = *addXRealIP
|
||||
cfg.EnableCORS = *enableCORS
|
||||
cfg.ReverseProxyRulesFile = *routeFile
|
||||
|
||||
// 创建选项
|
||||
opts := &proxy.Options{
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
// 创建监控
|
||||
if *enableMetrics {
|
||||
m := metrics.NewSimpleMetrics()
|
||||
opts.Metrics = m
|
||||
|
||||
// 启动监控服务器
|
||||
go func() {
|
||||
mux := http.NewServeMux()
|
||||
handler := m.GetHandler()
|
||||
mux.Handle("/metrics", handler)
|
||||
log.Printf("监控服务器启动在 %s\n", *metricsAddr)
|
||||
if err := http.ListenAndServe(*metricsAddr, mux); err != nil {
|
||||
log.Fatalf("监控服务器启动失败: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 创建自定义委托
|
||||
delegate := &ReverseProxyDelegate{
|
||||
backend: *backend,
|
||||
prefix: *pathPrefix,
|
||||
}
|
||||
opts.Delegate = delegate
|
||||
|
||||
// 创建代理
|
||||
p := proxy.New(opts)
|
||||
|
||||
// 如果有路径前缀,添加重写规则
|
||||
if *pathPrefix != "" {
|
||||
reverseProxy := p.NewReverseProxy()
|
||||
log.Printf("添加路径重写规则: 从请求路径移除前缀 %s\n", *pathPrefix)
|
||||
reverseProxy.AddRewriteRule(*pathPrefix, "", false)
|
||||
}
|
||||
|
||||
// 创建HTTP服务器
|
||||
server := &http.Server{
|
||||
Addr: *addr,
|
||||
Handler: p,
|
||||
}
|
||||
|
||||
// 启动HTTP服务器
|
||||
go func() {
|
||||
log.Printf("反向代理服务器启动在 %s,后端服务器为 %s\n", *addr, *backend)
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 等待信号
|
||||
quit := make(chan os.Signal, 1)
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-quit
|
||||
|
||||
log.Println("正在关闭代理服务器...")
|
||||
server.Close()
|
||||
log.Println("代理服务器已关闭")
|
||||
}
|
||||
|
||||
// ReverseProxyDelegate 反向代理委托
|
||||
type ReverseProxyDelegate struct {
|
||||
proxy.DefaultDelegate
|
||||
backend string
|
||||
prefix string
|
||||
}
|
||||
|
||||
// ResolveBackend 解析后端服务器
|
||||
func (d *ReverseProxyDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||
// 这里可以实现基于请求路径、主机名等的路由逻辑
|
||||
return d.backend, nil
|
||||
}
|
||||
|
||||
// ModifyRequest 修改请求
|
||||
func (d *ReverseProxyDelegate) ModifyRequest(req *http.Request) {
|
||||
// 移除路径前缀
|
||||
if d.prefix != "" && strings.HasPrefix(req.URL.Path, d.prefix) {
|
||||
req.URL.Path = strings.TrimPrefix(req.URL.Path, d.prefix)
|
||||
if req.URL.Path == "" {
|
||||
req.URL.Path = "/"
|
||||
}
|
||||
}
|
||||
|
||||
// 添加自定义请求头
|
||||
req.Header.Set("X-Proxy-Time", time.Now().Format(time.RFC3339))
|
||||
}
|
||||
|
||||
// ModifyResponse 修改响应
|
||||
func (d *ReverseProxyDelegate) ModifyResponse(resp *http.Response) error {
|
||||
// 添加自定义响应头
|
||||
resp.Header.Set("X-Proxied-By", "GoProxy")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Connect 连接事件
|
||||
func (d *ReverseProxyDelegate) Connect(ctx *proxy.Context, rw http.ResponseWriter) {
|
||||
log.Printf("收到连接: %s -> %s %s\n", ctx.Req.RemoteAddr, ctx.Req.Method, ctx.Req.URL.Path)
|
||||
}
|
||||
|
||||
// BeforeRequest 请求前事件
|
||||
func (d *ReverseProxyDelegate) BeforeRequest(ctx *proxy.Context) {
|
||||
log.Printf("处理请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.Path)
|
||||
}
|
||||
|
||||
// BeforeResponse 响应前事件
|
||||
func (d *ReverseProxyDelegate) BeforeResponse(ctx *proxy.Context, resp *http.Response, err error) {
|
||||
if err != nil {
|
||||
log.Printf("响应错误: %v\n", err)
|
||||
return
|
||||
}
|
||||
log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
// ErrorLog 错误日志
|
||||
func (d *ReverseProxyDelegate) ErrorLog(err error) {
|
||||
log.Printf("错误: %v\n", err)
|
||||
}
|
||||
|
||||
// HandleError 处理错误
|
||||
func (d *ReverseProxyDelegate) HandleError(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
log.Printf("处理错误: %v\n", err)
|
||||
http.Error(rw, "代理服务器错误: "+err.Error(), http.StatusBadGateway)
|
||||
}
|
132
delegate.go
Normal file
132
delegate.go
Normal file
@@ -0,0 +1,132 @@
|
||||
// Copyright 2018 ouqiang authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package goproxy
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Context 代理上下文
|
||||
type Context struct {
|
||||
Req *http.Request
|
||||
Data map[interface{}]interface{}
|
||||
TunnelProxy bool
|
||||
abort bool
|
||||
}
|
||||
|
||||
func (c *Context) IsHTTPS() bool {
|
||||
return c.Req.URL.Scheme == "https"
|
||||
}
|
||||
|
||||
var defaultPorts = map[string]string{
|
||||
"https": "443",
|
||||
"http": "80",
|
||||
"": "80",
|
||||
}
|
||||
|
||||
func (c *Context) WebsocketUrl() *url.URL {
|
||||
u := new(url.URL)
|
||||
*u = *c.Req.URL
|
||||
if c.IsHTTPS() {
|
||||
u.Scheme = "wss"
|
||||
} else {
|
||||
u.Scheme = "ws"
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func (c *Context) Addr() string {
|
||||
addr := c.Req.Host
|
||||
|
||||
if !strings.Contains(c.Req.URL.Host, ":") {
|
||||
addr += ":" + defaultPorts[c.Req.URL.Scheme]
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
// Abort 中断执行
|
||||
func (c *Context) Abort() {
|
||||
c.abort = true
|
||||
}
|
||||
|
||||
// IsAborted 是否已中断执行
|
||||
func (c *Context) IsAborted() bool {
|
||||
return c.abort
|
||||
}
|
||||
|
||||
// Reset 重置
|
||||
func (c *Context) Reset(req *http.Request) {
|
||||
c.Req = req
|
||||
c.Data = make(map[interface{}]interface{})
|
||||
c.abort = false
|
||||
c.TunnelProxy = false
|
||||
}
|
||||
|
||||
type Delegate interface {
|
||||
// Connect 收到客户端连接
|
||||
Connect(ctx *Context, rw http.ResponseWriter)
|
||||
// Auth 代理身份认证
|
||||
Auth(ctx *Context, rw http.ResponseWriter)
|
||||
// BeforeRequest HTTP请求前 设置X-Forwarded-For, 修改Header、Body
|
||||
BeforeRequest(ctx *Context)
|
||||
// BeforeResponse 响应发送到客户端前, 修改Header、Body、Status Code
|
||||
BeforeResponse(ctx *Context, resp *http.Response, err error)
|
||||
// WebSocketSendMessage websocket发送消息
|
||||
WebSocketSendMessage(ctx *Context, messageType *int, p *[]byte)
|
||||
// WebSockerReceiveMessage websocket接收 消息
|
||||
WebSocketReceiveMessage(ctx *Context, messageType *int, p *[]byte)
|
||||
// ParentProxy 上级代理
|
||||
ParentProxy(*http.Request) (*url.URL, error)
|
||||
// Finish 本次请求结束
|
||||
Finish(ctx *Context)
|
||||
// 记录错误信息
|
||||
ErrorLog(err error)
|
||||
}
|
||||
|
||||
var _ Delegate = &DefaultDelegate{}
|
||||
|
||||
// DefaultDelegate 默认Handler什么也不做
|
||||
type DefaultDelegate struct {
|
||||
Delegate
|
||||
}
|
||||
|
||||
func (h *DefaultDelegate) Connect(ctx *Context, rw http.ResponseWriter) {}
|
||||
|
||||
func (h *DefaultDelegate) Auth(ctx *Context, rw http.ResponseWriter) {}
|
||||
|
||||
func (h *DefaultDelegate) BeforeRequest(ctx *Context) {}
|
||||
|
||||
func (h *DefaultDelegate) BeforeResponse(ctx *Context, resp *http.Response, err error) {}
|
||||
|
||||
func (h *DefaultDelegate) ParentProxy(req *http.Request) (*url.URL, error) {
|
||||
return http.ProxyFromEnvironment(req)
|
||||
}
|
||||
|
||||
// WebSocketSendMessage websocket发送消息
|
||||
func (h *DefaultDelegate) WebSocketSendMessage(ctx *Context, messageType *int, payload *[]byte) {}
|
||||
|
||||
// WebSockerReceiveMessage websocket接收 消息
|
||||
func (h *DefaultDelegate) WebSocketReceiveMessage(ctx *Context, messageType *int, payload *[]byte) {}
|
||||
|
||||
func (h *DefaultDelegate) Finish(ctx *Context) {}
|
||||
|
||||
func (h *DefaultDelegate) ErrorLog(err error) {
|
||||
log.Println(err)
|
||||
}
|
10
go.mod
Normal file
10
go.mod
Normal file
@@ -0,0 +1,10 @@
|
||||
module github.com/goproxy
|
||||
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/ouqiang/goproxy v1.3.2
|
||||
github.com/ouqiang/websocket v1.6.2
|
||||
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8
|
||||
golang.org/x/time v0.11.0
|
||||
)
|
8
go.sum
Normal file
8
go.sum
Normal file
@@ -0,0 +1,8 @@
|
||||
github.com/ouqiang/goproxy v1.3.2 h1:+3uBRrM0RU4LFcsH0lbWsdUCoHIzoRxk+ISPbIS3lTk=
|
||||
github.com/ouqiang/goproxy v1.3.2/go.mod h1:yF0a+DlUi0Zff28iUeuqLov90bivevUX9uOn3Yk9rww=
|
||||
github.com/ouqiang/websocket v1.6.2 h1:LGQIySbQO3ahZCl34v9xBVb0yncDk8yIcuEIbWBab/U=
|
||||
github.com/ouqiang/websocket v1.6.2/go.mod h1:fIROJIHRlQwgCyUFTMzaaIcs4HIwUj2xlOW43u9Sf+M=
|
||||
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 h1:EVObHAr8DqpoJCVv6KYTle8FEImKhtkfcZetNqxDoJQ=
|
||||
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8/go.mod h1:dniwbG03GafCjFohMDmz6Zc6oCuiqgH6tGNyXTkHzXE=
|
||||
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
|
||||
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
220
internal/cache/cache.go
vendored
Normal file
220
internal/cache/cache.go
vendored
Normal file
@@ -0,0 +1,220 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Cache 缓存接口
|
||||
type Cache interface {
|
||||
// Get 获取缓存
|
||||
Get(key string) (*http.Response, bool)
|
||||
// Set 设置缓存
|
||||
Set(key string, resp *http.Response)
|
||||
// Delete 删除缓存
|
||||
Delete(key string)
|
||||
// Clear 清空缓存
|
||||
Clear()
|
||||
}
|
||||
|
||||
// MemoryCache 内存缓存实现
|
||||
type MemoryCache struct {
|
||||
// 缓存内容
|
||||
items sync.Map
|
||||
// 过期时间
|
||||
ttl time.Duration
|
||||
// 清理间隔
|
||||
cleanupInterval time.Duration
|
||||
// 最大条目数
|
||||
maxEntries int
|
||||
// 当前条目数
|
||||
size int32
|
||||
// 互斥锁
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// CacheItem 缓存项
|
||||
type CacheItem struct {
|
||||
response *http.Response
|
||||
responseBody []byte
|
||||
expiry time.Time
|
||||
}
|
||||
|
||||
// NewMemoryCache 创建内存缓存
|
||||
func NewMemoryCache(ttl, cleanupInterval time.Duration, maxEntries int) *MemoryCache {
|
||||
cache := &MemoryCache{
|
||||
ttl: ttl,
|
||||
cleanupInterval: cleanupInterval,
|
||||
maxEntries: maxEntries,
|
||||
}
|
||||
|
||||
// 启动过期清理
|
||||
if cleanupInterval > 0 {
|
||||
go cache.startCleanup()
|
||||
}
|
||||
|
||||
return cache
|
||||
}
|
||||
|
||||
// Get 获取缓存
|
||||
func (c *MemoryCache) Get(key string) (*http.Response, bool) {
|
||||
value, ok := c.items.Load(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
item := value.(*CacheItem)
|
||||
if time.Now().After(item.expiry) {
|
||||
c.Delete(key)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 克隆响应,避免修改原始数据
|
||||
resp := cloneResponse(item.response, item.responseBody)
|
||||
return resp, true
|
||||
}
|
||||
|
||||
// Set 设置缓存
|
||||
func (c *MemoryCache) Set(key string, resp *http.Response) {
|
||||
// 检查缓存是否已满
|
||||
c.mu.Lock()
|
||||
if c.maxEntries > 0 && c.size >= int32(c.maxEntries) {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
c.size++
|
||||
c.mu.Unlock()
|
||||
|
||||
// 读取并保存响应体
|
||||
var bodyBytes []byte
|
||||
if resp.Body != nil {
|
||||
bodyBytes, _ = io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
}
|
||||
|
||||
item := &CacheItem{
|
||||
response: resp,
|
||||
responseBody: bodyBytes,
|
||||
expiry: time.Now().Add(c.ttl),
|
||||
}
|
||||
|
||||
c.items.Store(key, item)
|
||||
}
|
||||
|
||||
// Delete 删除缓存
|
||||
func (c *MemoryCache) Delete(key string) {
|
||||
c.items.Delete(key)
|
||||
c.mu.Lock()
|
||||
c.size--
|
||||
if c.size < 0 {
|
||||
c.size = 0
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Clear 清空缓存
|
||||
func (c *MemoryCache) Clear() {
|
||||
c.items = sync.Map{}
|
||||
c.mu.Lock()
|
||||
c.size = 0
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// startCleanup 启动过期清理
|
||||
func (c *MemoryCache) startCleanup() {
|
||||
ticker := time.NewTicker(c.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
c.items.Range(func(key, value interface{}) bool {
|
||||
item := value.(*CacheItem)
|
||||
if now.After(item.expiry) {
|
||||
c.Delete(key.(string))
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateCacheKey 生成缓存键
|
||||
func GenerateCacheKey(req *http.Request) string {
|
||||
// 忽略一些可变的头部
|
||||
ignoredHeaders := map[string]bool{
|
||||
"Connection": true,
|
||||
"Keep-Alive": true,
|
||||
"Proxy-Authenticate": true,
|
||||
"Proxy-Authorization": true,
|
||||
"TE": true,
|
||||
"Trailers": true,
|
||||
"Transfer-Encoding": true,
|
||||
"Upgrade": true,
|
||||
}
|
||||
|
||||
// 提取缓存键组件
|
||||
components := []string{
|
||||
req.Method,
|
||||
req.URL.String(),
|
||||
}
|
||||
|
||||
// 添加选择性头部
|
||||
for key, values := range req.Header {
|
||||
if !ignoredHeaders[key] {
|
||||
for _, value := range values {
|
||||
components = append(components, key+":"+value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 连接并计算哈希
|
||||
data := strings.Join(components, "|")
|
||||
hash := md5.New()
|
||||
hash.Write([]byte(data))
|
||||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
// cloneResponse 克隆HTTP响应
|
||||
func cloneResponse(resp *http.Response, body []byte) *http.Response {
|
||||
clone := *resp
|
||||
clone.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||
clone.Header = make(http.Header)
|
||||
|
||||
for k, v := range resp.Header {
|
||||
clone.Header[k] = v
|
||||
}
|
||||
|
||||
return &clone
|
||||
}
|
||||
|
||||
// ShouldCache 判断请求是否应该缓存
|
||||
func ShouldCache(req *http.Request, resp *http.Response) bool {
|
||||
// 只缓存GET请求
|
||||
if req.Method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查响应状态码
|
||||
if resp.StatusCode != http.StatusOK &&
|
||||
resp.StatusCode != http.StatusNotModified &&
|
||||
resp.StatusCode != http.StatusMovedPermanently &&
|
||||
resp.StatusCode != http.StatusPermanentRedirect {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查Cache-Control头
|
||||
cacheControl := resp.Header.Get("Cache-Control")
|
||||
if strings.Contains(cacheControl, "no-store") ||
|
||||
strings.Contains(cacheControl, "no-cache") ||
|
||||
strings.Contains(cacheControl, "private") {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
135
internal/config/config.go
Normal file
135
internal/config/config.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config 代理配置
|
||||
type Config struct {
|
||||
// 监听地址
|
||||
ListenAddr string
|
||||
// 是否启用负载均衡
|
||||
EnableLoadBalancing bool
|
||||
// 负载均衡后端列表
|
||||
Backends []string
|
||||
// 是否启用限流
|
||||
EnableRateLimit bool
|
||||
// 每秒请求速率限制
|
||||
RateLimit float64
|
||||
// 并发请求峰值限制
|
||||
MaxBurst int
|
||||
// 最大连接数
|
||||
MaxConnections int
|
||||
// 是否启用连接池
|
||||
EnableConnectionPool bool
|
||||
// 连接池大小
|
||||
ConnectionPoolSize int
|
||||
// 连接空闲超时时间
|
||||
IdleTimeout time.Duration
|
||||
// 请求超时时间
|
||||
RequestTimeout time.Duration
|
||||
// 是否启用响应缓存
|
||||
EnableCache bool
|
||||
// 缓存过期时间
|
||||
CacheTTL time.Duration
|
||||
// 是否启用HTTPS解密
|
||||
DecryptHTTPS bool
|
||||
// TLS证书文件路径
|
||||
TLSCert string
|
||||
// TLS密钥文件路径
|
||||
TLSKey string
|
||||
// CA证书文件路径(用于生成动态证书)
|
||||
CACert string
|
||||
// CA密钥文件路径(用于生成动态证书)
|
||||
CAKey string
|
||||
// 是否启用健康检查
|
||||
EnableHealthCheck bool
|
||||
// 健康检查间隔时间
|
||||
HealthCheckInterval time.Duration
|
||||
// 健康检查超时时间
|
||||
HealthCheckTimeout time.Duration
|
||||
// 是否启用重试机制
|
||||
EnableRetry bool
|
||||
// 最大重试次数
|
||||
MaxRetries int
|
||||
// 重试间隔基数
|
||||
RetryBackoff time.Duration
|
||||
// 最大重试间隔
|
||||
MaxRetryBackoff time.Duration
|
||||
// 是否启用监控指标
|
||||
EnableMetrics bool
|
||||
// 是否启用请求追踪
|
||||
EnableTracing bool
|
||||
// 是否拦截WebSocket
|
||||
WebSocketIntercept bool
|
||||
// DNS缓存过期时间
|
||||
DNSCacheTTL time.Duration
|
||||
// 是否作为反向代理
|
||||
ReverseProxy bool
|
||||
// 反向代理规则文件路径
|
||||
ReverseProxyRulesFile string
|
||||
// 是否启用URL重写
|
||||
EnableURLRewrite bool
|
||||
// 是否保留客户端IP
|
||||
PreserveClientIP bool
|
||||
// 是否启用压缩
|
||||
EnableCompression bool
|
||||
// 是否自动添加CORS头
|
||||
EnableCORS bool
|
||||
// 重写Host头
|
||||
RewriteHostHeader bool
|
||||
// 是否添加X-Forwarded-For头
|
||||
AddXForwardedFor bool
|
||||
// 是否添加X-Real-IP头
|
||||
AddXRealIP bool
|
||||
// 是否支持Websocket升级
|
||||
SupportWebSocketUpgrade bool
|
||||
// 是否使用ECDSA生成证书(默认使用RSA)
|
||||
UseECDSA bool
|
||||
}
|
||||
|
||||
// DefaultConfig 默认配置
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
ListenAddr: ":8080",
|
||||
EnableLoadBalancing: false,
|
||||
Backends: []string{},
|
||||
EnableRateLimit: false,
|
||||
RateLimit: 100,
|
||||
MaxBurst: 50,
|
||||
MaxConnections: 1000,
|
||||
EnableConnectionPool: true,
|
||||
ConnectionPoolSize: 100,
|
||||
IdleTimeout: 60 * time.Second,
|
||||
RequestTimeout: 30 * time.Second,
|
||||
EnableCache: false,
|
||||
CacheTTL: 5 * time.Minute,
|
||||
DecryptHTTPS: false,
|
||||
TLSCert: "",
|
||||
TLSKey: "",
|
||||
CACert: "",
|
||||
CAKey: "",
|
||||
EnableHealthCheck: false,
|
||||
HealthCheckInterval: 10 * time.Second,
|
||||
HealthCheckTimeout: 5 * time.Second,
|
||||
EnableRetry: true,
|
||||
MaxRetries: 3,
|
||||
RetryBackoff: 100 * time.Millisecond,
|
||||
MaxRetryBackoff: 2 * time.Second,
|
||||
EnableMetrics: false,
|
||||
EnableTracing: false,
|
||||
WebSocketIntercept: false,
|
||||
DNSCacheTTL: 5 * time.Minute,
|
||||
ReverseProxy: false,
|
||||
ReverseProxyRulesFile: "",
|
||||
EnableURLRewrite: false,
|
||||
PreserveClientIP: true,
|
||||
EnableCompression: false,
|
||||
EnableCORS: false,
|
||||
RewriteHostHeader: false,
|
||||
AddXForwardedFor: true,
|
||||
AddXRealIP: true,
|
||||
SupportWebSocketUpgrade: true,
|
||||
UseECDSA: false,
|
||||
}
|
||||
}
|
248
internal/healthcheck/healthchecker.go
Normal file
248
internal/healthcheck/healthchecker.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HealthChecker 健康检查器
|
||||
type HealthChecker struct {
|
||||
// 配置
|
||||
config *Config
|
||||
// 健康检查状态
|
||||
statusMap sync.Map
|
||||
// 状态变更回调
|
||||
statusChangeCallback func(string, bool)
|
||||
// 是否运行中
|
||||
running bool
|
||||
// 上下文
|
||||
ctx context.Context
|
||||
// 取消函数
|
||||
cancel context.CancelFunc
|
||||
// 互斥锁
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Config 健康检查配置
|
||||
type Config struct {
|
||||
// 检查间隔
|
||||
Interval time.Duration
|
||||
// 检查超时
|
||||
Timeout time.Duration
|
||||
// 检查路径
|
||||
Path string
|
||||
// 检查方法
|
||||
Method string
|
||||
// 检查状态码
|
||||
SuccessStatus int
|
||||
// 最大失败次数
|
||||
MaxFails int
|
||||
// 最小成功次数
|
||||
MinSuccess int
|
||||
}
|
||||
|
||||
// NewHealthChecker 创建健康检查器
|
||||
func NewHealthChecker(config *Config) *HealthChecker {
|
||||
if config.Path == "" {
|
||||
config.Path = "/"
|
||||
}
|
||||
if config.Method == "" {
|
||||
config.Method = http.MethodGet
|
||||
}
|
||||
if config.SuccessStatus == 0 {
|
||||
config.SuccessStatus = http.StatusOK
|
||||
}
|
||||
if config.MaxFails == 0 {
|
||||
config.MaxFails = 3
|
||||
}
|
||||
if config.MinSuccess == 0 {
|
||||
config.MinSuccess = 2
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &HealthChecker{
|
||||
config: config,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
running: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动健康检查
|
||||
func (hc *HealthChecker) Start() {
|
||||
hc.mu.Lock()
|
||||
defer hc.mu.Unlock()
|
||||
|
||||
if hc.running {
|
||||
return
|
||||
}
|
||||
|
||||
hc.running = true
|
||||
go hc.run()
|
||||
}
|
||||
|
||||
// Stop 停止健康检查
|
||||
func (hc *HealthChecker) Stop() {
|
||||
hc.mu.Lock()
|
||||
defer hc.mu.Unlock()
|
||||
|
||||
if !hc.running {
|
||||
return
|
||||
}
|
||||
|
||||
hc.cancel()
|
||||
hc.running = false
|
||||
}
|
||||
|
||||
// AddTarget 添加监控目标
|
||||
func (hc *HealthChecker) AddTarget(target string) error {
|
||||
u, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 初始化为健康状态
|
||||
hc.statusMap.Store(u.String(), &backendStatus{
|
||||
URL: u,
|
||||
Healthy: true,
|
||||
FailCount: 0,
|
||||
SuccessCount: 0,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveTarget 移除监控目标
|
||||
func (hc *HealthChecker) RemoveTarget(target string) error {
|
||||
u, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hc.statusMap.Delete(u.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHealthy 检查目标是否健康
|
||||
func (hc *HealthChecker) IsHealthy(target string) bool {
|
||||
u, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
value, ok := hc.statusMap.Load(u.String())
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
status := value.(*backendStatus)
|
||||
return status.Healthy
|
||||
}
|
||||
|
||||
// SetStatusChangeCallback 设置状态变更回调
|
||||
func (hc *HealthChecker) SetStatusChangeCallback(callback func(string, bool)) {
|
||||
hc.statusChangeCallback = callback
|
||||
}
|
||||
|
||||
// backendStatus 后端健康状态
|
||||
type backendStatus struct {
|
||||
URL *url.URL
|
||||
Healthy bool
|
||||
FailCount int
|
||||
SuccessCount int
|
||||
}
|
||||
|
||||
// run 运行健康检查
|
||||
func (hc *HealthChecker) run() {
|
||||
ticker := time.NewTicker(hc.config.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hc.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
hc.checkAll()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkAll 检查所有后端
|
||||
func (hc *HealthChecker) checkAll() {
|
||||
hc.statusMap.Range(func(key, value interface{}) bool {
|
||||
go hc.check(key.(string), value.(*backendStatus))
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// check 检查单个后端
|
||||
func (hc *HealthChecker) check(key string, status *backendStatus) {
|
||||
// 创建检查请求
|
||||
u := *status.URL
|
||||
u.Path = hc.config.Path
|
||||
req, err := http.NewRequest(hc.config.Method, u.String(), nil)
|
||||
if err != nil {
|
||||
hc.updateStatus(key, status, false)
|
||||
return
|
||||
}
|
||||
|
||||
// 设置超时的客户端
|
||||
client := &http.Client{
|
||||
Timeout: hc.config.Timeout,
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: hc.config.Timeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ResponseHeaderTimeout: hc.config.Timeout,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
hc.updateStatus(key, status, false)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode == hc.config.SuccessStatus {
|
||||
hc.updateStatus(key, status, true)
|
||||
} else {
|
||||
hc.updateStatus(key, status, false)
|
||||
}
|
||||
}
|
||||
|
||||
// updateStatus 更新后端状态
|
||||
func (hc *HealthChecker) updateStatus(key string, status *backendStatus, success bool) {
|
||||
if success {
|
||||
status.SuccessCount++
|
||||
status.FailCount = 0
|
||||
if !status.Healthy && status.SuccessCount >= hc.config.MinSuccess {
|
||||
status.Healthy = true
|
||||
if hc.statusChangeCallback != nil {
|
||||
hc.statusChangeCallback(key, true)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
status.FailCount++
|
||||
status.SuccessCount = 0
|
||||
if status.Healthy && status.FailCount >= hc.config.MaxFails {
|
||||
status.Healthy = false
|
||||
if hc.statusChangeCallback != nil {
|
||||
hc.statusChangeCallback(key, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hc.statusMap.Store(key, status)
|
||||
}
|
330
internal/loadbalance/loadbalancer.go
Normal file
330
internal/loadbalance/loadbalancer.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package loadbalance
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Strategy 负载均衡策略
|
||||
type Strategy int
|
||||
|
||||
const (
|
||||
// StrategyRoundRobin 轮询策略
|
||||
StrategyRoundRobin Strategy = iota
|
||||
// StrategyRandom 随机策略
|
||||
StrategyRandom
|
||||
// StrategyWeightedRoundRobin 加权轮询策略
|
||||
StrategyWeightedRoundRobin
|
||||
// StrategyIPHash IP哈希策略
|
||||
StrategyIPHash
|
||||
)
|
||||
|
||||
// LoadBalancer 负载均衡器接口
|
||||
type LoadBalancer interface {
|
||||
// Next 获取下一个后端
|
||||
Next(key string) (*url.URL, error)
|
||||
// Add 添加后端
|
||||
Add(backend string, weight int) error
|
||||
// Remove 删除后端
|
||||
Remove(backend string) error
|
||||
// MarkDown 标记后端为不可用
|
||||
MarkDown(backend string) error
|
||||
// MarkUp 标记后端为可用
|
||||
MarkUp(backend string) error
|
||||
// Reset 重置负载均衡器
|
||||
Reset() error
|
||||
}
|
||||
|
||||
// Backend 后端服务器
|
||||
type Backend struct {
|
||||
// URL 后端URL
|
||||
URL *url.URL
|
||||
// Weight 权重
|
||||
Weight int
|
||||
// Down 是否不可用
|
||||
Down bool
|
||||
}
|
||||
|
||||
// RoundRobinBalancer 轮询负载均衡器
|
||||
type RoundRobinBalancer struct {
|
||||
backends []*Backend
|
||||
current int32
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRoundRobinBalancer 创建轮询负载均衡器
|
||||
func NewRoundRobinBalancer() *RoundRobinBalancer {
|
||||
return &RoundRobinBalancer{
|
||||
backends: make([]*Backend, 0),
|
||||
current: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Next 获取下一个后端
|
||||
func (lb *RoundRobinBalancer) Next(key string) (*url.URL, error) {
|
||||
lb.mutex.RLock()
|
||||
defer lb.mutex.RUnlock()
|
||||
|
||||
if len(lb.backends) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 计算可用后端数量
|
||||
var availableCount int
|
||||
for _, backend := range lb.backends {
|
||||
if !backend.Down {
|
||||
availableCount++
|
||||
}
|
||||
}
|
||||
|
||||
if availableCount == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 循环直到找到可用后端
|
||||
for i := 0; i < len(lb.backends); i++ {
|
||||
idx := atomic.AddInt32(&lb.current, 1) % int32(len(lb.backends))
|
||||
backend := lb.backends[idx]
|
||||
if !backend.Down {
|
||||
return backend.URL, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Add 添加后端
|
||||
func (lb *RoundRobinBalancer) Add(backend string, weight int) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
lb.backends = append(lb.backends, &Backend{
|
||||
URL: url,
|
||||
Weight: weight,
|
||||
Down: false,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove 删除后端
|
||||
func (lb *RoundRobinBalancer) Remove(backend string) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
lb.backends = append(lb.backends[:i], lb.backends[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkDown 标记后端为不可用
|
||||
func (lb *RoundRobinBalancer) MarkDown(backend string) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
b.Down = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkUp 标记后端为可用
|
||||
func (lb *RoundRobinBalancer) MarkUp(backend string) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
b.Down = false
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset 重置负载均衡器
|
||||
func (lb *RoundRobinBalancer) Reset() error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
lb.backends = make([]*Backend, 0)
|
||||
atomic.StoreInt32(&lb.current, 0)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RandomBalancer 随机负载均衡器
|
||||
type RandomBalancer struct {
|
||||
backends []*Backend
|
||||
mutex sync.RWMutex
|
||||
rand *rand.Rand
|
||||
}
|
||||
|
||||
// NewRandomBalancer 创建随机负载均衡器
|
||||
func NewRandomBalancer() *RandomBalancer {
|
||||
source := rand.NewSource(time.Now().UnixNano())
|
||||
random := rand.New(source)
|
||||
return &RandomBalancer{
|
||||
backends: make([]*Backend, 0),
|
||||
rand: random,
|
||||
}
|
||||
}
|
||||
|
||||
// Next 获取下一个后端
|
||||
func (lb *RandomBalancer) Next(key string) (*url.URL, error) {
|
||||
lb.mutex.RLock()
|
||||
defer lb.mutex.RUnlock()
|
||||
|
||||
if len(lb.backends) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 计算可用后端数量
|
||||
var availableBackends []*Backend
|
||||
for _, backend := range lb.backends {
|
||||
if !backend.Down {
|
||||
availableBackends = append(availableBackends, backend)
|
||||
}
|
||||
}
|
||||
|
||||
if len(availableBackends) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
idx := lb.rand.Intn(len(availableBackends))
|
||||
return availableBackends[idx].URL, nil
|
||||
}
|
||||
|
||||
// Add 添加后端
|
||||
func (lb *RandomBalancer) Add(backend string, weight int) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
lb.backends = append(lb.backends, &Backend{
|
||||
URL: url,
|
||||
Weight: weight,
|
||||
Down: false,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove 删除后端
|
||||
func (lb *RandomBalancer) Remove(backend string) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
lb.backends = append(lb.backends[:i], lb.backends[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkDown 标记后端为不可用
|
||||
func (lb *RandomBalancer) MarkDown(backend string) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
b.Down = true
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkUp 标记后端为可用
|
||||
func (lb *RandomBalancer) MarkUp(backend string) error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
url, err := url.Parse(backend)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, b := range lb.backends {
|
||||
if b.URL.String() == url.String() {
|
||||
b.Down = false
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset 重置负载均衡器
|
||||
func (lb *RandomBalancer) Reset() error {
|
||||
lb.mutex.Lock()
|
||||
defer lb.mutex.Unlock()
|
||||
|
||||
lb.backends = make([]*Backend, 0)
|
||||
|
||||
return nil
|
||||
}
|
250
internal/metrics/metrics.go
Normal file
250
internal/metrics/metrics.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Metrics 监控指标接口
|
||||
type Metrics interface {
|
||||
// 增加请求计数
|
||||
IncRequestCount()
|
||||
// 增加错误计数
|
||||
IncErrorCount(err error)
|
||||
// 观察请求持续时间
|
||||
ObserveRequestDuration(seconds float64)
|
||||
// 增加活跃连接数
|
||||
IncActiveConnections()
|
||||
// 减少活跃连接数
|
||||
DecActiveConnections()
|
||||
// 设置后端健康状态
|
||||
SetBackendHealth(backend string, healthy bool)
|
||||
// 设置后端响应时间
|
||||
SetBackendResponseTime(backend string, duration time.Duration)
|
||||
// 观察请求字节数
|
||||
ObserveRequestBytes(bytes int64)
|
||||
// 观察响应字节数
|
||||
ObserveResponseBytes(bytes int64)
|
||||
// 添加传输字节数
|
||||
AddBytesTransferred(direction string, bytes int64)
|
||||
// 增加缓存命中计数
|
||||
IncCacheHit()
|
||||
// 获取指标处理器
|
||||
GetHandler() http.Handler
|
||||
}
|
||||
|
||||
// SimpleMetrics 简单指标实现
|
||||
type SimpleMetrics struct {
|
||||
// 请求计数
|
||||
requestCount int64
|
||||
// 错误计数
|
||||
errorCount int64
|
||||
// 活跃连接数
|
||||
activeConnections int64
|
||||
// 累计响应时间
|
||||
totalResponseTime int64
|
||||
// 传输字节数
|
||||
bytesTransferred map[string]int64
|
||||
// 后端健康状态
|
||||
backendHealth map[string]bool
|
||||
// 后端响应时间
|
||||
backendResponseTime map[string]time.Duration
|
||||
// 缓存命中计数
|
||||
cacheHits int64
|
||||
// 互斥锁
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewSimpleMetrics 创建简单指标
|
||||
func NewSimpleMetrics() *SimpleMetrics {
|
||||
return &SimpleMetrics{
|
||||
bytesTransferred: make(map[string]int64),
|
||||
backendHealth: make(map[string]bool),
|
||||
backendResponseTime: make(map[string]time.Duration),
|
||||
}
|
||||
}
|
||||
|
||||
// IncRequestCount 增加请求计数
|
||||
func (m *SimpleMetrics) IncRequestCount() {
|
||||
atomic.AddInt64(&m.requestCount, 1)
|
||||
}
|
||||
|
||||
// IncErrorCount 增加错误计数
|
||||
func (m *SimpleMetrics) IncErrorCount(err error) {
|
||||
atomic.AddInt64(&m.errorCount, 1)
|
||||
}
|
||||
|
||||
// ObserveRequestDuration 观察请求持续时间
|
||||
func (m *SimpleMetrics) ObserveRequestDuration(seconds float64) {
|
||||
nsec := int64(seconds * float64(time.Second))
|
||||
atomic.AddInt64(&m.totalResponseTime, nsec)
|
||||
}
|
||||
|
||||
// IncActiveConnections 增加活跃连接数
|
||||
func (m *SimpleMetrics) IncActiveConnections() {
|
||||
atomic.AddInt64(&m.activeConnections, 1)
|
||||
}
|
||||
|
||||
// DecActiveConnections 减少活跃连接数
|
||||
func (m *SimpleMetrics) DecActiveConnections() {
|
||||
atomic.AddInt64(&m.activeConnections, -1)
|
||||
}
|
||||
|
||||
// SetBackendHealth 设置后端健康状态
|
||||
func (m *SimpleMetrics) SetBackendHealth(backend string, healthy bool) {
|
||||
m.backendHealth[backend] = healthy
|
||||
}
|
||||
|
||||
// SetBackendResponseTime 设置后端响应时间
|
||||
func (m *SimpleMetrics) SetBackendResponseTime(backend string, duration time.Duration) {
|
||||
m.backendResponseTime[backend] = duration
|
||||
}
|
||||
|
||||
// ObserveRequestBytes 观察请求字节数
|
||||
func (m *SimpleMetrics) ObserveRequestBytes(bytes int64) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.bytesTransferred["request"] += bytes
|
||||
}
|
||||
|
||||
// ObserveResponseBytes 观察响应字节数
|
||||
func (m *SimpleMetrics) ObserveResponseBytes(bytes int64) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.bytesTransferred["response"] += bytes
|
||||
}
|
||||
|
||||
// AddBytesTransferred 添加传输字节数
|
||||
func (m *SimpleMetrics) AddBytesTransferred(direction string, bytes int64) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.bytesTransferred[direction] += bytes
|
||||
}
|
||||
|
||||
// IncCacheHit 增加缓存命中计数
|
||||
func (m *SimpleMetrics) IncCacheHit() {
|
||||
atomic.AddInt64(&m.cacheHits, 1)
|
||||
}
|
||||
|
||||
// GetHandler 获取指标处理器
|
||||
func (m *SimpleMetrics) GetHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
|
||||
// 输出基本指标
|
||||
w.Write([]byte("# HELP proxy_requests_total 代理请求总数\n"))
|
||||
w.Write([]byte("# TYPE proxy_requests_total counter\n"))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_requests_total %d\n", m.requestCount)))
|
||||
|
||||
w.Write([]byte("# HELP proxy_errors_total 代理错误总数\n"))
|
||||
w.Write([]byte("# TYPE proxy_errors_total counter\n"))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_errors_total %d\n", m.errorCount)))
|
||||
|
||||
w.Write([]byte("# HELP proxy_active_connections 当前活跃连接数\n"))
|
||||
w.Write([]byte("# TYPE proxy_active_connections gauge\n"))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_active_connections %d\n", m.activeConnections)))
|
||||
|
||||
// 输出缓存命中数据
|
||||
w.Write([]byte("# HELP proxy_cache_hits_total 缓存命中总数\n"))
|
||||
w.Write([]byte("# TYPE proxy_cache_hits_total counter\n"))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_cache_hits_total %d\n", m.cacheHits)))
|
||||
|
||||
// 输出传输字节数
|
||||
for direction, bytes := range m.bytesTransferred {
|
||||
w.Write([]byte(fmt.Sprintf("# HELP proxy_bytes_transferred_%s 代理传输字节数(%s)\n", direction, direction)))
|
||||
w.Write([]byte(fmt.Sprintf("# TYPE proxy_bytes_transferred_%s counter\n", direction)))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_bytes_transferred_%s %d\n", direction, bytes)))
|
||||
}
|
||||
|
||||
// 输出后端健康状态
|
||||
for backend, healthy := range m.backendHealth {
|
||||
healthValue := 0
|
||||
if healthy {
|
||||
healthValue = 1
|
||||
}
|
||||
w.Write([]byte(fmt.Sprintf("# HELP proxy_backend_health 后端健康状态\n")))
|
||||
w.Write([]byte(fmt.Sprintf("# TYPE proxy_backend_health gauge\n")))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_backend_health{backend=\"%s\"} %d\n", backend, healthValue)))
|
||||
}
|
||||
|
||||
// 输出后端响应时间
|
||||
for backend, duration := range m.backendResponseTime {
|
||||
w.Write([]byte(fmt.Sprintf("# HELP proxy_backend_response_time 后端响应时间\n")))
|
||||
w.Write([]byte(fmt.Sprintf("# TYPE proxy_backend_response_time gauge\n")))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_backend_response_time{backend=\"%s\"} %f\n", backend, float64(duration)/float64(time.Second))))
|
||||
}
|
||||
|
||||
// 平均响应时间
|
||||
if m.requestCount > 0 {
|
||||
avgTime := float64(m.totalResponseTime) / float64(m.requestCount) / float64(time.Second)
|
||||
w.Write([]byte("# HELP proxy_average_response_time 平均响应时间\n"))
|
||||
w.Write([]byte("# TYPE proxy_average_response_time gauge\n"))
|
||||
w.Write([]byte(fmt.Sprintf("proxy_average_response_time %f\n", avgTime)))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// PrometheusMetrics Prometheus指标实现
|
||||
type PrometheusMetrics struct {
|
||||
// 可以通过引入prometheus客户端库实现更完整的指标收集
|
||||
// 此处省略具体实现
|
||||
}
|
||||
|
||||
// MetricsMiddleware 指标中间件
|
||||
type MetricsMiddleware struct {
|
||||
metrics Metrics
|
||||
}
|
||||
|
||||
// NewMetricsMiddleware 创建指标中间件
|
||||
func NewMetricsMiddleware(metrics Metrics) *MetricsMiddleware {
|
||||
return &MetricsMiddleware{
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware 中间件处理函数
|
||||
func (m *MetricsMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// 包装响应写入器,用于捕获状态码
|
||||
rw := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
|
||||
// 继续处理请求
|
||||
next.ServeHTTP(rw, r)
|
||||
|
||||
// 记录请求指标
|
||||
duration := time.Since(start)
|
||||
m.metrics.ObserveRequestDuration(duration.Seconds())
|
||||
})
|
||||
}
|
||||
|
||||
// responseWriter 包装的响应写入器
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
written int64
|
||||
}
|
||||
|
||||
// WriteHeader 写入状态码
|
||||
func (rw *responseWriter) WriteHeader(statusCode int) {
|
||||
rw.statusCode = statusCode
|
||||
rw.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
// Write 写入数据
|
||||
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||
n, err := rw.ResponseWriter.Write(b)
|
||||
rw.written += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Flush 刷新数据
|
||||
func (rw *responseWriter) Flush() {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
184
internal/middleware/ratelimiter.go
Normal file
184
internal/middleware/ratelimiter.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// RateLimiter 限流器接口
|
||||
type RateLimiter interface {
|
||||
// Allow 检查请求是否允许通过
|
||||
Allow(key string) bool
|
||||
}
|
||||
|
||||
// SimpleRateLimiter 简单限流器
|
||||
type SimpleRateLimiter struct {
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
// NewSimpleRateLimiter 创建简单限流器
|
||||
func NewSimpleRateLimiter(r float64, b int) *SimpleRateLimiter {
|
||||
return &SimpleRateLimiter{
|
||||
limiter: rate.NewLimiter(rate.Limit(r), b),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow 检查请求是否允许通过
|
||||
func (rl *SimpleRateLimiter) Allow(key string) bool {
|
||||
return rl.limiter.Allow()
|
||||
}
|
||||
|
||||
// IPRateLimiter 按IP限流
|
||||
type IPRateLimiter struct {
|
||||
ips map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
rate rate.Limit
|
||||
burst int
|
||||
cleanupInterval time.Duration
|
||||
lastSeen map[string]time.Time
|
||||
}
|
||||
|
||||
// NewIPRateLimiter 创建IP限流器
|
||||
func NewIPRateLimiter(r float64, b int, cleanup time.Duration) *IPRateLimiter {
|
||||
limiter := &IPRateLimiter{
|
||||
ips: make(map[string]*rate.Limiter),
|
||||
rate: rate.Limit(r),
|
||||
burst: b,
|
||||
cleanupInterval: cleanup,
|
||||
lastSeen: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
// 启动过期清理
|
||||
if cleanup > 0 {
|
||||
go limiter.startCleanup()
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// startCleanup 启动过期清理
|
||||
func (rl *IPRateLimiter) startCleanup() {
|
||||
ticker := time.NewTicker(rl.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup 清理过期限流器
|
||||
func (rl *IPRateLimiter) cleanup() {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for ip, lastSeen := range rl.lastSeen {
|
||||
if now.Sub(lastSeen) > rl.cleanupInterval {
|
||||
delete(rl.ips, ip)
|
||||
delete(rl.lastSeen, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddIP 添加IP限流器
|
||||
func (rl *IPRateLimiter) AddIP(ip string) *rate.Limiter {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
limiter := rate.NewLimiter(rl.rate, rl.burst)
|
||||
rl.ips[ip] = limiter
|
||||
rl.lastSeen[ip] = time.Now()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// GetLimiter 获取IP限流器
|
||||
func (rl *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
|
||||
rl.mu.RLock()
|
||||
limiter, exists := rl.ips[ip]
|
||||
rl.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return rl.AddIP(ip)
|
||||
}
|
||||
|
||||
// 更新最后访问时间
|
||||
rl.mu.Lock()
|
||||
rl.lastSeen[ip] = time.Now()
|
||||
rl.mu.Unlock()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Allow 检查请求是否允许通过
|
||||
func (rl *IPRateLimiter) Allow(ip string) bool {
|
||||
limiter := rl.GetLimiter(ip)
|
||||
return limiter.Allow()
|
||||
}
|
||||
|
||||
// RateLimitMiddleware 限流中间件
|
||||
type RateLimitMiddleware struct {
|
||||
limiter RateLimiter
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(limiter RateLimiter) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
limiter: limiter,
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware 中间件处理函数
|
||||
func (m *RateLimitMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 获取客户端IP
|
||||
ip := getClientIP(r)
|
||||
|
||||
// 检查是否允许通过
|
||||
if !m.limiter.Allow(ip) {
|
||||
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
// 继续处理请求
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端IP
|
||||
func getClientIP(r *http.Request) string {
|
||||
// 检查 X-Forwarded-For 头
|
||||
ip := r.Header.Get("X-Forwarded-For")
|
||||
if ip != "" {
|
||||
// 取第一个IP
|
||||
for i := 0; i < len(ip) && i < 15; i++ {
|
||||
if ip[i] == ',' {
|
||||
ip = ip[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// 检查 X-Real-IP 头
|
||||
ip = r.Header.Get("X-Real-IP")
|
||||
if ip != "" {
|
||||
return ip
|
||||
}
|
||||
|
||||
// 从 RemoteAddr 获取
|
||||
if r.RemoteAddr != "" {
|
||||
// 去掉端口部分
|
||||
for i := 0; i < len(r.RemoteAddr); i++ {
|
||||
if r.RemoteAddr[i] == ':' {
|
||||
return r.RemoteAddr[:i]
|
||||
}
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
156
internal/middleware/retry.go
Normal file
156
internal/middleware/retry.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RetryPolicy 重试策略
|
||||
type RetryPolicy struct {
|
||||
// 最大重试次数
|
||||
MaxRetries int
|
||||
// 基础退避时间
|
||||
BaseBackoff time.Duration
|
||||
// 最大退避时间
|
||||
MaxBackoff time.Duration
|
||||
// 重试判断函数
|
||||
ShouldRetry func(req *http.Request, resp *http.Response, err error) bool
|
||||
}
|
||||
|
||||
// DefaultRetryPolicy 默认重试策略
|
||||
func DefaultRetryPolicy() *RetryPolicy {
|
||||
return &RetryPolicy{
|
||||
MaxRetries: 3,
|
||||
BaseBackoff: 100 * time.Millisecond,
|
||||
MaxBackoff: 2 * time.Second,
|
||||
ShouldRetry: defaultShouldRetry,
|
||||
}
|
||||
}
|
||||
|
||||
// defaultShouldRetry 默认重试判断
|
||||
func defaultShouldRetry(req *http.Request, resp *http.Response, err error) bool {
|
||||
// 不重试非幂等请求
|
||||
if req.Method != http.MethodGet && req.Method != http.MethodHead && req.Method != http.MethodOptions {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查错误
|
||||
if err != nil {
|
||||
// 重试网络错误
|
||||
if netErr, ok := err.(net.Error); ok {
|
||||
return netErr.Temporary() || netErr.Timeout()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查响应状态码
|
||||
if resp != nil {
|
||||
// 重试服务器错误
|
||||
return resp.StatusCode >= 500 && resp.StatusCode < 600
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// RetryRoundTripper 重试HTTP传输
|
||||
type RetryRoundTripper struct {
|
||||
// 下一级传输
|
||||
Next http.RoundTripper
|
||||
// 重试策略
|
||||
Policy *RetryPolicy
|
||||
}
|
||||
|
||||
// NewRetryRoundTripper 创建重试HTTP传输
|
||||
func NewRetryRoundTripper(next http.RoundTripper, policy *RetryPolicy) *RetryRoundTripper {
|
||||
if next == nil {
|
||||
next = http.DefaultTransport
|
||||
}
|
||||
if policy == nil {
|
||||
policy = DefaultRetryPolicy()
|
||||
}
|
||||
return &RetryRoundTripper{
|
||||
Next: next,
|
||||
Policy: policy,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip 执行HTTP请求
|
||||
func (rt *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// 需要保留原始请求体,以便重试
|
||||
var reqBodyBytes []byte
|
||||
if req.Body != nil {
|
||||
var err error
|
||||
reqBodyBytes, err = io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Body.Close()
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
var err error
|
||||
|
||||
// 尝试请求直到成功或达到最大重试次数
|
||||
for attempt := 0; attempt <= rt.Policy.MaxRetries; attempt++ {
|
||||
// 复制请求体
|
||||
if len(reqBodyBytes) > 0 {
|
||||
req.Body = io.NopCloser(bytes.NewBuffer(reqBodyBytes))
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err = rt.Next.RoundTrip(req)
|
||||
|
||||
// 检查是否需要重试
|
||||
if attempt < rt.Policy.MaxRetries && rt.Policy.ShouldRetry(req, resp, err) {
|
||||
// 如果需要重试,先关闭当前响应
|
||||
if resp != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// 计算退避时间
|
||||
backoff := rt.calculateBackoff(attempt)
|
||||
time.Sleep(backoff)
|
||||
continue
|
||||
}
|
||||
|
||||
// 不需要重试,返回响应
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// 所有重试都失败
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// calculateBackoff 计算退避时间
|
||||
func (rt *RetryRoundTripper) calculateBackoff(attempt int) time.Duration {
|
||||
// 指数退避: baseBackoff * 2^attempt
|
||||
backoff := rt.Policy.BaseBackoff * time.Duration(math.Pow(2, float64(attempt)))
|
||||
if backoff > rt.Policy.MaxBackoff {
|
||||
backoff = rt.Policy.MaxBackoff
|
||||
}
|
||||
return backoff
|
||||
}
|
||||
|
||||
// RetryMiddleware 重试中间件
|
||||
type RetryMiddleware struct {
|
||||
policy *RetryPolicy
|
||||
}
|
||||
|
||||
// NewRetryMiddleware 创建重试中间件
|
||||
func NewRetryMiddleware(policy *RetryPolicy) *RetryMiddleware {
|
||||
if policy == nil {
|
||||
policy = DefaultRetryPolicy()
|
||||
}
|
||||
return &RetryMiddleware{
|
||||
policy: policy,
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware 中间件处理函数
|
||||
func (m *RetryMiddleware) Transport(next http.RoundTripper) http.RoundTripper {
|
||||
return NewRetryRoundTripper(next, m.policy)
|
||||
}
|
552
internal/proxy/certificate.go
Normal file
552
internal/proxy/certificate.go
Normal file
@@ -0,0 +1,552 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 内置默认CA证书和私钥
|
||||
var (
|
||||
// 默认根证书
|
||||
defaultRootCAPem = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIICJzCCAcygAwIBAgIITWWCIQf8/VIwCgYIKoZIzj0EAwIwUzEOMAwGA1UEBhMF
|
||||
Q2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0GA1UEBxMGWGlhbWVuMRAwDgYDVQQK
|
||||
EwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMB4XDTIyMDMyNTA1NDgwMFoXDTQyMDQy
|
||||
NTA1NDgwMFowUzEOMAwGA1UEBhMFQ2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0G
|
||||
A1UEBxMGWGlhbWVuMRAwDgYDVQQKEwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMFkw
|
||||
EwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQX
|
||||
mRgsFV5KHHmxOrVJBFC/nDetmGowkARShWtBsX1Irm4w6i6Qk2QliKOBiTCBhjAO
|
||||
BgNVHQ8BAf8EBAMCAQYwHQYDVR0lBBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIG
|
||||
A1UdEwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFBI5TkWYcvUIWsBAdffs833FnBrI
|
||||
MCIGA1UdEQQbMBmBF3FpbmdxaWFubHVkYW9AZ21haWwuY29tMAoGCCqGSM49BAMC
|
||||
A0kAMEYCIQCk1DhW7AmIW/n/QLftQq8BHZKLevWYJ813zdrNr5kXlwIhAIVvqglY
|
||||
9BkYWg4NEe/mVO4C5Vtu4FnzNU9I+rFpXVSO
|
||||
-----END CERTIFICATE-----
|
||||
`)
|
||||
// 默认根私钥
|
||||
defaultRootKeyPem = []byte(`-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIAXeEHO0FtFqQhTvsn/DT4g3rEos97+1Nibp9RfKOKhroAoGCCqGSM49
|
||||
AwEHoUQDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQXmRgsFV5KHHmxOrVJBFC/
|
||||
nDetmGowkARShWtBsX1Irm4w6i6Qk2QliA==
|
||||
-----END EC PRIVATE KEY-----
|
||||
`)
|
||||
)
|
||||
|
||||
// 加载和初始化默认根证书
|
||||
var (
|
||||
defaultRootCA *x509.Certificate
|
||||
defaultRootKey *ecdsa.PrivateKey
|
||||
)
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
block, _ := pem.Decode(defaultRootCAPem)
|
||||
if block == nil {
|
||||
panic("解析默认根证书PEM块失败")
|
||||
}
|
||||
defaultRootCA, err = x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("加载默认根证书失败: %s", err))
|
||||
}
|
||||
|
||||
block, _ = pem.Decode(defaultRootKeyPem)
|
||||
if block == nil {
|
||||
panic("解析默认根私钥PEM块失败")
|
||||
}
|
||||
defaultRootKey, err = x509.ParseECPrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("加载默认根私钥失败: %s", err))
|
||||
}
|
||||
}
|
||||
|
||||
// CertManager 证书管理器
|
||||
type CertManager struct {
|
||||
// 证书缓存
|
||||
cache CertificateCache
|
||||
// 默认私钥,可用于多个证书共享
|
||||
defaultPrivateKey interface{} // 改为interface{}以支持不同类型的私钥
|
||||
// 默认使用ECDSA P-256曲线
|
||||
curve elliptic.Curve
|
||||
// 证书有效期(年)
|
||||
validityYears int
|
||||
// 是否使用ECDSA(否则使用RSA)
|
||||
useECDSA bool
|
||||
}
|
||||
|
||||
// NewCertManager 创建证书管理器
|
||||
func NewCertManager(cache CertificateCache, options ...CertManagerOption) *CertManager {
|
||||
manager := &CertManager{
|
||||
cache: cache,
|
||||
curve: elliptic.P256(), // 默认使用P-256曲线
|
||||
validityYears: 1, // 默认证书有效期1年
|
||||
useECDSA: true, // 默认使用ECDSA
|
||||
}
|
||||
|
||||
// 应用选项
|
||||
for _, option := range options {
|
||||
option(manager)
|
||||
}
|
||||
|
||||
return manager
|
||||
}
|
||||
|
||||
// CertManagerOption 证书管理器选项
|
||||
type CertManagerOption func(*CertManager)
|
||||
|
||||
// WithUseECDSA 设置是否使用ECDSA(否则使用RSA)
|
||||
func WithUseECDSA(useECDSA bool) CertManagerOption {
|
||||
return func(m *CertManager) {
|
||||
m.useECDSA = useECDSA
|
||||
}
|
||||
}
|
||||
|
||||
// WithDefaultPrivateKey 设置是否使用默认私钥
|
||||
func WithDefaultPrivateKey(enable bool) CertManagerOption {
|
||||
return func(m *CertManager) {
|
||||
if enable {
|
||||
if m.useECDSA {
|
||||
// 生成ECDSA私钥
|
||||
priv, err := ecdsa.GenerateKey(m.curve, rand.Reader)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("生成默认ECDSA私钥失败: %s", err))
|
||||
}
|
||||
m.defaultPrivateKey = priv
|
||||
} else {
|
||||
// 生成RSA私钥
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("生成默认RSA私钥失败: %s", err))
|
||||
}
|
||||
m.defaultPrivateKey = priv
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// WithCurve 设置椭圆曲线
|
||||
func WithCurve(curve elliptic.Curve) CertManagerOption {
|
||||
return func(m *CertManager) {
|
||||
m.curve = curve
|
||||
}
|
||||
}
|
||||
|
||||
// WithValidityYears 设置证书有效期(年)
|
||||
func WithValidityYears(years int) CertManagerOption {
|
||||
return func(m *CertManager) {
|
||||
if years > 0 {
|
||||
m.validityYears = years
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateTLSConfig 为指定主机生成TLS配置
|
||||
func (m *CertManager) GenerateTLSConfig(host string) (*tls.Config, error) {
|
||||
// 处理可能的端口
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
|
||||
// 检查证书缓存
|
||||
if m.cache != nil {
|
||||
// 检查主域名和子域名
|
||||
fields := strings.Split(host, ".")
|
||||
domains := []string{host}
|
||||
|
||||
// 添加父域名
|
||||
if len(fields) > 2 {
|
||||
domains = append(domains, strings.Join(fields[1:], "."))
|
||||
}
|
||||
|
||||
// 查找缓存
|
||||
for _, domain := range domains {
|
||||
if cert := m.cache.Get(domain); cert != nil {
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{*cert},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 生成新证书
|
||||
cert, err := m.GenerateCertificate(host, defaultRootCA, defaultRootKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 缓存证书
|
||||
if m.cache != nil {
|
||||
// 缓存主机名
|
||||
m.cache.Set(host, cert)
|
||||
|
||||
// 如果是IP地址,不进行其他处理
|
||||
if net.ParseIP(host) != nil {
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{*cert},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 缓存域名和子域名证书
|
||||
fields := strings.Split(host, ".")
|
||||
if len(fields) >= 2 {
|
||||
// 缓存主域名
|
||||
domain := strings.Join(fields[1:], ".")
|
||||
m.cache.Set(domain, cert)
|
||||
}
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
Certificates: []tls.Certificate{*cert},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateCertificate 生成证书
|
||||
func (m *CertManager) GenerateCertificate(host string, rootCA *x509.Certificate, rootKey *ecdsa.PrivateKey) (*tls.Certificate, error) {
|
||||
// 准备私钥
|
||||
var priv interface{}
|
||||
var pubKey interface{}
|
||||
var err error
|
||||
|
||||
// 使用默认私钥或生成新私钥
|
||||
if m.defaultPrivateKey != nil {
|
||||
priv = m.defaultPrivateKey
|
||||
} else if m.useECDSA {
|
||||
// 生成ECDSA私钥
|
||||
ecdsaKey, err := ecdsa.GenerateKey(m.curve, rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成ECDSA私钥失败: %s", err)
|
||||
}
|
||||
priv = ecdsaKey
|
||||
pubKey = &ecdsaKey.PublicKey
|
||||
} else {
|
||||
// 生成RSA私钥
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成RSA私钥失败: %s", err)
|
||||
}
|
||||
priv = rsaKey
|
||||
pubKey = &rsaKey.PublicKey
|
||||
}
|
||||
|
||||
// 获取公钥
|
||||
if pubKey == nil {
|
||||
switch k := priv.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
pubKey = &k.PublicKey
|
||||
case *rsa.PrivateKey:
|
||||
pubKey = &k.PublicKey
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的私钥类型")
|
||||
}
|
||||
}
|
||||
|
||||
// 创建证书模板
|
||||
template := m.createCertificateTemplate(host)
|
||||
|
||||
// 签名证书
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, template, rootCA, pubKey, rootKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建证书失败: %s", err)
|
||||
}
|
||||
|
||||
// 编码为PEM格式
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: derBytes,
|
||||
})
|
||||
|
||||
// 编码私钥
|
||||
var keyPEM []byte
|
||||
switch k := priv.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
privBytes, err := x509.MarshalECPrivateKey(k)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化ECDSA私钥失败: %s", err)
|
||||
}
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: privBytes,
|
||||
})
|
||||
case *rsa.PrivateKey:
|
||||
privBytes := x509.MarshalPKCS1PrivateKey(k)
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: privBytes,
|
||||
})
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的私钥类型")
|
||||
}
|
||||
|
||||
// 创建TLS证书
|
||||
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建TLS证书对失败: %s", err)
|
||||
}
|
||||
|
||||
return &tlsCert, nil
|
||||
}
|
||||
|
||||
// createCertificateTemplate 创建证书模板
|
||||
func (m *CertManager) createCertificateTemplate(host string) *x509.Certificate {
|
||||
// 使用基于主机名的哈希值作为序列号
|
||||
fv := fnv.New64a()
|
||||
fv.Write([]byte(host))
|
||||
serialNumber := big.NewInt(0).SetUint64(fv.Sum64())
|
||||
|
||||
// 准备模板
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: host,
|
||||
Organization: []string{"GoProxy Dynamic CA"},
|
||||
Country: []string{"CN"},
|
||||
Province: []string{"GuangDong"},
|
||||
Locality: []string{"Guangzhou"},
|
||||
},
|
||||
NotBefore: time.Now().Add(-10 * time.Minute), // 提前10分钟生效,容忍时间偏差
|
||||
NotAfter: time.Now().AddDate(m.validityYears, 0, 0),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: false,
|
||||
}
|
||||
|
||||
// 处理IP地址和域名
|
||||
ipAddr := net.ParseIP(host)
|
||||
if ipAddr != nil {
|
||||
template.IPAddresses = []net.IP{ipAddr}
|
||||
} else {
|
||||
// 移除可能的端口部分
|
||||
if strings.Contains(host, ":") {
|
||||
host = strings.Split(host, ":")[0]
|
||||
}
|
||||
|
||||
// 将主机名添加到DNS名称列表
|
||||
template.DNSNames = []string{host}
|
||||
|
||||
// 添加通配符域名支持
|
||||
fields := strings.Split(host, ".")
|
||||
fieldNum := len(fields)
|
||||
|
||||
// 为每一级子域名添加通配符
|
||||
for i := 0; i <= (fieldNum - 2); i++ {
|
||||
wildcardDomain := "*." + strings.Join(fields[i:], ".")
|
||||
// 避免重复
|
||||
if wildcardDomain != host {
|
||||
template.DNSNames = append(template.DNSNames, wildcardDomain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return template
|
||||
}
|
||||
|
||||
// LoadCAFromFiles 从文件加载CA证书和私钥
|
||||
func LoadCAFromFiles(certFile, keyFile string) (*x509.Certificate, *ecdsa.PrivateKey, error) {
|
||||
// 读取CA证书
|
||||
caCertPEM, err := os.ReadFile(certFile)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取CA证书文件失败: %s", err)
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(caCertPEM)
|
||||
if block == nil {
|
||||
return nil, nil, fmt.Errorf("解析CA证书PEM块失败")
|
||||
}
|
||||
|
||||
caCert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("解析CA证书失败: %s", err)
|
||||
}
|
||||
|
||||
// 读取CA私钥
|
||||
caKeyPEM, err := os.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("读取CA私钥文件失败: %s", err)
|
||||
}
|
||||
|
||||
block, _ = pem.Decode(caKeyPEM)
|
||||
if block == nil {
|
||||
return nil, nil, fmt.Errorf("解析CA私钥PEM块失败")
|
||||
}
|
||||
|
||||
var caKey *ecdsa.PrivateKey
|
||||
|
||||
// 尝试不同的私钥格式
|
||||
switch block.Type {
|
||||
case "EC PRIVATE KEY":
|
||||
caKey, err = x509.ParseECPrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("解析EC私钥失败: %s", err)
|
||||
}
|
||||
case "PRIVATE KEY":
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("解析PKCS8私钥失败: %s", err)
|
||||
}
|
||||
var ok bool
|
||||
caKey, ok = key.(*ecdsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("私钥不是ECDSA类型")
|
||||
}
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("不支持的私钥类型: %s", block.Type)
|
||||
}
|
||||
|
||||
return caCert, caKey, nil
|
||||
}
|
||||
|
||||
// GetDefaultRootCA 获取默认根证书和私钥
|
||||
func GetDefaultRootCA() (*x509.Certificate, *ecdsa.PrivateKey) {
|
||||
return defaultRootCA, defaultRootKey
|
||||
}
|
||||
|
||||
// GenerateRootCA 生成新的根证书和私钥
|
||||
func GenerateRootCA(validYears int) (*x509.Certificate, *ecdsa.PrivateKey, error) {
|
||||
// 生成私钥
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("生成根证书私钥失败: %s", err)
|
||||
}
|
||||
|
||||
// 随机生成序列号
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("生成序列号失败: %s", err)
|
||||
}
|
||||
|
||||
// 创建根证书模板
|
||||
notBefore := time.Now().Add(-10 * time.Minute)
|
||||
notAfter := notBefore.AddDate(validYears, 0, 0)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: "GoProxy Root CA",
|
||||
Organization: []string{"GoProxy"},
|
||||
Country: []string{"CN"},
|
||||
Province: []string{"GuangDong"},
|
||||
Locality: []string{"Guangzhou"},
|
||||
},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
MaxPathLen: 2,
|
||||
}
|
||||
|
||||
// 自签名
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("创建根证书失败: %s", err)
|
||||
}
|
||||
|
||||
// 解析生成的证书
|
||||
cert, err := x509.ParseCertificate(derBytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("解析生成的根证书失败: %s", err)
|
||||
}
|
||||
|
||||
return cert, priv, nil
|
||||
}
|
||||
|
||||
// SaveCertificateToFile 将证书和私钥保存到文件
|
||||
func SaveCertificateToFile(cert *x509.Certificate, key *ecdsa.PrivateKey, certFile, keyFile string) error {
|
||||
// 保存证书
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
})
|
||||
if err := os.WriteFile(certFile, certPEM, 0644); err != nil {
|
||||
return fmt.Errorf("保存证书到文件失败: %s", err)
|
||||
}
|
||||
|
||||
// 保存私钥
|
||||
keyBytes, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化私钥失败: %s", err)
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "EC PRIVATE KEY",
|
||||
Bytes: keyBytes,
|
||||
})
|
||||
if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil {
|
||||
return fmt.Errorf("保存私钥到文件失败: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateRootCA 生成根证书
|
||||
func (m *CertManager) GenerateRootCA() (*x509.Certificate, interface{}, error) {
|
||||
// 生成私钥
|
||||
var priv interface{}
|
||||
var pubKey interface{}
|
||||
var err error
|
||||
|
||||
if m.useECDSA {
|
||||
// 生成ECDSA私钥
|
||||
ecdsaKey, err := ecdsa.GenerateKey(m.curve, rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("生成ECDSA根私钥失败: %s", err)
|
||||
}
|
||||
priv = ecdsaKey
|
||||
pubKey = &ecdsaKey.PublicKey
|
||||
} else {
|
||||
// 生成RSA私钥
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("生成RSA根私钥失败: %s", err)
|
||||
}
|
||||
priv = rsaKey
|
||||
pubKey = &rsaKey.PublicKey
|
||||
}
|
||||
|
||||
// 创建根证书模板
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "GoProxy Root CA",
|
||||
Organization: []string{"GoProxy Authority"},
|
||||
Country: []string{"CN"},
|
||||
Province: []string{"GuangDong"},
|
||||
Locality: []string{"Guangzhou"},
|
||||
},
|
||||
NotBefore: time.Now().Add(-10 * time.Minute),
|
||||
NotAfter: time.Now().AddDate(10, 0, 0), // 10年有效期
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
MaxPathLen: 1,
|
||||
}
|
||||
|
||||
// 自签名
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, pubKey, priv)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("创建根证书失败: %s", err)
|
||||
}
|
||||
|
||||
// 解析证书
|
||||
cert, err := x509.ParseCertificate(derBytes)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("解析生成的根证书失败: %s", err)
|
||||
}
|
||||
|
||||
return cert, priv, nil
|
||||
}
|
134
internal/proxy/conn_buffer.go
Normal file
134
internal/proxy/conn_buffer.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConnBuffer 连接缓冲区
|
||||
// 封装了底层网络连接和缓冲读取器,提供了更方便的读写接口
|
||||
type ConnBuffer struct {
|
||||
// 底层连接
|
||||
conn net.Conn
|
||||
// 缓冲读取器
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
// NewConnBuffer 创建连接缓冲区
|
||||
func NewConnBuffer(conn net.Conn, reader *bufio.Reader) *ConnBuffer {
|
||||
if reader == nil {
|
||||
reader = bufio.NewReader(conn)
|
||||
}
|
||||
return &ConnBuffer{
|
||||
conn: conn,
|
||||
reader: reader,
|
||||
}
|
||||
}
|
||||
|
||||
// Read 从连接读取数据
|
||||
func (c *ConnBuffer) Read(b []byte) (int, error) {
|
||||
return c.reader.Read(b)
|
||||
}
|
||||
|
||||
// Write 向连接写入数据
|
||||
func (c *ConnBuffer) Write(b []byte) (int, error) {
|
||||
return c.conn.Write(b)
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (c *ConnBuffer) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// LocalAddr 获取本地地址
|
||||
func (c *ConnBuffer) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
// RemoteAddr 获取远程地址
|
||||
func (c *ConnBuffer) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
// SetDeadline 设置读写超时
|
||||
func (c *ConnBuffer) SetDeadline(t time.Time) error {
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
// SetReadDeadline 设置读取超时
|
||||
func (c *ConnBuffer) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetWriteDeadline 设置写入超时
|
||||
func (c *ConnBuffer) SetWriteDeadline(t time.Time) error {
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// BufferReader 获取缓冲读取器
|
||||
func (c *ConnBuffer) BufferReader() *bufio.Reader {
|
||||
return c.reader
|
||||
}
|
||||
|
||||
// Peek 查看缓冲区中的数据,但不消费
|
||||
func (c *ConnBuffer) Peek(n int) ([]byte, error) {
|
||||
return c.reader.Peek(n)
|
||||
}
|
||||
|
||||
// ReadByte 读取一个字节
|
||||
func (c *ConnBuffer) ReadByte() (byte, error) {
|
||||
return c.reader.ReadByte()
|
||||
}
|
||||
|
||||
// UnreadByte 将最后读取的字节放回缓冲区
|
||||
func (c *ConnBuffer) UnreadByte() error {
|
||||
return c.reader.UnreadByte()
|
||||
}
|
||||
|
||||
// ReadLine 读取一行数据
|
||||
func (c *ConnBuffer) ReadLine() (string, error) {
|
||||
line, isPrefix, err := c.reader.ReadLine()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 如果一行数据没有读取完整,继续读取
|
||||
if isPrefix {
|
||||
var buf []byte
|
||||
buf = append(buf, line...)
|
||||
for isPrefix && err == nil {
|
||||
line, isPrefix, err = c.reader.ReadLine()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
buf = append(buf, line...)
|
||||
}
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
return string(line), nil
|
||||
}
|
||||
|
||||
// ReadN 读取指定字节数的数据
|
||||
func (c *ConnBuffer) ReadN(n int) ([]byte, error) {
|
||||
buf := make([]byte, n)
|
||||
_, err := io.ReadFull(c.reader, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Flush 刷新缓冲区
|
||||
func (c *ConnBuffer) Flush() error {
|
||||
// 由于我们只有读取缓冲区,没有写入缓冲区,所以这里不需要实际操作
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset 重置连接缓冲区
|
||||
func (c *ConnBuffer) Reset(conn net.Conn) {
|
||||
c.conn = conn
|
||||
c.reader = bufio.NewReader(conn)
|
||||
}
|
132
internal/proxy/context.go
Normal file
132
internal/proxy/context.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Context 代理上下文
|
||||
// 包含了代理请求的上下文信息,用于在代理处理过程中传递数据
|
||||
type Context struct {
|
||||
// 原始请求
|
||||
Req *http.Request
|
||||
// 请求开始时间
|
||||
StartTime time.Time
|
||||
// 上下文数据,用于在各个处理阶段传递数据
|
||||
Data map[interface{}]interface{}
|
||||
// 是否是隧道代理
|
||||
TunnelProxy bool
|
||||
// 请求ID
|
||||
RequestID string
|
||||
// 目标地址
|
||||
TargetAddr string
|
||||
// 上级代理地址
|
||||
ParentProxyURL *url.URL
|
||||
// 是否中断执行
|
||||
abort bool
|
||||
// 请求标签,用于标记请求类型
|
||||
Tags []string
|
||||
// 是否已中止
|
||||
aborted bool
|
||||
// 互斥锁
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// IsHTTPS 是否是HTTPS请求
|
||||
func (c *Context) IsHTTPS() bool {
|
||||
return c.Req.URL.Scheme == "https" || c.Req.Method == http.MethodConnect
|
||||
}
|
||||
|
||||
// defaultPorts 默认端口映射
|
||||
var defaultPorts = map[string]string{
|
||||
"https": "443",
|
||||
"http": "80",
|
||||
"": "80",
|
||||
}
|
||||
|
||||
// WebSocketURL 获取WebSocket URL
|
||||
func (c *Context) WebSocketURL() *url.URL {
|
||||
u := *c.Req.URL
|
||||
if c.IsHTTPS() {
|
||||
u.Scheme = "wss"
|
||||
} else {
|
||||
u.Scheme = "ws"
|
||||
}
|
||||
return &u
|
||||
}
|
||||
|
||||
// Addr 获取请求地址
|
||||
func (c *Context) Addr() string {
|
||||
addr := c.Req.Host
|
||||
|
||||
if !strings.Contains(c.Req.URL.Host, ":") {
|
||||
addr += ":" + defaultPorts[c.Req.URL.Scheme]
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
// AddTag 添加请求标签
|
||||
func (c *Context) AddTag(tag string) {
|
||||
c.Tags = append(c.Tags, tag)
|
||||
}
|
||||
|
||||
// HasTag 检查是否包含指定标签
|
||||
func (c *Context) HasTag(tag string) bool {
|
||||
for _, t := range c.Tags {
|
||||
if t == tag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Abort 中断执行
|
||||
func (c *Context) Abort() {
|
||||
c.aborted = true
|
||||
}
|
||||
|
||||
// IsAborted 是否已中断执行
|
||||
func (c *Context) IsAborted() bool {
|
||||
return c.aborted
|
||||
}
|
||||
|
||||
// Reset 重置上下文
|
||||
func (c *Context) Reset(req *http.Request) {
|
||||
c.Req = req
|
||||
c.StartTime = time.Now()
|
||||
c.Data = make(map[interface{}]interface{})
|
||||
c.abort = false
|
||||
c.TunnelProxy = false
|
||||
c.Tags = make([]string, 0)
|
||||
c.RequestID = ""
|
||||
c.TargetAddr = ""
|
||||
c.ParentProxyURL = nil
|
||||
c.aborted = false
|
||||
}
|
||||
|
||||
// Set 设置数据
|
||||
func (c *Context) Set(key, value interface{}) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.Data == nil {
|
||||
c.Data = make(map[interface{}]interface{})
|
||||
}
|
||||
c.Data[key] = value
|
||||
}
|
||||
|
||||
// Get 获取数据
|
||||
func (c *Context) Get(key interface{}) (interface{}, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if c.Data == nil {
|
||||
return nil, false
|
||||
}
|
||||
val, ok := c.Data[key]
|
||||
return val, ok
|
||||
}
|
123
internal/proxy/delegate.go
Normal file
123
internal/proxy/delegate.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Delegate 代理委托接口
|
||||
// 定义了代理处理请求的各个阶段的回调方法
|
||||
type Delegate interface {
|
||||
// Connect 连接事件
|
||||
Connect(ctx *Context, rw http.ResponseWriter)
|
||||
|
||||
// Auth 认证事件
|
||||
Auth(ctx *Context, rw http.ResponseWriter)
|
||||
|
||||
// BeforeRequest 请求前事件
|
||||
BeforeRequest(ctx *Context)
|
||||
|
||||
// BeforeResponse 响应前事件
|
||||
BeforeResponse(ctx *Context, resp *http.Response, err error)
|
||||
|
||||
// WebSocketSendMessage websocket发送消息拦截
|
||||
// WebSocketSendMessage(ctx *Context, messageType *int, p *[]byte)
|
||||
|
||||
// WebSocketReceiveMessage websocket接收消息拦截
|
||||
// WebSocketReceiveMessage(ctx *Context, messageType *int, p *[]byte)
|
||||
|
||||
// ParentProxy 获取上级代理
|
||||
ParentProxy(req *http.Request) (*url.URL, error)
|
||||
|
||||
// ErrorLog 错误日志
|
||||
ErrorLog(err error)
|
||||
|
||||
// Finish 完成事件
|
||||
Finish(ctx *Context)
|
||||
|
||||
// 以下是反向代理相关的方法
|
||||
|
||||
// ResolveBackend 解析后端服务器
|
||||
// 在反向代理模式下,根据请求确定应该转发到哪个后端服务器
|
||||
ResolveBackend(req *http.Request) (string, error)
|
||||
|
||||
// ModifyRequest 修改请求
|
||||
// 在反向代理模式下,可以修改发往后端服务器的请求
|
||||
ModifyRequest(req *http.Request)
|
||||
|
||||
// ModifyResponse 修改响应
|
||||
// 在反向代理模式下,可以修改来自后端服务器的响应
|
||||
ModifyResponse(resp *http.Response) error
|
||||
|
||||
// HandleError 处理错误
|
||||
// 在反向代理模式下,可以自定义错误处理逻辑
|
||||
HandleError(rw http.ResponseWriter, req *http.Request, err error)
|
||||
}
|
||||
|
||||
// DefaultDelegate 默认代理委托
|
||||
type DefaultDelegate struct{}
|
||||
|
||||
// Connect 连接事件
|
||||
func (d *DefaultDelegate) Connect(ctx *Context, rw http.ResponseWriter) {
|
||||
// 默认实现不做任何处理
|
||||
}
|
||||
|
||||
// Auth 认证事件
|
||||
func (d *DefaultDelegate) Auth(ctx *Context, rw http.ResponseWriter) {
|
||||
// 默认实现不做任何处理
|
||||
}
|
||||
|
||||
// BeforeRequest 请求前事件
|
||||
func (d *DefaultDelegate) BeforeRequest(ctx *Context) {
|
||||
// 默认实现不做任何处理
|
||||
}
|
||||
|
||||
// BeforeResponse 响应前事件
|
||||
func (d *DefaultDelegate) BeforeResponse(ctx *Context, resp *http.Response, err error) {
|
||||
// 默认实现不做任何处理
|
||||
}
|
||||
|
||||
// ParentProxy 获取上级代理
|
||||
func (d *DefaultDelegate) ParentProxy(req *http.Request) (*url.URL, error) {
|
||||
// 默认实现不使用上级代理
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// WebSocketSendMessage websocket发送消息拦截
|
||||
// func (h *DefaultDelegate) WebSocketSendMessage(ctx *Context, messageType *int, payload *[]byte) {}
|
||||
|
||||
// WebSocketReceiveMessage websocket接收消息拦截
|
||||
// func (h *DefaultDelegate) WebSocketReceiveMessage(ctx *Context, messageType *int, payload *[]byte) {}
|
||||
|
||||
// ErrorLog 错误日志
|
||||
func (d *DefaultDelegate) ErrorLog(err error) {
|
||||
// 默认实现不做任何处理
|
||||
}
|
||||
|
||||
// Finish 完成事件
|
||||
func (d *DefaultDelegate) Finish(ctx *Context) {
|
||||
// 默认实现不做任何处理
|
||||
}
|
||||
|
||||
// ResolveBackend 解析后端服务器
|
||||
func (d *DefaultDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||
// 默认实现返回请求中的主机
|
||||
return req.Host, nil
|
||||
}
|
||||
|
||||
// ModifyRequest 修改请求
|
||||
func (d *DefaultDelegate) ModifyRequest(req *http.Request) {
|
||||
// 默认实现不做任何修改
|
||||
}
|
||||
|
||||
// ModifyResponse 修改响应
|
||||
func (d *DefaultDelegate) ModifyResponse(resp *http.Response) error {
|
||||
// 默认实现不做任何修改
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleError 处理错误
|
||||
func (d *DefaultDelegate) HandleError(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
// 默认实现返回502错误
|
||||
http.Error(rw, err.Error(), http.StatusBadGateway)
|
||||
}
|
262
internal/proxy/options.go
Normal file
262
internal/proxy/options.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"time"
|
||||
|
||||
"github.com/goproxy/internal/cache"
|
||||
"github.com/goproxy/internal/config"
|
||||
"github.com/goproxy/internal/healthcheck"
|
||||
"github.com/goproxy/internal/loadbalance"
|
||||
"github.com/goproxy/internal/metrics"
|
||||
)
|
||||
|
||||
// Option 用于配置代理选项的函数类型
|
||||
type Option func(*Options)
|
||||
|
||||
// WithConfig 设置代理配置
|
||||
func WithConfig(cfg *config.Config) Option {
|
||||
return func(opt *Options) {
|
||||
opt.Config = cfg
|
||||
}
|
||||
}
|
||||
|
||||
// WithDisableKeepAlive 设置连接是否重用
|
||||
func WithDisableKeepAlive(disableKeepAlive bool) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
// 在transport中设置DisableKeepAlives
|
||||
}
|
||||
}
|
||||
|
||||
// WithClientTrace 设置HTTP客户端跟踪
|
||||
func WithClientTrace(t *httptrace.ClientTrace) Option {
|
||||
return func(opt *Options) {
|
||||
opt.ClientTrace = t
|
||||
}
|
||||
}
|
||||
|
||||
// WithDelegate 设置委托类
|
||||
func WithDelegate(delegate Delegate) Option {
|
||||
return func(opt *Options) {
|
||||
opt.Delegate = delegate
|
||||
}
|
||||
}
|
||||
|
||||
// WithTransport 使用自定义HTTP传输
|
||||
func WithTransport(t *http.Transport) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
// 在New方法中处理transport
|
||||
}
|
||||
}
|
||||
|
||||
// WithDecryptHTTPS 启用中间人代理解密HTTPS
|
||||
func WithDecryptHTTPS(c CertificateCache) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.DecryptHTTPS = true
|
||||
opt.CertCache = c
|
||||
}
|
||||
}
|
||||
|
||||
// WithEnableWebsocketIntercept 启用WebSocket拦截
|
||||
func WithEnableWebsocketIntercept() Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
// WebSocket拦截在代理处理逻辑中实现
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPCache 设置HTTP缓存
|
||||
func WithHTTPCache(c cache.Cache) Option {
|
||||
return func(opt *Options) {
|
||||
opt.HTTPCache = c
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.EnableCache = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithLoadBalancer 设置负载均衡器
|
||||
func WithLoadBalancer(lb loadbalance.LoadBalancer) Option {
|
||||
return func(opt *Options) {
|
||||
opt.LoadBalancer = lb
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.EnableLoadBalancing = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithHealthChecker 设置健康检查器
|
||||
func WithHealthChecker(hc *healthcheck.HealthChecker) Option {
|
||||
return func(opt *Options) {
|
||||
opt.HealthChecker = hc
|
||||
}
|
||||
}
|
||||
|
||||
// WithMetrics 设置监控指标
|
||||
func WithMetrics(m metrics.Metrics) Option {
|
||||
return func(opt *Options) {
|
||||
opt.Metrics = m
|
||||
}
|
||||
}
|
||||
|
||||
// WithTLSCertAndKey 设置TLS证书和密钥
|
||||
func WithTLSCertAndKey(certPath, keyPath string) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.TLSCert = certPath
|
||||
opt.Config.TLSKey = keyPath
|
||||
}
|
||||
}
|
||||
|
||||
// WithCACertAndKey 设置CA证书和密钥(用于生成动态证书)
|
||||
func WithCACertAndKey(caCertPath, caKeyPath string) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.CACert = caCertPath
|
||||
opt.Config.CAKey = caKeyPath
|
||||
}
|
||||
}
|
||||
|
||||
// WithConnectionPoolSize 设置连接池大小
|
||||
func WithConnectionPoolSize(size int) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.ConnectionPoolSize = size
|
||||
}
|
||||
}
|
||||
|
||||
// WithIdleTimeout 设置空闲超时时间
|
||||
func WithIdleTimeout(timeout time.Duration) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.IdleTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// WithRequestTimeout 设置请求超时时间
|
||||
func WithRequestTimeout(timeout time.Duration) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.RequestTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// WithReverseProxy 启用反向代理模式
|
||||
func WithReverseProxy(enable bool) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.ReverseProxy = enable
|
||||
}
|
||||
}
|
||||
|
||||
// WithEnableRetry 启用请求重试
|
||||
func WithEnableRetry(maxRetries int, baseBackoff, maxBackoff time.Duration) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.EnableRetry = true
|
||||
opt.Config.MaxRetries = maxRetries
|
||||
opt.Config.RetryBackoff = baseBackoff
|
||||
opt.Config.MaxRetryBackoff = maxBackoff
|
||||
}
|
||||
}
|
||||
|
||||
// WithRateLimit 设置请求限流
|
||||
func WithRateLimit(rps float64) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.EnableRateLimit = true
|
||||
opt.Config.RateLimit = rps
|
||||
}
|
||||
}
|
||||
|
||||
// WithDNSCacheTTL 设置DNS缓存TTL
|
||||
func WithDNSCacheTTL(ttl time.Duration) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.DNSCacheTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
// WithURLRewrite 启用URL重写
|
||||
func WithURLRewrite(enable bool) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.EnableURLRewrite = enable
|
||||
}
|
||||
}
|
||||
|
||||
// WithEnableCORS 启用CORS支持
|
||||
func WithEnableCORS(enable bool) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.EnableCORS = enable
|
||||
}
|
||||
}
|
||||
|
||||
// WithCertManager 设置证书管理器
|
||||
// 这是一个内部函数,主要用于在New方法中设置CertManager
|
||||
func WithCertManager(certManager *CertManager) Option {
|
||||
return func(opt *Options) {
|
||||
opt.CertManager = certManager
|
||||
}
|
||||
}
|
||||
|
||||
// WithEnableECDSA 启用ECDSA证书生成(默认使用RSA)
|
||||
func WithEnableECDSA(enable bool) Option {
|
||||
return func(opt *Options) {
|
||||
if opt.Config == nil {
|
||||
opt.Config = config.DefaultConfig()
|
||||
}
|
||||
opt.Config.UseECDSA = enable
|
||||
}
|
||||
}
|
||||
|
||||
// NewWithOptions 使用选项函数创建代理
|
||||
func NewWithOptions(options ...Option) *Proxy {
|
||||
opts := &Options{
|
||||
Config: config.DefaultConfig(),
|
||||
}
|
||||
|
||||
// 应用所有选项
|
||||
for _, option := range options {
|
||||
option(opts)
|
||||
}
|
||||
|
||||
return New(opts)
|
||||
}
|
1130
internal/proxy/proxy.go
Normal file
1130
internal/proxy/proxy.go
Normal file
File diff suppressed because it is too large
Load Diff
277
internal/proxy/reverse_proxy.go
Normal file
277
internal/proxy/reverse_proxy.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/goproxy/internal/rewriter"
|
||||
"github.com/goproxy/internal/router"
|
||||
)
|
||||
|
||||
// ReverseProxy 反向代理
|
||||
type ReverseProxy struct {
|
||||
// 代理对象
|
||||
proxy *Proxy
|
||||
// 路由器
|
||||
router *router.Router
|
||||
// URL重写器
|
||||
rewriter *rewriter.Rewriter
|
||||
// HTTP传输对象
|
||||
transport http.RoundTripper
|
||||
}
|
||||
|
||||
// NewReverseProxy 创建反向代理
|
||||
func (p *Proxy) NewReverseProxy() *ReverseProxy {
|
||||
rp := &ReverseProxy{
|
||||
proxy: p,
|
||||
router: router.NewRouter(),
|
||||
rewriter: rewriter.NewRewriter(),
|
||||
}
|
||||
|
||||
// 创建自定义的传输对象
|
||||
transport := &http.Transport{
|
||||
Proxy: func(req *http.Request) (*url.URL, error) {
|
||||
// 使用代理委托中的方法获取代理
|
||||
return p.delegate.ParentProxy(req)
|
||||
},
|
||||
DialContext: p.dialContextWithCache(),
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
MaxIdleConns: p.config.ConnectionPoolSize,
|
||||
IdleConnTimeout: p.config.IdleTimeout,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
MaxIdleConnsPerHost: p.config.ConnectionPoolSize,
|
||||
DisableCompression: !p.config.EnableCompression,
|
||||
}
|
||||
|
||||
rp.transport = transport
|
||||
|
||||
// 如果配置了规则文件,加载规则
|
||||
if p.config.ReverseProxyRulesFile != "" {
|
||||
// 省略加载规则文件的实现
|
||||
}
|
||||
|
||||
return rp
|
||||
}
|
||||
|
||||
// ServeHTTP 处理反向代理请求
|
||||
func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// 获取请求上下文
|
||||
ctx := ctxPool.Get().(*Context)
|
||||
ctx.Reset(req)
|
||||
defer ctxPool.Put(ctx)
|
||||
|
||||
// 调用连接事件
|
||||
rp.proxy.delegate.Connect(ctx, rw)
|
||||
|
||||
// 认证检查
|
||||
rp.proxy.delegate.Auth(ctx, rw)
|
||||
if ctx.IsAborted() {
|
||||
return
|
||||
}
|
||||
|
||||
// 请求前处理
|
||||
rp.proxy.delegate.BeforeRequest(ctx)
|
||||
if ctx.IsAborted() {
|
||||
return
|
||||
}
|
||||
|
||||
// 解析后端地址
|
||||
backend, err := rp.proxy.delegate.ResolveBackend(req)
|
||||
if err != nil {
|
||||
rp.proxy.delegate.HandleError(rw, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建请求代理对象
|
||||
proxy := httputil.NewSingleHostReverseProxy(&url.URL{
|
||||
Scheme: "http",
|
||||
Host: backend,
|
||||
})
|
||||
|
||||
// 使用自定义传输对象
|
||||
proxy.Transport = rp.transport
|
||||
|
||||
// 设置自定义错误处理函数
|
||||
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
rp.proxy.delegate.HandleError(rw, req, err)
|
||||
}
|
||||
|
||||
// 设置请求修改函数
|
||||
originalDirector := proxy.Director
|
||||
proxy.Director = func(req *http.Request) {
|
||||
// 调用原始Director函数
|
||||
originalDirector(req)
|
||||
|
||||
// 处理URL重写
|
||||
if rp.proxy.config.EnableURLRewrite {
|
||||
rp.rewriter.Rewrite(req)
|
||||
}
|
||||
|
||||
// 修改请求头
|
||||
if rp.proxy.config.RewriteHostHeader {
|
||||
req.Host = backend
|
||||
}
|
||||
|
||||
// 添加X-Forwarded-For头
|
||||
if rp.proxy.config.AddXForwardedFor {
|
||||
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
// 如果已经有X-Forwarded-For,添加到末尾
|
||||
if prior, ok := req.Header["X-Forwarded-For"]; ok {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
req.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加X-Real-IP头
|
||||
if rp.proxy.config.AddXRealIP {
|
||||
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err == nil {
|
||||
req.Header.Set("X-Real-IP", clientIP)
|
||||
}
|
||||
}
|
||||
|
||||
// 设置协议头
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
if req.TLS != nil {
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
}
|
||||
|
||||
// 调用委托的ModifyRequest方法
|
||||
rp.proxy.delegate.ModifyRequest(req)
|
||||
}
|
||||
|
||||
// 设置响应修改函数
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// 处理响应URL重写
|
||||
if rp.proxy.config.EnableURLRewrite && resp != nil {
|
||||
rp.rewriter.RewriteResponse(resp, req.Host)
|
||||
}
|
||||
|
||||
// 添加CORS头
|
||||
if rp.proxy.config.EnableCORS && resp != nil {
|
||||
resp.Header.Set("Access-Control-Allow-Origin", "*")
|
||||
resp.Header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
resp.Header.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
}
|
||||
|
||||
// 调用委托的ModifyResponse方法
|
||||
return rp.proxy.delegate.ModifyResponse(resp)
|
||||
}
|
||||
|
||||
// 更新监控指标
|
||||
if rp.proxy.metrics != nil {
|
||||
rp.proxy.metrics.IncActiveConnections()
|
||||
defer rp.proxy.metrics.DecActiveConnections()
|
||||
|
||||
startTime := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(startTime)
|
||||
rp.proxy.metrics.ObserveRequestDuration(duration.Seconds())
|
||||
rp.proxy.metrics.IncRequestCount()
|
||||
}()
|
||||
}
|
||||
|
||||
// 处理WebSocket升级
|
||||
if rp.proxy.config.SupportWebSocketUpgrade && isWebSocketRequest(req) {
|
||||
rp.handleWebSocketUpgrade(rw, req, backend)
|
||||
return
|
||||
}
|
||||
|
||||
// 处理普通请求
|
||||
proxy.ServeHTTP(rw, req)
|
||||
|
||||
// 完成事件
|
||||
rp.proxy.delegate.Finish(ctx)
|
||||
}
|
||||
|
||||
// 处理WebSocket升级
|
||||
func (rp *ReverseProxy) handleWebSocketUpgrade(rw http.ResponseWriter, req *http.Request, backend string) {
|
||||
// 创建WebSocket代理
|
||||
target := &url.URL{
|
||||
Scheme: "ws",
|
||||
Host: backend,
|
||||
}
|
||||
|
||||
if req.TLS != nil {
|
||||
target.Scheme = "wss"
|
||||
}
|
||||
|
||||
// 创建连接到后端的WebSocket连接
|
||||
backendConn, err := rp.dialBackend(target.String(), req)
|
||||
if err != nil {
|
||||
rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("无法连接到后端WebSocket服务: %v", err))
|
||||
return
|
||||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
// 将请求转发给后端
|
||||
err = req.Write(backendConn)
|
||||
if err != nil {
|
||||
rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("写入WebSocket请求错误: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 升级客户端连接
|
||||
clientConn, err := hijacker(rw)
|
||||
if err != nil {
|
||||
rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("升级WebSocket连接错误: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 双向转发数据
|
||||
rp.proxy.transfer(clientConn, backendConn)
|
||||
}
|
||||
|
||||
// 连接到后端
|
||||
func (rp *ReverseProxy) dialBackend(url string, req *http.Request) (net.Conn, error) {
|
||||
ctx, cancel := context.WithTimeout(req.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
backend := strings.TrimPrefix(url, "ws://")
|
||||
backend = strings.TrimPrefix(backend, "wss://")
|
||||
|
||||
if strings.Contains(backend, "/") {
|
||||
backend = backend[:strings.Index(backend, "/")]
|
||||
}
|
||||
|
||||
// 根据协议选择连接方式
|
||||
if strings.HasPrefix(url, "wss://") {
|
||||
// 使用 tls.Dialer 替代不存在的 tls.DialWithContext
|
||||
dialer := &tls.Dialer{
|
||||
Config: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
return dialer.DialContext(ctx, "tcp", backend)
|
||||
}
|
||||
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", backend)
|
||||
}
|
||||
|
||||
// 添加路由规则
|
||||
func (rp *ReverseProxy) AddRoute(pattern string, routeType router.RouteType, target string) {
|
||||
route := &router.Route{
|
||||
Pattern: pattern,
|
||||
Type: routeType,
|
||||
Target: target,
|
||||
}
|
||||
rp.router.AddRoute(route)
|
||||
}
|
||||
|
||||
// 添加重写规则
|
||||
func (rp *ReverseProxy) AddRewriteRule(pattern, replacement string, useRegex bool) error {
|
||||
return rp.rewriter.AddRule(pattern, replacement, useRegex)
|
||||
}
|
98
internal/rewriter/rewriter.go
Normal file
98
internal/rewriter/rewriter.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package rewriter
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Rewriter URL重写器
|
||||
// 用于在反向代理中重写请求URL
|
||||
type Rewriter struct {
|
||||
// 重写规则列表
|
||||
rules []*RewriteRule
|
||||
}
|
||||
|
||||
// RewriteRule 重写规则
|
||||
type RewriteRule struct {
|
||||
// 匹配模式
|
||||
Pattern string
|
||||
// 替换模式
|
||||
Replacement string
|
||||
// 是否使用正则表达式
|
||||
UseRegex bool
|
||||
// 编译后的正则表达式
|
||||
regex *regexp.Regexp
|
||||
}
|
||||
|
||||
// NewRewriter 创建URL重写器
|
||||
func NewRewriter() *Rewriter {
|
||||
return &Rewriter{
|
||||
rules: make([]*RewriteRule, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddRule 添加重写规则
|
||||
func (r *Rewriter) AddRule(pattern, replacement string, useRegex bool) error {
|
||||
rule := &RewriteRule{
|
||||
Pattern: pattern,
|
||||
Replacement: replacement,
|
||||
UseRegex: useRegex,
|
||||
}
|
||||
|
||||
if useRegex {
|
||||
regex, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rule.regex = regex
|
||||
}
|
||||
|
||||
r.rules = append(r.rules, rule)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rewrite 重写URL
|
||||
func (r *Rewriter) Rewrite(req *http.Request) {
|
||||
path := req.URL.Path
|
||||
|
||||
for _, rule := range r.rules {
|
||||
if rule.UseRegex {
|
||||
if rule.regex.MatchString(path) {
|
||||
req.URL.Path = rule.regex.ReplaceAllString(path, rule.Replacement)
|
||||
break
|
||||
}
|
||||
} else {
|
||||
if strings.HasPrefix(path, rule.Pattern) {
|
||||
req.URL.Path = strings.Replace(path, rule.Pattern, rule.Replacement, 1)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RewriteResponse 重写响应
|
||||
// 主要用于处理响应中的Location头和内容中的URL
|
||||
func (r *Rewriter) RewriteResponse(resp *http.Response, originHost string) {
|
||||
// 处理重定向头
|
||||
location := resp.Header.Get("Location")
|
||||
if location != "" {
|
||||
// 将后端服务器的域名替换成代理服务器的域名
|
||||
for _, rule := range r.rules {
|
||||
if rule.UseRegex && rule.regex != nil {
|
||||
if rule.regex.MatchString(location) {
|
||||
newLocation := rule.regex.ReplaceAllString(location, rule.Replacement)
|
||||
resp.Header.Set("Location", newLocation)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LoadRulesFromFile 从文件加载重写规则
|
||||
func (r *Rewriter) LoadRulesFromFile(filename string) error {
|
||||
// 实现从配置文件加载规则的逻辑
|
||||
// 这里省略实现细节
|
||||
return nil
|
||||
}
|
103
internal/router/router.go
Normal file
103
internal/router/router.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Route 路由规则
|
||||
type Route struct {
|
||||
// 匹配模式(主机名、路径、正则表达式)
|
||||
Pattern string
|
||||
// 匹配类型
|
||||
Type RouteType
|
||||
// 目标地址
|
||||
Target string
|
||||
// 路径重写规则
|
||||
RewritePattern string
|
||||
// 请求头修改
|
||||
HeaderModifier HeaderModifier
|
||||
// 自定义匹配函数
|
||||
MatchFunc func(req *http.Request) bool
|
||||
}
|
||||
|
||||
// RouteType 路由类型
|
||||
type RouteType int
|
||||
|
||||
const (
|
||||
// HostRoute 主机名路由
|
||||
HostRoute RouteType = iota
|
||||
// PathRoute 路径路由
|
||||
PathRoute
|
||||
// RegexRoute 正则表达式路由
|
||||
RegexRoute
|
||||
// CustomRoute 自定义路由
|
||||
CustomRoute
|
||||
)
|
||||
|
||||
// HeaderModifier 头部修改接口
|
||||
type HeaderModifier interface {
|
||||
// ModifyRequest 修改请求头
|
||||
ModifyRequest(req *http.Request)
|
||||
// ModifyResponse 修改响应头
|
||||
ModifyResponse(resp *http.Response)
|
||||
}
|
||||
|
||||
// Router 路由器
|
||||
type Router struct {
|
||||
routes []*Route
|
||||
}
|
||||
|
||||
// NewRouter 创建路由器
|
||||
func NewRouter() *Router {
|
||||
return &Router{
|
||||
routes: make([]*Route, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddRoute 添加路由规则
|
||||
func (r *Router) AddRoute(route *Route) {
|
||||
r.routes = append(r.routes, route)
|
||||
}
|
||||
|
||||
// Match 匹配请求
|
||||
func (r *Router) Match(req *http.Request) (*Route, bool) {
|
||||
for _, route := range r.routes {
|
||||
switch route.Type {
|
||||
case HostRoute:
|
||||
if matchHost(req.Host, route.Pattern) {
|
||||
return route, true
|
||||
}
|
||||
case PathRoute:
|
||||
if matchPath(req.URL.Path, route.Pattern) {
|
||||
return route, true
|
||||
}
|
||||
case RegexRoute:
|
||||
if matchRegex(req.URL.String(), route.Pattern) {
|
||||
return route, true
|
||||
}
|
||||
case CustomRoute:
|
||||
if route.MatchFunc != nil && route.MatchFunc(req) {
|
||||
return route, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 匹配主机名
|
||||
func matchHost(host, pattern string) bool {
|
||||
return host == pattern || strings.HasSuffix(host, "."+pattern)
|
||||
}
|
||||
|
||||
// 匹配路径
|
||||
func matchPath(path, pattern string) bool {
|
||||
return strings.HasPrefix(path, pattern)
|
||||
}
|
||||
|
||||
// 匹配正则表达式
|
||||
func matchRegex(url, pattern string) bool {
|
||||
matched, _ := regexp.MatchString(pattern, url)
|
||||
return matched
|
||||
}
|
14
mitm-proxy.crt
Normal file
14
mitm-proxy.crt
Normal file
@@ -0,0 +1,14 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIICJzCCAcygAwIBAgIITWWCIQf8/VIwCgYIKoZIzj0EAwIwUzEOMAwGA1UEBhMF
|
||||
Q2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0GA1UEBxMGWGlhbWVuMRAwDgYDVQQK
|
||||
EwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMB4XDTIyMDMyNTA1NDgwMFoXDTQyMDQy
|
||||
NTA1NDgwMFowUzEOMAwGA1UEBhMFQ2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0G
|
||||
A1UEBxMGWGlhbWVuMRAwDgYDVQQKEwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMFkw
|
||||
EwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQX
|
||||
mRgsFV5KHHmxOrVJBFC/nDetmGowkARShWtBsX1Irm4w6i6Qk2QliKOBiTCBhjAO
|
||||
BgNVHQ8BAf8EBAMCAQYwHQYDVR0lBBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIG
|
||||
A1UdEwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFBI5TkWYcvUIWsBAdffs833FnBrI
|
||||
MCIGA1UdEQQbMBmBF3FpbmdxaWFubHVkYW9AZ21haWwuY29tMAoGCCqGSM49BAMC
|
||||
A0kAMEYCIQCk1DhW7AmIW/n/QLftQq8BHZKLevWYJ813zdrNr5kXlwIhAIVvqglY
|
||||
9BkYWg4NEe/mVO4C5Vtu4FnzNU9I+rFpXVSO
|
||||
-----END CERTIFICATE-----
|
753
proxy.go
Normal file
753
proxy.go
Normal file
@@ -0,0 +1,753 @@
|
||||
// Copyright 2018 ouqiang authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
// Package goproxy HTTP(S)代理, 支持中间人代理解密HTTPS数据
|
||||
package goproxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/viki-org/dnscache"
|
||||
|
||||
"github.com/ouqiang/goproxy/cert"
|
||||
"github.com/ouqiang/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
// 连接目标服务器超时时间
|
||||
defaultTargetConnectTimeout = 5 * time.Second
|
||||
// 目标服务器读写超时时间
|
||||
defaultTargetReadWriteTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// 隧道连接成功响应行
|
||||
var tunnelEstablishedResponseLine = []byte("HTTP/1.1 200 Connection established\r\n\r\n")
|
||||
|
||||
var badGateway = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", http.StatusBadGateway, http.StatusText(http.StatusBadGateway)))
|
||||
|
||||
var (
|
||||
bufPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, 32*1024)
|
||||
},
|
||||
}
|
||||
|
||||
ctxPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(Context)
|
||||
},
|
||||
}
|
||||
headerPool = NewHeaderPool()
|
||||
requestPool = newRequestPool()
|
||||
)
|
||||
|
||||
type RequestPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func newRequestPool() *RequestPool {
|
||||
return &RequestPool{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(http.Request)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RequestPool) Get() *http.Request {
|
||||
req := p.pool.Get().(*http.Request)
|
||||
|
||||
req.Method = ""
|
||||
req.URL = nil
|
||||
req.Proto = ""
|
||||
req.ProtoMajor = 0
|
||||
req.ProtoMinor = 0
|
||||
req.Header = nil
|
||||
req.Body = nil
|
||||
req.GetBody = nil
|
||||
req.ContentLength = 0
|
||||
req.TransferEncoding = nil
|
||||
req.Close = false
|
||||
req.Host = ""
|
||||
req.Form = nil
|
||||
req.PostForm = nil
|
||||
req.MultipartForm = nil
|
||||
req.Trailer = nil
|
||||
req.RemoteAddr = ""
|
||||
req.RequestURI = ""
|
||||
req.TLS = nil
|
||||
req.Cancel = nil
|
||||
req.Response = nil
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func (p *RequestPool) Put(req *http.Request) {
|
||||
if req != nil {
|
||||
p.pool.Put(req)
|
||||
}
|
||||
}
|
||||
|
||||
type HeaderPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func NewHeaderPool() *HeaderPool {
|
||||
return &HeaderPool{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return http.Header{}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *HeaderPool) Get() http.Header {
|
||||
header := p.pool.Get().(http.Header)
|
||||
for k := range header {
|
||||
delete(header, k)
|
||||
}
|
||||
|
||||
return header
|
||||
}
|
||||
|
||||
func (p *HeaderPool) Put(header http.Header) {
|
||||
if header != nil {
|
||||
p.pool.Put(header)
|
||||
}
|
||||
}
|
||||
|
||||
// 生成隧道建立请求行
|
||||
func makeTunnelRequestLine(addr string) string {
|
||||
return fmt.Sprintf("CONNECT %s HTTP/1.1\r\n\r\n", addr)
|
||||
}
|
||||
|
||||
type options struct {
|
||||
disableKeepAlive bool
|
||||
delegate Delegate
|
||||
|
||||
decryptHTTPS bool
|
||||
websocketIntercept bool
|
||||
certCache cert.Cache
|
||||
transport *http.Transport
|
||||
clientTrace *httptrace.ClientTrace
|
||||
}
|
||||
|
||||
type Option func(*options)
|
||||
|
||||
// WithDisableKeepAlive 连接是否重用
|
||||
func WithDisableKeepAlive(disableKeepAlive bool) Option {
|
||||
return func(opt *options) {
|
||||
opt.disableKeepAlive = disableKeepAlive
|
||||
}
|
||||
}
|
||||
|
||||
func WithClientTrace(t *httptrace.ClientTrace) Option {
|
||||
return func(opt *options) {
|
||||
opt.clientTrace = t
|
||||
}
|
||||
}
|
||||
|
||||
// WithDelegate 设置委托类
|
||||
func WithDelegate(delegate Delegate) Option {
|
||||
return func(opt *options) {
|
||||
opt.delegate = delegate
|
||||
}
|
||||
}
|
||||
|
||||
// WithTransport 自定义http transport
|
||||
func WithTransport(t *http.Transport) Option {
|
||||
return func(opt *options) {
|
||||
opt.transport = t
|
||||
}
|
||||
}
|
||||
|
||||
// WithDecryptHTTPS 中间人代理, 解密HTTPS, 需实现证书缓存接口
|
||||
func WithDecryptHTTPS(c cert.Cache) Option {
|
||||
return func(opt *options) {
|
||||
opt.decryptHTTPS = true
|
||||
opt.certCache = c
|
||||
}
|
||||
}
|
||||
|
||||
// WithEnableWebsocketIntercept 拦截websocket
|
||||
func WithEnableWebsocketIntercept() Option {
|
||||
return func(opt *options) {
|
||||
opt.websocketIntercept = true
|
||||
}
|
||||
}
|
||||
|
||||
// New 创建proxy实例
|
||||
func New(opt ...Option) *Proxy {
|
||||
opts := &options{}
|
||||
for _, o := range opt {
|
||||
o(opts)
|
||||
}
|
||||
if opts.delegate == nil {
|
||||
opts.delegate = &DefaultDelegate{}
|
||||
}
|
||||
if opts.transport == nil {
|
||||
opts.transport = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
MaxIdleConns: 100,
|
||||
MaxConnsPerHost: 10,
|
||||
IdleConnTimeout: 10 * time.Second,
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
p := &Proxy{}
|
||||
p.delegate = opts.delegate
|
||||
p.websocketIntercept = opts.websocketIntercept
|
||||
p.decryptHTTPS = opts.decryptHTTPS
|
||||
if p.decryptHTTPS {
|
||||
p.cert = cert.NewCertificate(opts.certCache, true)
|
||||
}
|
||||
p.transport = opts.transport
|
||||
p.transport.DialContext = p.dialContext()
|
||||
p.dnsCache = dnscache.New(5 * time.Minute)
|
||||
p.transport.DisableKeepAlives = opts.disableKeepAlive
|
||||
p.transport.Proxy = p.delegate.ParentProxy
|
||||
p.clientTrace = opts.clientTrace
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// Proxy 实现了http.Handler接口
|
||||
type Proxy struct {
|
||||
delegate Delegate
|
||||
clientConnNum int32
|
||||
decryptHTTPS bool
|
||||
websocketIntercept bool
|
||||
cert *cert.Certificate
|
||||
transport *http.Transport
|
||||
clientTrace *httptrace.ClientTrace
|
||||
dnsCache *dnscache.Resolver
|
||||
}
|
||||
|
||||
var _ http.Handler = &Proxy{}
|
||||
|
||||
// ServeHTTP 实现了http.Handler接口
|
||||
func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.Host == "" {
|
||||
req.URL.Host = req.Host
|
||||
}
|
||||
atomic.AddInt32(&p.clientConnNum, 1)
|
||||
ctx := ctxPool.Get().(*Context)
|
||||
ctx.Reset(req)
|
||||
|
||||
defer func() {
|
||||
p.delegate.Finish(ctx)
|
||||
ctxPool.Put(ctx)
|
||||
atomic.AddInt32(&p.clientConnNum, -1)
|
||||
}()
|
||||
p.delegate.Connect(ctx, rw)
|
||||
if ctx.abort {
|
||||
return
|
||||
}
|
||||
p.delegate.Auth(ctx, rw)
|
||||
if ctx.abort {
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case ctx.Req.Method == http.MethodConnect:
|
||||
p.tunnelProxy(ctx, rw)
|
||||
case websocket.IsWebSocketUpgrade(ctx.Req):
|
||||
p.tunnelProxy(ctx, rw)
|
||||
default:
|
||||
p.httpProxy(ctx, rw)
|
||||
}
|
||||
}
|
||||
|
||||
// ClientConnNum 获取客户端连接数
|
||||
func (p *Proxy) ClientConnNum() int32 {
|
||||
return atomic.LoadInt32(&p.clientConnNum)
|
||||
}
|
||||
|
||||
// DoRequest 执行HTTP请求,并调用responseFunc处理response
|
||||
func (p *Proxy) DoRequest(ctx *Context, responseFunc func(*http.Response, error)) {
|
||||
if ctx.Data == nil {
|
||||
ctx.Data = make(map[interface{}]interface{})
|
||||
}
|
||||
p.delegate.BeforeRequest(ctx)
|
||||
if ctx.abort {
|
||||
return
|
||||
}
|
||||
newReq := requestPool.Get()
|
||||
*newReq = *ctx.Req
|
||||
newHeader := headerPool.Get()
|
||||
CloneHeader(newReq.Header, newHeader)
|
||||
newReq.Header = newHeader
|
||||
for _, item := range hopHeaders {
|
||||
if newReq.Header.Get(item) != "" {
|
||||
newReq.Header.Del(item)
|
||||
}
|
||||
}
|
||||
if p.clientTrace != nil {
|
||||
newReq = newReq.WithContext(httptrace.WithClientTrace(newReq.Context(), p.clientTrace))
|
||||
}
|
||||
|
||||
resp, err := p.transport.RoundTrip(newReq)
|
||||
p.delegate.BeforeResponse(ctx, resp, err)
|
||||
if ctx.abort {
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
for _, h := range hopHeaders {
|
||||
resp.Header.Del(h)
|
||||
}
|
||||
}
|
||||
responseFunc(resp, err)
|
||||
headerPool.Put(newHeader)
|
||||
requestPool.Put(newReq)
|
||||
}
|
||||
|
||||
// HTTP代理
|
||||
func (p *Proxy) httpProxy(ctx *Context, rw http.ResponseWriter) {
|
||||
ctx.Req.URL.Scheme = "http"
|
||||
p.DoRequest(ctx, func(resp *http.Response, err error) {
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - HTTP请求错误: %s", ctx.Req.URL, err))
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
CopyHeader(rw.Header(), resp.Header)
|
||||
rw.WriteHeader(resp.StatusCode)
|
||||
buf := bufPool.Get().([]byte)
|
||||
_, _ = io.CopyBuffer(rw, resp.Body, buf)
|
||||
bufPool.Put(buf)
|
||||
})
|
||||
}
|
||||
|
||||
// HTTPS代理
|
||||
func (p *Proxy) httpsProxy(ctx *Context, tlsClientConn *tls.Conn) {
|
||||
if websocket.IsWebSocketUpgrade(ctx.Req) {
|
||||
p.websocketProxy(ctx, NewConnBuffer(tlsClientConn, nil))
|
||||
return
|
||||
}
|
||||
p.DoRequest(ctx, func(resp *http.Response, err error) {
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 请求错误: %s", ctx.Req.URL, err))
|
||||
_, _ = tlsClientConn.Write(badGateway)
|
||||
return
|
||||
}
|
||||
err = resp.Write(tlsClientConn)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, response写入客户端失败, %s", ctx.Req.URL, err))
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
})
|
||||
}
|
||||
|
||||
// 隧道代理
|
||||
func (p *Proxy) tunnelProxy(ctx *Context, rw http.ResponseWriter) {
|
||||
clientConn, err := hijacker(rw)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(err)
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = clientConn.Close()
|
||||
}()
|
||||
|
||||
if websocket.IsWebSocketUpgrade(ctx.Req) {
|
||||
p.websocketProxy(ctx, clientConn)
|
||||
return
|
||||
}
|
||||
|
||||
parentProxyURL, err := p.delegate.ParentProxy(ctx.Req)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - 解析代理地址错误: %s", ctx.Req.URL.Host, err))
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
if parentProxyURL == nil {
|
||||
_, err = clientConn.Write(tunnelEstablishedResponseLine)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - 隧道连接成功,通知客户端错误: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
isWebsocket := p.detectConnProtocol(clientConn)
|
||||
if isWebsocket {
|
||||
req, err := http.ReadRequest(clientConn.BufferReader())
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - websocket读取客户端升级请求失败: %s", ctx.Req.URL.Host, err))
|
||||
}
|
||||
return
|
||||
}
|
||||
req.RemoteAddr = ctx.Req.RemoteAddr
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = req.Host
|
||||
ctx.Req = req
|
||||
|
||||
p.websocketProxy(ctx, clientConn)
|
||||
return
|
||||
}
|
||||
var tlsClientConn *tls.Conn
|
||||
if p.decryptHTTPS {
|
||||
tlsConfig, err := p.cert.GenerateTlsConfig(ctx.Req.URL.Host)
|
||||
if err != nil {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 生成证书失败: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
tlsClientConn = tls.Server(clientConn, tlsConfig)
|
||||
defer func() {
|
||||
_ = tlsClientConn.Close()
|
||||
}()
|
||||
if err := tlsClientConn.Handshake(); err != nil {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 握手失败: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
|
||||
buf := bufio.NewReader(tlsClientConn)
|
||||
tlsReq, err := http.ReadRequest(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 读取客户端请求失败: %s", ctx.Req.URL.Host, err))
|
||||
}
|
||||
return
|
||||
}
|
||||
tlsReq.RemoteAddr = ctx.Req.RemoteAddr
|
||||
tlsReq.URL.Scheme = "https"
|
||||
tlsReq.URL.Host = tlsReq.Host
|
||||
ctx.Req = tlsReq
|
||||
}
|
||||
|
||||
targetAddr := ctx.Req.URL.Host
|
||||
if parentProxyURL != nil {
|
||||
targetAddr = parentProxyURL.Host
|
||||
}
|
||||
if !strings.Contains(targetAddr, ":") {
|
||||
targetAddr += ":443"
|
||||
}
|
||||
|
||||
targetConn, err := net.DialTimeout("tcp", targetAddr, defaultTargetConnectTimeout)
|
||||
if err != nil {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - 隧道转发连接目标服务器失败: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = targetConn.Close()
|
||||
}()
|
||||
if parentProxyURL != nil {
|
||||
tunnelRequestLine := makeTunnelRequestLine(ctx.Req.URL.Host)
|
||||
_, _ = targetConn.Write([]byte(tunnelRequestLine))
|
||||
}
|
||||
|
||||
if p.decryptHTTPS {
|
||||
p.httpsProxy(ctx, tlsClientConn)
|
||||
} else {
|
||||
p.tunnelConnected(ctx, nil)
|
||||
p.transfer(clientConn, targetConn)
|
||||
}
|
||||
}
|
||||
|
||||
// WebSocket代理
|
||||
func (p *Proxy) websocketProxy(ctx *Context, srcConn *ConnBuffer) {
|
||||
if !p.websocketIntercept {
|
||||
remoteAddr := ctx.Addr()
|
||||
var err error
|
||||
var targetConn net.Conn
|
||||
if ctx.IsHTTPS() {
|
||||
targetConn, err = tls.Dial("tcp", remoteAddr, &tls.Config{InsecureSkipVerify: true})
|
||||
} else {
|
||||
targetConn, err = net.Dial("tcp", remoteAddr)
|
||||
}
|
||||
if err != nil {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - websocket连接目标服务器错误: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
err = ctx.Req.Write(targetConn)
|
||||
if err != nil {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - websocket协议转换请求写入目标服务器错误: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
p.tunnelConnected(ctx, nil)
|
||||
p.transfer(srcConn, targetConn)
|
||||
return
|
||||
}
|
||||
|
||||
up := &websocket.Upgrader{
|
||||
HandshakeTimeout: defaultTargetConnectTimeout,
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
srcWSConn, err := up.Upgrade(srcConn, ctx.Req, http.Header{})
|
||||
if err != nil {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - 源连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
|
||||
u := ctx.WebsocketUrl()
|
||||
d := websocket.Dialer{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
}
|
||||
|
||||
dialTimeoutCtx, cancel := context.WithTimeout(context.Background(), defaultTargetConnectTimeout)
|
||||
defer cancel()
|
||||
targetWSConn, _, err := d.DialContext(dialTimeoutCtx, u.String(), ctx.Req.Header)
|
||||
if err != nil {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - 目标连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
p.tunnelConnected(ctx, nil)
|
||||
p.transferWebsocket(ctx, srcWSConn, targetWSConn)
|
||||
}
|
||||
|
||||
// 探测连接协议
|
||||
func (p *Proxy) detectConnProtocol(connBuf *ConnBuffer) (isWebsocket bool) {
|
||||
methodBytes, err := connBuf.Peek(3)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
method := string(methodBytes)
|
||||
if method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// webSocket双向转发
|
||||
func (p *Proxy) transferWebsocket(ctx *Context, srcConn *websocket.Conn, targetConn *websocket.Conn) {
|
||||
doneCtx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
if doneCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msgType, msg, err := srcConn.ReadMessage()
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err))
|
||||
return
|
||||
}
|
||||
p.delegate.WebSocketSendMessage(ctx, &msgType, &msg)
|
||||
err = targetConn.WriteMessage(msgType, msg)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if doneCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msgType, msg, err := targetConn.ReadMessage()
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err))
|
||||
return
|
||||
}
|
||||
p.delegate.WebSocketReceiveMessage(ctx, &msgType, &msg)
|
||||
err = srcConn.WriteMessage(msgType, msg)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 双向转发
|
||||
func (p *Proxy) transfer(src net.Conn, dst net.Conn) {
|
||||
go func() {
|
||||
buf := bufPool.Get().([]byte)
|
||||
_, err := io.CopyBuffer(src, dst, buf)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", dst.RemoteAddr().String(), src.RemoteAddr().String(), err))
|
||||
}
|
||||
bufPool.Put(buf)
|
||||
_ = src.Close()
|
||||
_ = dst.Close()
|
||||
}()
|
||||
|
||||
buf := bufPool.Get().([]byte)
|
||||
_, err := io.CopyBuffer(dst, src, buf)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err))
|
||||
}
|
||||
bufPool.Put(buf)
|
||||
_ = dst.Close()
|
||||
_ = src.Close()
|
||||
}
|
||||
|
||||
func (p *Proxy) tunnelConnected(ctx *Context, err error) {
|
||||
ctx.TunnelProxy = true
|
||||
p.delegate.BeforeRequest(ctx)
|
||||
if err != nil {
|
||||
p.delegate.BeforeResponse(ctx, nil, err)
|
||||
return
|
||||
}
|
||||
resp := &http.Response{
|
||||
Status: "200 OK",
|
||||
StatusCode: http.StatusOK,
|
||||
Proto: "1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: http.Header{},
|
||||
Body: http.NoBody,
|
||||
}
|
||||
p.delegate.BeforeResponse(ctx, resp, nil)
|
||||
}
|
||||
|
||||
func (p *Proxy) dialContext() DialContext {
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: defaultTargetConnectTimeout,
|
||||
}
|
||||
separator := strings.LastIndex(addr, ":")
|
||||
ips, err := p.dnsCache.Fetch(addr[:separator])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ip string
|
||||
for _, item := range ips {
|
||||
ip = item.String()
|
||||
if !strings.Contains(ip, ":") {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
addr = ip + addr[separator:]
|
||||
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取底层连接
|
||||
func hijacker(rw http.ResponseWriter) (*ConnBuffer, error) {
|
||||
hijacker, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("http server不支持Hijacker")
|
||||
}
|
||||
conn, buf, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hijacker错误: %s", err)
|
||||
}
|
||||
|
||||
return NewConnBuffer(conn, buf), nil
|
||||
}
|
||||
|
||||
// CopyHeader 浅拷贝Header
|
||||
func CopyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CloneHeader 深拷贝Header
|
||||
func CloneHeader(h http.Header, h2 http.Header) {
|
||||
for k, vv := range h {
|
||||
vv2 := make([]string, len(vv))
|
||||
copy(vv2, vv)
|
||||
h2[k] = vv2
|
||||
}
|
||||
}
|
||||
|
||||
var hopHeaders = []string{
|
||||
"Proxy-Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te",
|
||||
"Trailer",
|
||||
"Transfer-Encoding",
|
||||
}
|
||||
|
||||
type ConnBuffer struct {
|
||||
net.Conn
|
||||
buf *bufio.ReadWriter
|
||||
}
|
||||
|
||||
func NewConnBuffer(conn net.Conn, buf *bufio.ReadWriter) *ConnBuffer {
|
||||
if buf == nil {
|
||||
buf = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
||||
}
|
||||
return &ConnBuffer{
|
||||
Conn: conn,
|
||||
buf: buf,
|
||||
}
|
||||
}
|
||||
|
||||
func (cb *ConnBuffer) BufferReader() *bufio.Reader {
|
||||
return cb.buf.Reader
|
||||
}
|
||||
|
||||
func (cb *ConnBuffer) Read(b []byte) (n int, err error) {
|
||||
return cb.buf.Read(b)
|
||||
}
|
||||
|
||||
func (cb *ConnBuffer) Peek(n int) ([]byte, error) {
|
||||
return cb.buf.Peek(n)
|
||||
}
|
||||
|
||||
func (cb *ConnBuffer) Write(p []byte) (n int, err error) {
|
||||
n, err = cb.buf.Write(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return n, cb.buf.Flush()
|
||||
}
|
||||
|
||||
func (cb *ConnBuffer) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return cb.Conn, cb.buf, nil
|
||||
}
|
||||
|
||||
func (cb *ConnBuffer) WriteHeader(_ int) {}
|
||||
|
||||
func (cb *ConnBuffer) Header() http.Header { return nil }
|
Reference in New Issue
Block a user