init
This commit is contained in:
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
.idea
|
||||||
|
./delegate.go
|
||||||
|
./proxy.go
|
454
README.md
Normal file
454
README.md
Normal file
@@ -0,0 +1,454 @@
|
|||||||
|
# GoProxy
|
||||||
|
|
||||||
|
GoProxy是一个功能强大的Go语言HTTP代理库,支持HTTP、HTTPS和WebSocket代理,并提供了丰富的功能和扩展点。
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
- 支持HTTP、HTTPS和WebSocket代理
|
||||||
|
- 支持正向代理和反向代理
|
||||||
|
- 支持HTTPS解密(中间人模式)
|
||||||
|
- 自定义CA证书和私钥
|
||||||
|
- 动态证书生成与缓存
|
||||||
|
- 通配符域名证书支持
|
||||||
|
- 支持RSA和ECDSA证书算法选择
|
||||||
|
- 支持上游代理链
|
||||||
|
- 支持负载均衡(轮询、随机、权重等)
|
||||||
|
- 支持健康检查
|
||||||
|
- 支持请求重试
|
||||||
|
- 支持HTTP缓存
|
||||||
|
- 支持请求限流
|
||||||
|
- 支持监控指标收集
|
||||||
|
- 支持自定义处理逻辑(委托模式)
|
||||||
|
- 支持DNS缓存
|
||||||
|
- 支持URL重写(反向代理模式)
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go get github.com/goproxy
|
||||||
|
```
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
### 正向代理
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/goproxy/internal/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// 创建代理
|
||||||
|
p := proxy.New(nil)
|
||||||
|
|
||||||
|
// 启动HTTP服务器
|
||||||
|
log.Println("代理服务器启动在 :8080")
|
||||||
|
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||||
|
log.Fatalf("代理服务器启动失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 启用HTTPS解密
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/goproxy/internal/config"
|
||||||
|
"github.com/goproxy/internal/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// 创建配置
|
||||||
|
cfg := config.DefaultConfig()
|
||||||
|
cfg.DecryptHTTPS = true
|
||||||
|
cfg.CACert = "ca.crt" // CA证书路径
|
||||||
|
cfg.CAKey = "ca.key" // CA私钥路径
|
||||||
|
cfg.UseECDSA = true // 使用ECDSA生成证书(默认为false,使用RSA)
|
||||||
|
|
||||||
|
// 可选:使用自定义TLS证书
|
||||||
|
// cfg.TLSCert = "server.crt"
|
||||||
|
// cfg.TLSKey = "server.key"
|
||||||
|
|
||||||
|
// 创建证书缓存
|
||||||
|
certCache := &proxy.MemCertCache{}
|
||||||
|
|
||||||
|
// 创建代理
|
||||||
|
p := proxy.New(&proxy.Options{
|
||||||
|
Config: cfg,
|
||||||
|
CertCache: certCache,
|
||||||
|
})
|
||||||
|
|
||||||
|
// 启动HTTP服务器
|
||||||
|
log.Println("HTTPS解密代理服务器启动在 :8080")
|
||||||
|
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||||
|
log.Fatalf("代理服务器启动失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
> **注意**: 使用HTTPS解密功能时,需要在客户端安装CA证书,否则会出现证书警告。
|
||||||
|
|
||||||
|
### 反向代理
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/goproxy/internal/config"
|
||||||
|
"github.com/goproxy/internal/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// 创建配置
|
||||||
|
cfg := config.DefaultConfig()
|
||||||
|
cfg.ReverseProxy = true
|
||||||
|
cfg.EnableURLRewrite = true
|
||||||
|
cfg.AddXForwardedFor = true
|
||||||
|
cfg.AddXRealIP = true
|
||||||
|
|
||||||
|
// 创建自定义委托
|
||||||
|
delegate := &ReverseProxyDelegate{
|
||||||
|
backend: "localhost:8081",
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建代理
|
||||||
|
p := proxy.New(&proxy.Options{
|
||||||
|
Config: cfg,
|
||||||
|
Delegate: delegate,
|
||||||
|
})
|
||||||
|
|
||||||
|
// 启动HTTP服务器
|
||||||
|
log.Println("反向代理服务器启动在 :8080")
|
||||||
|
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||||
|
log.Fatalf("代理服务器启动失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReverseProxyDelegate 反向代理委托
|
||||||
|
type ReverseProxyDelegate struct {
|
||||||
|
proxy.DefaultDelegate
|
||||||
|
backend string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveBackend 解析后端服务器
|
||||||
|
func (d *ReverseProxyDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||||
|
return d.backend, nil
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 自定义委托
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/goproxy/internal/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// 创建自定义委托
|
||||||
|
delegate := &CustomDelegate{}
|
||||||
|
|
||||||
|
// 创建代理
|
||||||
|
p := proxy.New(&proxy.Options{
|
||||||
|
Delegate: delegate,
|
||||||
|
})
|
||||||
|
|
||||||
|
// 启动HTTP服务器
|
||||||
|
log.Println("代理服务器启动在 :8080")
|
||||||
|
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||||
|
log.Fatalf("代理服务器启动失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CustomDelegate 自定义委托
|
||||||
|
type CustomDelegate struct {
|
||||||
|
proxy.DefaultDelegate
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeRequest 请求前事件
|
||||||
|
func (d *CustomDelegate) BeforeRequest(ctx *proxy.Context) {
|
||||||
|
log.Printf("请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeResponse 响应前事件
|
||||||
|
func (d *CustomDelegate) BeforeResponse(ctx *proxy.Context, resp *http.Response, err error) {
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("响应错误: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 完整示例
|
||||||
|
|
||||||
|
- 正向代理示例: [cmd/example/main.go](cmd/example/main.go)
|
||||||
|
- 反向代理示例: [cmd/reverse_proxy_example/main.go](cmd/reverse_proxy_example/main.go)
|
||||||
|
|
||||||
|
### 使用函数式选项模式
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/goproxy/internal/metrics"
|
||||||
|
"github.com/goproxy/internal/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// 创建监控指标
|
||||||
|
metricsCollector := metrics.NewSimpleMetrics()
|
||||||
|
|
||||||
|
// 创建证书缓存
|
||||||
|
certCache := &proxy.MemCertCache{}
|
||||||
|
|
||||||
|
// 使用函数式选项模式创建代理
|
||||||
|
p := proxy.NewProxy(
|
||||||
|
// 启用HTTPS解密
|
||||||
|
proxy.WithDecryptHTTPS(certCache),
|
||||||
|
proxy.WithCACertAndKey("ca.crt", "ca.key"),
|
||||||
|
|
||||||
|
// 设置监控指标
|
||||||
|
proxy.WithMetrics(metricsCollector),
|
||||||
|
|
||||||
|
// 设置请求超时和连接池
|
||||||
|
proxy.WithRequestTimeout(30 * time.Second),
|
||||||
|
proxy.WithConnectionPoolSize(100),
|
||||||
|
proxy.WithIdleTimeout(90 * time.Second),
|
||||||
|
|
||||||
|
// 启用DNS缓存
|
||||||
|
proxy.WithDNSCacheTTL(10 * time.Minute),
|
||||||
|
|
||||||
|
// 启用请求重试
|
||||||
|
proxy.WithEnableRetry(3, 1*time.Second, 10*time.Second),
|
||||||
|
|
||||||
|
// 启用CORS支持
|
||||||
|
proxy.WithEnableCORS(true),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 启动HTTP服务器和监控服务器
|
||||||
|
go func() {
|
||||||
|
log.Println("监控服务器启动在 :8081")
|
||||||
|
http.Handle("/metrics", metricsCollector)
|
||||||
|
if err := http.ListenAndServe(":8081", nil); err != nil {
|
||||||
|
log.Fatalf("监控服务器启动失败: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 启动代理服务器
|
||||||
|
log.Println("代理服务器启动在 :8080")
|
||||||
|
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||||
|
log.Fatalf("代理服务器启动失败: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 架构设计
|
||||||
|
|
||||||
|
GoProxy采用模块化设计,主要包含以下模块:
|
||||||
|
|
||||||
|
- **代理核心(Proxy)**:处理HTTP请求和响应,实现代理功能
|
||||||
|
- **反向代理(ReverseProxy)**:处理反向代理请求,支持URL重写和请求修改
|
||||||
|
- **路由(Router)**:基于主机名、路径、正则表达式等规则路由请求到不同的后端
|
||||||
|
- **URL重写(Rewriter)**:重写请求URL和响应中的URL
|
||||||
|
- **代理上下文(Context)**:保存请求上下文信息,用于在处理过程中传递数据
|
||||||
|
- **代理委托(Delegate)**:定义代理处理请求的各个阶段的回调方法,用于自定义处理逻辑
|
||||||
|
- **连接缓冲区(ConnBuffer)**:封装网络连接,提供缓冲读写功能
|
||||||
|
- **负载均衡(LoadBalancer)**:实现负载均衡算法,支持轮询、随机、权重等
|
||||||
|
- **健康检查(HealthChecker)**:检查上游服务器的健康状态,自动剔除不健康的服务器
|
||||||
|
- **缓存(Cache)**:实现HTTP缓存,减少重复请求
|
||||||
|
- **缓存适配器(CacheAdapter)**:统一不同缓存实现的接口,提高代码可读性和性能
|
||||||
|
- **证书生成(CertGenerator)**:动态生成TLS证书,支持HTTPS解密
|
||||||
|
- **限流(RateLimit)**:实现请求限流,防止过载
|
||||||
|
- **监控(Metrics)**:收集代理运行指标,用于监控和分析
|
||||||
|
- **重试(Retry)**:实现请求重试,提高请求成功率
|
||||||
|
|
||||||
|
## 配置选项
|
||||||
|
|
||||||
|
GoProxy提供了丰富的配置选项,可以通过`Options`结构体进行配置:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type Options struct {
|
||||||
|
// 配置
|
||||||
|
Config *config.Config
|
||||||
|
// 委托
|
||||||
|
Delegate Delegate
|
||||||
|
// 证书缓存
|
||||||
|
CertCache CertificateCache
|
||||||
|
// HTTP缓存
|
||||||
|
HTTPCache cache.Cache
|
||||||
|
// 负载均衡器
|
||||||
|
LoadBalancer loadbalance.LoadBalancer
|
||||||
|
// 健康检查器
|
||||||
|
HealthChecker *healthcheck.HealthChecker
|
||||||
|
// 监控指标
|
||||||
|
Metrics metrics.Metrics
|
||||||
|
// 客户端跟踪
|
||||||
|
ClientTrace *httptrace.ClientTrace
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 函数式选项模式
|
||||||
|
|
||||||
|
GoProxy 现在支持函数式选项模式(Functional Options Pattern),通过一系列的 `With` 方法提供更加灵活和可读性更高的配置方式。此模式的优势在于:
|
||||||
|
|
||||||
|
- 参数配置更加直观和清晰
|
||||||
|
- 可以灵活选择需要的配置项,不必记忆参数顺序
|
||||||
|
- 代码可读性更高,便于维护
|
||||||
|
- 可以逐步添加新的配置选项而不破坏兼容性
|
||||||
|
|
||||||
|
可以使用 `NewProxy` 函数和函数式选项创建代理:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// 创建一个简单的代理
|
||||||
|
proxy := proxy.NewProxy()
|
||||||
|
|
||||||
|
// 创建一个功能丰富的代理
|
||||||
|
proxy := proxy.NewProxy(
|
||||||
|
proxy.WithConfig(config.DefaultConfig()),
|
||||||
|
proxy.WithHTTPCache(myCache),
|
||||||
|
proxy.WithDecryptHTTPS(myCertCache),
|
||||||
|
proxy.WithCACertAndKey("ca.crt", "ca.key"),
|
||||||
|
proxy.WithMetrics(myMetrics),
|
||||||
|
proxy.WithLoadBalancer(myLoadBalancer),
|
||||||
|
proxy.WithRequestTimeout(10 * time.Second),
|
||||||
|
proxy.WithEnableCORS(true)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 可用的 With 方法
|
||||||
|
|
||||||
|
GoProxy 提供了以下 With 方法用于配置代理的各个方面:
|
||||||
|
|
||||||
|
#### 基础配置选项
|
||||||
|
- `WithConfig(cfg *config.Config)`: 设置代理配置
|
||||||
|
- `WithDisableKeepAlive(disableKeepAlive bool)`: 设置连接是否重用
|
||||||
|
- `WithTransport(t *http.Transport)`: 使用自定义HTTP传输
|
||||||
|
- `WithClientTrace(t *httptrace.ClientTrace)`: 设置HTTP客户端跟踪
|
||||||
|
|
||||||
|
#### 功能模块选项
|
||||||
|
- `WithDelegate(delegate Delegate)`: 设置委托类
|
||||||
|
- `WithHTTPCache(c cache.Cache)`: 设置HTTP缓存
|
||||||
|
- `WithLoadBalancer(lb loadbalance.LoadBalancer)`: 设置负载均衡器
|
||||||
|
- `WithHealthChecker(hc *healthcheck.HealthChecker)`: 设置健康检查器
|
||||||
|
- `WithMetrics(m metrics.Metrics)`: 设置监控指标
|
||||||
|
|
||||||
|
#### 功能开启选项
|
||||||
|
- `WithDecryptHTTPS(c CertificateCache)`: 启用中间人代理解密HTTPS
|
||||||
|
- `WithEnableECDSA(enable bool)`: 启用ECDSA证书生成(默认使用RSA)
|
||||||
|
- `WithEnableWebsocketIntercept()`: 启用WebSocket拦截
|
||||||
|
- `WithReverseProxy(enable bool)`: 启用反向代理模式
|
||||||
|
- `WithEnableRetry(maxRetries int, baseBackoff, maxBackoff time.Duration)`: 启用请求重试
|
||||||
|
- `WithRateLimit(rps float64)`: 设置请求限流
|
||||||
|
- `WithURLRewrite(enable bool)`: 启用URL重写
|
||||||
|
- `WithEnableCORS(enable bool)`: 启用CORS支持
|
||||||
|
|
||||||
|
#### 证书相关选项
|
||||||
|
- `WithTLSCertAndKey(certPath, keyPath string)`: 设置TLS证书和密钥
|
||||||
|
- `WithCACertAndKey(caCertPath, caKeyPath string)`: 设置CA证书和密钥
|
||||||
|
|
||||||
|
#### 性能和超时相关选项
|
||||||
|
- `WithConnectionPoolSize(size int)`: 设置连接池大小
|
||||||
|
- `WithIdleTimeout(timeout time.Duration)`: 设置空闲超时时间
|
||||||
|
- `WithRequestTimeout(timeout time.Duration)`: 设置请求超时时间
|
||||||
|
- `WithDNSCacheTTL(ttl time.Duration)`: 设置DNS缓存TTL
|
||||||
|
|
||||||
|
### 配置HTTPS解密
|
||||||
|
|
||||||
|
要启用HTTPS解密功能,需要在配置中设置以下选项:
|
||||||
|
|
||||||
|
```go
|
||||||
|
config := &config.Config{
|
||||||
|
// 启用HTTPS解密
|
||||||
|
DecryptHTTPS: true,
|
||||||
|
|
||||||
|
// 方式一:使用CA证书和私钥动态生成证书
|
||||||
|
CACert: "path/to/ca.crt", // CA证书路径
|
||||||
|
CAKey: "path/to/ca.key", // CA私钥路径
|
||||||
|
|
||||||
|
// 选择证书生成算法(可选)
|
||||||
|
UseECDSA: true, // 使用ECDSA生成证书(默认为false,使用RSA)
|
||||||
|
|
||||||
|
// 方式二:使用固定的TLS证书和私钥
|
||||||
|
// TLSCert: "path/to/server.crt",
|
||||||
|
// TLSKey: "path/to/server.key",
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
或者使用函数式选项模式:
|
||||||
|
|
||||||
|
```go
|
||||||
|
proxy := proxy.NewProxy(
|
||||||
|
proxy.WithDecryptHTTPS(&proxy.MemCertCache{}),
|
||||||
|
proxy.WithCACertAndKey("path/to/ca.crt", "path/to/ca.key"),
|
||||||
|
proxy.WithEnableECDSA(true), // 使用ECDSA生成证书
|
||||||
|
// 或者使用静态TLS证书
|
||||||
|
// proxy.WithTLSCertAndKey("path/to/server.crt", "path/to/server.key")
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
同时,建议配置证书缓存以提高性能:
|
||||||
|
|
||||||
|
```go
|
||||||
|
certCache := &proxy.MemCertCache{}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 扩展点
|
||||||
|
|
||||||
|
GoProxy提供了多个扩展点,可以通过实现相应的接口进行扩展:
|
||||||
|
|
||||||
|
- **Delegate**:代理委托接口,用于自定义代理处理逻辑
|
||||||
|
- **LoadBalancer**:负载均衡接口,用于实现自定义负载均衡算法
|
||||||
|
- **Cache**:缓存接口,用于实现自定义缓存策略
|
||||||
|
- **CertificateCache**:证书缓存接口,用于自定义证书存储方式
|
||||||
|
- **Metrics**:监控接口,用于实现自定义监控指标收集
|
||||||
|
|
||||||
|
## 反向代理特性
|
||||||
|
|
||||||
|
GoProxy的反向代理模式提供以下特性:
|
||||||
|
|
||||||
|
- **URL重写**:支持基于前缀和正则表达式的URL重写
|
||||||
|
- **路由规则**:支持基于主机名、路径、正则表达式等的路由规则
|
||||||
|
- **请求修改**:支持修改发往后端服务器的请求
|
||||||
|
- **响应修改**:支持修改来自后端服务器的响应
|
||||||
|
- **保留客户端信息**:支持添加X-Forwarded-For和X-Real-IP头
|
||||||
|
- **CORS支持**:支持自动添加CORS头
|
||||||
|
- **WebSocket支持**:支持WebSocket协议的透明代理
|
||||||
|
- **负载均衡**:支持多种负载均衡算法
|
||||||
|
- **健康检查**:支持对后端服务器进行健康检查
|
||||||
|
- **监控指标**:支持收集反向代理的监控指标
|
||||||
|
|
||||||
|
## 贡献
|
||||||
|
|
||||||
|
欢迎贡献代码、报告问题或提出建议。请遵循以下步骤:
|
||||||
|
|
||||||
|
1. Fork 项目
|
||||||
|
2. 创建特性分支 (`git checkout -b feature/amazing-feature`)
|
||||||
|
3. 提交更改 (`git commit -m 'Add some amazing feature'`)
|
||||||
|
4. 推送到分支 (`git push origin feature/amazing-feature`)
|
||||||
|
5. 创建 Pull Request
|
||||||
|
|
||||||
|
## 许可证
|
||||||
|
|
||||||
|
本项目采用 MIT 许可证,详情请参阅 [LICENSE](LICENSE) 文件。
|
73
SUMMARY.md
Normal file
73
SUMMARY.md
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
# GoProxy 项目总结
|
||||||
|
|
||||||
|
## 项目概述
|
||||||
|
|
||||||
|
GoProxy 是一个功能强大的 Go 语言 HTTP 代理库,支持 HTTP、HTTPS 和 WebSocket 代理,并提供了丰富的功能和扩展点。该项目采用模块化设计,各个模块之间职责明确,耦合度低,便于扩展和维护。
|
||||||
|
|
||||||
|
## 已完成的模块
|
||||||
|
|
||||||
|
1. **配置模块(config)**:提供代理配置选项,包括连接池大小、超时时间、是否启用缓存、负载均衡等。
|
||||||
|
|
||||||
|
2. **代理上下文(context)**:保存请求上下文信息,用于在代理处理过程中传递数据,包括原始请求、目标地址、上级代理等。
|
||||||
|
|
||||||
|
3. **代理委托(delegate)**:定义代理处理请求的各个阶段的回调方法,用于自定义处理逻辑,包括连接、认证、请求前、响应前等事件。
|
||||||
|
|
||||||
|
4. **连接缓冲区(conn_buffer)**:封装网络连接,提供缓冲读写功能,简化网络 IO 操作。
|
||||||
|
|
||||||
|
5. **负载均衡(loadbalance)**:实现负载均衡算法,支持轮询、随机、权重等,自动选择合适的上游服务器。
|
||||||
|
|
||||||
|
6. **健康检查(healthcheck)**:检查上游服务器的健康状态,自动剔除不健康的服务器,提高代理可靠性。
|
||||||
|
|
||||||
|
7. **缓存(cache)**:实现 HTTP 缓存,减少重复请求,提高代理性能。
|
||||||
|
|
||||||
|
8. **缓存适配器(cache_adapter)**:统一不同缓存实现的接口,使用适配器模式提高代码可读性和执行效率,支持多种缓存实现方式。
|
||||||
|
|
||||||
|
9. **证书生成(cert_generator)**:动态生成TLS证书,支持HTTPS解密(中间人模式),可基于CA证书和私钥创建域名证书,并支持证书缓存。支持RSA和ECDSA两种算法,用户可根据需要选择安全性和性能的平衡。
|
||||||
|
|
||||||
|
10. **限流(ratelimit)**:实现请求限流,防止过载,保护上游服务器。
|
||||||
|
|
||||||
|
11. **监控(metrics)**:收集代理运行指标,用于监控和分析,包括请求数、响应时间、错误数等。
|
||||||
|
|
||||||
|
12. **重试(retry)**:实现请求重试,提高请求成功率,处理临时性故障。
|
||||||
|
|
||||||
|
13. **代理核心(proxy)**:处理 HTTP 请求和响应,实现代理功能,包括 HTTP、HTTPS 和 WebSocket 代理。
|
||||||
|
|
||||||
|
## 项目特点
|
||||||
|
|
||||||
|
1. **模块化设计**:各个模块职责明确,耦合度低,便于扩展和维护。
|
||||||
|
|
||||||
|
2. **丰富的功能**:支持 HTTP、HTTPS 和 WebSocket 代理,并提供了负载均衡、健康检查、缓存、限流、监控等功能。
|
||||||
|
|
||||||
|
3. **灵活的扩展点**:提供了多个扩展点,可以通过实现相应的接口进行扩展,如代理委托、负载均衡、缓存、监控等。
|
||||||
|
|
||||||
|
4. **高性能**:采用 Go 语言的并发特性,实现高性能的代理服务。优化的缓存适配器和证书缓存机制进一步提升性能。
|
||||||
|
|
||||||
|
5. **可靠性**:通过健康检查、重试等机制,提高代理的可靠性。
|
||||||
|
|
||||||
|
6. **可观测性**:通过监控指标收集,提高代理的可观测性。
|
||||||
|
|
||||||
|
7. **安全性**:支持HTTPS解密功能,可用于调试、安全审计和内容过滤。提供RSA和ECDSA双算法选择,满足不同的安全需求和性能场景。
|
||||||
|
|
||||||
|
8. **支持更多证书格式**:增加对PEM、DER等更多证书格式的支持,以及对更多密钥算法(如Ed25519等)的支持。
|
||||||
|
|
||||||
|
## 使用示例
|
||||||
|
|
||||||
|
我们提供了一个完整的示例程序,展示了如何使用 GoProxy 库创建一个功能完善的代理服务器,包括负载均衡、健康检查、缓存、监控等功能。另外还提供了启用HTTPS解密功能的示例,展示如何使用CA证书动态生成站点证书。
|
||||||
|
|
||||||
|
## 未来计划
|
||||||
|
|
||||||
|
1. **完善文档**:编写更详细的文档,包括 API 文档、使用示例、最佳实践等。
|
||||||
|
|
||||||
|
2. **增加测试**:增加单元测试和集成测试,提高代码质量和可靠性。
|
||||||
|
|
||||||
|
3. **性能优化**:进一步优化代理性能,减少资源消耗。
|
||||||
|
|
||||||
|
4. **增加更多功能**:如请求过滤、内容修改、安全检查等。
|
||||||
|
|
||||||
|
5. **提供更多扩展点**:如请求路由、请求转换等。
|
||||||
|
|
||||||
|
6. **支持更多证书格式**:增加对PEM、DER等更多证书格式的支持。
|
||||||
|
|
||||||
|
## 总结
|
||||||
|
|
||||||
|
GoProxy 是一个功能强大、设计良好的 Go 语言 HTTP 代理库,可以满足各种代理需求,如开发调试、负载均衡、API 网关等。通过模块化设计和丰富的扩展点,GoProxy 可以灵活地适应各种场景,是构建代理服务的理想选择。最近的增强包括缓存适配器优化和完整的HTTPS解密功能实现,使GoProxy更加完善和高效。
|
165
cmd/example/main.go
Normal file
165
cmd/example/main.go
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/goproxy/internal/cache"
|
||||||
|
"github.com/goproxy/internal/config"
|
||||||
|
"github.com/goproxy/internal/healthcheck"
|
||||||
|
"github.com/goproxy/internal/loadbalance"
|
||||||
|
"github.com/goproxy/internal/metrics"
|
||||||
|
"github.com/goproxy/internal/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// 监听地址
|
||||||
|
addr = flag.String("addr", ":8080", "代理服务器监听地址")
|
||||||
|
// 上游代理服务器
|
||||||
|
upstream = flag.String("upstream", "", "上游代理服务器地址,多个地址用逗号分隔")
|
||||||
|
// 是否启用负载均衡
|
||||||
|
enableLoadBalance = flag.Bool("enable-lb", false, "是否启用负载均衡")
|
||||||
|
// 是否启用健康检查
|
||||||
|
enableHealthCheck = flag.Bool("enable-hc", false, "是否启用健康检查")
|
||||||
|
// 是否启用缓存
|
||||||
|
enableCache = flag.Bool("enable-cache", false, "是否启用缓存")
|
||||||
|
// 是否启用重试
|
||||||
|
enableRetry = flag.Bool("enable-retry", false, "是否启用重试")
|
||||||
|
// 是否启用监控
|
||||||
|
enableMetrics = flag.Bool("enable-metrics", false, "是否启用监控")
|
||||||
|
// 监控地址
|
||||||
|
metricsAddr = flag.String("metrics-addr", ":8081", "监控服务器监听地址")
|
||||||
|
)
|
||||||
|
|
||||||
|
// 解析目标地址
|
||||||
|
func parseTargets(targets string) []string {
|
||||||
|
if targets == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return strings.Split(targets, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// 解析命令行参数
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// 创建配置
|
||||||
|
cfg := config.DefaultConfig()
|
||||||
|
cfg.EnableLoadBalancing = *enableLoadBalance
|
||||||
|
cfg.EnableHealthCheck = *enableHealthCheck
|
||||||
|
cfg.EnableCache = *enableCache
|
||||||
|
cfg.EnableRetry = *enableRetry
|
||||||
|
|
||||||
|
// 创建选项
|
||||||
|
opts := &proxy.Options{
|
||||||
|
Config: cfg,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建负载均衡器
|
||||||
|
if *enableLoadBalance && *upstream != "" {
|
||||||
|
lb := loadbalance.NewRoundRobinBalancer()
|
||||||
|
for _, target := range parseTargets(*upstream) {
|
||||||
|
lb.Add(target, 1)
|
||||||
|
}
|
||||||
|
opts.LoadBalancer = lb
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建健康检查器
|
||||||
|
if *enableHealthCheck && opts.LoadBalancer != nil {
|
||||||
|
hc := healthcheck.NewHealthChecker(&healthcheck.Config{
|
||||||
|
Interval: time.Second * 10,
|
||||||
|
Timeout: time.Second * 2,
|
||||||
|
MaxFails: 3,
|
||||||
|
MinSuccess: 2,
|
||||||
|
})
|
||||||
|
opts.HealthChecker = hc
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建缓存
|
||||||
|
if *enableCache {
|
||||||
|
c := cache.NewMemoryCache(time.Minute*5, time.Minute*5, 1000)
|
||||||
|
opts.HTTPCache = c
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建监控
|
||||||
|
if *enableMetrics {
|
||||||
|
m := metrics.NewSimpleMetrics()
|
||||||
|
opts.Metrics = m
|
||||||
|
|
||||||
|
// 启动监控服务器
|
||||||
|
go func() {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
handler := m.GetHandler()
|
||||||
|
mux.Handle("/metrics", handler)
|
||||||
|
log.Printf("监控服务器启动在 %s\n", *metricsAddr)
|
||||||
|
if err := http.ListenAndServe(*metricsAddr, mux); err != nil {
|
||||||
|
log.Fatalf("监控服务器启动失败: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建自定义委托
|
||||||
|
delegate := &CustomDelegate{}
|
||||||
|
opts.Delegate = delegate
|
||||||
|
|
||||||
|
// 创建代理
|
||||||
|
p := proxy.New(opts)
|
||||||
|
|
||||||
|
// 创建HTTP服务器
|
||||||
|
server := &http.Server{
|
||||||
|
Addr: *addr,
|
||||||
|
Handler: p,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动HTTP服务器
|
||||||
|
go func() {
|
||||||
|
log.Printf("代理服务器启动在 %s\n", *addr)
|
||||||
|
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Fatalf("代理服务器启动失败: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 等待信号
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
<-quit
|
||||||
|
|
||||||
|
log.Println("正在关闭代理服务器...")
|
||||||
|
server.Close()
|
||||||
|
log.Println("代理服务器已关闭")
|
||||||
|
}
|
||||||
|
|
||||||
|
// CustomDelegate 自定义委托
|
||||||
|
type CustomDelegate struct {
|
||||||
|
proxy.DefaultDelegate
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect 连接事件
|
||||||
|
func (d *CustomDelegate) Connect(ctx *proxy.Context, rw http.ResponseWriter) {
|
||||||
|
log.Printf("收到连接: %s -> %s\n", ctx.Req.RemoteAddr, ctx.Req.URL.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeRequest 请求前事件
|
||||||
|
func (d *CustomDelegate) BeforeRequest(ctx *proxy.Context) {
|
||||||
|
log.Printf("请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeResponse 响应前事件
|
||||||
|
func (d *CustomDelegate) BeforeResponse(ctx *proxy.Context, resp *http.Response, err error) {
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("响应错误: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorLog 错误日志
|
||||||
|
func (d *CustomDelegate) ErrorLog(err error) {
|
||||||
|
log.Printf("错误: %v\n", err)
|
||||||
|
}
|
185
cmd/reverse_proxy_example/main.go
Normal file
185
cmd/reverse_proxy_example/main.go
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/goproxy/internal/config"
|
||||||
|
"github.com/goproxy/internal/metrics"
|
||||||
|
"github.com/goproxy/internal/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// 监听地址
|
||||||
|
addr = flag.String("addr", ":8080", "反向代理服务器监听地址")
|
||||||
|
// 后端服务器
|
||||||
|
backend = flag.String("backend", "localhost:8081", "后端服务器地址")
|
||||||
|
// 路由规则文件
|
||||||
|
routeFile = flag.String("route-file", "", "路由规则文件路径")
|
||||||
|
// 是否启用URL重写
|
||||||
|
enableRewrite = flag.Bool("enable-rewrite", false, "是否启用URL重写")
|
||||||
|
// 是否启用缓存
|
||||||
|
enableCache = flag.Bool("enable-cache", false, "是否启用缓存")
|
||||||
|
// 是否启用压缩
|
||||||
|
enableCompression = flag.Bool("enable-compression", false, "是否启用压缩")
|
||||||
|
// 是否启用监控
|
||||||
|
enableMetrics = flag.Bool("enable-metrics", false, "是否启用监控")
|
||||||
|
// 监控地址
|
||||||
|
metricsAddr = flag.String("metrics-addr", ":8082", "监控服务器监听地址")
|
||||||
|
// 是否添加X-Forwarded-For
|
||||||
|
addXForwardedFor = flag.Bool("add-x-forwarded-for", true, "是否添加X-Forwarded-For头")
|
||||||
|
// 是否添加X-Real-IP
|
||||||
|
addXRealIP = flag.Bool("add-x-real-ip", true, "是否添加X-Real-IP头")
|
||||||
|
// 是否启用CORS
|
||||||
|
enableCORS = flag.Bool("enable-cors", false, "是否启用CORS")
|
||||||
|
// 路径前缀
|
||||||
|
pathPrefix = flag.String("path-prefix", "", "路径前缀,将从请求路径中移除")
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// 解析命令行参数
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// 创建配置
|
||||||
|
cfg := config.DefaultConfig()
|
||||||
|
cfg.ReverseProxy = true
|
||||||
|
cfg.EnableCache = *enableCache
|
||||||
|
cfg.EnableCompression = *enableCompression
|
||||||
|
cfg.EnableURLRewrite = *enableRewrite
|
||||||
|
cfg.AddXForwardedFor = *addXForwardedFor
|
||||||
|
cfg.AddXRealIP = *addXRealIP
|
||||||
|
cfg.EnableCORS = *enableCORS
|
||||||
|
cfg.ReverseProxyRulesFile = *routeFile
|
||||||
|
|
||||||
|
// 创建选项
|
||||||
|
opts := &proxy.Options{
|
||||||
|
Config: cfg,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建监控
|
||||||
|
if *enableMetrics {
|
||||||
|
m := metrics.NewSimpleMetrics()
|
||||||
|
opts.Metrics = m
|
||||||
|
|
||||||
|
// 启动监控服务器
|
||||||
|
go func() {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
handler := m.GetHandler()
|
||||||
|
mux.Handle("/metrics", handler)
|
||||||
|
log.Printf("监控服务器启动在 %s\n", *metricsAddr)
|
||||||
|
if err := http.ListenAndServe(*metricsAddr, mux); err != nil {
|
||||||
|
log.Fatalf("监控服务器启动失败: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建自定义委托
|
||||||
|
delegate := &ReverseProxyDelegate{
|
||||||
|
backend: *backend,
|
||||||
|
prefix: *pathPrefix,
|
||||||
|
}
|
||||||
|
opts.Delegate = delegate
|
||||||
|
|
||||||
|
// 创建代理
|
||||||
|
p := proxy.New(opts)
|
||||||
|
|
||||||
|
// 如果有路径前缀,添加重写规则
|
||||||
|
if *pathPrefix != "" {
|
||||||
|
reverseProxy := p.NewReverseProxy()
|
||||||
|
log.Printf("添加路径重写规则: 从请求路径移除前缀 %s\n", *pathPrefix)
|
||||||
|
reverseProxy.AddRewriteRule(*pathPrefix, "", false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建HTTP服务器
|
||||||
|
server := &http.Server{
|
||||||
|
Addr: *addr,
|
||||||
|
Handler: p,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动HTTP服务器
|
||||||
|
go func() {
|
||||||
|
log.Printf("反向代理服务器启动在 %s,后端服务器为 %s\n", *addr, *backend)
|
||||||
|
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Fatalf("代理服务器启动失败: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 等待信号
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
<-quit
|
||||||
|
|
||||||
|
log.Println("正在关闭代理服务器...")
|
||||||
|
server.Close()
|
||||||
|
log.Println("代理服务器已关闭")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReverseProxyDelegate 反向代理委托
|
||||||
|
type ReverseProxyDelegate struct {
|
||||||
|
proxy.DefaultDelegate
|
||||||
|
backend string
|
||||||
|
prefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveBackend 解析后端服务器
|
||||||
|
func (d *ReverseProxyDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||||
|
// 这里可以实现基于请求路径、主机名等的路由逻辑
|
||||||
|
return d.backend, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModifyRequest 修改请求
|
||||||
|
func (d *ReverseProxyDelegate) ModifyRequest(req *http.Request) {
|
||||||
|
// 移除路径前缀
|
||||||
|
if d.prefix != "" && strings.HasPrefix(req.URL.Path, d.prefix) {
|
||||||
|
req.URL.Path = strings.TrimPrefix(req.URL.Path, d.prefix)
|
||||||
|
if req.URL.Path == "" {
|
||||||
|
req.URL.Path = "/"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加自定义请求头
|
||||||
|
req.Header.Set("X-Proxy-Time", time.Now().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModifyResponse 修改响应
|
||||||
|
func (d *ReverseProxyDelegate) ModifyResponse(resp *http.Response) error {
|
||||||
|
// 添加自定义响应头
|
||||||
|
resp.Header.Set("X-Proxied-By", "GoProxy")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect 连接事件
|
||||||
|
func (d *ReverseProxyDelegate) Connect(ctx *proxy.Context, rw http.ResponseWriter) {
|
||||||
|
log.Printf("收到连接: %s -> %s %s\n", ctx.Req.RemoteAddr, ctx.Req.Method, ctx.Req.URL.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeRequest 请求前事件
|
||||||
|
func (d *ReverseProxyDelegate) BeforeRequest(ctx *proxy.Context) {
|
||||||
|
log.Printf("处理请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeResponse 响应前事件
|
||||||
|
func (d *ReverseProxyDelegate) BeforeResponse(ctx *proxy.Context, resp *http.Response, err error) {
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("响应错误: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorLog 错误日志
|
||||||
|
func (d *ReverseProxyDelegate) ErrorLog(err error) {
|
||||||
|
log.Printf("错误: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleError 处理错误
|
||||||
|
func (d *ReverseProxyDelegate) HandleError(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
|
log.Printf("处理错误: %v\n", err)
|
||||||
|
http.Error(rw, "代理服务器错误: "+err.Error(), http.StatusBadGateway)
|
||||||
|
}
|
132
delegate.go
Normal file
132
delegate.go
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
// Copyright 2018 ouqiang authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
// License for the specific language governing permissions and limitations
|
||||||
|
// under the License.
|
||||||
|
|
||||||
|
package goproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Context 代理上下文
|
||||||
|
type Context struct {
|
||||||
|
Req *http.Request
|
||||||
|
Data map[interface{}]interface{}
|
||||||
|
TunnelProxy bool
|
||||||
|
abort bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) IsHTTPS() bool {
|
||||||
|
return c.Req.URL.Scheme == "https"
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultPorts = map[string]string{
|
||||||
|
"https": "443",
|
||||||
|
"http": "80",
|
||||||
|
"": "80",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) WebsocketUrl() *url.URL {
|
||||||
|
u := new(url.URL)
|
||||||
|
*u = *c.Req.URL
|
||||||
|
if c.IsHTTPS() {
|
||||||
|
u.Scheme = "wss"
|
||||||
|
} else {
|
||||||
|
u.Scheme = "ws"
|
||||||
|
}
|
||||||
|
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) Addr() string {
|
||||||
|
addr := c.Req.Host
|
||||||
|
|
||||||
|
if !strings.Contains(c.Req.URL.Host, ":") {
|
||||||
|
addr += ":" + defaultPorts[c.Req.URL.Scheme]
|
||||||
|
}
|
||||||
|
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Abort 中断执行
|
||||||
|
func (c *Context) Abort() {
|
||||||
|
c.abort = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAborted 是否已中断执行
|
||||||
|
func (c *Context) IsAborted() bool {
|
||||||
|
return c.abort
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset 重置
|
||||||
|
func (c *Context) Reset(req *http.Request) {
|
||||||
|
c.Req = req
|
||||||
|
c.Data = make(map[interface{}]interface{})
|
||||||
|
c.abort = false
|
||||||
|
c.TunnelProxy = false
|
||||||
|
}
|
||||||
|
|
||||||
|
type Delegate interface {
|
||||||
|
// Connect 收到客户端连接
|
||||||
|
Connect(ctx *Context, rw http.ResponseWriter)
|
||||||
|
// Auth 代理身份认证
|
||||||
|
Auth(ctx *Context, rw http.ResponseWriter)
|
||||||
|
// BeforeRequest HTTP请求前 设置X-Forwarded-For, 修改Header、Body
|
||||||
|
BeforeRequest(ctx *Context)
|
||||||
|
// BeforeResponse 响应发送到客户端前, 修改Header、Body、Status Code
|
||||||
|
BeforeResponse(ctx *Context, resp *http.Response, err error)
|
||||||
|
// WebSocketSendMessage websocket发送消息
|
||||||
|
WebSocketSendMessage(ctx *Context, messageType *int, p *[]byte)
|
||||||
|
// WebSockerReceiveMessage websocket接收 消息
|
||||||
|
WebSocketReceiveMessage(ctx *Context, messageType *int, p *[]byte)
|
||||||
|
// ParentProxy 上级代理
|
||||||
|
ParentProxy(*http.Request) (*url.URL, error)
|
||||||
|
// Finish 本次请求结束
|
||||||
|
Finish(ctx *Context)
|
||||||
|
// 记录错误信息
|
||||||
|
ErrorLog(err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Delegate = &DefaultDelegate{}
|
||||||
|
|
||||||
|
// DefaultDelegate 默认Handler什么也不做
|
||||||
|
type DefaultDelegate struct {
|
||||||
|
Delegate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *DefaultDelegate) Connect(ctx *Context, rw http.ResponseWriter) {}
|
||||||
|
|
||||||
|
func (h *DefaultDelegate) Auth(ctx *Context, rw http.ResponseWriter) {}
|
||||||
|
|
||||||
|
func (h *DefaultDelegate) BeforeRequest(ctx *Context) {}
|
||||||
|
|
||||||
|
func (h *DefaultDelegate) BeforeResponse(ctx *Context, resp *http.Response, err error) {}
|
||||||
|
|
||||||
|
func (h *DefaultDelegate) ParentProxy(req *http.Request) (*url.URL, error) {
|
||||||
|
return http.ProxyFromEnvironment(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocketSendMessage websocket发送消息
|
||||||
|
func (h *DefaultDelegate) WebSocketSendMessage(ctx *Context, messageType *int, payload *[]byte) {}
|
||||||
|
|
||||||
|
// WebSockerReceiveMessage websocket接收 消息
|
||||||
|
func (h *DefaultDelegate) WebSocketReceiveMessage(ctx *Context, messageType *int, payload *[]byte) {}
|
||||||
|
|
||||||
|
func (h *DefaultDelegate) Finish(ctx *Context) {}
|
||||||
|
|
||||||
|
func (h *DefaultDelegate) ErrorLog(err error) {
|
||||||
|
log.Println(err)
|
||||||
|
}
|
10
go.mod
Normal file
10
go.mod
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
module github.com/goproxy
|
||||||
|
|
||||||
|
go 1.24.0
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/ouqiang/goproxy v1.3.2
|
||||||
|
github.com/ouqiang/websocket v1.6.2
|
||||||
|
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8
|
||||||
|
golang.org/x/time v0.11.0
|
||||||
|
)
|
8
go.sum
Normal file
8
go.sum
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
github.com/ouqiang/goproxy v1.3.2 h1:+3uBRrM0RU4LFcsH0lbWsdUCoHIzoRxk+ISPbIS3lTk=
|
||||||
|
github.com/ouqiang/goproxy v1.3.2/go.mod h1:yF0a+DlUi0Zff28iUeuqLov90bivevUX9uOn3Yk9rww=
|
||||||
|
github.com/ouqiang/websocket v1.6.2 h1:LGQIySbQO3ahZCl34v9xBVb0yncDk8yIcuEIbWBab/U=
|
||||||
|
github.com/ouqiang/websocket v1.6.2/go.mod h1:fIROJIHRlQwgCyUFTMzaaIcs4HIwUj2xlOW43u9Sf+M=
|
||||||
|
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 h1:EVObHAr8DqpoJCVv6KYTle8FEImKhtkfcZetNqxDoJQ=
|
||||||
|
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8/go.mod h1:dniwbG03GafCjFohMDmz6Zc6oCuiqgH6tGNyXTkHzXE=
|
||||||
|
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
|
||||||
|
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
220
internal/cache/cache.go
vendored
Normal file
220
internal/cache/cache.go
vendored
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/md5"
|
||||||
|
"encoding/hex"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cache 缓存接口
|
||||||
|
type Cache interface {
|
||||||
|
// Get 获取缓存
|
||||||
|
Get(key string) (*http.Response, bool)
|
||||||
|
// Set 设置缓存
|
||||||
|
Set(key string, resp *http.Response)
|
||||||
|
// Delete 删除缓存
|
||||||
|
Delete(key string)
|
||||||
|
// Clear 清空缓存
|
||||||
|
Clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemoryCache 内存缓存实现
|
||||||
|
type MemoryCache struct {
|
||||||
|
// 缓存内容
|
||||||
|
items sync.Map
|
||||||
|
// 过期时间
|
||||||
|
ttl time.Duration
|
||||||
|
// 清理间隔
|
||||||
|
cleanupInterval time.Duration
|
||||||
|
// 最大条目数
|
||||||
|
maxEntries int
|
||||||
|
// 当前条目数
|
||||||
|
size int32
|
||||||
|
// 互斥锁
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheItem 缓存项
|
||||||
|
type CacheItem struct {
|
||||||
|
response *http.Response
|
||||||
|
responseBody []byte
|
||||||
|
expiry time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemoryCache 创建内存缓存
|
||||||
|
func NewMemoryCache(ttl, cleanupInterval time.Duration, maxEntries int) *MemoryCache {
|
||||||
|
cache := &MemoryCache{
|
||||||
|
ttl: ttl,
|
||||||
|
cleanupInterval: cleanupInterval,
|
||||||
|
maxEntries: maxEntries,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动过期清理
|
||||||
|
if cleanupInterval > 0 {
|
||||||
|
go cache.startCleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
return cache
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 获取缓存
|
||||||
|
func (c *MemoryCache) Get(key string) (*http.Response, bool) {
|
||||||
|
value, ok := c.items.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
item := value.(*CacheItem)
|
||||||
|
if time.Now().After(item.expiry) {
|
||||||
|
c.Delete(key)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 克隆响应,避免修改原始数据
|
||||||
|
resp := cloneResponse(item.response, item.responseBody)
|
||||||
|
return resp, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set 设置缓存
|
||||||
|
func (c *MemoryCache) Set(key string, resp *http.Response) {
|
||||||
|
// 检查缓存是否已满
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.maxEntries > 0 && c.size >= int32(c.maxEntries) {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.size++
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
// 读取并保存响应体
|
||||||
|
var bodyBytes []byte
|
||||||
|
if resp.Body != nil {
|
||||||
|
bodyBytes, _ = io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
item := &CacheItem{
|
||||||
|
response: resp,
|
||||||
|
responseBody: bodyBytes,
|
||||||
|
expiry: time.Now().Add(c.ttl),
|
||||||
|
}
|
||||||
|
|
||||||
|
c.items.Store(key, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 删除缓存
|
||||||
|
func (c *MemoryCache) Delete(key string) {
|
||||||
|
c.items.Delete(key)
|
||||||
|
c.mu.Lock()
|
||||||
|
c.size--
|
||||||
|
if c.size < 0 {
|
||||||
|
c.size = 0
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear 清空缓存
|
||||||
|
func (c *MemoryCache) Clear() {
|
||||||
|
c.items = sync.Map{}
|
||||||
|
c.mu.Lock()
|
||||||
|
c.size = 0
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// startCleanup 启动过期清理
|
||||||
|
func (c *MemoryCache) startCleanup() {
|
||||||
|
ticker := time.NewTicker(c.cleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
now := time.Now()
|
||||||
|
c.items.Range(func(key, value interface{}) bool {
|
||||||
|
item := value.(*CacheItem)
|
||||||
|
if now.After(item.expiry) {
|
||||||
|
c.Delete(key.(string))
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateCacheKey 生成缓存键
|
||||||
|
func GenerateCacheKey(req *http.Request) string {
|
||||||
|
// 忽略一些可变的头部
|
||||||
|
ignoredHeaders := map[string]bool{
|
||||||
|
"Connection": true,
|
||||||
|
"Keep-Alive": true,
|
||||||
|
"Proxy-Authenticate": true,
|
||||||
|
"Proxy-Authorization": true,
|
||||||
|
"TE": true,
|
||||||
|
"Trailers": true,
|
||||||
|
"Transfer-Encoding": true,
|
||||||
|
"Upgrade": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取缓存键组件
|
||||||
|
components := []string{
|
||||||
|
req.Method,
|
||||||
|
req.URL.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加选择性头部
|
||||||
|
for key, values := range req.Header {
|
||||||
|
if !ignoredHeaders[key] {
|
||||||
|
for _, value := range values {
|
||||||
|
components = append(components, key+":"+value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 连接并计算哈希
|
||||||
|
data := strings.Join(components, "|")
|
||||||
|
hash := md5.New()
|
||||||
|
hash.Write([]byte(data))
|
||||||
|
return hex.EncodeToString(hash.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneResponse 克隆HTTP响应
|
||||||
|
func cloneResponse(resp *http.Response, body []byte) *http.Response {
|
||||||
|
clone := *resp
|
||||||
|
clone.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||||
|
clone.Header = make(http.Header)
|
||||||
|
|
||||||
|
for k, v := range resp.Header {
|
||||||
|
clone.Header[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return &clone
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShouldCache 判断请求是否应该缓存
|
||||||
|
func ShouldCache(req *http.Request, resp *http.Response) bool {
|
||||||
|
// 只缓存GET请求
|
||||||
|
if req.Method != http.MethodGet {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查响应状态码
|
||||||
|
if resp.StatusCode != http.StatusOK &&
|
||||||
|
resp.StatusCode != http.StatusNotModified &&
|
||||||
|
resp.StatusCode != http.StatusMovedPermanently &&
|
||||||
|
resp.StatusCode != http.StatusPermanentRedirect {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查Cache-Control头
|
||||||
|
cacheControl := resp.Header.Get("Cache-Control")
|
||||||
|
if strings.Contains(cacheControl, "no-store") ||
|
||||||
|
strings.Contains(cacheControl, "no-cache") ||
|
||||||
|
strings.Contains(cacheControl, "private") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
135
internal/config/config.go
Normal file
135
internal/config/config.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config 代理配置
|
||||||
|
type Config struct {
|
||||||
|
// 监听地址
|
||||||
|
ListenAddr string
|
||||||
|
// 是否启用负载均衡
|
||||||
|
EnableLoadBalancing bool
|
||||||
|
// 负载均衡后端列表
|
||||||
|
Backends []string
|
||||||
|
// 是否启用限流
|
||||||
|
EnableRateLimit bool
|
||||||
|
// 每秒请求速率限制
|
||||||
|
RateLimit float64
|
||||||
|
// 并发请求峰值限制
|
||||||
|
MaxBurst int
|
||||||
|
// 最大连接数
|
||||||
|
MaxConnections int
|
||||||
|
// 是否启用连接池
|
||||||
|
EnableConnectionPool bool
|
||||||
|
// 连接池大小
|
||||||
|
ConnectionPoolSize int
|
||||||
|
// 连接空闲超时时间
|
||||||
|
IdleTimeout time.Duration
|
||||||
|
// 请求超时时间
|
||||||
|
RequestTimeout time.Duration
|
||||||
|
// 是否启用响应缓存
|
||||||
|
EnableCache bool
|
||||||
|
// 缓存过期时间
|
||||||
|
CacheTTL time.Duration
|
||||||
|
// 是否启用HTTPS解密
|
||||||
|
DecryptHTTPS bool
|
||||||
|
// TLS证书文件路径
|
||||||
|
TLSCert string
|
||||||
|
// TLS密钥文件路径
|
||||||
|
TLSKey string
|
||||||
|
// CA证书文件路径(用于生成动态证书)
|
||||||
|
CACert string
|
||||||
|
// CA密钥文件路径(用于生成动态证书)
|
||||||
|
CAKey string
|
||||||
|
// 是否启用健康检查
|
||||||
|
EnableHealthCheck bool
|
||||||
|
// 健康检查间隔时间
|
||||||
|
HealthCheckInterval time.Duration
|
||||||
|
// 健康检查超时时间
|
||||||
|
HealthCheckTimeout time.Duration
|
||||||
|
// 是否启用重试机制
|
||||||
|
EnableRetry bool
|
||||||
|
// 最大重试次数
|
||||||
|
MaxRetries int
|
||||||
|
// 重试间隔基数
|
||||||
|
RetryBackoff time.Duration
|
||||||
|
// 最大重试间隔
|
||||||
|
MaxRetryBackoff time.Duration
|
||||||
|
// 是否启用监控指标
|
||||||
|
EnableMetrics bool
|
||||||
|
// 是否启用请求追踪
|
||||||
|
EnableTracing bool
|
||||||
|
// 是否拦截WebSocket
|
||||||
|
WebSocketIntercept bool
|
||||||
|
// DNS缓存过期时间
|
||||||
|
DNSCacheTTL time.Duration
|
||||||
|
// 是否作为反向代理
|
||||||
|
ReverseProxy bool
|
||||||
|
// 反向代理规则文件路径
|
||||||
|
ReverseProxyRulesFile string
|
||||||
|
// 是否启用URL重写
|
||||||
|
EnableURLRewrite bool
|
||||||
|
// 是否保留客户端IP
|
||||||
|
PreserveClientIP bool
|
||||||
|
// 是否启用压缩
|
||||||
|
EnableCompression bool
|
||||||
|
// 是否自动添加CORS头
|
||||||
|
EnableCORS bool
|
||||||
|
// 重写Host头
|
||||||
|
RewriteHostHeader bool
|
||||||
|
// 是否添加X-Forwarded-For头
|
||||||
|
AddXForwardedFor bool
|
||||||
|
// 是否添加X-Real-IP头
|
||||||
|
AddXRealIP bool
|
||||||
|
// 是否支持Websocket升级
|
||||||
|
SupportWebSocketUpgrade bool
|
||||||
|
// 是否使用ECDSA生成证书(默认使用RSA)
|
||||||
|
UseECDSA bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultConfig 默认配置
|
||||||
|
func DefaultConfig() *Config {
|
||||||
|
return &Config{
|
||||||
|
ListenAddr: ":8080",
|
||||||
|
EnableLoadBalancing: false,
|
||||||
|
Backends: []string{},
|
||||||
|
EnableRateLimit: false,
|
||||||
|
RateLimit: 100,
|
||||||
|
MaxBurst: 50,
|
||||||
|
MaxConnections: 1000,
|
||||||
|
EnableConnectionPool: true,
|
||||||
|
ConnectionPoolSize: 100,
|
||||||
|
IdleTimeout: 60 * time.Second,
|
||||||
|
RequestTimeout: 30 * time.Second,
|
||||||
|
EnableCache: false,
|
||||||
|
CacheTTL: 5 * time.Minute,
|
||||||
|
DecryptHTTPS: false,
|
||||||
|
TLSCert: "",
|
||||||
|
TLSKey: "",
|
||||||
|
CACert: "",
|
||||||
|
CAKey: "",
|
||||||
|
EnableHealthCheck: false,
|
||||||
|
HealthCheckInterval: 10 * time.Second,
|
||||||
|
HealthCheckTimeout: 5 * time.Second,
|
||||||
|
EnableRetry: true,
|
||||||
|
MaxRetries: 3,
|
||||||
|
RetryBackoff: 100 * time.Millisecond,
|
||||||
|
MaxRetryBackoff: 2 * time.Second,
|
||||||
|
EnableMetrics: false,
|
||||||
|
EnableTracing: false,
|
||||||
|
WebSocketIntercept: false,
|
||||||
|
DNSCacheTTL: 5 * time.Minute,
|
||||||
|
ReverseProxy: false,
|
||||||
|
ReverseProxyRulesFile: "",
|
||||||
|
EnableURLRewrite: false,
|
||||||
|
PreserveClientIP: true,
|
||||||
|
EnableCompression: false,
|
||||||
|
EnableCORS: false,
|
||||||
|
RewriteHostHeader: false,
|
||||||
|
AddXForwardedFor: true,
|
||||||
|
AddXRealIP: true,
|
||||||
|
SupportWebSocketUpgrade: true,
|
||||||
|
UseECDSA: false,
|
||||||
|
}
|
||||||
|
}
|
248
internal/healthcheck/healthchecker.go
Normal file
248
internal/healthcheck/healthchecker.go
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
package healthcheck
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HealthChecker 健康检查器
|
||||||
|
type HealthChecker struct {
|
||||||
|
// 配置
|
||||||
|
config *Config
|
||||||
|
// 健康检查状态
|
||||||
|
statusMap sync.Map
|
||||||
|
// 状态变更回调
|
||||||
|
statusChangeCallback func(string, bool)
|
||||||
|
// 是否运行中
|
||||||
|
running bool
|
||||||
|
// 上下文
|
||||||
|
ctx context.Context
|
||||||
|
// 取消函数
|
||||||
|
cancel context.CancelFunc
|
||||||
|
// 互斥锁
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config 健康检查配置
|
||||||
|
type Config struct {
|
||||||
|
// 检查间隔
|
||||||
|
Interval time.Duration
|
||||||
|
// 检查超时
|
||||||
|
Timeout time.Duration
|
||||||
|
// 检查路径
|
||||||
|
Path string
|
||||||
|
// 检查方法
|
||||||
|
Method string
|
||||||
|
// 检查状态码
|
||||||
|
SuccessStatus int
|
||||||
|
// 最大失败次数
|
||||||
|
MaxFails int
|
||||||
|
// 最小成功次数
|
||||||
|
MinSuccess int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHealthChecker 创建健康检查器
|
||||||
|
func NewHealthChecker(config *Config) *HealthChecker {
|
||||||
|
if config.Path == "" {
|
||||||
|
config.Path = "/"
|
||||||
|
}
|
||||||
|
if config.Method == "" {
|
||||||
|
config.Method = http.MethodGet
|
||||||
|
}
|
||||||
|
if config.SuccessStatus == 0 {
|
||||||
|
config.SuccessStatus = http.StatusOK
|
||||||
|
}
|
||||||
|
if config.MaxFails == 0 {
|
||||||
|
config.MaxFails = 3
|
||||||
|
}
|
||||||
|
if config.MinSuccess == 0 {
|
||||||
|
config.MinSuccess = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
return &HealthChecker{
|
||||||
|
config: config,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
running: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start 启动健康检查
|
||||||
|
func (hc *HealthChecker) Start() {
|
||||||
|
hc.mu.Lock()
|
||||||
|
defer hc.mu.Unlock()
|
||||||
|
|
||||||
|
if hc.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hc.running = true
|
||||||
|
go hc.run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop 停止健康检查
|
||||||
|
func (hc *HealthChecker) Stop() {
|
||||||
|
hc.mu.Lock()
|
||||||
|
defer hc.mu.Unlock()
|
||||||
|
|
||||||
|
if !hc.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hc.cancel()
|
||||||
|
hc.running = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTarget 添加监控目标
|
||||||
|
func (hc *HealthChecker) AddTarget(target string) error {
|
||||||
|
u, err := url.Parse(target)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化为健康状态
|
||||||
|
hc.statusMap.Store(u.String(), &backendStatus{
|
||||||
|
URL: u,
|
||||||
|
Healthy: true,
|
||||||
|
FailCount: 0,
|
||||||
|
SuccessCount: 0,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveTarget 移除监控目标
|
||||||
|
func (hc *HealthChecker) RemoveTarget(target string) error {
|
||||||
|
u, err := url.Parse(target)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
hc.statusMap.Delete(u.String())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsHealthy 检查目标是否健康
|
||||||
|
func (hc *HealthChecker) IsHealthy(target string) bool {
|
||||||
|
u, err := url.Parse(target)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
value, ok := hc.statusMap.Load(u.String())
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
status := value.(*backendStatus)
|
||||||
|
return status.Healthy
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetStatusChangeCallback 设置状态变更回调
|
||||||
|
func (hc *HealthChecker) SetStatusChangeCallback(callback func(string, bool)) {
|
||||||
|
hc.statusChangeCallback = callback
|
||||||
|
}
|
||||||
|
|
||||||
|
// backendStatus 后端健康状态
|
||||||
|
type backendStatus struct {
|
||||||
|
URL *url.URL
|
||||||
|
Healthy bool
|
||||||
|
FailCount int
|
||||||
|
SuccessCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
// run 运行健康检查
|
||||||
|
func (hc *HealthChecker) run() {
|
||||||
|
ticker := time.NewTicker(hc.config.Interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-hc.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
hc.checkAll()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkAll 检查所有后端
|
||||||
|
func (hc *HealthChecker) checkAll() {
|
||||||
|
hc.statusMap.Range(func(key, value interface{}) bool {
|
||||||
|
go hc.check(key.(string), value.(*backendStatus))
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// check 检查单个后端
|
||||||
|
func (hc *HealthChecker) check(key string, status *backendStatus) {
|
||||||
|
// 创建检查请求
|
||||||
|
u := *status.URL
|
||||||
|
u.Path = hc.config.Path
|
||||||
|
req, err := http.NewRequest(hc.config.Method, u.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
hc.updateStatus(key, status, false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置超时的客户端
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: hc.config.Timeout,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
DialContext: (&net.Dialer{
|
||||||
|
Timeout: hc.config.Timeout,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}).DialContext,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ResponseHeaderTimeout: hc.config.Timeout,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
MaxIdleConns: 10,
|
||||||
|
IdleConnTimeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
hc.updateStatus(key, status, false)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 检查响应状态
|
||||||
|
if resp.StatusCode == hc.config.SuccessStatus {
|
||||||
|
hc.updateStatus(key, status, true)
|
||||||
|
} else {
|
||||||
|
hc.updateStatus(key, status, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateStatus 更新后端状态
|
||||||
|
func (hc *HealthChecker) updateStatus(key string, status *backendStatus, success bool) {
|
||||||
|
if success {
|
||||||
|
status.SuccessCount++
|
||||||
|
status.FailCount = 0
|
||||||
|
if !status.Healthy && status.SuccessCount >= hc.config.MinSuccess {
|
||||||
|
status.Healthy = true
|
||||||
|
if hc.statusChangeCallback != nil {
|
||||||
|
hc.statusChangeCallback(key, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
status.FailCount++
|
||||||
|
status.SuccessCount = 0
|
||||||
|
if status.Healthy && status.FailCount >= hc.config.MaxFails {
|
||||||
|
status.Healthy = false
|
||||||
|
if hc.statusChangeCallback != nil {
|
||||||
|
hc.statusChangeCallback(key, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hc.statusMap.Store(key, status)
|
||||||
|
}
|
330
internal/loadbalance/loadbalancer.go
Normal file
330
internal/loadbalance/loadbalancer.go
Normal file
@@ -0,0 +1,330 @@
|
|||||||
|
package loadbalance
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Strategy 负载均衡策略
|
||||||
|
type Strategy int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// StrategyRoundRobin 轮询策略
|
||||||
|
StrategyRoundRobin Strategy = iota
|
||||||
|
// StrategyRandom 随机策略
|
||||||
|
StrategyRandom
|
||||||
|
// StrategyWeightedRoundRobin 加权轮询策略
|
||||||
|
StrategyWeightedRoundRobin
|
||||||
|
// StrategyIPHash IP哈希策略
|
||||||
|
StrategyIPHash
|
||||||
|
)
|
||||||
|
|
||||||
|
// LoadBalancer 负载均衡器接口
|
||||||
|
type LoadBalancer interface {
|
||||||
|
// Next 获取下一个后端
|
||||||
|
Next(key string) (*url.URL, error)
|
||||||
|
// Add 添加后端
|
||||||
|
Add(backend string, weight int) error
|
||||||
|
// Remove 删除后端
|
||||||
|
Remove(backend string) error
|
||||||
|
// MarkDown 标记后端为不可用
|
||||||
|
MarkDown(backend string) error
|
||||||
|
// MarkUp 标记后端为可用
|
||||||
|
MarkUp(backend string) error
|
||||||
|
// Reset 重置负载均衡器
|
||||||
|
Reset() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend 后端服务器
|
||||||
|
type Backend struct {
|
||||||
|
// URL 后端URL
|
||||||
|
URL *url.URL
|
||||||
|
// Weight 权重
|
||||||
|
Weight int
|
||||||
|
// Down 是否不可用
|
||||||
|
Down bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundRobinBalancer 轮询负载均衡器
|
||||||
|
type RoundRobinBalancer struct {
|
||||||
|
backends []*Backend
|
||||||
|
current int32
|
||||||
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRoundRobinBalancer 创建轮询负载均衡器
|
||||||
|
func NewRoundRobinBalancer() *RoundRobinBalancer {
|
||||||
|
return &RoundRobinBalancer{
|
||||||
|
backends: make([]*Backend, 0),
|
||||||
|
current: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next 获取下一个后端
|
||||||
|
func (lb *RoundRobinBalancer) Next(key string) (*url.URL, error) {
|
||||||
|
lb.mutex.RLock()
|
||||||
|
defer lb.mutex.RUnlock()
|
||||||
|
|
||||||
|
if len(lb.backends) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算可用后端数量
|
||||||
|
var availableCount int
|
||||||
|
for _, backend := range lb.backends {
|
||||||
|
if !backend.Down {
|
||||||
|
availableCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if availableCount == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 循环直到找到可用后端
|
||||||
|
for i := 0; i < len(lb.backends); i++ {
|
||||||
|
idx := atomic.AddInt32(&lb.current, 1) % int32(len(lb.backends))
|
||||||
|
backend := lb.backends[idx]
|
||||||
|
if !backend.Down {
|
||||||
|
return backend.URL, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add 添加后端
|
||||||
|
func (lb *RoundRobinBalancer) Add(backend string, weight int) error {
|
||||||
|
lb.mutex.Lock()
|
||||||
|
defer lb.mutex.Unlock()
|
||||||
|
|
||||||
|
url, err := url.Parse(backend)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, b := range lb.backends {
|
||||||
|
if b.URL.String() == url.String() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lb.backends = append(lb.backends, &Backend{
|
||||||
|
URL: url,
|
||||||
|
Weight: weight,
|
||||||
|
Down: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove 删除后端
|
||||||
|
func (lb *RoundRobinBalancer) Remove(backend string) error {
|
||||||
|
lb.mutex.Lock()
|
||||||
|
defer lb.mutex.Unlock()
|
||||||
|
|
||||||
|
url, err := url.Parse(backend)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, b := range lb.backends {
|
||||||
|
if b.URL.String() == url.String() {
|
||||||
|
lb.backends = append(lb.backends[:i], lb.backends[i+1:]...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkDown 标记后端为不可用
|
||||||
|
func (lb *RoundRobinBalancer) MarkDown(backend string) error {
|
||||||
|
lb.mutex.Lock()
|
||||||
|
defer lb.mutex.Unlock()
|
||||||
|
|
||||||
|
url, err := url.Parse(backend)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, b := range lb.backends {
|
||||||
|
if b.URL.String() == url.String() {
|
||||||
|
b.Down = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkUp 标记后端为可用
|
||||||
|
func (lb *RoundRobinBalancer) MarkUp(backend string) error {
|
||||||
|
lb.mutex.Lock()
|
||||||
|
defer lb.mutex.Unlock()
|
||||||
|
|
||||||
|
url, err := url.Parse(backend)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, b := range lb.backends {
|
||||||
|
if b.URL.String() == url.String() {
|
||||||
|
b.Down = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset 重置负载均衡器
|
||||||
|
func (lb *RoundRobinBalancer) Reset() error {
|
||||||
|
lb.mutex.Lock()
|
||||||
|
defer lb.mutex.Unlock()
|
||||||
|
|
||||||
|
lb.backends = make([]*Backend, 0)
|
||||||
|
atomic.StoreInt32(&lb.current, 0)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RandomBalancer 随机负载均衡器
|
||||||
|
type RandomBalancer struct {
|
||||||
|
backends []*Backend
|
||||||
|
mutex sync.RWMutex
|
||||||
|
rand *rand.Rand
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRandomBalancer 创建随机负载均衡器
|
||||||
|
func NewRandomBalancer() *RandomBalancer {
|
||||||
|
source := rand.NewSource(time.Now().UnixNano())
|
||||||
|
random := rand.New(source)
|
||||||
|
return &RandomBalancer{
|
||||||
|
backends: make([]*Backend, 0),
|
||||||
|
rand: random,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next 获取下一个后端
|
||||||
|
func (lb *RandomBalancer) Next(key string) (*url.URL, error) {
|
||||||
|
lb.mutex.RLock()
|
||||||
|
defer lb.mutex.RUnlock()
|
||||||
|
|
||||||
|
if len(lb.backends) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算可用后端数量
|
||||||
|
var availableBackends []*Backend
|
||||||
|
for _, backend := range lb.backends {
|
||||||
|
if !backend.Down {
|
||||||
|
availableBackends = append(availableBackends, backend)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(availableBackends) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
idx := lb.rand.Intn(len(availableBackends))
|
||||||
|
return availableBackends[idx].URL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add 添加后端
|
||||||
|
func (lb *RandomBalancer) Add(backend string, weight int) error {
|
||||||
|
lb.mutex.Lock()
|
||||||
|
defer lb.mutex.Unlock()
|
||||||
|
|
||||||
|
url, err := url.Parse(backend)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, b := range lb.backends {
|
||||||
|
if b.URL.String() == url.String() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lb.backends = append(lb.backends, &Backend{
|
||||||
|
URL: url,
|
||||||
|
Weight: weight,
|
||||||
|
Down: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove 删除后端
|
||||||
|
func (lb *RandomBalancer) Remove(backend string) error {
|
||||||
|
lb.mutex.Lock()
|
||||||
|
defer lb.mutex.Unlock()
|
||||||
|
|
||||||
|
url, err := url.Parse(backend)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, b := range lb.backends {
|
||||||
|
if b.URL.String() == url.String() {
|
||||||
|
lb.backends = append(lb.backends[:i], lb.backends[i+1:]...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkDown 标记后端为不可用
|
||||||
|
func (lb *RandomBalancer) MarkDown(backend string) error {
|
||||||
|
lb.mutex.Lock()
|
||||||
|
defer lb.mutex.Unlock()
|
||||||
|
|
||||||
|
url, err := url.Parse(backend)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, b := range lb.backends {
|
||||||
|
if b.URL.String() == url.String() {
|
||||||
|
b.Down = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkUp 标记后端为可用
|
||||||
|
func (lb *RandomBalancer) MarkUp(backend string) error {
|
||||||
|
lb.mutex.Lock()
|
||||||
|
defer lb.mutex.Unlock()
|
||||||
|
|
||||||
|
url, err := url.Parse(backend)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, b := range lb.backends {
|
||||||
|
if b.URL.String() == url.String() {
|
||||||
|
b.Down = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset 重置负载均衡器
|
||||||
|
func (lb *RandomBalancer) Reset() error {
|
||||||
|
lb.mutex.Lock()
|
||||||
|
defer lb.mutex.Unlock()
|
||||||
|
|
||||||
|
lb.backends = make([]*Backend, 0)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
250
internal/metrics/metrics.go
Normal file
250
internal/metrics/metrics.go
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Metrics 监控指标接口
|
||||||
|
type Metrics interface {
|
||||||
|
// 增加请求计数
|
||||||
|
IncRequestCount()
|
||||||
|
// 增加错误计数
|
||||||
|
IncErrorCount(err error)
|
||||||
|
// 观察请求持续时间
|
||||||
|
ObserveRequestDuration(seconds float64)
|
||||||
|
// 增加活跃连接数
|
||||||
|
IncActiveConnections()
|
||||||
|
// 减少活跃连接数
|
||||||
|
DecActiveConnections()
|
||||||
|
// 设置后端健康状态
|
||||||
|
SetBackendHealth(backend string, healthy bool)
|
||||||
|
// 设置后端响应时间
|
||||||
|
SetBackendResponseTime(backend string, duration time.Duration)
|
||||||
|
// 观察请求字节数
|
||||||
|
ObserveRequestBytes(bytes int64)
|
||||||
|
// 观察响应字节数
|
||||||
|
ObserveResponseBytes(bytes int64)
|
||||||
|
// 添加传输字节数
|
||||||
|
AddBytesTransferred(direction string, bytes int64)
|
||||||
|
// 增加缓存命中计数
|
||||||
|
IncCacheHit()
|
||||||
|
// 获取指标处理器
|
||||||
|
GetHandler() http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// SimpleMetrics 简单指标实现
|
||||||
|
type SimpleMetrics struct {
|
||||||
|
// 请求计数
|
||||||
|
requestCount int64
|
||||||
|
// 错误计数
|
||||||
|
errorCount int64
|
||||||
|
// 活跃连接数
|
||||||
|
activeConnections int64
|
||||||
|
// 累计响应时间
|
||||||
|
totalResponseTime int64
|
||||||
|
// 传输字节数
|
||||||
|
bytesTransferred map[string]int64
|
||||||
|
// 后端健康状态
|
||||||
|
backendHealth map[string]bool
|
||||||
|
// 后端响应时间
|
||||||
|
backendResponseTime map[string]time.Duration
|
||||||
|
// 缓存命中计数
|
||||||
|
cacheHits int64
|
||||||
|
// 互斥锁
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSimpleMetrics 创建简单指标
|
||||||
|
func NewSimpleMetrics() *SimpleMetrics {
|
||||||
|
return &SimpleMetrics{
|
||||||
|
bytesTransferred: make(map[string]int64),
|
||||||
|
backendHealth: make(map[string]bool),
|
||||||
|
backendResponseTime: make(map[string]time.Duration),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncRequestCount 增加请求计数
|
||||||
|
func (m *SimpleMetrics) IncRequestCount() {
|
||||||
|
atomic.AddInt64(&m.requestCount, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncErrorCount 增加错误计数
|
||||||
|
func (m *SimpleMetrics) IncErrorCount(err error) {
|
||||||
|
atomic.AddInt64(&m.errorCount, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ObserveRequestDuration 观察请求持续时间
|
||||||
|
func (m *SimpleMetrics) ObserveRequestDuration(seconds float64) {
|
||||||
|
nsec := int64(seconds * float64(time.Second))
|
||||||
|
atomic.AddInt64(&m.totalResponseTime, nsec)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncActiveConnections 增加活跃连接数
|
||||||
|
func (m *SimpleMetrics) IncActiveConnections() {
|
||||||
|
atomic.AddInt64(&m.activeConnections, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecActiveConnections 减少活跃连接数
|
||||||
|
func (m *SimpleMetrics) DecActiveConnections() {
|
||||||
|
atomic.AddInt64(&m.activeConnections, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBackendHealth 设置后端健康状态
|
||||||
|
func (m *SimpleMetrics) SetBackendHealth(backend string, healthy bool) {
|
||||||
|
m.backendHealth[backend] = healthy
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBackendResponseTime 设置后端响应时间
|
||||||
|
func (m *SimpleMetrics) SetBackendResponseTime(backend string, duration time.Duration) {
|
||||||
|
m.backendResponseTime[backend] = duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// ObserveRequestBytes 观察请求字节数
|
||||||
|
func (m *SimpleMetrics) ObserveRequestBytes(bytes int64) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.bytesTransferred["request"] += bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// ObserveResponseBytes 观察响应字节数
|
||||||
|
func (m *SimpleMetrics) ObserveResponseBytes(bytes int64) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.bytesTransferred["response"] += bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddBytesTransferred 添加传输字节数
|
||||||
|
func (m *SimpleMetrics) AddBytesTransferred(direction string, bytes int64) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.bytesTransferred[direction] += bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncCacheHit 增加缓存命中计数
|
||||||
|
func (m *SimpleMetrics) IncCacheHit() {
|
||||||
|
atomic.AddInt64(&m.cacheHits, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHandler 获取指标处理器
|
||||||
|
func (m *SimpleMetrics) GetHandler() http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
|
||||||
|
// 输出基本指标
|
||||||
|
w.Write([]byte("# HELP proxy_requests_total 代理请求总数\n"))
|
||||||
|
w.Write([]byte("# TYPE proxy_requests_total counter\n"))
|
||||||
|
w.Write([]byte(fmt.Sprintf("proxy_requests_total %d\n", m.requestCount)))
|
||||||
|
|
||||||
|
w.Write([]byte("# HELP proxy_errors_total 代理错误总数\n"))
|
||||||
|
w.Write([]byte("# TYPE proxy_errors_total counter\n"))
|
||||||
|
w.Write([]byte(fmt.Sprintf("proxy_errors_total %d\n", m.errorCount)))
|
||||||
|
|
||||||
|
w.Write([]byte("# HELP proxy_active_connections 当前活跃连接数\n"))
|
||||||
|
w.Write([]byte("# TYPE proxy_active_connections gauge\n"))
|
||||||
|
w.Write([]byte(fmt.Sprintf("proxy_active_connections %d\n", m.activeConnections)))
|
||||||
|
|
||||||
|
// 输出缓存命中数据
|
||||||
|
w.Write([]byte("# HELP proxy_cache_hits_total 缓存命中总数\n"))
|
||||||
|
w.Write([]byte("# TYPE proxy_cache_hits_total counter\n"))
|
||||||
|
w.Write([]byte(fmt.Sprintf("proxy_cache_hits_total %d\n", m.cacheHits)))
|
||||||
|
|
||||||
|
// 输出传输字节数
|
||||||
|
for direction, bytes := range m.bytesTransferred {
|
||||||
|
w.Write([]byte(fmt.Sprintf("# HELP proxy_bytes_transferred_%s 代理传输字节数(%s)\n", direction, direction)))
|
||||||
|
w.Write([]byte(fmt.Sprintf("# TYPE proxy_bytes_transferred_%s counter\n", direction)))
|
||||||
|
w.Write([]byte(fmt.Sprintf("proxy_bytes_transferred_%s %d\n", direction, bytes)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 输出后端健康状态
|
||||||
|
for backend, healthy := range m.backendHealth {
|
||||||
|
healthValue := 0
|
||||||
|
if healthy {
|
||||||
|
healthValue = 1
|
||||||
|
}
|
||||||
|
w.Write([]byte(fmt.Sprintf("# HELP proxy_backend_health 后端健康状态\n")))
|
||||||
|
w.Write([]byte(fmt.Sprintf("# TYPE proxy_backend_health gauge\n")))
|
||||||
|
w.Write([]byte(fmt.Sprintf("proxy_backend_health{backend=\"%s\"} %d\n", backend, healthValue)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 输出后端响应时间
|
||||||
|
for backend, duration := range m.backendResponseTime {
|
||||||
|
w.Write([]byte(fmt.Sprintf("# HELP proxy_backend_response_time 后端响应时间\n")))
|
||||||
|
w.Write([]byte(fmt.Sprintf("# TYPE proxy_backend_response_time gauge\n")))
|
||||||
|
w.Write([]byte(fmt.Sprintf("proxy_backend_response_time{backend=\"%s\"} %f\n", backend, float64(duration)/float64(time.Second))))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 平均响应时间
|
||||||
|
if m.requestCount > 0 {
|
||||||
|
avgTime := float64(m.totalResponseTime) / float64(m.requestCount) / float64(time.Second)
|
||||||
|
w.Write([]byte("# HELP proxy_average_response_time 平均响应时间\n"))
|
||||||
|
w.Write([]byte("# TYPE proxy_average_response_time gauge\n"))
|
||||||
|
w.Write([]byte(fmt.Sprintf("proxy_average_response_time %f\n", avgTime)))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrometheusMetrics Prometheus指标实现
|
||||||
|
type PrometheusMetrics struct {
|
||||||
|
// 可以通过引入prometheus客户端库实现更完整的指标收集
|
||||||
|
// 此处省略具体实现
|
||||||
|
}
|
||||||
|
|
||||||
|
// MetricsMiddleware 指标中间件
|
||||||
|
type MetricsMiddleware struct {
|
||||||
|
metrics Metrics
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMetricsMiddleware 创建指标中间件
|
||||||
|
func NewMetricsMiddleware(metrics Metrics) *MetricsMiddleware {
|
||||||
|
return &MetricsMiddleware{
|
||||||
|
metrics: metrics,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware 中间件处理函数
|
||||||
|
func (m *MetricsMiddleware) Middleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
// 包装响应写入器,用于捕获状态码
|
||||||
|
rw := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
|
||||||
|
|
||||||
|
// 继续处理请求
|
||||||
|
next.ServeHTTP(rw, r)
|
||||||
|
|
||||||
|
// 记录请求指标
|
||||||
|
duration := time.Since(start)
|
||||||
|
m.metrics.ObserveRequestDuration(duration.Seconds())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// responseWriter 包装的响应写入器
|
||||||
|
type responseWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
statusCode int
|
||||||
|
written int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteHeader 写入状态码
|
||||||
|
func (rw *responseWriter) WriteHeader(statusCode int) {
|
||||||
|
rw.statusCode = statusCode
|
||||||
|
rw.ResponseWriter.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write 写入数据
|
||||||
|
func (rw *responseWriter) Write(b []byte) (int, error) {
|
||||||
|
n, err := rw.ResponseWriter.Write(b)
|
||||||
|
rw.written += int64(n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush 刷新数据
|
||||||
|
func (rw *responseWriter) Flush() {
|
||||||
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
184
internal/middleware/ratelimiter.go
Normal file
184
internal/middleware/ratelimiter.go
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RateLimiter 限流器接口
|
||||||
|
type RateLimiter interface {
|
||||||
|
// Allow 检查请求是否允许通过
|
||||||
|
Allow(key string) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// SimpleRateLimiter 简单限流器
|
||||||
|
type SimpleRateLimiter struct {
|
||||||
|
limiter *rate.Limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSimpleRateLimiter 创建简单限流器
|
||||||
|
func NewSimpleRateLimiter(r float64, b int) *SimpleRateLimiter {
|
||||||
|
return &SimpleRateLimiter{
|
||||||
|
limiter: rate.NewLimiter(rate.Limit(r), b),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow 检查请求是否允许通过
|
||||||
|
func (rl *SimpleRateLimiter) Allow(key string) bool {
|
||||||
|
return rl.limiter.Allow()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPRateLimiter 按IP限流
|
||||||
|
type IPRateLimiter struct {
|
||||||
|
ips map[string]*rate.Limiter
|
||||||
|
mu sync.RWMutex
|
||||||
|
rate rate.Limit
|
||||||
|
burst int
|
||||||
|
cleanupInterval time.Duration
|
||||||
|
lastSeen map[string]time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIPRateLimiter 创建IP限流器
|
||||||
|
func NewIPRateLimiter(r float64, b int, cleanup time.Duration) *IPRateLimiter {
|
||||||
|
limiter := &IPRateLimiter{
|
||||||
|
ips: make(map[string]*rate.Limiter),
|
||||||
|
rate: rate.Limit(r),
|
||||||
|
burst: b,
|
||||||
|
cleanupInterval: cleanup,
|
||||||
|
lastSeen: make(map[string]time.Time),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动过期清理
|
||||||
|
if cleanup > 0 {
|
||||||
|
go limiter.startCleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
return limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// startCleanup 启动过期清理
|
||||||
|
func (rl *IPRateLimiter) startCleanup() {
|
||||||
|
ticker := time.NewTicker(rl.cleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
rl.cleanup()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup 清理过期限流器
|
||||||
|
func (rl *IPRateLimiter) cleanup() {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for ip, lastSeen := range rl.lastSeen {
|
||||||
|
if now.Sub(lastSeen) > rl.cleanupInterval {
|
||||||
|
delete(rl.ips, ip)
|
||||||
|
delete(rl.lastSeen, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddIP 添加IP限流器
|
||||||
|
func (rl *IPRateLimiter) AddIP(ip string) *rate.Limiter {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
limiter := rate.NewLimiter(rl.rate, rl.burst)
|
||||||
|
rl.ips[ip] = limiter
|
||||||
|
rl.lastSeen[ip] = time.Now()
|
||||||
|
|
||||||
|
return limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLimiter 获取IP限流器
|
||||||
|
func (rl *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
|
||||||
|
rl.mu.RLock()
|
||||||
|
limiter, exists := rl.ips[ip]
|
||||||
|
rl.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return rl.AddIP(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新最后访问时间
|
||||||
|
rl.mu.Lock()
|
||||||
|
rl.lastSeen[ip] = time.Now()
|
||||||
|
rl.mu.Unlock()
|
||||||
|
|
||||||
|
return limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow 检查请求是否允许通过
|
||||||
|
func (rl *IPRateLimiter) Allow(ip string) bool {
|
||||||
|
limiter := rl.GetLimiter(ip)
|
||||||
|
return limiter.Allow()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimitMiddleware 限流中间件
|
||||||
|
type RateLimitMiddleware struct {
|
||||||
|
limiter RateLimiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRateLimitMiddleware 创建限流中间件
|
||||||
|
func NewRateLimitMiddleware(limiter RateLimiter) *RateLimitMiddleware {
|
||||||
|
return &RateLimitMiddleware{
|
||||||
|
limiter: limiter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware 中间件处理函数
|
||||||
|
func (m *RateLimitMiddleware) Middleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// 获取客户端IP
|
||||||
|
ip := getClientIP(r)
|
||||||
|
|
||||||
|
// 检查是否允许通过
|
||||||
|
if !m.limiter.Allow(ip) {
|
||||||
|
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 继续处理请求
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// getClientIP 获取客户端IP
|
||||||
|
func getClientIP(r *http.Request) string {
|
||||||
|
// 检查 X-Forwarded-For 头
|
||||||
|
ip := r.Header.Get("X-Forwarded-For")
|
||||||
|
if ip != "" {
|
||||||
|
// 取第一个IP
|
||||||
|
for i := 0; i < len(ip) && i < 15; i++ {
|
||||||
|
if ip[i] == ',' {
|
||||||
|
ip = ip[:i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 X-Real-IP 头
|
||||||
|
ip = r.Header.Get("X-Real-IP")
|
||||||
|
if ip != "" {
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从 RemoteAddr 获取
|
||||||
|
if r.RemoteAddr != "" {
|
||||||
|
// 去掉端口部分
|
||||||
|
for i := 0; i < len(r.RemoteAddr); i++ {
|
||||||
|
if r.RemoteAddr[i] == ':' {
|
||||||
|
return r.RemoteAddr[:i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
}
|
156
internal/middleware/retry.go
Normal file
156
internal/middleware/retry.go
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RetryPolicy 重试策略
|
||||||
|
type RetryPolicy struct {
|
||||||
|
// 最大重试次数
|
||||||
|
MaxRetries int
|
||||||
|
// 基础退避时间
|
||||||
|
BaseBackoff time.Duration
|
||||||
|
// 最大退避时间
|
||||||
|
MaxBackoff time.Duration
|
||||||
|
// 重试判断函数
|
||||||
|
ShouldRetry func(req *http.Request, resp *http.Response, err error) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultRetryPolicy 默认重试策略
|
||||||
|
func DefaultRetryPolicy() *RetryPolicy {
|
||||||
|
return &RetryPolicy{
|
||||||
|
MaxRetries: 3,
|
||||||
|
BaseBackoff: 100 * time.Millisecond,
|
||||||
|
MaxBackoff: 2 * time.Second,
|
||||||
|
ShouldRetry: defaultShouldRetry,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultShouldRetry 默认重试判断
|
||||||
|
func defaultShouldRetry(req *http.Request, resp *http.Response, err error) bool {
|
||||||
|
// 不重试非幂等请求
|
||||||
|
if req.Method != http.MethodGet && req.Method != http.MethodHead && req.Method != http.MethodOptions {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查错误
|
||||||
|
if err != nil {
|
||||||
|
// 重试网络错误
|
||||||
|
if netErr, ok := err.(net.Error); ok {
|
||||||
|
return netErr.Temporary() || netErr.Timeout()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查响应状态码
|
||||||
|
if resp != nil {
|
||||||
|
// 重试服务器错误
|
||||||
|
return resp.StatusCode >= 500 && resp.StatusCode < 600
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryRoundTripper 重试HTTP传输
|
||||||
|
type RetryRoundTripper struct {
|
||||||
|
// 下一级传输
|
||||||
|
Next http.RoundTripper
|
||||||
|
// 重试策略
|
||||||
|
Policy *RetryPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRetryRoundTripper 创建重试HTTP传输
|
||||||
|
func NewRetryRoundTripper(next http.RoundTripper, policy *RetryPolicy) *RetryRoundTripper {
|
||||||
|
if next == nil {
|
||||||
|
next = http.DefaultTransport
|
||||||
|
}
|
||||||
|
if policy == nil {
|
||||||
|
policy = DefaultRetryPolicy()
|
||||||
|
}
|
||||||
|
return &RetryRoundTripper{
|
||||||
|
Next: next,
|
||||||
|
Policy: policy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundTrip 执行HTTP请求
|
||||||
|
func (rt *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
// 需要保留原始请求体,以便重试
|
||||||
|
var reqBodyBytes []byte
|
||||||
|
if req.Body != nil {
|
||||||
|
var err error
|
||||||
|
reqBodyBytes, err = io.ReadAll(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp *http.Response
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// 尝试请求直到成功或达到最大重试次数
|
||||||
|
for attempt := 0; attempt <= rt.Policy.MaxRetries; attempt++ {
|
||||||
|
// 复制请求体
|
||||||
|
if len(reqBodyBytes) > 0 {
|
||||||
|
req.Body = io.NopCloser(bytes.NewBuffer(reqBodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
resp, err = rt.Next.RoundTrip(req)
|
||||||
|
|
||||||
|
// 检查是否需要重试
|
||||||
|
if attempt < rt.Policy.MaxRetries && rt.Policy.ShouldRetry(req, resp, err) {
|
||||||
|
// 如果需要重试,先关闭当前响应
|
||||||
|
if resp != nil {
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算退避时间
|
||||||
|
backoff := rt.calculateBackoff(attempt)
|
||||||
|
time.Sleep(backoff)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 不需要重试,返回响应
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 所有重试都失败
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateBackoff 计算退避时间
|
||||||
|
func (rt *RetryRoundTripper) calculateBackoff(attempt int) time.Duration {
|
||||||
|
// 指数退避: baseBackoff * 2^attempt
|
||||||
|
backoff := rt.Policy.BaseBackoff * time.Duration(math.Pow(2, float64(attempt)))
|
||||||
|
if backoff > rt.Policy.MaxBackoff {
|
||||||
|
backoff = rt.Policy.MaxBackoff
|
||||||
|
}
|
||||||
|
return backoff
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetryMiddleware 重试中间件
|
||||||
|
type RetryMiddleware struct {
|
||||||
|
policy *RetryPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRetryMiddleware 创建重试中间件
|
||||||
|
func NewRetryMiddleware(policy *RetryPolicy) *RetryMiddleware {
|
||||||
|
if policy == nil {
|
||||||
|
policy = DefaultRetryPolicy()
|
||||||
|
}
|
||||||
|
return &RetryMiddleware{
|
||||||
|
policy: policy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware 中间件处理函数
|
||||||
|
func (m *RetryMiddleware) Transport(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return NewRetryRoundTripper(next, m.policy)
|
||||||
|
}
|
552
internal/proxy/certificate.go
Normal file
552
internal/proxy/certificate.go
Normal file
@@ -0,0 +1,552 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"hash/fnv"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 内置默认CA证书和私钥
|
||||||
|
var (
|
||||||
|
// 默认根证书
|
||||||
|
defaultRootCAPem = []byte(`-----BEGIN CERTIFICATE-----
|
||||||
|
MIICJzCCAcygAwIBAgIITWWCIQf8/VIwCgYIKoZIzj0EAwIwUzEOMAwGA1UEBhMF
|
||||||
|
Q2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0GA1UEBxMGWGlhbWVuMRAwDgYDVQQK
|
||||||
|
EwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMB4XDTIyMDMyNTA1NDgwMFoXDTQyMDQy
|
||||||
|
NTA1NDgwMFowUzEOMAwGA1UEBhMFQ2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0G
|
||||||
|
A1UEBxMGWGlhbWVuMRAwDgYDVQQKEwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMFkw
|
||||||
|
EwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQX
|
||||||
|
mRgsFV5KHHmxOrVJBFC/nDetmGowkARShWtBsX1Irm4w6i6Qk2QliKOBiTCBhjAO
|
||||||
|
BgNVHQ8BAf8EBAMCAQYwHQYDVR0lBBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIG
|
||||||
|
A1UdEwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFBI5TkWYcvUIWsBAdffs833FnBrI
|
||||||
|
MCIGA1UdEQQbMBmBF3FpbmdxaWFubHVkYW9AZ21haWwuY29tMAoGCCqGSM49BAMC
|
||||||
|
A0kAMEYCIQCk1DhW7AmIW/n/QLftQq8BHZKLevWYJ813zdrNr5kXlwIhAIVvqglY
|
||||||
|
9BkYWg4NEe/mVO4C5Vtu4FnzNU9I+rFpXVSO
|
||||||
|
-----END CERTIFICATE-----
|
||||||
|
`)
|
||||||
|
// 默认根私钥
|
||||||
|
defaultRootKeyPem = []byte(`-----BEGIN EC PRIVATE KEY-----
|
||||||
|
MHcCAQEEIAXeEHO0FtFqQhTvsn/DT4g3rEos97+1Nibp9RfKOKhroAoGCCqGSM49
|
||||||
|
AwEHoUQDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQXmRgsFV5KHHmxOrVJBFC/
|
||||||
|
nDetmGowkARShWtBsX1Irm4w6i6Qk2QliA==
|
||||||
|
-----END EC PRIVATE KEY-----
|
||||||
|
`)
|
||||||
|
)
|
||||||
|
|
||||||
|
// 加载和初始化默认根证书
|
||||||
|
var (
|
||||||
|
defaultRootCA *x509.Certificate
|
||||||
|
defaultRootKey *ecdsa.PrivateKey
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
var err error
|
||||||
|
block, _ := pem.Decode(defaultRootCAPem)
|
||||||
|
if block == nil {
|
||||||
|
panic("解析默认根证书PEM块失败")
|
||||||
|
}
|
||||||
|
defaultRootCA, err = x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("加载默认根证书失败: %s", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ = pem.Decode(defaultRootKeyPem)
|
||||||
|
if block == nil {
|
||||||
|
panic("解析默认根私钥PEM块失败")
|
||||||
|
}
|
||||||
|
defaultRootKey, err = x509.ParseECPrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("加载默认根私钥失败: %s", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CertManager 证书管理器
|
||||||
|
type CertManager struct {
|
||||||
|
// 证书缓存
|
||||||
|
cache CertificateCache
|
||||||
|
// 默认私钥,可用于多个证书共享
|
||||||
|
defaultPrivateKey interface{} // 改为interface{}以支持不同类型的私钥
|
||||||
|
// 默认使用ECDSA P-256曲线
|
||||||
|
curve elliptic.Curve
|
||||||
|
// 证书有效期(年)
|
||||||
|
validityYears int
|
||||||
|
// 是否使用ECDSA(否则使用RSA)
|
||||||
|
useECDSA bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCertManager 创建证书管理器
|
||||||
|
func NewCertManager(cache CertificateCache, options ...CertManagerOption) *CertManager {
|
||||||
|
manager := &CertManager{
|
||||||
|
cache: cache,
|
||||||
|
curve: elliptic.P256(), // 默认使用P-256曲线
|
||||||
|
validityYears: 1, // 默认证书有效期1年
|
||||||
|
useECDSA: true, // 默认使用ECDSA
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应用选项
|
||||||
|
for _, option := range options {
|
||||||
|
option(manager)
|
||||||
|
}
|
||||||
|
|
||||||
|
return manager
|
||||||
|
}
|
||||||
|
|
||||||
|
// CertManagerOption 证书管理器选项
|
||||||
|
type CertManagerOption func(*CertManager)
|
||||||
|
|
||||||
|
// WithUseECDSA 设置是否使用ECDSA(否则使用RSA)
|
||||||
|
func WithUseECDSA(useECDSA bool) CertManagerOption {
|
||||||
|
return func(m *CertManager) {
|
||||||
|
m.useECDSA = useECDSA
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDefaultPrivateKey 设置是否使用默认私钥
|
||||||
|
func WithDefaultPrivateKey(enable bool) CertManagerOption {
|
||||||
|
return func(m *CertManager) {
|
||||||
|
if enable {
|
||||||
|
if m.useECDSA {
|
||||||
|
// 生成ECDSA私钥
|
||||||
|
priv, err := ecdsa.GenerateKey(m.curve, rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("生成默认ECDSA私钥失败: %s", err))
|
||||||
|
}
|
||||||
|
m.defaultPrivateKey = priv
|
||||||
|
} else {
|
||||||
|
// 生成RSA私钥
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("生成默认RSA私钥失败: %s", err))
|
||||||
|
}
|
||||||
|
m.defaultPrivateKey = priv
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCurve 设置椭圆曲线
|
||||||
|
func WithCurve(curve elliptic.Curve) CertManagerOption {
|
||||||
|
return func(m *CertManager) {
|
||||||
|
m.curve = curve
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithValidityYears 设置证书有效期(年)
|
||||||
|
func WithValidityYears(years int) CertManagerOption {
|
||||||
|
return func(m *CertManager) {
|
||||||
|
if years > 0 {
|
||||||
|
m.validityYears = years
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateTLSConfig 为指定主机生成TLS配置
|
||||||
|
func (m *CertManager) GenerateTLSConfig(host string) (*tls.Config, error) {
|
||||||
|
// 处理可能的端口
|
||||||
|
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||||
|
host = h
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查证书缓存
|
||||||
|
if m.cache != nil {
|
||||||
|
// 检查主域名和子域名
|
||||||
|
fields := strings.Split(host, ".")
|
||||||
|
domains := []string{host}
|
||||||
|
|
||||||
|
// 添加父域名
|
||||||
|
if len(fields) > 2 {
|
||||||
|
domains = append(domains, strings.Join(fields[1:], "."))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找缓存
|
||||||
|
for _, domain := range domains {
|
||||||
|
if cert := m.cache.Get(domain); cert != nil {
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{*cert},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成新证书
|
||||||
|
cert, err := m.GenerateCertificate(host, defaultRootCA, defaultRootKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 缓存证书
|
||||||
|
if m.cache != nil {
|
||||||
|
// 缓存主机名
|
||||||
|
m.cache.Set(host, cert)
|
||||||
|
|
||||||
|
// 如果是IP地址,不进行其他处理
|
||||||
|
if net.ParseIP(host) != nil {
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{*cert},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 缓存域名和子域名证书
|
||||||
|
fields := strings.Split(host, ".")
|
||||||
|
if len(fields) >= 2 {
|
||||||
|
// 缓存主域名
|
||||||
|
domain := strings.Join(fields[1:], ".")
|
||||||
|
m.cache.Set(domain, cert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{*cert},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateCertificate 生成证书
|
||||||
|
func (m *CertManager) GenerateCertificate(host string, rootCA *x509.Certificate, rootKey *ecdsa.PrivateKey) (*tls.Certificate, error) {
|
||||||
|
// 准备私钥
|
||||||
|
var priv interface{}
|
||||||
|
var pubKey interface{}
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// 使用默认私钥或生成新私钥
|
||||||
|
if m.defaultPrivateKey != nil {
|
||||||
|
priv = m.defaultPrivateKey
|
||||||
|
} else if m.useECDSA {
|
||||||
|
// 生成ECDSA私钥
|
||||||
|
ecdsaKey, err := ecdsa.GenerateKey(m.curve, rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("生成ECDSA私钥失败: %s", err)
|
||||||
|
}
|
||||||
|
priv = ecdsaKey
|
||||||
|
pubKey = &ecdsaKey.PublicKey
|
||||||
|
} else {
|
||||||
|
// 生成RSA私钥
|
||||||
|
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("生成RSA私钥失败: %s", err)
|
||||||
|
}
|
||||||
|
priv = rsaKey
|
||||||
|
pubKey = &rsaKey.PublicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取公钥
|
||||||
|
if pubKey == nil {
|
||||||
|
switch k := priv.(type) {
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
pubKey = &k.PublicKey
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
pubKey = &k.PublicKey
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("不支持的私钥类型")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建证书模板
|
||||||
|
template := m.createCertificateTemplate(host)
|
||||||
|
|
||||||
|
// 签名证书
|
||||||
|
derBytes, err := x509.CreateCertificate(rand.Reader, template, rootCA, pubKey, rootKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("创建证书失败: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 编码为PEM格式
|
||||||
|
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: derBytes,
|
||||||
|
})
|
||||||
|
|
||||||
|
// 编码私钥
|
||||||
|
var keyPEM []byte
|
||||||
|
switch k := priv.(type) {
|
||||||
|
case *ecdsa.PrivateKey:
|
||||||
|
privBytes, err := x509.MarshalECPrivateKey(k)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("序列化ECDSA私钥失败: %s", err)
|
||||||
|
}
|
||||||
|
keyPEM = pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "EC PRIVATE KEY",
|
||||||
|
Bytes: privBytes,
|
||||||
|
})
|
||||||
|
case *rsa.PrivateKey:
|
||||||
|
privBytes := x509.MarshalPKCS1PrivateKey(k)
|
||||||
|
keyPEM = pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: privBytes,
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("不支持的私钥类型")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建TLS证书
|
||||||
|
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("创建TLS证书对失败: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tlsCert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createCertificateTemplate 创建证书模板
|
||||||
|
func (m *CertManager) createCertificateTemplate(host string) *x509.Certificate {
|
||||||
|
// 使用基于主机名的哈希值作为序列号
|
||||||
|
fv := fnv.New64a()
|
||||||
|
fv.Write([]byte(host))
|
||||||
|
serialNumber := big.NewInt(0).SetUint64(fv.Sum64())
|
||||||
|
|
||||||
|
// 准备模板
|
||||||
|
template := &x509.Certificate{
|
||||||
|
SerialNumber: serialNumber,
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: host,
|
||||||
|
Organization: []string{"GoProxy Dynamic CA"},
|
||||||
|
Country: []string{"CN"},
|
||||||
|
Province: []string{"GuangDong"},
|
||||||
|
Locality: []string{"Guangzhou"},
|
||||||
|
},
|
||||||
|
NotBefore: time.Now().Add(-10 * time.Minute), // 提前10分钟生效,容忍时间偏差
|
||||||
|
NotAfter: time.Now().AddDate(m.validityYears, 0, 0),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
IsCA: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理IP地址和域名
|
||||||
|
ipAddr := net.ParseIP(host)
|
||||||
|
if ipAddr != nil {
|
||||||
|
template.IPAddresses = []net.IP{ipAddr}
|
||||||
|
} else {
|
||||||
|
// 移除可能的端口部分
|
||||||
|
if strings.Contains(host, ":") {
|
||||||
|
host = strings.Split(host, ":")[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将主机名添加到DNS名称列表
|
||||||
|
template.DNSNames = []string{host}
|
||||||
|
|
||||||
|
// 添加通配符域名支持
|
||||||
|
fields := strings.Split(host, ".")
|
||||||
|
fieldNum := len(fields)
|
||||||
|
|
||||||
|
// 为每一级子域名添加通配符
|
||||||
|
for i := 0; i <= (fieldNum - 2); i++ {
|
||||||
|
wildcardDomain := "*." + strings.Join(fields[i:], ".")
|
||||||
|
// 避免重复
|
||||||
|
if wildcardDomain != host {
|
||||||
|
template.DNSNames = append(template.DNSNames, wildcardDomain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return template
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadCAFromFiles 从文件加载CA证书和私钥
|
||||||
|
func LoadCAFromFiles(certFile, keyFile string) (*x509.Certificate, *ecdsa.PrivateKey, error) {
|
||||||
|
// 读取CA证书
|
||||||
|
caCertPEM, err := os.ReadFile(certFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("读取CA证书文件失败: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ := pem.Decode(caCertPEM)
|
||||||
|
if block == nil {
|
||||||
|
return nil, nil, fmt.Errorf("解析CA证书PEM块失败")
|
||||||
|
}
|
||||||
|
|
||||||
|
caCert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("解析CA证书失败: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 读取CA私钥
|
||||||
|
caKeyPEM, err := os.ReadFile(keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("读取CA私钥文件失败: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ = pem.Decode(caKeyPEM)
|
||||||
|
if block == nil {
|
||||||
|
return nil, nil, fmt.Errorf("解析CA私钥PEM块失败")
|
||||||
|
}
|
||||||
|
|
||||||
|
var caKey *ecdsa.PrivateKey
|
||||||
|
|
||||||
|
// 尝试不同的私钥格式
|
||||||
|
switch block.Type {
|
||||||
|
case "EC PRIVATE KEY":
|
||||||
|
caKey, err = x509.ParseECPrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("解析EC私钥失败: %s", err)
|
||||||
|
}
|
||||||
|
case "PRIVATE KEY":
|
||||||
|
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("解析PKCS8私钥失败: %s", err)
|
||||||
|
}
|
||||||
|
var ok bool
|
||||||
|
caKey, ok = key.(*ecdsa.PrivateKey)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, fmt.Errorf("私钥不是ECDSA类型")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, nil, fmt.Errorf("不支持的私钥类型: %s", block.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
return caCert, caKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDefaultRootCA 获取默认根证书和私钥
|
||||||
|
func GetDefaultRootCA() (*x509.Certificate, *ecdsa.PrivateKey) {
|
||||||
|
return defaultRootCA, defaultRootKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateRootCA 生成新的根证书和私钥
|
||||||
|
func GenerateRootCA(validYears int) (*x509.Certificate, *ecdsa.PrivateKey, error) {
|
||||||
|
// 生成私钥
|
||||||
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("生成根证书私钥失败: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 随机生成序列号
|
||||||
|
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||||
|
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("生成序列号失败: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建根证书模板
|
||||||
|
notBefore := time.Now().Add(-10 * time.Minute)
|
||||||
|
notAfter := notBefore.AddDate(validYears, 0, 0)
|
||||||
|
|
||||||
|
template := &x509.Certificate{
|
||||||
|
SerialNumber: serialNumber,
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: "GoProxy Root CA",
|
||||||
|
Organization: []string{"GoProxy"},
|
||||||
|
Country: []string{"CN"},
|
||||||
|
Province: []string{"GuangDong"},
|
||||||
|
Locality: []string{"Guangzhou"},
|
||||||
|
},
|
||||||
|
NotBefore: notBefore,
|
||||||
|
NotAfter: notAfter,
|
||||||
|
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
IsCA: true,
|
||||||
|
MaxPathLen: 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 自签名
|
||||||
|
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("创建根证书失败: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析生成的证书
|
||||||
|
cert, err := x509.ParseCertificate(derBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("解析生成的根证书失败: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, priv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveCertificateToFile 将证书和私钥保存到文件
|
||||||
|
func SaveCertificateToFile(cert *x509.Certificate, key *ecdsa.PrivateKey, certFile, keyFile string) error {
|
||||||
|
// 保存证书
|
||||||
|
certPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: cert.Raw,
|
||||||
|
})
|
||||||
|
if err := os.WriteFile(certFile, certPEM, 0644); err != nil {
|
||||||
|
return fmt.Errorf("保存证书到文件失败: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 保存私钥
|
||||||
|
keyBytes, err := x509.MarshalECPrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("序列化私钥失败: %s", err)
|
||||||
|
}
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "EC PRIVATE KEY",
|
||||||
|
Bytes: keyBytes,
|
||||||
|
})
|
||||||
|
if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil {
|
||||||
|
return fmt.Errorf("保存私钥到文件失败: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateRootCA 生成根证书
|
||||||
|
func (m *CertManager) GenerateRootCA() (*x509.Certificate, interface{}, error) {
|
||||||
|
// 生成私钥
|
||||||
|
var priv interface{}
|
||||||
|
var pubKey interface{}
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if m.useECDSA {
|
||||||
|
// 生成ECDSA私钥
|
||||||
|
ecdsaKey, err := ecdsa.GenerateKey(m.curve, rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("生成ECDSA根私钥失败: %s", err)
|
||||||
|
}
|
||||||
|
priv = ecdsaKey
|
||||||
|
pubKey = &ecdsaKey.PublicKey
|
||||||
|
} else {
|
||||||
|
// 生成RSA私钥
|
||||||
|
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("生成RSA根私钥失败: %s", err)
|
||||||
|
}
|
||||||
|
priv = rsaKey
|
||||||
|
pubKey = &rsaKey.PublicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建根证书模板
|
||||||
|
template := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: "GoProxy Root CA",
|
||||||
|
Organization: []string{"GoProxy Authority"},
|
||||||
|
Country: []string{"CN"},
|
||||||
|
Province: []string{"GuangDong"},
|
||||||
|
Locality: []string{"Guangzhou"},
|
||||||
|
},
|
||||||
|
NotBefore: time.Now().Add(-10 * time.Minute),
|
||||||
|
NotAfter: time.Now().AddDate(10, 0, 0), // 10年有效期
|
||||||
|
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
IsCA: true,
|
||||||
|
MaxPathLen: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 自签名
|
||||||
|
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, pubKey, priv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("创建根证书失败: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析证书
|
||||||
|
cert, err := x509.ParseCertificate(derBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("解析生成的根证书失败: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, priv, nil
|
||||||
|
}
|
134
internal/proxy/conn_buffer.go
Normal file
134
internal/proxy/conn_buffer.go
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnBuffer 连接缓冲区
|
||||||
|
// 封装了底层网络连接和缓冲读取器,提供了更方便的读写接口
|
||||||
|
type ConnBuffer struct {
|
||||||
|
// 底层连接
|
||||||
|
conn net.Conn
|
||||||
|
// 缓冲读取器
|
||||||
|
reader *bufio.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConnBuffer 创建连接缓冲区
|
||||||
|
func NewConnBuffer(conn net.Conn, reader *bufio.Reader) *ConnBuffer {
|
||||||
|
if reader == nil {
|
||||||
|
reader = bufio.NewReader(conn)
|
||||||
|
}
|
||||||
|
return &ConnBuffer{
|
||||||
|
conn: conn,
|
||||||
|
reader: reader,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read 从连接读取数据
|
||||||
|
func (c *ConnBuffer) Read(b []byte) (int, error) {
|
||||||
|
return c.reader.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write 向连接写入数据
|
||||||
|
func (c *ConnBuffer) Write(b []byte) (int, error) {
|
||||||
|
return c.conn.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭连接
|
||||||
|
func (c *ConnBuffer) Close() error {
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr 获取本地地址
|
||||||
|
func (c *ConnBuffer) LocalAddr() net.Addr {
|
||||||
|
return c.conn.LocalAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr 获取远程地址
|
||||||
|
func (c *ConnBuffer) RemoteAddr() net.Addr {
|
||||||
|
return c.conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline 设置读写超时
|
||||||
|
func (c *ConnBuffer) SetDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline 设置读取超时
|
||||||
|
func (c *ConnBuffer) SetReadDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline 设置写入超时
|
||||||
|
func (c *ConnBuffer) SetWriteDeadline(t time.Time) error {
|
||||||
|
return c.conn.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BufferReader 获取缓冲读取器
|
||||||
|
func (c *ConnBuffer) BufferReader() *bufio.Reader {
|
||||||
|
return c.reader
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peek 查看缓冲区中的数据,但不消费
|
||||||
|
func (c *ConnBuffer) Peek(n int) ([]byte, error) {
|
||||||
|
return c.reader.Peek(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadByte 读取一个字节
|
||||||
|
func (c *ConnBuffer) ReadByte() (byte, error) {
|
||||||
|
return c.reader.ReadByte()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnreadByte 将最后读取的字节放回缓冲区
|
||||||
|
func (c *ConnBuffer) UnreadByte() error {
|
||||||
|
return c.reader.UnreadByte()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadLine 读取一行数据
|
||||||
|
func (c *ConnBuffer) ReadLine() (string, error) {
|
||||||
|
line, isPrefix, err := c.reader.ReadLine()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果一行数据没有读取完整,继续读取
|
||||||
|
if isPrefix {
|
||||||
|
var buf []byte
|
||||||
|
buf = append(buf, line...)
|
||||||
|
for isPrefix && err == nil {
|
||||||
|
line, isPrefix, err = c.reader.ReadLine()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
buf = append(buf, line...)
|
||||||
|
}
|
||||||
|
return string(buf), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(line), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadN 读取指定字节数的数据
|
||||||
|
func (c *ConnBuffer) ReadN(n int) ([]byte, error) {
|
||||||
|
buf := make([]byte, n)
|
||||||
|
_, err := io.ReadFull(c.reader, buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush 刷新缓冲区
|
||||||
|
func (c *ConnBuffer) Flush() error {
|
||||||
|
// 由于我们只有读取缓冲区,没有写入缓冲区,所以这里不需要实际操作
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset 重置连接缓冲区
|
||||||
|
func (c *ConnBuffer) Reset(conn net.Conn) {
|
||||||
|
c.conn = conn
|
||||||
|
c.reader = bufio.NewReader(conn)
|
||||||
|
}
|
132
internal/proxy/context.go
Normal file
132
internal/proxy/context.go
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Context 代理上下文
|
||||||
|
// 包含了代理请求的上下文信息,用于在代理处理过程中传递数据
|
||||||
|
type Context struct {
|
||||||
|
// 原始请求
|
||||||
|
Req *http.Request
|
||||||
|
// 请求开始时间
|
||||||
|
StartTime time.Time
|
||||||
|
// 上下文数据,用于在各个处理阶段传递数据
|
||||||
|
Data map[interface{}]interface{}
|
||||||
|
// 是否是隧道代理
|
||||||
|
TunnelProxy bool
|
||||||
|
// 请求ID
|
||||||
|
RequestID string
|
||||||
|
// 目标地址
|
||||||
|
TargetAddr string
|
||||||
|
// 上级代理地址
|
||||||
|
ParentProxyURL *url.URL
|
||||||
|
// 是否中断执行
|
||||||
|
abort bool
|
||||||
|
// 请求标签,用于标记请求类型
|
||||||
|
Tags []string
|
||||||
|
// 是否已中止
|
||||||
|
aborted bool
|
||||||
|
// 互斥锁
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsHTTPS 是否是HTTPS请求
|
||||||
|
func (c *Context) IsHTTPS() bool {
|
||||||
|
return c.Req.URL.Scheme == "https" || c.Req.Method == http.MethodConnect
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultPorts 默认端口映射
|
||||||
|
var defaultPorts = map[string]string{
|
||||||
|
"https": "443",
|
||||||
|
"http": "80",
|
||||||
|
"": "80",
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocketURL 获取WebSocket URL
|
||||||
|
func (c *Context) WebSocketURL() *url.URL {
|
||||||
|
u := *c.Req.URL
|
||||||
|
if c.IsHTTPS() {
|
||||||
|
u.Scheme = "wss"
|
||||||
|
} else {
|
||||||
|
u.Scheme = "ws"
|
||||||
|
}
|
||||||
|
return &u
|
||||||
|
}
|
||||||
|
|
||||||
|
// Addr 获取请求地址
|
||||||
|
func (c *Context) Addr() string {
|
||||||
|
addr := c.Req.Host
|
||||||
|
|
||||||
|
if !strings.Contains(c.Req.URL.Host, ":") {
|
||||||
|
addr += ":" + defaultPorts[c.Req.URL.Scheme]
|
||||||
|
}
|
||||||
|
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTag 添加请求标签
|
||||||
|
func (c *Context) AddTag(tag string) {
|
||||||
|
c.Tags = append(c.Tags, tag)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasTag 检查是否包含指定标签
|
||||||
|
func (c *Context) HasTag(tag string) bool {
|
||||||
|
for _, t := range c.Tags {
|
||||||
|
if t == tag {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Abort 中断执行
|
||||||
|
func (c *Context) Abort() {
|
||||||
|
c.aborted = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAborted 是否已中断执行
|
||||||
|
func (c *Context) IsAborted() bool {
|
||||||
|
return c.aborted
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset 重置上下文
|
||||||
|
func (c *Context) Reset(req *http.Request) {
|
||||||
|
c.Req = req
|
||||||
|
c.StartTime = time.Now()
|
||||||
|
c.Data = make(map[interface{}]interface{})
|
||||||
|
c.abort = false
|
||||||
|
c.TunnelProxy = false
|
||||||
|
c.Tags = make([]string, 0)
|
||||||
|
c.RequestID = ""
|
||||||
|
c.TargetAddr = ""
|
||||||
|
c.ParentProxyURL = nil
|
||||||
|
c.aborted = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set 设置数据
|
||||||
|
func (c *Context) Set(key, value interface{}) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.Data == nil {
|
||||||
|
c.Data = make(map[interface{}]interface{})
|
||||||
|
}
|
||||||
|
c.Data[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 获取数据
|
||||||
|
func (c *Context) Get(key interface{}) (interface{}, bool) {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
if c.Data == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
val, ok := c.Data[key]
|
||||||
|
return val, ok
|
||||||
|
}
|
123
internal/proxy/delegate.go
Normal file
123
internal/proxy/delegate.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Delegate 代理委托接口
|
||||||
|
// 定义了代理处理请求的各个阶段的回调方法
|
||||||
|
type Delegate interface {
|
||||||
|
// Connect 连接事件
|
||||||
|
Connect(ctx *Context, rw http.ResponseWriter)
|
||||||
|
|
||||||
|
// Auth 认证事件
|
||||||
|
Auth(ctx *Context, rw http.ResponseWriter)
|
||||||
|
|
||||||
|
// BeforeRequest 请求前事件
|
||||||
|
BeforeRequest(ctx *Context)
|
||||||
|
|
||||||
|
// BeforeResponse 响应前事件
|
||||||
|
BeforeResponse(ctx *Context, resp *http.Response, err error)
|
||||||
|
|
||||||
|
// WebSocketSendMessage websocket发送消息拦截
|
||||||
|
// WebSocketSendMessage(ctx *Context, messageType *int, p *[]byte)
|
||||||
|
|
||||||
|
// WebSocketReceiveMessage websocket接收消息拦截
|
||||||
|
// WebSocketReceiveMessage(ctx *Context, messageType *int, p *[]byte)
|
||||||
|
|
||||||
|
// ParentProxy 获取上级代理
|
||||||
|
ParentProxy(req *http.Request) (*url.URL, error)
|
||||||
|
|
||||||
|
// ErrorLog 错误日志
|
||||||
|
ErrorLog(err error)
|
||||||
|
|
||||||
|
// Finish 完成事件
|
||||||
|
Finish(ctx *Context)
|
||||||
|
|
||||||
|
// 以下是反向代理相关的方法
|
||||||
|
|
||||||
|
// ResolveBackend 解析后端服务器
|
||||||
|
// 在反向代理模式下,根据请求确定应该转发到哪个后端服务器
|
||||||
|
ResolveBackend(req *http.Request) (string, error)
|
||||||
|
|
||||||
|
// ModifyRequest 修改请求
|
||||||
|
// 在反向代理模式下,可以修改发往后端服务器的请求
|
||||||
|
ModifyRequest(req *http.Request)
|
||||||
|
|
||||||
|
// ModifyResponse 修改响应
|
||||||
|
// 在反向代理模式下,可以修改来自后端服务器的响应
|
||||||
|
ModifyResponse(resp *http.Response) error
|
||||||
|
|
||||||
|
// HandleError 处理错误
|
||||||
|
// 在反向代理模式下,可以自定义错误处理逻辑
|
||||||
|
HandleError(rw http.ResponseWriter, req *http.Request, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultDelegate 默认代理委托
|
||||||
|
type DefaultDelegate struct{}
|
||||||
|
|
||||||
|
// Connect 连接事件
|
||||||
|
func (d *DefaultDelegate) Connect(ctx *Context, rw http.ResponseWriter) {
|
||||||
|
// 默认实现不做任何处理
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auth 认证事件
|
||||||
|
func (d *DefaultDelegate) Auth(ctx *Context, rw http.ResponseWriter) {
|
||||||
|
// 默认实现不做任何处理
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeRequest 请求前事件
|
||||||
|
func (d *DefaultDelegate) BeforeRequest(ctx *Context) {
|
||||||
|
// 默认实现不做任何处理
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeResponse 响应前事件
|
||||||
|
func (d *DefaultDelegate) BeforeResponse(ctx *Context, resp *http.Response, err error) {
|
||||||
|
// 默认实现不做任何处理
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParentProxy 获取上级代理
|
||||||
|
func (d *DefaultDelegate) ParentProxy(req *http.Request) (*url.URL, error) {
|
||||||
|
// 默认实现不使用上级代理
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocketSendMessage websocket发送消息拦截
|
||||||
|
// func (h *DefaultDelegate) WebSocketSendMessage(ctx *Context, messageType *int, payload *[]byte) {}
|
||||||
|
|
||||||
|
// WebSocketReceiveMessage websocket接收消息拦截
|
||||||
|
// func (h *DefaultDelegate) WebSocketReceiveMessage(ctx *Context, messageType *int, payload *[]byte) {}
|
||||||
|
|
||||||
|
// ErrorLog 错误日志
|
||||||
|
func (d *DefaultDelegate) ErrorLog(err error) {
|
||||||
|
// 默认实现不做任何处理
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finish 完成事件
|
||||||
|
func (d *DefaultDelegate) Finish(ctx *Context) {
|
||||||
|
// 默认实现不做任何处理
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveBackend 解析后端服务器
|
||||||
|
func (d *DefaultDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||||
|
// 默认实现返回请求中的主机
|
||||||
|
return req.Host, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModifyRequest 修改请求
|
||||||
|
func (d *DefaultDelegate) ModifyRequest(req *http.Request) {
|
||||||
|
// 默认实现不做任何修改
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModifyResponse 修改响应
|
||||||
|
func (d *DefaultDelegate) ModifyResponse(resp *http.Response) error {
|
||||||
|
// 默认实现不做任何修改
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleError 处理错误
|
||||||
|
func (d *DefaultDelegate) HandleError(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
|
// 默认实现返回502错误
|
||||||
|
http.Error(rw, err.Error(), http.StatusBadGateway)
|
||||||
|
}
|
262
internal/proxy/options.go
Normal file
262
internal/proxy/options.go
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptrace"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/goproxy/internal/cache"
|
||||||
|
"github.com/goproxy/internal/config"
|
||||||
|
"github.com/goproxy/internal/healthcheck"
|
||||||
|
"github.com/goproxy/internal/loadbalance"
|
||||||
|
"github.com/goproxy/internal/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Option 用于配置代理选项的函数类型
|
||||||
|
type Option func(*Options)
|
||||||
|
|
||||||
|
// WithConfig 设置代理配置
|
||||||
|
func WithConfig(cfg *config.Config) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
opt.Config = cfg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDisableKeepAlive 设置连接是否重用
|
||||||
|
func WithDisableKeepAlive(disableKeepAlive bool) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
// 在transport中设置DisableKeepAlives
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithClientTrace 设置HTTP客户端跟踪
|
||||||
|
func WithClientTrace(t *httptrace.ClientTrace) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
opt.ClientTrace = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDelegate 设置委托类
|
||||||
|
func WithDelegate(delegate Delegate) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
opt.Delegate = delegate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTransport 使用自定义HTTP传输
|
||||||
|
func WithTransport(t *http.Transport) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
// 在New方法中处理transport
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDecryptHTTPS 启用中间人代理解密HTTPS
|
||||||
|
func WithDecryptHTTPS(c CertificateCache) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.DecryptHTTPS = true
|
||||||
|
opt.CertCache = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithEnableWebsocketIntercept 启用WebSocket拦截
|
||||||
|
func WithEnableWebsocketIntercept() Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
// WebSocket拦截在代理处理逻辑中实现
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHTTPCache 设置HTTP缓存
|
||||||
|
func WithHTTPCache(c cache.Cache) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
opt.HTTPCache = c
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.EnableCache = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLoadBalancer 设置负载均衡器
|
||||||
|
func WithLoadBalancer(lb loadbalance.LoadBalancer) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
opt.LoadBalancer = lb
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.EnableLoadBalancing = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHealthChecker 设置健康检查器
|
||||||
|
func WithHealthChecker(hc *healthcheck.HealthChecker) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
opt.HealthChecker = hc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMetrics 设置监控指标
|
||||||
|
func WithMetrics(m metrics.Metrics) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
opt.Metrics = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTLSCertAndKey 设置TLS证书和密钥
|
||||||
|
func WithTLSCertAndKey(certPath, keyPath string) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.TLSCert = certPath
|
||||||
|
opt.Config.TLSKey = keyPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCACertAndKey 设置CA证书和密钥(用于生成动态证书)
|
||||||
|
func WithCACertAndKey(caCertPath, caKeyPath string) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.CACert = caCertPath
|
||||||
|
opt.Config.CAKey = caKeyPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithConnectionPoolSize 设置连接池大小
|
||||||
|
func WithConnectionPoolSize(size int) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.ConnectionPoolSize = size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithIdleTimeout 设置空闲超时时间
|
||||||
|
func WithIdleTimeout(timeout time.Duration) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.IdleTimeout = timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRequestTimeout 设置请求超时时间
|
||||||
|
func WithRequestTimeout(timeout time.Duration) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.RequestTimeout = timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithReverseProxy 启用反向代理模式
|
||||||
|
func WithReverseProxy(enable bool) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.ReverseProxy = enable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithEnableRetry 启用请求重试
|
||||||
|
func WithEnableRetry(maxRetries int, baseBackoff, maxBackoff time.Duration) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.EnableRetry = true
|
||||||
|
opt.Config.MaxRetries = maxRetries
|
||||||
|
opt.Config.RetryBackoff = baseBackoff
|
||||||
|
opt.Config.MaxRetryBackoff = maxBackoff
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRateLimit 设置请求限流
|
||||||
|
func WithRateLimit(rps float64) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.EnableRateLimit = true
|
||||||
|
opt.Config.RateLimit = rps
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDNSCacheTTL 设置DNS缓存TTL
|
||||||
|
func WithDNSCacheTTL(ttl time.Duration) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.DNSCacheTTL = ttl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithURLRewrite 启用URL重写
|
||||||
|
func WithURLRewrite(enable bool) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.EnableURLRewrite = enable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithEnableCORS 启用CORS支持
|
||||||
|
func WithEnableCORS(enable bool) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.EnableCORS = enable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCertManager 设置证书管理器
|
||||||
|
// 这是一个内部函数,主要用于在New方法中设置CertManager
|
||||||
|
func WithCertManager(certManager *CertManager) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
opt.CertManager = certManager
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithEnableECDSA 启用ECDSA证书生成(默认使用RSA)
|
||||||
|
func WithEnableECDSA(enable bool) Option {
|
||||||
|
return func(opt *Options) {
|
||||||
|
if opt.Config == nil {
|
||||||
|
opt.Config = config.DefaultConfig()
|
||||||
|
}
|
||||||
|
opt.Config.UseECDSA = enable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWithOptions 使用选项函数创建代理
|
||||||
|
func NewWithOptions(options ...Option) *Proxy {
|
||||||
|
opts := &Options{
|
||||||
|
Config: config.DefaultConfig(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 应用所有选项
|
||||||
|
for _, option := range options {
|
||||||
|
option(opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
return New(opts)
|
||||||
|
}
|
1130
internal/proxy/proxy.go
Normal file
1130
internal/proxy/proxy.go
Normal file
File diff suppressed because it is too large
Load Diff
277
internal/proxy/reverse_proxy.go
Normal file
277
internal/proxy/reverse_proxy.go
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/goproxy/internal/rewriter"
|
||||||
|
"github.com/goproxy/internal/router"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReverseProxy 反向代理
|
||||||
|
type ReverseProxy struct {
|
||||||
|
// 代理对象
|
||||||
|
proxy *Proxy
|
||||||
|
// 路由器
|
||||||
|
router *router.Router
|
||||||
|
// URL重写器
|
||||||
|
rewriter *rewriter.Rewriter
|
||||||
|
// HTTP传输对象
|
||||||
|
transport http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewReverseProxy 创建反向代理
|
||||||
|
func (p *Proxy) NewReverseProxy() *ReverseProxy {
|
||||||
|
rp := &ReverseProxy{
|
||||||
|
proxy: p,
|
||||||
|
router: router.NewRouter(),
|
||||||
|
rewriter: rewriter.NewRewriter(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建自定义的传输对象
|
||||||
|
transport := &http.Transport{
|
||||||
|
Proxy: func(req *http.Request) (*url.URL, error) {
|
||||||
|
// 使用代理委托中的方法获取代理
|
||||||
|
return p.delegate.ParentProxy(req)
|
||||||
|
},
|
||||||
|
DialContext: p.dialContextWithCache(),
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
MaxIdleConns: p.config.ConnectionPoolSize,
|
||||||
|
IdleConnTimeout: p.config.IdleTimeout,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
MaxIdleConnsPerHost: p.config.ConnectionPoolSize,
|
||||||
|
DisableCompression: !p.config.EnableCompression,
|
||||||
|
}
|
||||||
|
|
||||||
|
rp.transport = transport
|
||||||
|
|
||||||
|
// 如果配置了规则文件,加载规则
|
||||||
|
if p.config.ReverseProxyRulesFile != "" {
|
||||||
|
// 省略加载规则文件的实现
|
||||||
|
}
|
||||||
|
|
||||||
|
return rp
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP 处理反向代理请求
|
||||||
|
func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
// 获取请求上下文
|
||||||
|
ctx := ctxPool.Get().(*Context)
|
||||||
|
ctx.Reset(req)
|
||||||
|
defer ctxPool.Put(ctx)
|
||||||
|
|
||||||
|
// 调用连接事件
|
||||||
|
rp.proxy.delegate.Connect(ctx, rw)
|
||||||
|
|
||||||
|
// 认证检查
|
||||||
|
rp.proxy.delegate.Auth(ctx, rw)
|
||||||
|
if ctx.IsAborted() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 请求前处理
|
||||||
|
rp.proxy.delegate.BeforeRequest(ctx)
|
||||||
|
if ctx.IsAborted() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析后端地址
|
||||||
|
backend, err := rp.proxy.delegate.ResolveBackend(req)
|
||||||
|
if err != nil {
|
||||||
|
rp.proxy.delegate.HandleError(rw, req, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建请求代理对象
|
||||||
|
proxy := httputil.NewSingleHostReverseProxy(&url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: backend,
|
||||||
|
})
|
||||||
|
|
||||||
|
// 使用自定义传输对象
|
||||||
|
proxy.Transport = rp.transport
|
||||||
|
|
||||||
|
// 设置自定义错误处理函数
|
||||||
|
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
|
rp.proxy.delegate.HandleError(rw, req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置请求修改函数
|
||||||
|
originalDirector := proxy.Director
|
||||||
|
proxy.Director = func(req *http.Request) {
|
||||||
|
// 调用原始Director函数
|
||||||
|
originalDirector(req)
|
||||||
|
|
||||||
|
// 处理URL重写
|
||||||
|
if rp.proxy.config.EnableURLRewrite {
|
||||||
|
rp.rewriter.Rewrite(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 修改请求头
|
||||||
|
if rp.proxy.config.RewriteHostHeader {
|
||||||
|
req.Host = backend
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加X-Forwarded-For头
|
||||||
|
if rp.proxy.config.AddXForwardedFor {
|
||||||
|
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||||
|
if err == nil {
|
||||||
|
// 如果已经有X-Forwarded-For,添加到末尾
|
||||||
|
if prior, ok := req.Header["X-Forwarded-For"]; ok {
|
||||||
|
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||||
|
}
|
||||||
|
req.Header.Set("X-Forwarded-For", clientIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加X-Real-IP头
|
||||||
|
if rp.proxy.config.AddXRealIP {
|
||||||
|
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||||
|
if err == nil {
|
||||||
|
req.Header.Set("X-Real-IP", clientIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置协议头
|
||||||
|
req.Header.Set("X-Forwarded-Proto", "http")
|
||||||
|
if req.TLS != nil {
|
||||||
|
req.Header.Set("X-Forwarded-Proto", "https")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用委托的ModifyRequest方法
|
||||||
|
rp.proxy.delegate.ModifyRequest(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置响应修改函数
|
||||||
|
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||||
|
// 处理响应URL重写
|
||||||
|
if rp.proxy.config.EnableURLRewrite && resp != nil {
|
||||||
|
rp.rewriter.RewriteResponse(resp, req.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加CORS头
|
||||||
|
if rp.proxy.config.EnableCORS && resp != nil {
|
||||||
|
resp.Header.Set("Access-Control-Allow-Origin", "*")
|
||||||
|
resp.Header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||||
|
resp.Header.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用委托的ModifyResponse方法
|
||||||
|
return rp.proxy.delegate.ModifyResponse(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新监控指标
|
||||||
|
if rp.proxy.metrics != nil {
|
||||||
|
rp.proxy.metrics.IncActiveConnections()
|
||||||
|
defer rp.proxy.metrics.DecActiveConnections()
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
defer func() {
|
||||||
|
duration := time.Since(startTime)
|
||||||
|
rp.proxy.metrics.ObserveRequestDuration(duration.Seconds())
|
||||||
|
rp.proxy.metrics.IncRequestCount()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理WebSocket升级
|
||||||
|
if rp.proxy.config.SupportWebSocketUpgrade && isWebSocketRequest(req) {
|
||||||
|
rp.handleWebSocketUpgrade(rw, req, backend)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理普通请求
|
||||||
|
proxy.ServeHTTP(rw, req)
|
||||||
|
|
||||||
|
// 完成事件
|
||||||
|
rp.proxy.delegate.Finish(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理WebSocket升级
|
||||||
|
func (rp *ReverseProxy) handleWebSocketUpgrade(rw http.ResponseWriter, req *http.Request, backend string) {
|
||||||
|
// 创建WebSocket代理
|
||||||
|
target := &url.URL{
|
||||||
|
Scheme: "ws",
|
||||||
|
Host: backend,
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.TLS != nil {
|
||||||
|
target.Scheme = "wss"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建连接到后端的WebSocket连接
|
||||||
|
backendConn, err := rp.dialBackend(target.String(), req)
|
||||||
|
if err != nil {
|
||||||
|
rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("无法连接到后端WebSocket服务: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer backendConn.Close()
|
||||||
|
|
||||||
|
// 将请求转发给后端
|
||||||
|
err = req.Write(backendConn)
|
||||||
|
if err != nil {
|
||||||
|
rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("写入WebSocket请求错误: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 升级客户端连接
|
||||||
|
clientConn, err := hijacker(rw)
|
||||||
|
if err != nil {
|
||||||
|
rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("升级WebSocket连接错误: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 双向转发数据
|
||||||
|
rp.proxy.transfer(clientConn, backendConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 连接到后端
|
||||||
|
func (rp *ReverseProxy) dialBackend(url string, req *http.Request) (net.Conn, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(req.Context(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
backend := strings.TrimPrefix(url, "ws://")
|
||||||
|
backend = strings.TrimPrefix(backend, "wss://")
|
||||||
|
|
||||||
|
if strings.Contains(backend, "/") {
|
||||||
|
backend = backend[:strings.Index(backend, "/")]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据协议选择连接方式
|
||||||
|
if strings.HasPrefix(url, "wss://") {
|
||||||
|
// 使用 tls.Dialer 替代不存在的 tls.DialWithContext
|
||||||
|
dialer := &tls.Dialer{
|
||||||
|
Config: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return dialer.DialContext(ctx, "tcp", backend)
|
||||||
|
}
|
||||||
|
|
||||||
|
var d net.Dialer
|
||||||
|
return d.DialContext(ctx, "tcp", backend)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加路由规则
|
||||||
|
func (rp *ReverseProxy) AddRoute(pattern string, routeType router.RouteType, target string) {
|
||||||
|
route := &router.Route{
|
||||||
|
Pattern: pattern,
|
||||||
|
Type: routeType,
|
||||||
|
Target: target,
|
||||||
|
}
|
||||||
|
rp.router.AddRoute(route)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加重写规则
|
||||||
|
func (rp *ReverseProxy) AddRewriteRule(pattern, replacement string, useRegex bool) error {
|
||||||
|
return rp.rewriter.AddRule(pattern, replacement, useRegex)
|
||||||
|
}
|
98
internal/rewriter/rewriter.go
Normal file
98
internal/rewriter/rewriter.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package rewriter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Rewriter URL重写器
|
||||||
|
// 用于在反向代理中重写请求URL
|
||||||
|
type Rewriter struct {
|
||||||
|
// 重写规则列表
|
||||||
|
rules []*RewriteRule
|
||||||
|
}
|
||||||
|
|
||||||
|
// RewriteRule 重写规则
|
||||||
|
type RewriteRule struct {
|
||||||
|
// 匹配模式
|
||||||
|
Pattern string
|
||||||
|
// 替换模式
|
||||||
|
Replacement string
|
||||||
|
// 是否使用正则表达式
|
||||||
|
UseRegex bool
|
||||||
|
// 编译后的正则表达式
|
||||||
|
regex *regexp.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRewriter 创建URL重写器
|
||||||
|
func NewRewriter() *Rewriter {
|
||||||
|
return &Rewriter{
|
||||||
|
rules: make([]*RewriteRule, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRule 添加重写规则
|
||||||
|
func (r *Rewriter) AddRule(pattern, replacement string, useRegex bool) error {
|
||||||
|
rule := &RewriteRule{
|
||||||
|
Pattern: pattern,
|
||||||
|
Replacement: replacement,
|
||||||
|
UseRegex: useRegex,
|
||||||
|
}
|
||||||
|
|
||||||
|
if useRegex {
|
||||||
|
regex, err := regexp.Compile(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
rule.regex = regex
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules = append(r.rules, rule)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rewrite 重写URL
|
||||||
|
func (r *Rewriter) Rewrite(req *http.Request) {
|
||||||
|
path := req.URL.Path
|
||||||
|
|
||||||
|
for _, rule := range r.rules {
|
||||||
|
if rule.UseRegex {
|
||||||
|
if rule.regex.MatchString(path) {
|
||||||
|
req.URL.Path = rule.regex.ReplaceAllString(path, rule.Replacement)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if strings.HasPrefix(path, rule.Pattern) {
|
||||||
|
req.URL.Path = strings.Replace(path, rule.Pattern, rule.Replacement, 1)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RewriteResponse 重写响应
|
||||||
|
// 主要用于处理响应中的Location头和内容中的URL
|
||||||
|
func (r *Rewriter) RewriteResponse(resp *http.Response, originHost string) {
|
||||||
|
// 处理重定向头
|
||||||
|
location := resp.Header.Get("Location")
|
||||||
|
if location != "" {
|
||||||
|
// 将后端服务器的域名替换成代理服务器的域名
|
||||||
|
for _, rule := range r.rules {
|
||||||
|
if rule.UseRegex && rule.regex != nil {
|
||||||
|
if rule.regex.MatchString(location) {
|
||||||
|
newLocation := rule.regex.ReplaceAllString(location, rule.Replacement)
|
||||||
|
resp.Header.Set("Location", newLocation)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadRulesFromFile 从文件加载重写规则
|
||||||
|
func (r *Rewriter) LoadRulesFromFile(filename string) error {
|
||||||
|
// 实现从配置文件加载规则的逻辑
|
||||||
|
// 这里省略实现细节
|
||||||
|
return nil
|
||||||
|
}
|
103
internal/router/router.go
Normal file
103
internal/router/router.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Route 路由规则
|
||||||
|
type Route struct {
|
||||||
|
// 匹配模式(主机名、路径、正则表达式)
|
||||||
|
Pattern string
|
||||||
|
// 匹配类型
|
||||||
|
Type RouteType
|
||||||
|
// 目标地址
|
||||||
|
Target string
|
||||||
|
// 路径重写规则
|
||||||
|
RewritePattern string
|
||||||
|
// 请求头修改
|
||||||
|
HeaderModifier HeaderModifier
|
||||||
|
// 自定义匹配函数
|
||||||
|
MatchFunc func(req *http.Request) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouteType 路由类型
|
||||||
|
type RouteType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// HostRoute 主机名路由
|
||||||
|
HostRoute RouteType = iota
|
||||||
|
// PathRoute 路径路由
|
||||||
|
PathRoute
|
||||||
|
// RegexRoute 正则表达式路由
|
||||||
|
RegexRoute
|
||||||
|
// CustomRoute 自定义路由
|
||||||
|
CustomRoute
|
||||||
|
)
|
||||||
|
|
||||||
|
// HeaderModifier 头部修改接口
|
||||||
|
type HeaderModifier interface {
|
||||||
|
// ModifyRequest 修改请求头
|
||||||
|
ModifyRequest(req *http.Request)
|
||||||
|
// ModifyResponse 修改响应头
|
||||||
|
ModifyResponse(resp *http.Response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router 路由器
|
||||||
|
type Router struct {
|
||||||
|
routes []*Route
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRouter 创建路由器
|
||||||
|
func NewRouter() *Router {
|
||||||
|
return &Router{
|
||||||
|
routes: make([]*Route, 0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRoute 添加路由规则
|
||||||
|
func (r *Router) AddRoute(route *Route) {
|
||||||
|
r.routes = append(r.routes, route)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match 匹配请求
|
||||||
|
func (r *Router) Match(req *http.Request) (*Route, bool) {
|
||||||
|
for _, route := range r.routes {
|
||||||
|
switch route.Type {
|
||||||
|
case HostRoute:
|
||||||
|
if matchHost(req.Host, route.Pattern) {
|
||||||
|
return route, true
|
||||||
|
}
|
||||||
|
case PathRoute:
|
||||||
|
if matchPath(req.URL.Path, route.Pattern) {
|
||||||
|
return route, true
|
||||||
|
}
|
||||||
|
case RegexRoute:
|
||||||
|
if matchRegex(req.URL.String(), route.Pattern) {
|
||||||
|
return route, true
|
||||||
|
}
|
||||||
|
case CustomRoute:
|
||||||
|
if route.MatchFunc != nil && route.MatchFunc(req) {
|
||||||
|
return route, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 匹配主机名
|
||||||
|
func matchHost(host, pattern string) bool {
|
||||||
|
return host == pattern || strings.HasSuffix(host, "."+pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 匹配路径
|
||||||
|
func matchPath(path, pattern string) bool {
|
||||||
|
return strings.HasPrefix(path, pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 匹配正则表达式
|
||||||
|
func matchRegex(url, pattern string) bool {
|
||||||
|
matched, _ := regexp.MatchString(pattern, url)
|
||||||
|
return matched
|
||||||
|
}
|
14
mitm-proxy.crt
Normal file
14
mitm-proxy.crt
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
-----BEGIN CERTIFICATE-----
|
||||||
|
MIICJzCCAcygAwIBAgIITWWCIQf8/VIwCgYIKoZIzj0EAwIwUzEOMAwGA1UEBhMF
|
||||||
|
Q2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0GA1UEBxMGWGlhbWVuMRAwDgYDVQQK
|
||||||
|
EwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMB4XDTIyMDMyNTA1NDgwMFoXDTQyMDQy
|
||||||
|
NTA1NDgwMFowUzEOMAwGA1UEBhMFQ2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0G
|
||||||
|
A1UEBxMGWGlhbWVuMRAwDgYDVQQKEwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMFkw
|
||||||
|
EwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQX
|
||||||
|
mRgsFV5KHHmxOrVJBFC/nDetmGowkARShWtBsX1Irm4w6i6Qk2QliKOBiTCBhjAO
|
||||||
|
BgNVHQ8BAf8EBAMCAQYwHQYDVR0lBBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIG
|
||||||
|
A1UdEwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFBI5TkWYcvUIWsBAdffs833FnBrI
|
||||||
|
MCIGA1UdEQQbMBmBF3FpbmdxaWFubHVkYW9AZ21haWwuY29tMAoGCCqGSM49BAMC
|
||||||
|
A0kAMEYCIQCk1DhW7AmIW/n/QLftQq8BHZKLevWYJ813zdrNr5kXlwIhAIVvqglY
|
||||||
|
9BkYWg4NEe/mVO4C5Vtu4FnzNU9I+rFpXVSO
|
||||||
|
-----END CERTIFICATE-----
|
753
proxy.go
Normal file
753
proxy.go
Normal file
@@ -0,0 +1,753 @@
|
|||||||
|
// Copyright 2018 ouqiang authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||||
|
// not use this file except in compliance with the License. You may obtain
|
||||||
|
// a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
// License for the specific language governing permissions and limitations
|
||||||
|
// under the License.
|
||||||
|
|
||||||
|
// Package goproxy HTTP(S)代理, 支持中间人代理解密HTTPS数据
|
||||||
|
package goproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptrace"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/viki-org/dnscache"
|
||||||
|
|
||||||
|
"github.com/ouqiang/goproxy/cert"
|
||||||
|
"github.com/ouqiang/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// 连接目标服务器超时时间
|
||||||
|
defaultTargetConnectTimeout = 5 * time.Second
|
||||||
|
// 目标服务器读写超时时间
|
||||||
|
defaultTargetReadWriteTimeout = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
|
// 隧道连接成功响应行
|
||||||
|
var tunnelEstablishedResponseLine = []byte("HTTP/1.1 200 Connection established\r\n\r\n")
|
||||||
|
|
||||||
|
var badGateway = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", http.StatusBadGateway, http.StatusText(http.StatusBadGateway)))
|
||||||
|
|
||||||
|
var (
|
||||||
|
bufPool = sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return make([]byte, 32*1024)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctxPool = sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return new(Context)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
headerPool = NewHeaderPool()
|
||||||
|
requestPool = newRequestPool()
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestPool struct {
|
||||||
|
pool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRequestPool() *RequestPool {
|
||||||
|
return &RequestPool{
|
||||||
|
pool: sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return new(http.Request)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RequestPool) Get() *http.Request {
|
||||||
|
req := p.pool.Get().(*http.Request)
|
||||||
|
|
||||||
|
req.Method = ""
|
||||||
|
req.URL = nil
|
||||||
|
req.Proto = ""
|
||||||
|
req.ProtoMajor = 0
|
||||||
|
req.ProtoMinor = 0
|
||||||
|
req.Header = nil
|
||||||
|
req.Body = nil
|
||||||
|
req.GetBody = nil
|
||||||
|
req.ContentLength = 0
|
||||||
|
req.TransferEncoding = nil
|
||||||
|
req.Close = false
|
||||||
|
req.Host = ""
|
||||||
|
req.Form = nil
|
||||||
|
req.PostForm = nil
|
||||||
|
req.MultipartForm = nil
|
||||||
|
req.Trailer = nil
|
||||||
|
req.RemoteAddr = ""
|
||||||
|
req.RequestURI = ""
|
||||||
|
req.TLS = nil
|
||||||
|
req.Cancel = nil
|
||||||
|
req.Response = nil
|
||||||
|
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *RequestPool) Put(req *http.Request) {
|
||||||
|
if req != nil {
|
||||||
|
p.pool.Put(req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type HeaderPool struct {
|
||||||
|
pool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHeaderPool() *HeaderPool {
|
||||||
|
return &HeaderPool{
|
||||||
|
pool: sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return http.Header{}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *HeaderPool) Get() http.Header {
|
||||||
|
header := p.pool.Get().(http.Header)
|
||||||
|
for k := range header {
|
||||||
|
delete(header, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
return header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *HeaderPool) Put(header http.Header) {
|
||||||
|
if header != nil {
|
||||||
|
p.pool.Put(header)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成隧道建立请求行
|
||||||
|
func makeTunnelRequestLine(addr string) string {
|
||||||
|
return fmt.Sprintf("CONNECT %s HTTP/1.1\r\n\r\n", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
type options struct {
|
||||||
|
disableKeepAlive bool
|
||||||
|
delegate Delegate
|
||||||
|
|
||||||
|
decryptHTTPS bool
|
||||||
|
websocketIntercept bool
|
||||||
|
certCache cert.Cache
|
||||||
|
transport *http.Transport
|
||||||
|
clientTrace *httptrace.ClientTrace
|
||||||
|
}
|
||||||
|
|
||||||
|
type Option func(*options)
|
||||||
|
|
||||||
|
// WithDisableKeepAlive 连接是否重用
|
||||||
|
func WithDisableKeepAlive(disableKeepAlive bool) Option {
|
||||||
|
return func(opt *options) {
|
||||||
|
opt.disableKeepAlive = disableKeepAlive
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithClientTrace(t *httptrace.ClientTrace) Option {
|
||||||
|
return func(opt *options) {
|
||||||
|
opt.clientTrace = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDelegate 设置委托类
|
||||||
|
func WithDelegate(delegate Delegate) Option {
|
||||||
|
return func(opt *options) {
|
||||||
|
opt.delegate = delegate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTransport 自定义http transport
|
||||||
|
func WithTransport(t *http.Transport) Option {
|
||||||
|
return func(opt *options) {
|
||||||
|
opt.transport = t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDecryptHTTPS 中间人代理, 解密HTTPS, 需实现证书缓存接口
|
||||||
|
func WithDecryptHTTPS(c cert.Cache) Option {
|
||||||
|
return func(opt *options) {
|
||||||
|
opt.decryptHTTPS = true
|
||||||
|
opt.certCache = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithEnableWebsocketIntercept 拦截websocket
|
||||||
|
func WithEnableWebsocketIntercept() Option {
|
||||||
|
return func(opt *options) {
|
||||||
|
opt.websocketIntercept = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// New 创建proxy实例
|
||||||
|
func New(opt ...Option) *Proxy {
|
||||||
|
opts := &options{}
|
||||||
|
for _, o := range opt {
|
||||||
|
o(opts)
|
||||||
|
}
|
||||||
|
if opts.delegate == nil {
|
||||||
|
opts.delegate = &DefaultDelegate{}
|
||||||
|
}
|
||||||
|
if opts.transport == nil {
|
||||||
|
opts.transport = &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
},
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
MaxConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: 10 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 5 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p := &Proxy{}
|
||||||
|
p.delegate = opts.delegate
|
||||||
|
p.websocketIntercept = opts.websocketIntercept
|
||||||
|
p.decryptHTTPS = opts.decryptHTTPS
|
||||||
|
if p.decryptHTTPS {
|
||||||
|
p.cert = cert.NewCertificate(opts.certCache, true)
|
||||||
|
}
|
||||||
|
p.transport = opts.transport
|
||||||
|
p.transport.DialContext = p.dialContext()
|
||||||
|
p.dnsCache = dnscache.New(5 * time.Minute)
|
||||||
|
p.transport.DisableKeepAlives = opts.disableKeepAlive
|
||||||
|
p.transport.Proxy = p.delegate.ParentProxy
|
||||||
|
p.clientTrace = opts.clientTrace
|
||||||
|
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// Proxy 实现了http.Handler接口
|
||||||
|
type Proxy struct {
|
||||||
|
delegate Delegate
|
||||||
|
clientConnNum int32
|
||||||
|
decryptHTTPS bool
|
||||||
|
websocketIntercept bool
|
||||||
|
cert *cert.Certificate
|
||||||
|
transport *http.Transport
|
||||||
|
clientTrace *httptrace.ClientTrace
|
||||||
|
dnsCache *dnscache.Resolver
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ http.Handler = &Proxy{}
|
||||||
|
|
||||||
|
// ServeHTTP 实现了http.Handler接口
|
||||||
|
func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if req.URL.Host == "" {
|
||||||
|
req.URL.Host = req.Host
|
||||||
|
}
|
||||||
|
atomic.AddInt32(&p.clientConnNum, 1)
|
||||||
|
ctx := ctxPool.Get().(*Context)
|
||||||
|
ctx.Reset(req)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
p.delegate.Finish(ctx)
|
||||||
|
ctxPool.Put(ctx)
|
||||||
|
atomic.AddInt32(&p.clientConnNum, -1)
|
||||||
|
}()
|
||||||
|
p.delegate.Connect(ctx, rw)
|
||||||
|
if ctx.abort {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.delegate.Auth(ctx, rw)
|
||||||
|
if ctx.abort {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case ctx.Req.Method == http.MethodConnect:
|
||||||
|
p.tunnelProxy(ctx, rw)
|
||||||
|
case websocket.IsWebSocketUpgrade(ctx.Req):
|
||||||
|
p.tunnelProxy(ctx, rw)
|
||||||
|
default:
|
||||||
|
p.httpProxy(ctx, rw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientConnNum 获取客户端连接数
|
||||||
|
func (p *Proxy) ClientConnNum() int32 {
|
||||||
|
return atomic.LoadInt32(&p.clientConnNum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoRequest 执行HTTP请求,并调用responseFunc处理response
|
||||||
|
func (p *Proxy) DoRequest(ctx *Context, responseFunc func(*http.Response, error)) {
|
||||||
|
if ctx.Data == nil {
|
||||||
|
ctx.Data = make(map[interface{}]interface{})
|
||||||
|
}
|
||||||
|
p.delegate.BeforeRequest(ctx)
|
||||||
|
if ctx.abort {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
newReq := requestPool.Get()
|
||||||
|
*newReq = *ctx.Req
|
||||||
|
newHeader := headerPool.Get()
|
||||||
|
CloneHeader(newReq.Header, newHeader)
|
||||||
|
newReq.Header = newHeader
|
||||||
|
for _, item := range hopHeaders {
|
||||||
|
if newReq.Header.Get(item) != "" {
|
||||||
|
newReq.Header.Del(item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if p.clientTrace != nil {
|
||||||
|
newReq = newReq.WithContext(httptrace.WithClientTrace(newReq.Context(), p.clientTrace))
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := p.transport.RoundTrip(newReq)
|
||||||
|
p.delegate.BeforeResponse(ctx, resp, err)
|
||||||
|
if ctx.abort {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
for _, h := range hopHeaders {
|
||||||
|
resp.Header.Del(h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
responseFunc(resp, err)
|
||||||
|
headerPool.Put(newHeader)
|
||||||
|
requestPool.Put(newReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTP代理
|
||||||
|
func (p *Proxy) httpProxy(ctx *Context, rw http.ResponseWriter) {
|
||||||
|
ctx.Req.URL.Scheme = "http"
|
||||||
|
p.DoRequest(ctx, func(resp *http.Response, err error) {
|
||||||
|
if err != nil {
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("%s - HTTP请求错误: %s", ctx.Req.URL, err))
|
||||||
|
rw.WriteHeader(http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}()
|
||||||
|
CopyHeader(rw.Header(), resp.Header)
|
||||||
|
rw.WriteHeader(resp.StatusCode)
|
||||||
|
buf := bufPool.Get().([]byte)
|
||||||
|
_, _ = io.CopyBuffer(rw, resp.Body, buf)
|
||||||
|
bufPool.Put(buf)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPS代理
|
||||||
|
func (p *Proxy) httpsProxy(ctx *Context, tlsClientConn *tls.Conn) {
|
||||||
|
if websocket.IsWebSocketUpgrade(ctx.Req) {
|
||||||
|
p.websocketProxy(ctx, NewConnBuffer(tlsClientConn, nil))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.DoRequest(ctx, func(resp *http.Response, err error) {
|
||||||
|
if err != nil {
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 请求错误: %s", ctx.Req.URL, err))
|
||||||
|
_, _ = tlsClientConn.Write(badGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = resp.Write(tlsClientConn)
|
||||||
|
if err != nil {
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, response写入客户端失败, %s", ctx.Req.URL, err))
|
||||||
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 隧道代理
|
||||||
|
func (p *Proxy) tunnelProxy(ctx *Context, rw http.ResponseWriter) {
|
||||||
|
clientConn, err := hijacker(rw)
|
||||||
|
if err != nil {
|
||||||
|
p.delegate.ErrorLog(err)
|
||||||
|
rw.WriteHeader(http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = clientConn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if websocket.IsWebSocketUpgrade(ctx.Req) {
|
||||||
|
p.websocketProxy(ctx, clientConn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
parentProxyURL, err := p.delegate.ParentProxy(ctx.Req)
|
||||||
|
if err != nil {
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("%s - 解析代理地址错误: %s", ctx.Req.URL.Host, err))
|
||||||
|
rw.WriteHeader(http.StatusBadGateway)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if parentProxyURL == nil {
|
||||||
|
_, err = clientConn.Write(tunnelEstablishedResponseLine)
|
||||||
|
if err != nil {
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("%s - 隧道连接成功,通知客户端错误: %s", ctx.Req.URL.Host, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
isWebsocket := p.detectConnProtocol(clientConn)
|
||||||
|
if isWebsocket {
|
||||||
|
req, err := http.ReadRequest(clientConn.BufferReader())
|
||||||
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("%s - websocket读取客户端升级请求失败: %s", ctx.Req.URL.Host, err))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.RemoteAddr = ctx.Req.RemoteAddr
|
||||||
|
req.URL.Scheme = "http"
|
||||||
|
req.URL.Host = req.Host
|
||||||
|
ctx.Req = req
|
||||||
|
|
||||||
|
p.websocketProxy(ctx, clientConn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var tlsClientConn *tls.Conn
|
||||||
|
if p.decryptHTTPS {
|
||||||
|
tlsConfig, err := p.cert.GenerateTlsConfig(ctx.Req.URL.Host)
|
||||||
|
if err != nil {
|
||||||
|
p.tunnelConnected(ctx, err)
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 生成证书失败: %s", ctx.Req.URL.Host, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tlsClientConn = tls.Server(clientConn, tlsConfig)
|
||||||
|
defer func() {
|
||||||
|
_ = tlsClientConn.Close()
|
||||||
|
}()
|
||||||
|
if err := tlsClientConn.Handshake(); err != nil {
|
||||||
|
p.tunnelConnected(ctx, err)
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 握手失败: %s", ctx.Req.URL.Host, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := bufio.NewReader(tlsClientConn)
|
||||||
|
tlsReq, err := http.ReadRequest(buf)
|
||||||
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
p.tunnelConnected(ctx, err)
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 读取客户端请求失败: %s", ctx.Req.URL.Host, err))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tlsReq.RemoteAddr = ctx.Req.RemoteAddr
|
||||||
|
tlsReq.URL.Scheme = "https"
|
||||||
|
tlsReq.URL.Host = tlsReq.Host
|
||||||
|
ctx.Req = tlsReq
|
||||||
|
}
|
||||||
|
|
||||||
|
targetAddr := ctx.Req.URL.Host
|
||||||
|
if parentProxyURL != nil {
|
||||||
|
targetAddr = parentProxyURL.Host
|
||||||
|
}
|
||||||
|
if !strings.Contains(targetAddr, ":") {
|
||||||
|
targetAddr += ":443"
|
||||||
|
}
|
||||||
|
|
||||||
|
targetConn, err := net.DialTimeout("tcp", targetAddr, defaultTargetConnectTimeout)
|
||||||
|
if err != nil {
|
||||||
|
p.tunnelConnected(ctx, err)
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("%s - 隧道转发连接目标服务器失败: %s", ctx.Req.URL.Host, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = targetConn.Close()
|
||||||
|
}()
|
||||||
|
if parentProxyURL != nil {
|
||||||
|
tunnelRequestLine := makeTunnelRequestLine(ctx.Req.URL.Host)
|
||||||
|
_, _ = targetConn.Write([]byte(tunnelRequestLine))
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.decryptHTTPS {
|
||||||
|
p.httpsProxy(ctx, tlsClientConn)
|
||||||
|
} else {
|
||||||
|
p.tunnelConnected(ctx, nil)
|
||||||
|
p.transfer(clientConn, targetConn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocket代理
|
||||||
|
func (p *Proxy) websocketProxy(ctx *Context, srcConn *ConnBuffer) {
|
||||||
|
if !p.websocketIntercept {
|
||||||
|
remoteAddr := ctx.Addr()
|
||||||
|
var err error
|
||||||
|
var targetConn net.Conn
|
||||||
|
if ctx.IsHTTPS() {
|
||||||
|
targetConn, err = tls.Dial("tcp", remoteAddr, &tls.Config{InsecureSkipVerify: true})
|
||||||
|
} else {
|
||||||
|
targetConn, err = net.Dial("tcp", remoteAddr)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
p.tunnelConnected(ctx, err)
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("%s - websocket连接目标服务器错误: %s", ctx.Req.URL.Host, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = ctx.Req.Write(targetConn)
|
||||||
|
if err != nil {
|
||||||
|
p.tunnelConnected(ctx, err)
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("%s - websocket协议转换请求写入目标服务器错误: %s", ctx.Req.URL.Host, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.tunnelConnected(ctx, nil)
|
||||||
|
p.transfer(srcConn, targetConn)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
up := &websocket.Upgrader{
|
||||||
|
HandshakeTimeout: defaultTargetConnectTimeout,
|
||||||
|
ReadBufferSize: 4096,
|
||||||
|
WriteBufferSize: 4096,
|
||||||
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
srcWSConn, err := up.Upgrade(srcConn, ctx.Req, http.Header{})
|
||||||
|
if err != nil {
|
||||||
|
p.tunnelConnected(ctx, err)
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("%s - 源连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
u := ctx.WebsocketUrl()
|
||||||
|
d := websocket.Dialer{
|
||||||
|
ReadBufferSize: 4096,
|
||||||
|
WriteBufferSize: 4096,
|
||||||
|
}
|
||||||
|
|
||||||
|
dialTimeoutCtx, cancel := context.WithTimeout(context.Background(), defaultTargetConnectTimeout)
|
||||||
|
defer cancel()
|
||||||
|
targetWSConn, _, err := d.DialContext(dialTimeoutCtx, u.String(), ctx.Req.Header)
|
||||||
|
if err != nil {
|
||||||
|
p.tunnelConnected(ctx, err)
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("%s - 目标连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.tunnelConnected(ctx, nil)
|
||||||
|
p.transferWebsocket(ctx, srcWSConn, targetWSConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 探测连接协议
|
||||||
|
func (p *Proxy) detectConnProtocol(connBuf *ConnBuffer) (isWebsocket bool) {
|
||||||
|
methodBytes, err := connBuf.Peek(3)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
method := string(methodBytes)
|
||||||
|
if method != http.MethodGet {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// webSocket双向转发
|
||||||
|
func (p *Proxy) transferWebsocket(ctx *Context, srcConn *websocket.Conn, targetConn *websocket.Conn) {
|
||||||
|
doneCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
if doneCtx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msgType, msg, err := srcConn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.delegate.WebSocketSendMessage(ctx, &msgType, &msg)
|
||||||
|
err = targetConn.WriteMessage(msgType, msg)
|
||||||
|
if err != nil {
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if doneCtx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msgType, msg, err := targetConn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.delegate.WebSocketReceiveMessage(ctx, &msgType, &msg)
|
||||||
|
err = srcConn.WriteMessage(msgType, msg)
|
||||||
|
if err != nil {
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 双向转发
|
||||||
|
func (p *Proxy) transfer(src net.Conn, dst net.Conn) {
|
||||||
|
go func() {
|
||||||
|
buf := bufPool.Get().([]byte)
|
||||||
|
_, err := io.CopyBuffer(src, dst, buf)
|
||||||
|
if err != nil {
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", dst.RemoteAddr().String(), src.RemoteAddr().String(), err))
|
||||||
|
}
|
||||||
|
bufPool.Put(buf)
|
||||||
|
_ = src.Close()
|
||||||
|
_ = dst.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
buf := bufPool.Get().([]byte)
|
||||||
|
_, err := io.CopyBuffer(dst, src, buf)
|
||||||
|
if err != nil {
|
||||||
|
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err))
|
||||||
|
}
|
||||||
|
bufPool.Put(buf)
|
||||||
|
_ = dst.Close()
|
||||||
|
_ = src.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Proxy) tunnelConnected(ctx *Context, err error) {
|
||||||
|
ctx.TunnelProxy = true
|
||||||
|
p.delegate.BeforeRequest(ctx)
|
||||||
|
if err != nil {
|
||||||
|
p.delegate.BeforeResponse(ctx, nil, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp := &http.Response{
|
||||||
|
Status: "200 OK",
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Proto: "1.1",
|
||||||
|
ProtoMajor: 1,
|
||||||
|
ProtoMinor: 1,
|
||||||
|
Header: http.Header{},
|
||||||
|
Body: http.NoBody,
|
||||||
|
}
|
||||||
|
p.delegate.BeforeResponse(ctx, resp, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Proxy) dialContext() DialContext {
|
||||||
|
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
dialer := &net.Dialer{
|
||||||
|
Timeout: defaultTargetConnectTimeout,
|
||||||
|
}
|
||||||
|
separator := strings.LastIndex(addr, ":")
|
||||||
|
ips, err := p.dnsCache.Fetch(addr[:separator])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var ip string
|
||||||
|
for _, item := range ips {
|
||||||
|
ip = item.String()
|
||||||
|
if !strings.Contains(ip, ":") {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
addr = ip + addr[separator:]
|
||||||
|
|
||||||
|
return dialer.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取底层连接
|
||||||
|
func hijacker(rw http.ResponseWriter) (*ConnBuffer, error) {
|
||||||
|
hijacker, ok := rw.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("http server不支持Hijacker")
|
||||||
|
}
|
||||||
|
conn, buf, err := hijacker.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("hijacker错误: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewConnBuffer(conn, buf), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopyHeader 浅拷贝Header
|
||||||
|
func CopyHeader(dst, src http.Header) {
|
||||||
|
for k, vv := range src {
|
||||||
|
for _, v := range vv {
|
||||||
|
dst.Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloneHeader 深拷贝Header
|
||||||
|
func CloneHeader(h http.Header, h2 http.Header) {
|
||||||
|
for k, vv := range h {
|
||||||
|
vv2 := make([]string, len(vv))
|
||||||
|
copy(vv2, vv)
|
||||||
|
h2[k] = vv2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var hopHeaders = []string{
|
||||||
|
"Proxy-Connection",
|
||||||
|
"Keep-Alive",
|
||||||
|
"Proxy-Authenticate",
|
||||||
|
"Proxy-Authorization",
|
||||||
|
"Te",
|
||||||
|
"Trailer",
|
||||||
|
"Transfer-Encoding",
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConnBuffer struct {
|
||||||
|
net.Conn
|
||||||
|
buf *bufio.ReadWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnBuffer(conn net.Conn, buf *bufio.ReadWriter) *ConnBuffer {
|
||||||
|
if buf == nil {
|
||||||
|
buf = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
||||||
|
}
|
||||||
|
return &ConnBuffer{
|
||||||
|
Conn: conn,
|
||||||
|
buf: buf,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cb *ConnBuffer) BufferReader() *bufio.Reader {
|
||||||
|
return cb.buf.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cb *ConnBuffer) Read(b []byte) (n int, err error) {
|
||||||
|
return cb.buf.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cb *ConnBuffer) Peek(n int) ([]byte, error) {
|
||||||
|
return cb.buf.Peek(n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cb *ConnBuffer) Write(p []byte) (n int, err error) {
|
||||||
|
n, err = cb.buf.Write(p)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, cb.buf.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cb *ConnBuffer) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
return cb.Conn, cb.buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cb *ConnBuffer) WriteHeader(_ int) {}
|
||||||
|
|
||||||
|
func (cb *ConnBuffer) Header() http.Header { return nil }
|
Reference in New Issue
Block a user