This commit is contained in:
2025-03-13 15:56:33 +08:00
commit 21e0a73e5c
26 changed files with 6131 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
.idea
./delegate.go
./proxy.go

454
README.md Normal file
View File

@@ -0,0 +1,454 @@
# GoProxy
GoProxy是一个功能强大的Go语言HTTP代理库支持HTTP、HTTPS和WebSocket代理并提供了丰富的功能和扩展点。
## 功能特性
- 支持HTTP、HTTPS和WebSocket代理
- 支持正向代理和反向代理
- 支持HTTPS解密中间人模式
- 自定义CA证书和私钥
- 动态证书生成与缓存
- 通配符域名证书支持
- 支持RSA和ECDSA证书算法选择
- 支持上游代理链
- 支持负载均衡(轮询、随机、权重等)
- 支持健康检查
- 支持请求重试
- 支持HTTP缓存
- 支持请求限流
- 支持监控指标收集
- 支持自定义处理逻辑(委托模式)
- 支持DNS缓存
- 支持URL重写反向代理模式
## 安装
```bash
go get github.com/goproxy
```
## 快速开始
### 正向代理
```go
package main
import (
"log"
"net/http"
"github.com/goproxy/internal/proxy"
)
func main() {
// 创建代理
p := proxy.New(nil)
// 启动HTTP服务器
log.Println("代理服务器启动在 :8080")
if err := http.ListenAndServe(":8080", p); err != nil {
log.Fatalf("代理服务器启动失败: %v", err)
}
}
```
### 启用HTTPS解密
```go
package main
import (
"log"
"net/http"
"github.com/goproxy/internal/config"
"github.com/goproxy/internal/proxy"
)
func main() {
// 创建配置
cfg := config.DefaultConfig()
cfg.DecryptHTTPS = true
cfg.CACert = "ca.crt" // CA证书路径
cfg.CAKey = "ca.key" // CA私钥路径
cfg.UseECDSA = true // 使用ECDSA生成证书默认为false使用RSA
// 可选使用自定义TLS证书
// cfg.TLSCert = "server.crt"
// cfg.TLSKey = "server.key"
// 创建证书缓存
certCache := &proxy.MemCertCache{}
// 创建代理
p := proxy.New(&proxy.Options{
Config: cfg,
CertCache: certCache,
})
// 启动HTTP服务器
log.Println("HTTPS解密代理服务器启动在 :8080")
if err := http.ListenAndServe(":8080", p); err != nil {
log.Fatalf("代理服务器启动失败: %v", err)
}
}
```
> **注意**: 使用HTTPS解密功能时需要在客户端安装CA证书否则会出现证书警告。
### 反向代理
```go
package main
import (
"log"
"net/http"
"github.com/goproxy/internal/config"
"github.com/goproxy/internal/proxy"
)
func main() {
// 创建配置
cfg := config.DefaultConfig()
cfg.ReverseProxy = true
cfg.EnableURLRewrite = true
cfg.AddXForwardedFor = true
cfg.AddXRealIP = true
// 创建自定义委托
delegate := &ReverseProxyDelegate{
backend: "localhost:8081",
}
// 创建代理
p := proxy.New(&proxy.Options{
Config: cfg,
Delegate: delegate,
})
// 启动HTTP服务器
log.Println("反向代理服务器启动在 :8080")
if err := http.ListenAndServe(":8080", p); err != nil {
log.Fatalf("代理服务器启动失败: %v", err)
}
}
// ReverseProxyDelegate 反向代理委托
type ReverseProxyDelegate struct {
proxy.DefaultDelegate
backend string
}
// ResolveBackend 解析后端服务器
func (d *ReverseProxyDelegate) ResolveBackend(req *http.Request) (string, error) {
return d.backend, nil
}
```
### 自定义委托
```go
package main
import (
"log"
"net/http"
"github.com/goproxy/internal/proxy"
)
func main() {
// 创建自定义委托
delegate := &CustomDelegate{}
// 创建代理
p := proxy.New(&proxy.Options{
Delegate: delegate,
})
// 启动HTTP服务器
log.Println("代理服务器启动在 :8080")
if err := http.ListenAndServe(":8080", p); err != nil {
log.Fatalf("代理服务器启动失败: %v", err)
}
}
// CustomDelegate 自定义委托
type CustomDelegate struct {
proxy.DefaultDelegate
}
// BeforeRequest 请求前事件
func (d *CustomDelegate) BeforeRequest(ctx *proxy.Context) {
log.Printf("请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.String())
}
// BeforeResponse 响应前事件
func (d *CustomDelegate) BeforeResponse(ctx *proxy.Context, resp *http.Response, err error) {
if err != nil {
log.Printf("响应错误: %v\n", err)
return
}
log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status)
}
```
### 完整示例
- 正向代理示例: [cmd/example/main.go](cmd/example/main.go)
- 反向代理示例: [cmd/reverse_proxy_example/main.go](cmd/reverse_proxy_example/main.go)
### 使用函数式选项模式
```go
package main
import (
"log"
"net/http"
"time"
"github.com/goproxy/internal/metrics"
"github.com/goproxy/internal/proxy"
)
func main() {
// 创建监控指标
metricsCollector := metrics.NewSimpleMetrics()
// 创建证书缓存
certCache := &proxy.MemCertCache{}
// 使用函数式选项模式创建代理
p := proxy.NewProxy(
// 启用HTTPS解密
proxy.WithDecryptHTTPS(certCache),
proxy.WithCACertAndKey("ca.crt", "ca.key"),
// 设置监控指标
proxy.WithMetrics(metricsCollector),
// 设置请求超时和连接池
proxy.WithRequestTimeout(30 * time.Second),
proxy.WithConnectionPoolSize(100),
proxy.WithIdleTimeout(90 * time.Second),
// 启用DNS缓存
proxy.WithDNSCacheTTL(10 * time.Minute),
// 启用请求重试
proxy.WithEnableRetry(3, 1*time.Second, 10*time.Second),
// 启用CORS支持
proxy.WithEnableCORS(true),
)
// 启动HTTP服务器和监控服务器
go func() {
log.Println("监控服务器启动在 :8081")
http.Handle("/metrics", metricsCollector)
if err := http.ListenAndServe(":8081", nil); err != nil {
log.Fatalf("监控服务器启动失败: %v", err)
}
}()
// 启动代理服务器
log.Println("代理服务器启动在 :8080")
if err := http.ListenAndServe(":8080", p); err != nil {
log.Fatalf("代理服务器启动失败: %v", err)
}
}
```
## 架构设计
GoProxy采用模块化设计主要包含以下模块
- **代理核心Proxy**处理HTTP请求和响应实现代理功能
- **反向代理ReverseProxy**处理反向代理请求支持URL重写和请求修改
- **路由Router**:基于主机名、路径、正则表达式等规则路由请求到不同的后端
- **URL重写Rewriter**重写请求URL和响应中的URL
- **代理上下文Context**:保存请求上下文信息,用于在处理过程中传递数据
- **代理委托Delegate**:定义代理处理请求的各个阶段的回调方法,用于自定义处理逻辑
- **连接缓冲区ConnBuffer**:封装网络连接,提供缓冲读写功能
- **负载均衡LoadBalancer**:实现负载均衡算法,支持轮询、随机、权重等
- **健康检查HealthChecker**:检查上游服务器的健康状态,自动剔除不健康的服务器
- **缓存Cache**实现HTTP缓存减少重复请求
- **缓存适配器CacheAdapter**:统一不同缓存实现的接口,提高代码可读性和性能
- **证书生成CertGenerator**动态生成TLS证书支持HTTPS解密
- **限流RateLimit**:实现请求限流,防止过载
- **监控Metrics**:收集代理运行指标,用于监控和分析
- **重试Retry**:实现请求重试,提高请求成功率
## 配置选项
GoProxy提供了丰富的配置选项可以通过`Options`结构体进行配置:
```go
type Options struct {
// 配置
Config *config.Config
// 委托
Delegate Delegate
// 证书缓存
CertCache CertificateCache
// HTTP缓存
HTTPCache cache.Cache
// 负载均衡器
LoadBalancer loadbalance.LoadBalancer
// 健康检查器
HealthChecker *healthcheck.HealthChecker
// 监控指标
Metrics metrics.Metrics
// 客户端跟踪
ClientTrace *httptrace.ClientTrace
}
```
### 函数式选项模式
GoProxy 现在支持函数式选项模式Functional Options Pattern通过一系列的 `With` 方法提供更加灵活和可读性更高的配置方式。此模式的优势在于:
- 参数配置更加直观和清晰
- 可以灵活选择需要的配置项,不必记忆参数顺序
- 代码可读性更高,便于维护
- 可以逐步添加新的配置选项而不破坏兼容性
可以使用 `NewProxy` 函数和函数式选项创建代理:
```go
// 创建一个简单的代理
proxy := proxy.NewProxy()
// 创建一个功能丰富的代理
proxy := proxy.NewProxy(
proxy.WithConfig(config.DefaultConfig()),
proxy.WithHTTPCache(myCache),
proxy.WithDecryptHTTPS(myCertCache),
proxy.WithCACertAndKey("ca.crt", "ca.key"),
proxy.WithMetrics(myMetrics),
proxy.WithLoadBalancer(myLoadBalancer),
proxy.WithRequestTimeout(10 * time.Second),
proxy.WithEnableCORS(true)
)
```
### 可用的 With 方法
GoProxy 提供了以下 With 方法用于配置代理的各个方面:
#### 基础配置选项
- `WithConfig(cfg *config.Config)`: 设置代理配置
- `WithDisableKeepAlive(disableKeepAlive bool)`: 设置连接是否重用
- `WithTransport(t *http.Transport)`: 使用自定义HTTP传输
- `WithClientTrace(t *httptrace.ClientTrace)`: 设置HTTP客户端跟踪
#### 功能模块选项
- `WithDelegate(delegate Delegate)`: 设置委托类
- `WithHTTPCache(c cache.Cache)`: 设置HTTP缓存
- `WithLoadBalancer(lb loadbalance.LoadBalancer)`: 设置负载均衡器
- `WithHealthChecker(hc *healthcheck.HealthChecker)`: 设置健康检查器
- `WithMetrics(m metrics.Metrics)`: 设置监控指标
#### 功能开启选项
- `WithDecryptHTTPS(c CertificateCache)`: 启用中间人代理解密HTTPS
- `WithEnableECDSA(enable bool)`: 启用ECDSA证书生成默认使用RSA
- `WithEnableWebsocketIntercept()`: 启用WebSocket拦截
- `WithReverseProxy(enable bool)`: 启用反向代理模式
- `WithEnableRetry(maxRetries int, baseBackoff, maxBackoff time.Duration)`: 启用请求重试
- `WithRateLimit(rps float64)`: 设置请求限流
- `WithURLRewrite(enable bool)`: 启用URL重写
- `WithEnableCORS(enable bool)`: 启用CORS支持
#### 证书相关选项
- `WithTLSCertAndKey(certPath, keyPath string)`: 设置TLS证书和密钥
- `WithCACertAndKey(caCertPath, caKeyPath string)`: 设置CA证书和密钥
#### 性能和超时相关选项
- `WithConnectionPoolSize(size int)`: 设置连接池大小
- `WithIdleTimeout(timeout time.Duration)`: 设置空闲超时时间
- `WithRequestTimeout(timeout time.Duration)`: 设置请求超时时间
- `WithDNSCacheTTL(ttl time.Duration)`: 设置DNS缓存TTL
### 配置HTTPS解密
要启用HTTPS解密功能需要在配置中设置以下选项
```go
config := &config.Config{
// 启用HTTPS解密
DecryptHTTPS: true,
// 方式一使用CA证书和私钥动态生成证书
CACert: "path/to/ca.crt", // CA证书路径
CAKey: "path/to/ca.key", // CA私钥路径
// 选择证书生成算法(可选)
UseECDSA: true, // 使用ECDSA生成证书默认为false使用RSA
// 方式二使用固定的TLS证书和私钥
// TLSCert: "path/to/server.crt",
// TLSKey: "path/to/server.key",
}
```
或者使用函数式选项模式:
```go
proxy := proxy.NewProxy(
proxy.WithDecryptHTTPS(&proxy.MemCertCache{}),
proxy.WithCACertAndKey("path/to/ca.crt", "path/to/ca.key"),
proxy.WithEnableECDSA(true), // 使用ECDSA生成证书
// 或者使用静态TLS证书
// proxy.WithTLSCertAndKey("path/to/server.crt", "path/to/server.key")
)
```
同时,建议配置证书缓存以提高性能:
```go
certCache := &proxy.MemCertCache{}
```
## 扩展点
GoProxy提供了多个扩展点可以通过实现相应的接口进行扩展
- **Delegate**:代理委托接口,用于自定义代理处理逻辑
- **LoadBalancer**:负载均衡接口,用于实现自定义负载均衡算法
- **Cache**:缓存接口,用于实现自定义缓存策略
- **CertificateCache**:证书缓存接口,用于自定义证书存储方式
- **Metrics**:监控接口,用于实现自定义监控指标收集
## 反向代理特性
GoProxy的反向代理模式提供以下特性
- **URL重写**支持基于前缀和正则表达式的URL重写
- **路由规则**:支持基于主机名、路径、正则表达式等的路由规则
- **请求修改**:支持修改发往后端服务器的请求
- **响应修改**:支持修改来自后端服务器的响应
- **保留客户端信息**支持添加X-Forwarded-For和X-Real-IP头
- **CORS支持**支持自动添加CORS头
- **WebSocket支持**支持WebSocket协议的透明代理
- **负载均衡**:支持多种负载均衡算法
- **健康检查**:支持对后端服务器进行健康检查
- **监控指标**:支持收集反向代理的监控指标
## 贡献
欢迎贡献代码、报告问题或提出建议。请遵循以下步骤:
1. Fork 项目
2. 创建特性分支 (`git checkout -b feature/amazing-feature`)
3. 提交更改 (`git commit -m 'Add some amazing feature'`)
4. 推送到分支 (`git push origin feature/amazing-feature`)
5. 创建 Pull Request
## 许可证
本项目采用 MIT 许可证,详情请参阅 [LICENSE](LICENSE) 文件。

73
SUMMARY.md Normal file
View File

@@ -0,0 +1,73 @@
# GoProxy 项目总结
## 项目概述
GoProxy 是一个功能强大的 Go 语言 HTTP 代理库,支持 HTTP、HTTPS 和 WebSocket 代理,并提供了丰富的功能和扩展点。该项目采用模块化设计,各个模块之间职责明确,耦合度低,便于扩展和维护。
## 已完成的模块
1. **配置模块config**:提供代理配置选项,包括连接池大小、超时时间、是否启用缓存、负载均衡等。
2. **代理上下文context**:保存请求上下文信息,用于在代理处理过程中传递数据,包括原始请求、目标地址、上级代理等。
3. **代理委托delegate**:定义代理处理请求的各个阶段的回调方法,用于自定义处理逻辑,包括连接、认证、请求前、响应前等事件。
4. **连接缓冲区conn_buffer**:封装网络连接,提供缓冲读写功能,简化网络 IO 操作。
5. **负载均衡loadbalance**:实现负载均衡算法,支持轮询、随机、权重等,自动选择合适的上游服务器。
6. **健康检查healthcheck**:检查上游服务器的健康状态,自动剔除不健康的服务器,提高代理可靠性。
7. **缓存cache**:实现 HTTP 缓存,减少重复请求,提高代理性能。
8. **缓存适配器cache_adapter**:统一不同缓存实现的接口,使用适配器模式提高代码可读性和执行效率,支持多种缓存实现方式。
9. **证书生成cert_generator**动态生成TLS证书支持HTTPS解密中间人模式可基于CA证书和私钥创建域名证书并支持证书缓存。支持RSA和ECDSA两种算法用户可根据需要选择安全性和性能的平衡。
10. **限流ratelimit**:实现请求限流,防止过载,保护上游服务器。
11. **监控metrics**:收集代理运行指标,用于监控和分析,包括请求数、响应时间、错误数等。
12. **重试retry**:实现请求重试,提高请求成功率,处理临时性故障。
13. **代理核心proxy**:处理 HTTP 请求和响应,实现代理功能,包括 HTTP、HTTPS 和 WebSocket 代理。
## 项目特点
1. **模块化设计**:各个模块职责明确,耦合度低,便于扩展和维护。
2. **丰富的功能**:支持 HTTP、HTTPS 和 WebSocket 代理,并提供了负载均衡、健康检查、缓存、限流、监控等功能。
3. **灵活的扩展点**:提供了多个扩展点,可以通过实现相应的接口进行扩展,如代理委托、负载均衡、缓存、监控等。
4. **高性能**:采用 Go 语言的并发特性,实现高性能的代理服务。优化的缓存适配器和证书缓存机制进一步提升性能。
5. **可靠性**:通过健康检查、重试等机制,提高代理的可靠性。
6. **可观测性**:通过监控指标收集,提高代理的可观测性。
7. **安全性**支持HTTPS解密功能可用于调试、安全审计和内容过滤。提供RSA和ECDSA双算法选择满足不同的安全需求和性能场景。
8. **支持更多证书格式**增加对PEM、DER等更多证书格式的支持以及对更多密钥算法如Ed25519等的支持。
## 使用示例
我们提供了一个完整的示例程序,展示了如何使用 GoProxy 库创建一个功能完善的代理服务器包括负载均衡、健康检查、缓存、监控等功能。另外还提供了启用HTTPS解密功能的示例展示如何使用CA证书动态生成站点证书。
## 未来计划
1. **完善文档**:编写更详细的文档,包括 API 文档、使用示例、最佳实践等。
2. **增加测试**:增加单元测试和集成测试,提高代码质量和可靠性。
3. **性能优化**:进一步优化代理性能,减少资源消耗。
4. **增加更多功能**:如请求过滤、内容修改、安全检查等。
5. **提供更多扩展点**:如请求路由、请求转换等。
6. **支持更多证书格式**增加对PEM、DER等更多证书格式的支持。
## 总结
GoProxy 是一个功能强大、设计良好的 Go 语言 HTTP 代理库可以满足各种代理需求如开发调试、负载均衡、API 网关等。通过模块化设计和丰富的扩展点GoProxy 可以灵活地适应各种场景是构建代理服务的理想选择。最近的增强包括缓存适配器优化和完整的HTTPS解密功能实现使GoProxy更加完善和高效。

165
cmd/example/main.go Normal file
View File

@@ -0,0 +1,165 @@
package main
import (
"flag"
"log"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/goproxy/internal/cache"
"github.com/goproxy/internal/config"
"github.com/goproxy/internal/healthcheck"
"github.com/goproxy/internal/loadbalance"
"github.com/goproxy/internal/metrics"
"github.com/goproxy/internal/proxy"
)
var (
// 监听地址
addr = flag.String("addr", ":8080", "代理服务器监听地址")
// 上游代理服务器
upstream = flag.String("upstream", "", "上游代理服务器地址,多个地址用逗号分隔")
// 是否启用负载均衡
enableLoadBalance = flag.Bool("enable-lb", false, "是否启用负载均衡")
// 是否启用健康检查
enableHealthCheck = flag.Bool("enable-hc", false, "是否启用健康检查")
// 是否启用缓存
enableCache = flag.Bool("enable-cache", false, "是否启用缓存")
// 是否启用重试
enableRetry = flag.Bool("enable-retry", false, "是否启用重试")
// 是否启用监控
enableMetrics = flag.Bool("enable-metrics", false, "是否启用监控")
// 监控地址
metricsAddr = flag.String("metrics-addr", ":8081", "监控服务器监听地址")
)
// 解析目标地址
func parseTargets(targets string) []string {
if targets == "" {
return nil
}
return strings.Split(targets, ",")
}
func main() {
// 解析命令行参数
flag.Parse()
// 创建配置
cfg := config.DefaultConfig()
cfg.EnableLoadBalancing = *enableLoadBalance
cfg.EnableHealthCheck = *enableHealthCheck
cfg.EnableCache = *enableCache
cfg.EnableRetry = *enableRetry
// 创建选项
opts := &proxy.Options{
Config: cfg,
}
// 创建负载均衡器
if *enableLoadBalance && *upstream != "" {
lb := loadbalance.NewRoundRobinBalancer()
for _, target := range parseTargets(*upstream) {
lb.Add(target, 1)
}
opts.LoadBalancer = lb
}
// 创建健康检查器
if *enableHealthCheck && opts.LoadBalancer != nil {
hc := healthcheck.NewHealthChecker(&healthcheck.Config{
Interval: time.Second * 10,
Timeout: time.Second * 2,
MaxFails: 3,
MinSuccess: 2,
})
opts.HealthChecker = hc
}
// 创建缓存
if *enableCache {
c := cache.NewMemoryCache(time.Minute*5, time.Minute*5, 1000)
opts.HTTPCache = c
}
// 创建监控
if *enableMetrics {
m := metrics.NewSimpleMetrics()
opts.Metrics = m
// 启动监控服务器
go func() {
mux := http.NewServeMux()
handler := m.GetHandler()
mux.Handle("/metrics", handler)
log.Printf("监控服务器启动在 %s\n", *metricsAddr)
if err := http.ListenAndServe(*metricsAddr, mux); err != nil {
log.Fatalf("监控服务器启动失败: %v", err)
}
}()
}
// 创建自定义委托
delegate := &CustomDelegate{}
opts.Delegate = delegate
// 创建代理
p := proxy.New(opts)
// 创建HTTP服务器
server := &http.Server{
Addr: *addr,
Handler: p,
}
// 启动HTTP服务器
go func() {
log.Printf("代理服务器启动在 %s\n", *addr)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("代理服务器启动失败: %v", err)
}
}()
// 等待信号
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
log.Println("正在关闭代理服务器...")
server.Close()
log.Println("代理服务器已关闭")
}
// CustomDelegate 自定义委托
type CustomDelegate struct {
proxy.DefaultDelegate
}
// Connect 连接事件
func (d *CustomDelegate) Connect(ctx *proxy.Context, rw http.ResponseWriter) {
log.Printf("收到连接: %s -> %s\n", ctx.Req.RemoteAddr, ctx.Req.URL.Host)
}
// BeforeRequest 请求前事件
func (d *CustomDelegate) BeforeRequest(ctx *proxy.Context) {
log.Printf("请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.String())
}
// BeforeResponse 响应前事件
func (d *CustomDelegate) BeforeResponse(ctx *proxy.Context, resp *http.Response, err error) {
if err != nil {
log.Printf("响应错误: %v\n", err)
return
}
log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status)
}
// ErrorLog 错误日志
func (d *CustomDelegate) ErrorLog(err error) {
log.Printf("错误: %v\n", err)
}

View File

@@ -0,0 +1,185 @@
package main
import (
"flag"
"log"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/goproxy/internal/config"
"github.com/goproxy/internal/metrics"
"github.com/goproxy/internal/proxy"
)
var (
// 监听地址
addr = flag.String("addr", ":8080", "反向代理服务器监听地址")
// 后端服务器
backend = flag.String("backend", "localhost:8081", "后端服务器地址")
// 路由规则文件
routeFile = flag.String("route-file", "", "路由规则文件路径")
// 是否启用URL重写
enableRewrite = flag.Bool("enable-rewrite", false, "是否启用URL重写")
// 是否启用缓存
enableCache = flag.Bool("enable-cache", false, "是否启用缓存")
// 是否启用压缩
enableCompression = flag.Bool("enable-compression", false, "是否启用压缩")
// 是否启用监控
enableMetrics = flag.Bool("enable-metrics", false, "是否启用监控")
// 监控地址
metricsAddr = flag.String("metrics-addr", ":8082", "监控服务器监听地址")
// 是否添加X-Forwarded-For
addXForwardedFor = flag.Bool("add-x-forwarded-for", true, "是否添加X-Forwarded-For头")
// 是否添加X-Real-IP
addXRealIP = flag.Bool("add-x-real-ip", true, "是否添加X-Real-IP头")
// 是否启用CORS
enableCORS = flag.Bool("enable-cors", false, "是否启用CORS")
// 路径前缀
pathPrefix = flag.String("path-prefix", "", "路径前缀,将从请求路径中移除")
)
func main() {
// 解析命令行参数
flag.Parse()
// 创建配置
cfg := config.DefaultConfig()
cfg.ReverseProxy = true
cfg.EnableCache = *enableCache
cfg.EnableCompression = *enableCompression
cfg.EnableURLRewrite = *enableRewrite
cfg.AddXForwardedFor = *addXForwardedFor
cfg.AddXRealIP = *addXRealIP
cfg.EnableCORS = *enableCORS
cfg.ReverseProxyRulesFile = *routeFile
// 创建选项
opts := &proxy.Options{
Config: cfg,
}
// 创建监控
if *enableMetrics {
m := metrics.NewSimpleMetrics()
opts.Metrics = m
// 启动监控服务器
go func() {
mux := http.NewServeMux()
handler := m.GetHandler()
mux.Handle("/metrics", handler)
log.Printf("监控服务器启动在 %s\n", *metricsAddr)
if err := http.ListenAndServe(*metricsAddr, mux); err != nil {
log.Fatalf("监控服务器启动失败: %v", err)
}
}()
}
// 创建自定义委托
delegate := &ReverseProxyDelegate{
backend: *backend,
prefix: *pathPrefix,
}
opts.Delegate = delegate
// 创建代理
p := proxy.New(opts)
// 如果有路径前缀,添加重写规则
if *pathPrefix != "" {
reverseProxy := p.NewReverseProxy()
log.Printf("添加路径重写规则: 从请求路径移除前缀 %s\n", *pathPrefix)
reverseProxy.AddRewriteRule(*pathPrefix, "", false)
}
// 创建HTTP服务器
server := &http.Server{
Addr: *addr,
Handler: p,
}
// 启动HTTP服务器
go func() {
log.Printf("反向代理服务器启动在 %s后端服务器为 %s\n", *addr, *backend)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("代理服务器启动失败: %v", err)
}
}()
// 等待信号
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
log.Println("正在关闭代理服务器...")
server.Close()
log.Println("代理服务器已关闭")
}
// ReverseProxyDelegate 反向代理委托
type ReverseProxyDelegate struct {
proxy.DefaultDelegate
backend string
prefix string
}
// ResolveBackend 解析后端服务器
func (d *ReverseProxyDelegate) ResolveBackend(req *http.Request) (string, error) {
// 这里可以实现基于请求路径、主机名等的路由逻辑
return d.backend, nil
}
// ModifyRequest 修改请求
func (d *ReverseProxyDelegate) ModifyRequest(req *http.Request) {
// 移除路径前缀
if d.prefix != "" && strings.HasPrefix(req.URL.Path, d.prefix) {
req.URL.Path = strings.TrimPrefix(req.URL.Path, d.prefix)
if req.URL.Path == "" {
req.URL.Path = "/"
}
}
// 添加自定义请求头
req.Header.Set("X-Proxy-Time", time.Now().Format(time.RFC3339))
}
// ModifyResponse 修改响应
func (d *ReverseProxyDelegate) ModifyResponse(resp *http.Response) error {
// 添加自定义响应头
resp.Header.Set("X-Proxied-By", "GoProxy")
return nil
}
// Connect 连接事件
func (d *ReverseProxyDelegate) Connect(ctx *proxy.Context, rw http.ResponseWriter) {
log.Printf("收到连接: %s -> %s %s\n", ctx.Req.RemoteAddr, ctx.Req.Method, ctx.Req.URL.Path)
}
// BeforeRequest 请求前事件
func (d *ReverseProxyDelegate) BeforeRequest(ctx *proxy.Context) {
log.Printf("处理请求: %s %s\n", ctx.Req.Method, ctx.Req.URL.Path)
}
// BeforeResponse 响应前事件
func (d *ReverseProxyDelegate) BeforeResponse(ctx *proxy.Context, resp *http.Response, err error) {
if err != nil {
log.Printf("响应错误: %v\n", err)
return
}
log.Printf("响应: %d %s\n", resp.StatusCode, resp.Status)
}
// ErrorLog 错误日志
func (d *ReverseProxyDelegate) ErrorLog(err error) {
log.Printf("错误: %v\n", err)
}
// HandleError 处理错误
func (d *ReverseProxyDelegate) HandleError(rw http.ResponseWriter, req *http.Request, err error) {
log.Printf("处理错误: %v\n", err)
http.Error(rw, "代理服务器错误: "+err.Error(), http.StatusBadGateway)
}

132
delegate.go Normal file
View File

@@ -0,0 +1,132 @@
// Copyright 2018 ouqiang authors
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package goproxy
import (
"log"
"net/http"
"net/url"
"strings"
)
// Context 代理上下文
type Context struct {
Req *http.Request
Data map[interface{}]interface{}
TunnelProxy bool
abort bool
}
func (c *Context) IsHTTPS() bool {
return c.Req.URL.Scheme == "https"
}
var defaultPorts = map[string]string{
"https": "443",
"http": "80",
"": "80",
}
func (c *Context) WebsocketUrl() *url.URL {
u := new(url.URL)
*u = *c.Req.URL
if c.IsHTTPS() {
u.Scheme = "wss"
} else {
u.Scheme = "ws"
}
return u
}
func (c *Context) Addr() string {
addr := c.Req.Host
if !strings.Contains(c.Req.URL.Host, ":") {
addr += ":" + defaultPorts[c.Req.URL.Scheme]
}
return addr
}
// Abort 中断执行
func (c *Context) Abort() {
c.abort = true
}
// IsAborted 是否已中断执行
func (c *Context) IsAborted() bool {
return c.abort
}
// Reset 重置
func (c *Context) Reset(req *http.Request) {
c.Req = req
c.Data = make(map[interface{}]interface{})
c.abort = false
c.TunnelProxy = false
}
type Delegate interface {
// Connect 收到客户端连接
Connect(ctx *Context, rw http.ResponseWriter)
// Auth 代理身份认证
Auth(ctx *Context, rw http.ResponseWriter)
// BeforeRequest HTTP请求前 设置X-Forwarded-For, 修改Header、Body
BeforeRequest(ctx *Context)
// BeforeResponse 响应发送到客户端前, 修改Header、Body、Status Code
BeforeResponse(ctx *Context, resp *http.Response, err error)
// WebSocketSendMessage websocket发送消息
WebSocketSendMessage(ctx *Context, messageType *int, p *[]byte)
// WebSockerReceiveMessage websocket接收 消息
WebSocketReceiveMessage(ctx *Context, messageType *int, p *[]byte)
// ParentProxy 上级代理
ParentProxy(*http.Request) (*url.URL, error)
// Finish 本次请求结束
Finish(ctx *Context)
// 记录错误信息
ErrorLog(err error)
}
var _ Delegate = &DefaultDelegate{}
// DefaultDelegate 默认Handler什么也不做
type DefaultDelegate struct {
Delegate
}
func (h *DefaultDelegate) Connect(ctx *Context, rw http.ResponseWriter) {}
func (h *DefaultDelegate) Auth(ctx *Context, rw http.ResponseWriter) {}
func (h *DefaultDelegate) BeforeRequest(ctx *Context) {}
func (h *DefaultDelegate) BeforeResponse(ctx *Context, resp *http.Response, err error) {}
func (h *DefaultDelegate) ParentProxy(req *http.Request) (*url.URL, error) {
return http.ProxyFromEnvironment(req)
}
// WebSocketSendMessage websocket发送消息
func (h *DefaultDelegate) WebSocketSendMessage(ctx *Context, messageType *int, payload *[]byte) {}
// WebSockerReceiveMessage websocket接收 消息
func (h *DefaultDelegate) WebSocketReceiveMessage(ctx *Context, messageType *int, payload *[]byte) {}
func (h *DefaultDelegate) Finish(ctx *Context) {}
func (h *DefaultDelegate) ErrorLog(err error) {
log.Println(err)
}

10
go.mod Normal file
View File

@@ -0,0 +1,10 @@
module github.com/goproxy
go 1.24.0
require (
github.com/ouqiang/goproxy v1.3.2
github.com/ouqiang/websocket v1.6.2
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8
golang.org/x/time v0.11.0
)

8
go.sum Normal file
View File

@@ -0,0 +1,8 @@
github.com/ouqiang/goproxy v1.3.2 h1:+3uBRrM0RU4LFcsH0lbWsdUCoHIzoRxk+ISPbIS3lTk=
github.com/ouqiang/goproxy v1.3.2/go.mod h1:yF0a+DlUi0Zff28iUeuqLov90bivevUX9uOn3Yk9rww=
github.com/ouqiang/websocket v1.6.2 h1:LGQIySbQO3ahZCl34v9xBVb0yncDk8yIcuEIbWBab/U=
github.com/ouqiang/websocket v1.6.2/go.mod h1:fIROJIHRlQwgCyUFTMzaaIcs4HIwUj2xlOW43u9Sf+M=
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 h1:EVObHAr8DqpoJCVv6KYTle8FEImKhtkfcZetNqxDoJQ=
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8/go.mod h1:dniwbG03GafCjFohMDmz6Zc6oCuiqgH6tGNyXTkHzXE=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=

220
internal/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,220 @@
package cache
import (
"bytes"
"crypto/md5"
"encoding/hex"
"io"
"net/http"
"strings"
"sync"
"time"
)
// Cache 缓存接口
type Cache interface {
// Get 获取缓存
Get(key string) (*http.Response, bool)
// Set 设置缓存
Set(key string, resp *http.Response)
// Delete 删除缓存
Delete(key string)
// Clear 清空缓存
Clear()
}
// MemoryCache 内存缓存实现
type MemoryCache struct {
// 缓存内容
items sync.Map
// 过期时间
ttl time.Duration
// 清理间隔
cleanupInterval time.Duration
// 最大条目数
maxEntries int
// 当前条目数
size int32
// 互斥锁
mu sync.Mutex
}
// CacheItem 缓存项
type CacheItem struct {
response *http.Response
responseBody []byte
expiry time.Time
}
// NewMemoryCache 创建内存缓存
func NewMemoryCache(ttl, cleanupInterval time.Duration, maxEntries int) *MemoryCache {
cache := &MemoryCache{
ttl: ttl,
cleanupInterval: cleanupInterval,
maxEntries: maxEntries,
}
// 启动过期清理
if cleanupInterval > 0 {
go cache.startCleanup()
}
return cache
}
// Get 获取缓存
func (c *MemoryCache) Get(key string) (*http.Response, bool) {
value, ok := c.items.Load(key)
if !ok {
return nil, false
}
item := value.(*CacheItem)
if time.Now().After(item.expiry) {
c.Delete(key)
return nil, false
}
// 克隆响应,避免修改原始数据
resp := cloneResponse(item.response, item.responseBody)
return resp, true
}
// Set 设置缓存
func (c *MemoryCache) Set(key string, resp *http.Response) {
// 检查缓存是否已满
c.mu.Lock()
if c.maxEntries > 0 && c.size >= int32(c.maxEntries) {
c.mu.Unlock()
return
}
c.size++
c.mu.Unlock()
// 读取并保存响应体
var bodyBytes []byte
if resp.Body != nil {
bodyBytes, _ = io.ReadAll(resp.Body)
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
item := &CacheItem{
response: resp,
responseBody: bodyBytes,
expiry: time.Now().Add(c.ttl),
}
c.items.Store(key, item)
}
// Delete 删除缓存
func (c *MemoryCache) Delete(key string) {
c.items.Delete(key)
c.mu.Lock()
c.size--
if c.size < 0 {
c.size = 0
}
c.mu.Unlock()
}
// Clear 清空缓存
func (c *MemoryCache) Clear() {
c.items = sync.Map{}
c.mu.Lock()
c.size = 0
c.mu.Unlock()
}
// startCleanup 启动过期清理
func (c *MemoryCache) startCleanup() {
ticker := time.NewTicker(c.cleanupInterval)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
c.items.Range(func(key, value interface{}) bool {
item := value.(*CacheItem)
if now.After(item.expiry) {
c.Delete(key.(string))
}
return true
})
}
}
// GenerateCacheKey 生成缓存键
func GenerateCacheKey(req *http.Request) string {
// 忽略一些可变的头部
ignoredHeaders := map[string]bool{
"Connection": true,
"Keep-Alive": true,
"Proxy-Authenticate": true,
"Proxy-Authorization": true,
"TE": true,
"Trailers": true,
"Transfer-Encoding": true,
"Upgrade": true,
}
// 提取缓存键组件
components := []string{
req.Method,
req.URL.String(),
}
// 添加选择性头部
for key, values := range req.Header {
if !ignoredHeaders[key] {
for _, value := range values {
components = append(components, key+":"+value)
}
}
}
// 连接并计算哈希
data := strings.Join(components, "|")
hash := md5.New()
hash.Write([]byte(data))
return hex.EncodeToString(hash.Sum(nil))
}
// cloneResponse 克隆HTTP响应
func cloneResponse(resp *http.Response, body []byte) *http.Response {
clone := *resp
clone.Body = io.NopCloser(bytes.NewBuffer(body))
clone.Header = make(http.Header)
for k, v := range resp.Header {
clone.Header[k] = v
}
return &clone
}
// ShouldCache 判断请求是否应该缓存
func ShouldCache(req *http.Request, resp *http.Response) bool {
// 只缓存GET请求
if req.Method != http.MethodGet {
return false
}
// 检查响应状态码
if resp.StatusCode != http.StatusOK &&
resp.StatusCode != http.StatusNotModified &&
resp.StatusCode != http.StatusMovedPermanently &&
resp.StatusCode != http.StatusPermanentRedirect {
return false
}
// 检查Cache-Control头
cacheControl := resp.Header.Get("Cache-Control")
if strings.Contains(cacheControl, "no-store") ||
strings.Contains(cacheControl, "no-cache") ||
strings.Contains(cacheControl, "private") {
return false
}
return true
}

135
internal/config/config.go Normal file
View File

@@ -0,0 +1,135 @@
package config
import (
"time"
)
// Config 代理配置
type Config struct {
// 监听地址
ListenAddr string
// 是否启用负载均衡
EnableLoadBalancing bool
// 负载均衡后端列表
Backends []string
// 是否启用限流
EnableRateLimit bool
// 每秒请求速率限制
RateLimit float64
// 并发请求峰值限制
MaxBurst int
// 最大连接数
MaxConnections int
// 是否启用连接池
EnableConnectionPool bool
// 连接池大小
ConnectionPoolSize int
// 连接空闲超时时间
IdleTimeout time.Duration
// 请求超时时间
RequestTimeout time.Duration
// 是否启用响应缓存
EnableCache bool
// 缓存过期时间
CacheTTL time.Duration
// 是否启用HTTPS解密
DecryptHTTPS bool
// TLS证书文件路径
TLSCert string
// TLS密钥文件路径
TLSKey string
// CA证书文件路径(用于生成动态证书)
CACert string
// CA密钥文件路径(用于生成动态证书)
CAKey string
// 是否启用健康检查
EnableHealthCheck bool
// 健康检查间隔时间
HealthCheckInterval time.Duration
// 健康检查超时时间
HealthCheckTimeout time.Duration
// 是否启用重试机制
EnableRetry bool
// 最大重试次数
MaxRetries int
// 重试间隔基数
RetryBackoff time.Duration
// 最大重试间隔
MaxRetryBackoff time.Duration
// 是否启用监控指标
EnableMetrics bool
// 是否启用请求追踪
EnableTracing bool
// 是否拦截WebSocket
WebSocketIntercept bool
// DNS缓存过期时间
DNSCacheTTL time.Duration
// 是否作为反向代理
ReverseProxy bool
// 反向代理规则文件路径
ReverseProxyRulesFile string
// 是否启用URL重写
EnableURLRewrite bool
// 是否保留客户端IP
PreserveClientIP bool
// 是否启用压缩
EnableCompression bool
// 是否自动添加CORS头
EnableCORS bool
// 重写Host头
RewriteHostHeader bool
// 是否添加X-Forwarded-For头
AddXForwardedFor bool
// 是否添加X-Real-IP头
AddXRealIP bool
// 是否支持Websocket升级
SupportWebSocketUpgrade bool
// 是否使用ECDSA生成证书默认使用RSA
UseECDSA bool
}
// DefaultConfig 默认配置
func DefaultConfig() *Config {
return &Config{
ListenAddr: ":8080",
EnableLoadBalancing: false,
Backends: []string{},
EnableRateLimit: false,
RateLimit: 100,
MaxBurst: 50,
MaxConnections: 1000,
EnableConnectionPool: true,
ConnectionPoolSize: 100,
IdleTimeout: 60 * time.Second,
RequestTimeout: 30 * time.Second,
EnableCache: false,
CacheTTL: 5 * time.Minute,
DecryptHTTPS: false,
TLSCert: "",
TLSKey: "",
CACert: "",
CAKey: "",
EnableHealthCheck: false,
HealthCheckInterval: 10 * time.Second,
HealthCheckTimeout: 5 * time.Second,
EnableRetry: true,
MaxRetries: 3,
RetryBackoff: 100 * time.Millisecond,
MaxRetryBackoff: 2 * time.Second,
EnableMetrics: false,
EnableTracing: false,
WebSocketIntercept: false,
DNSCacheTTL: 5 * time.Minute,
ReverseProxy: false,
ReverseProxyRulesFile: "",
EnableURLRewrite: false,
PreserveClientIP: true,
EnableCompression: false,
EnableCORS: false,
RewriteHostHeader: false,
AddXForwardedFor: true,
AddXRealIP: true,
SupportWebSocketUpgrade: true,
UseECDSA: false,
}
}

View File

@@ -0,0 +1,248 @@
package healthcheck
import (
"context"
"net"
"net/http"
"net/url"
"sync"
"time"
)
// HealthChecker 健康检查器
type HealthChecker struct {
// 配置
config *Config
// 健康检查状态
statusMap sync.Map
// 状态变更回调
statusChangeCallback func(string, bool)
// 是否运行中
running bool
// 上下文
ctx context.Context
// 取消函数
cancel context.CancelFunc
// 互斥锁
mu sync.Mutex
}
// Config 健康检查配置
type Config struct {
// 检查间隔
Interval time.Duration
// 检查超时
Timeout time.Duration
// 检查路径
Path string
// 检查方法
Method string
// 检查状态码
SuccessStatus int
// 最大失败次数
MaxFails int
// 最小成功次数
MinSuccess int
}
// NewHealthChecker 创建健康检查器
func NewHealthChecker(config *Config) *HealthChecker {
if config.Path == "" {
config.Path = "/"
}
if config.Method == "" {
config.Method = http.MethodGet
}
if config.SuccessStatus == 0 {
config.SuccessStatus = http.StatusOK
}
if config.MaxFails == 0 {
config.MaxFails = 3
}
if config.MinSuccess == 0 {
config.MinSuccess = 2
}
ctx, cancel := context.WithCancel(context.Background())
return &HealthChecker{
config: config,
ctx: ctx,
cancel: cancel,
running: false,
}
}
// Start 启动健康检查
func (hc *HealthChecker) Start() {
hc.mu.Lock()
defer hc.mu.Unlock()
if hc.running {
return
}
hc.running = true
go hc.run()
}
// Stop 停止健康检查
func (hc *HealthChecker) Stop() {
hc.mu.Lock()
defer hc.mu.Unlock()
if !hc.running {
return
}
hc.cancel()
hc.running = false
}
// AddTarget 添加监控目标
func (hc *HealthChecker) AddTarget(target string) error {
u, err := url.Parse(target)
if err != nil {
return err
}
// 初始化为健康状态
hc.statusMap.Store(u.String(), &backendStatus{
URL: u,
Healthy: true,
FailCount: 0,
SuccessCount: 0,
})
return nil
}
// RemoveTarget 移除监控目标
func (hc *HealthChecker) RemoveTarget(target string) error {
u, err := url.Parse(target)
if err != nil {
return err
}
hc.statusMap.Delete(u.String())
return nil
}
// IsHealthy 检查目标是否健康
func (hc *HealthChecker) IsHealthy(target string) bool {
u, err := url.Parse(target)
if err != nil {
return false
}
value, ok := hc.statusMap.Load(u.String())
if !ok {
return false
}
status := value.(*backendStatus)
return status.Healthy
}
// SetStatusChangeCallback 设置状态变更回调
func (hc *HealthChecker) SetStatusChangeCallback(callback func(string, bool)) {
hc.statusChangeCallback = callback
}
// backendStatus 后端健康状态
type backendStatus struct {
URL *url.URL
Healthy bool
FailCount int
SuccessCount int
}
// run 运行健康检查
func (hc *HealthChecker) run() {
ticker := time.NewTicker(hc.config.Interval)
defer ticker.Stop()
for {
select {
case <-hc.ctx.Done():
return
case <-ticker.C:
hc.checkAll()
}
}
}
// checkAll 检查所有后端
func (hc *HealthChecker) checkAll() {
hc.statusMap.Range(func(key, value interface{}) bool {
go hc.check(key.(string), value.(*backendStatus))
return true
})
}
// check 检查单个后端
func (hc *HealthChecker) check(key string, status *backendStatus) {
// 创建检查请求
u := *status.URL
u.Path = hc.config.Path
req, err := http.NewRequest(hc.config.Method, u.String(), nil)
if err != nil {
hc.updateStatus(key, status, false)
return
}
// 设置超时的客户端
client := &http.Client{
Timeout: hc.config.Timeout,
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: hc.config.Timeout,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: hc.config.Timeout,
ExpectContinueTimeout: 1 * time.Second,
MaxIdleConns: 10,
IdleConnTimeout: 30 * time.Second,
},
}
// 发送请求
resp, err := client.Do(req)
if err != nil {
hc.updateStatus(key, status, false)
return
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode == hc.config.SuccessStatus {
hc.updateStatus(key, status, true)
} else {
hc.updateStatus(key, status, false)
}
}
// updateStatus 更新后端状态
func (hc *HealthChecker) updateStatus(key string, status *backendStatus, success bool) {
if success {
status.SuccessCount++
status.FailCount = 0
if !status.Healthy && status.SuccessCount >= hc.config.MinSuccess {
status.Healthy = true
if hc.statusChangeCallback != nil {
hc.statusChangeCallback(key, true)
}
}
} else {
status.FailCount++
status.SuccessCount = 0
if status.Healthy && status.FailCount >= hc.config.MaxFails {
status.Healthy = false
if hc.statusChangeCallback != nil {
hc.statusChangeCallback(key, false)
}
}
}
hc.statusMap.Store(key, status)
}

View File

@@ -0,0 +1,330 @@
package loadbalance
import (
"math/rand"
"net/url"
"sync"
"sync/atomic"
"time"
)
// Strategy 负载均衡策略
type Strategy int
const (
// StrategyRoundRobin 轮询策略
StrategyRoundRobin Strategy = iota
// StrategyRandom 随机策略
StrategyRandom
// StrategyWeightedRoundRobin 加权轮询策略
StrategyWeightedRoundRobin
// StrategyIPHash IP哈希策略
StrategyIPHash
)
// LoadBalancer 负载均衡器接口
type LoadBalancer interface {
// Next 获取下一个后端
Next(key string) (*url.URL, error)
// Add 添加后端
Add(backend string, weight int) error
// Remove 删除后端
Remove(backend string) error
// MarkDown 标记后端为不可用
MarkDown(backend string) error
// MarkUp 标记后端为可用
MarkUp(backend string) error
// Reset 重置负载均衡器
Reset() error
}
// Backend 后端服务器
type Backend struct {
// URL 后端URL
URL *url.URL
// Weight 权重
Weight int
// Down 是否不可用
Down bool
}
// RoundRobinBalancer 轮询负载均衡器
type RoundRobinBalancer struct {
backends []*Backend
current int32
mutex sync.RWMutex
}
// NewRoundRobinBalancer 创建轮询负载均衡器
func NewRoundRobinBalancer() *RoundRobinBalancer {
return &RoundRobinBalancer{
backends: make([]*Backend, 0),
current: 0,
}
}
// Next 获取下一个后端
func (lb *RoundRobinBalancer) Next(key string) (*url.URL, error) {
lb.mutex.RLock()
defer lb.mutex.RUnlock()
if len(lb.backends) == 0 {
return nil, nil
}
// 计算可用后端数量
var availableCount int
for _, backend := range lb.backends {
if !backend.Down {
availableCount++
}
}
if availableCount == 0 {
return nil, nil
}
// 循环直到找到可用后端
for i := 0; i < len(lb.backends); i++ {
idx := atomic.AddInt32(&lb.current, 1) % int32(len(lb.backends))
backend := lb.backends[idx]
if !backend.Down {
return backend.URL, nil
}
}
return nil, nil
}
// Add 添加后端
func (lb *RoundRobinBalancer) Add(backend string, weight int) error {
lb.mutex.Lock()
defer lb.mutex.Unlock()
url, err := url.Parse(backend)
if err != nil {
return err
}
for _, b := range lb.backends {
if b.URL.String() == url.String() {
return nil
}
}
lb.backends = append(lb.backends, &Backend{
URL: url,
Weight: weight,
Down: false,
})
return nil
}
// Remove 删除后端
func (lb *RoundRobinBalancer) Remove(backend string) error {
lb.mutex.Lock()
defer lb.mutex.Unlock()
url, err := url.Parse(backend)
if err != nil {
return err
}
for i, b := range lb.backends {
if b.URL.String() == url.String() {
lb.backends = append(lb.backends[:i], lb.backends[i+1:]...)
return nil
}
}
return nil
}
// MarkDown 标记后端为不可用
func (lb *RoundRobinBalancer) MarkDown(backend string) error {
lb.mutex.Lock()
defer lb.mutex.Unlock()
url, err := url.Parse(backend)
if err != nil {
return err
}
for _, b := range lb.backends {
if b.URL.String() == url.String() {
b.Down = true
return nil
}
}
return nil
}
// MarkUp 标记后端为可用
func (lb *RoundRobinBalancer) MarkUp(backend string) error {
lb.mutex.Lock()
defer lb.mutex.Unlock()
url, err := url.Parse(backend)
if err != nil {
return err
}
for _, b := range lb.backends {
if b.URL.String() == url.String() {
b.Down = false
return nil
}
}
return nil
}
// Reset 重置负载均衡器
func (lb *RoundRobinBalancer) Reset() error {
lb.mutex.Lock()
defer lb.mutex.Unlock()
lb.backends = make([]*Backend, 0)
atomic.StoreInt32(&lb.current, 0)
return nil
}
// RandomBalancer 随机负载均衡器
type RandomBalancer struct {
backends []*Backend
mutex sync.RWMutex
rand *rand.Rand
}
// NewRandomBalancer 创建随机负载均衡器
func NewRandomBalancer() *RandomBalancer {
source := rand.NewSource(time.Now().UnixNano())
random := rand.New(source)
return &RandomBalancer{
backends: make([]*Backend, 0),
rand: random,
}
}
// Next 获取下一个后端
func (lb *RandomBalancer) Next(key string) (*url.URL, error) {
lb.mutex.RLock()
defer lb.mutex.RUnlock()
if len(lb.backends) == 0 {
return nil, nil
}
// 计算可用后端数量
var availableBackends []*Backend
for _, backend := range lb.backends {
if !backend.Down {
availableBackends = append(availableBackends, backend)
}
}
if len(availableBackends) == 0 {
return nil, nil
}
idx := lb.rand.Intn(len(availableBackends))
return availableBackends[idx].URL, nil
}
// Add 添加后端
func (lb *RandomBalancer) Add(backend string, weight int) error {
lb.mutex.Lock()
defer lb.mutex.Unlock()
url, err := url.Parse(backend)
if err != nil {
return err
}
for _, b := range lb.backends {
if b.URL.String() == url.String() {
return nil
}
}
lb.backends = append(lb.backends, &Backend{
URL: url,
Weight: weight,
Down: false,
})
return nil
}
// Remove 删除后端
func (lb *RandomBalancer) Remove(backend string) error {
lb.mutex.Lock()
defer lb.mutex.Unlock()
url, err := url.Parse(backend)
if err != nil {
return err
}
for i, b := range lb.backends {
if b.URL.String() == url.String() {
lb.backends = append(lb.backends[:i], lb.backends[i+1:]...)
return nil
}
}
return nil
}
// MarkDown 标记后端为不可用
func (lb *RandomBalancer) MarkDown(backend string) error {
lb.mutex.Lock()
defer lb.mutex.Unlock()
url, err := url.Parse(backend)
if err != nil {
return err
}
for _, b := range lb.backends {
if b.URL.String() == url.String() {
b.Down = true
return nil
}
}
return nil
}
// MarkUp 标记后端为可用
func (lb *RandomBalancer) MarkUp(backend string) error {
lb.mutex.Lock()
defer lb.mutex.Unlock()
url, err := url.Parse(backend)
if err != nil {
return err
}
for _, b := range lb.backends {
if b.URL.String() == url.String() {
b.Down = false
return nil
}
}
return nil
}
// Reset 重置负载均衡器
func (lb *RandomBalancer) Reset() error {
lb.mutex.Lock()
defer lb.mutex.Unlock()
lb.backends = make([]*Backend, 0)
return nil
}

250
internal/metrics/metrics.go Normal file
View File

@@ -0,0 +1,250 @@
package metrics
import (
"fmt"
"net/http"
"sync"
"sync/atomic"
"time"
)
// Metrics 监控指标接口
type Metrics interface {
// 增加请求计数
IncRequestCount()
// 增加错误计数
IncErrorCount(err error)
// 观察请求持续时间
ObserveRequestDuration(seconds float64)
// 增加活跃连接数
IncActiveConnections()
// 减少活跃连接数
DecActiveConnections()
// 设置后端健康状态
SetBackendHealth(backend string, healthy bool)
// 设置后端响应时间
SetBackendResponseTime(backend string, duration time.Duration)
// 观察请求字节数
ObserveRequestBytes(bytes int64)
// 观察响应字节数
ObserveResponseBytes(bytes int64)
// 添加传输字节数
AddBytesTransferred(direction string, bytes int64)
// 增加缓存命中计数
IncCacheHit()
// 获取指标处理器
GetHandler() http.Handler
}
// SimpleMetrics 简单指标实现
type SimpleMetrics struct {
// 请求计数
requestCount int64
// 错误计数
errorCount int64
// 活跃连接数
activeConnections int64
// 累计响应时间
totalResponseTime int64
// 传输字节数
bytesTransferred map[string]int64
// 后端健康状态
backendHealth map[string]bool
// 后端响应时间
backendResponseTime map[string]time.Duration
// 缓存命中计数
cacheHits int64
// 互斥锁
mu sync.Mutex
}
// NewSimpleMetrics 创建简单指标
func NewSimpleMetrics() *SimpleMetrics {
return &SimpleMetrics{
bytesTransferred: make(map[string]int64),
backendHealth: make(map[string]bool),
backendResponseTime: make(map[string]time.Duration),
}
}
// IncRequestCount 增加请求计数
func (m *SimpleMetrics) IncRequestCount() {
atomic.AddInt64(&m.requestCount, 1)
}
// IncErrorCount 增加错误计数
func (m *SimpleMetrics) IncErrorCount(err error) {
atomic.AddInt64(&m.errorCount, 1)
}
// ObserveRequestDuration 观察请求持续时间
func (m *SimpleMetrics) ObserveRequestDuration(seconds float64) {
nsec := int64(seconds * float64(time.Second))
atomic.AddInt64(&m.totalResponseTime, nsec)
}
// IncActiveConnections 增加活跃连接数
func (m *SimpleMetrics) IncActiveConnections() {
atomic.AddInt64(&m.activeConnections, 1)
}
// DecActiveConnections 减少活跃连接数
func (m *SimpleMetrics) DecActiveConnections() {
atomic.AddInt64(&m.activeConnections, -1)
}
// SetBackendHealth 设置后端健康状态
func (m *SimpleMetrics) SetBackendHealth(backend string, healthy bool) {
m.backendHealth[backend] = healthy
}
// SetBackendResponseTime 设置后端响应时间
func (m *SimpleMetrics) SetBackendResponseTime(backend string, duration time.Duration) {
m.backendResponseTime[backend] = duration
}
// ObserveRequestBytes 观察请求字节数
func (m *SimpleMetrics) ObserveRequestBytes(bytes int64) {
m.mu.Lock()
defer m.mu.Unlock()
m.bytesTransferred["request"] += bytes
}
// ObserveResponseBytes 观察响应字节数
func (m *SimpleMetrics) ObserveResponseBytes(bytes int64) {
m.mu.Lock()
defer m.mu.Unlock()
m.bytesTransferred["response"] += bytes
}
// AddBytesTransferred 添加传输字节数
func (m *SimpleMetrics) AddBytesTransferred(direction string, bytes int64) {
m.mu.Lock()
defer m.mu.Unlock()
m.bytesTransferred[direction] += bytes
}
// IncCacheHit 增加缓存命中计数
func (m *SimpleMetrics) IncCacheHit() {
atomic.AddInt64(&m.cacheHits, 1)
}
// GetHandler 获取指标处理器
func (m *SimpleMetrics) GetHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
// 输出基本指标
w.Write([]byte("# HELP proxy_requests_total 代理请求总数\n"))
w.Write([]byte("# TYPE proxy_requests_total counter\n"))
w.Write([]byte(fmt.Sprintf("proxy_requests_total %d\n", m.requestCount)))
w.Write([]byte("# HELP proxy_errors_total 代理错误总数\n"))
w.Write([]byte("# TYPE proxy_errors_total counter\n"))
w.Write([]byte(fmt.Sprintf("proxy_errors_total %d\n", m.errorCount)))
w.Write([]byte("# HELP proxy_active_connections 当前活跃连接数\n"))
w.Write([]byte("# TYPE proxy_active_connections gauge\n"))
w.Write([]byte(fmt.Sprintf("proxy_active_connections %d\n", m.activeConnections)))
// 输出缓存命中数据
w.Write([]byte("# HELP proxy_cache_hits_total 缓存命中总数\n"))
w.Write([]byte("# TYPE proxy_cache_hits_total counter\n"))
w.Write([]byte(fmt.Sprintf("proxy_cache_hits_total %d\n", m.cacheHits)))
// 输出传输字节数
for direction, bytes := range m.bytesTransferred {
w.Write([]byte(fmt.Sprintf("# HELP proxy_bytes_transferred_%s 代理传输字节数(%s)\n", direction, direction)))
w.Write([]byte(fmt.Sprintf("# TYPE proxy_bytes_transferred_%s counter\n", direction)))
w.Write([]byte(fmt.Sprintf("proxy_bytes_transferred_%s %d\n", direction, bytes)))
}
// 输出后端健康状态
for backend, healthy := range m.backendHealth {
healthValue := 0
if healthy {
healthValue = 1
}
w.Write([]byte(fmt.Sprintf("# HELP proxy_backend_health 后端健康状态\n")))
w.Write([]byte(fmt.Sprintf("# TYPE proxy_backend_health gauge\n")))
w.Write([]byte(fmt.Sprintf("proxy_backend_health{backend=\"%s\"} %d\n", backend, healthValue)))
}
// 输出后端响应时间
for backend, duration := range m.backendResponseTime {
w.Write([]byte(fmt.Sprintf("# HELP proxy_backend_response_time 后端响应时间\n")))
w.Write([]byte(fmt.Sprintf("# TYPE proxy_backend_response_time gauge\n")))
w.Write([]byte(fmt.Sprintf("proxy_backend_response_time{backend=\"%s\"} %f\n", backend, float64(duration)/float64(time.Second))))
}
// 平均响应时间
if m.requestCount > 0 {
avgTime := float64(m.totalResponseTime) / float64(m.requestCount) / float64(time.Second)
w.Write([]byte("# HELP proxy_average_response_time 平均响应时间\n"))
w.Write([]byte("# TYPE proxy_average_response_time gauge\n"))
w.Write([]byte(fmt.Sprintf("proxy_average_response_time %f\n", avgTime)))
}
})
}
// PrometheusMetrics Prometheus指标实现
type PrometheusMetrics struct {
// 可以通过引入prometheus客户端库实现更完整的指标收集
// 此处省略具体实现
}
// MetricsMiddleware 指标中间件
type MetricsMiddleware struct {
metrics Metrics
}
// NewMetricsMiddleware 创建指标中间件
func NewMetricsMiddleware(metrics Metrics) *MetricsMiddleware {
return &MetricsMiddleware{
metrics: metrics,
}
}
// Middleware 中间件处理函数
func (m *MetricsMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// 包装响应写入器,用于捕获状态码
rw := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
// 继续处理请求
next.ServeHTTP(rw, r)
// 记录请求指标
duration := time.Since(start)
m.metrics.ObserveRequestDuration(duration.Seconds())
})
}
// responseWriter 包装的响应写入器
type responseWriter struct {
http.ResponseWriter
statusCode int
written int64
}
// WriteHeader 写入状态码
func (rw *responseWriter) WriteHeader(statusCode int) {
rw.statusCode = statusCode
rw.ResponseWriter.WriteHeader(statusCode)
}
// Write 写入数据
func (rw *responseWriter) Write(b []byte) (int, error) {
n, err := rw.ResponseWriter.Write(b)
rw.written += int64(n)
return n, err
}
// Flush 刷新数据
func (rw *responseWriter) Flush() {
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

View File

@@ -0,0 +1,184 @@
package middleware
import (
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
)
// RateLimiter 限流器接口
type RateLimiter interface {
// Allow 检查请求是否允许通过
Allow(key string) bool
}
// SimpleRateLimiter 简单限流器
type SimpleRateLimiter struct {
limiter *rate.Limiter
}
// NewSimpleRateLimiter 创建简单限流器
func NewSimpleRateLimiter(r float64, b int) *SimpleRateLimiter {
return &SimpleRateLimiter{
limiter: rate.NewLimiter(rate.Limit(r), b),
}
}
// Allow 检查请求是否允许通过
func (rl *SimpleRateLimiter) Allow(key string) bool {
return rl.limiter.Allow()
}
// IPRateLimiter 按IP限流
type IPRateLimiter struct {
ips map[string]*rate.Limiter
mu sync.RWMutex
rate rate.Limit
burst int
cleanupInterval time.Duration
lastSeen map[string]time.Time
}
// NewIPRateLimiter 创建IP限流器
func NewIPRateLimiter(r float64, b int, cleanup time.Duration) *IPRateLimiter {
limiter := &IPRateLimiter{
ips: make(map[string]*rate.Limiter),
rate: rate.Limit(r),
burst: b,
cleanupInterval: cleanup,
lastSeen: make(map[string]time.Time),
}
// 启动过期清理
if cleanup > 0 {
go limiter.startCleanup()
}
return limiter
}
// startCleanup 启动过期清理
func (rl *IPRateLimiter) startCleanup() {
ticker := time.NewTicker(rl.cleanupInterval)
defer ticker.Stop()
for range ticker.C {
rl.cleanup()
}
}
// cleanup 清理过期限流器
func (rl *IPRateLimiter) cleanup() {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
for ip, lastSeen := range rl.lastSeen {
if now.Sub(lastSeen) > rl.cleanupInterval {
delete(rl.ips, ip)
delete(rl.lastSeen, ip)
}
}
}
// AddIP 添加IP限流器
func (rl *IPRateLimiter) AddIP(ip string) *rate.Limiter {
rl.mu.Lock()
defer rl.mu.Unlock()
limiter := rate.NewLimiter(rl.rate, rl.burst)
rl.ips[ip] = limiter
rl.lastSeen[ip] = time.Now()
return limiter
}
// GetLimiter 获取IP限流器
func (rl *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
rl.mu.RLock()
limiter, exists := rl.ips[ip]
rl.mu.RUnlock()
if !exists {
return rl.AddIP(ip)
}
// 更新最后访问时间
rl.mu.Lock()
rl.lastSeen[ip] = time.Now()
rl.mu.Unlock()
return limiter
}
// Allow 检查请求是否允许通过
func (rl *IPRateLimiter) Allow(ip string) bool {
limiter := rl.GetLimiter(ip)
return limiter.Allow()
}
// RateLimitMiddleware 限流中间件
type RateLimitMiddleware struct {
limiter RateLimiter
}
// NewRateLimitMiddleware 创建限流中间件
func NewRateLimitMiddleware(limiter RateLimiter) *RateLimitMiddleware {
return &RateLimitMiddleware{
limiter: limiter,
}
}
// Middleware 中间件处理函数
func (m *RateLimitMiddleware) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 获取客户端IP
ip := getClientIP(r)
// 检查是否允许通过
if !m.limiter.Allow(ip) {
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
return
}
// 继续处理请求
next.ServeHTTP(w, r)
})
}
// getClientIP 获取客户端IP
func getClientIP(r *http.Request) string {
// 检查 X-Forwarded-For 头
ip := r.Header.Get("X-Forwarded-For")
if ip != "" {
// 取第一个IP
for i := 0; i < len(ip) && i < 15; i++ {
if ip[i] == ',' {
ip = ip[:i]
break
}
}
return ip
}
// 检查 X-Real-IP 头
ip = r.Header.Get("X-Real-IP")
if ip != "" {
return ip
}
// 从 RemoteAddr 获取
if r.RemoteAddr != "" {
// 去掉端口部分
for i := 0; i < len(r.RemoteAddr); i++ {
if r.RemoteAddr[i] == ':' {
return r.RemoteAddr[:i]
}
}
return r.RemoteAddr
}
return "unknown"
}

View File

@@ -0,0 +1,156 @@
package middleware
import (
"bytes"
"io"
"math"
"net"
"net/http"
"time"
)
// RetryPolicy 重试策略
type RetryPolicy struct {
// 最大重试次数
MaxRetries int
// 基础退避时间
BaseBackoff time.Duration
// 最大退避时间
MaxBackoff time.Duration
// 重试判断函数
ShouldRetry func(req *http.Request, resp *http.Response, err error) bool
}
// DefaultRetryPolicy 默认重试策略
func DefaultRetryPolicy() *RetryPolicy {
return &RetryPolicy{
MaxRetries: 3,
BaseBackoff: 100 * time.Millisecond,
MaxBackoff: 2 * time.Second,
ShouldRetry: defaultShouldRetry,
}
}
// defaultShouldRetry 默认重试判断
func defaultShouldRetry(req *http.Request, resp *http.Response, err error) bool {
// 不重试非幂等请求
if req.Method != http.MethodGet && req.Method != http.MethodHead && req.Method != http.MethodOptions {
return false
}
// 检查错误
if err != nil {
// 重试网络错误
if netErr, ok := err.(net.Error); ok {
return netErr.Temporary() || netErr.Timeout()
}
return false
}
// 检查响应状态码
if resp != nil {
// 重试服务器错误
return resp.StatusCode >= 500 && resp.StatusCode < 600
}
return false
}
// RetryRoundTripper 重试HTTP传输
type RetryRoundTripper struct {
// 下一级传输
Next http.RoundTripper
// 重试策略
Policy *RetryPolicy
}
// NewRetryRoundTripper 创建重试HTTP传输
func NewRetryRoundTripper(next http.RoundTripper, policy *RetryPolicy) *RetryRoundTripper {
if next == nil {
next = http.DefaultTransport
}
if policy == nil {
policy = DefaultRetryPolicy()
}
return &RetryRoundTripper{
Next: next,
Policy: policy,
}
}
// RoundTrip 执行HTTP请求
func (rt *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// 需要保留原始请求体,以便重试
var reqBodyBytes []byte
if req.Body != nil {
var err error
reqBodyBytes, err = io.ReadAll(req.Body)
if err != nil {
return nil, err
}
req.Body.Close()
}
var resp *http.Response
var err error
// 尝试请求直到成功或达到最大重试次数
for attempt := 0; attempt <= rt.Policy.MaxRetries; attempt++ {
// 复制请求体
if len(reqBodyBytes) > 0 {
req.Body = io.NopCloser(bytes.NewBuffer(reqBodyBytes))
}
// 发送请求
resp, err = rt.Next.RoundTrip(req)
// 检查是否需要重试
if attempt < rt.Policy.MaxRetries && rt.Policy.ShouldRetry(req, resp, err) {
// 如果需要重试,先关闭当前响应
if resp != nil {
resp.Body.Close()
}
// 计算退避时间
backoff := rt.calculateBackoff(attempt)
time.Sleep(backoff)
continue
}
// 不需要重试,返回响应
return resp, err
}
// 所有重试都失败
return resp, err
}
// calculateBackoff 计算退避时间
func (rt *RetryRoundTripper) calculateBackoff(attempt int) time.Duration {
// 指数退避: baseBackoff * 2^attempt
backoff := rt.Policy.BaseBackoff * time.Duration(math.Pow(2, float64(attempt)))
if backoff > rt.Policy.MaxBackoff {
backoff = rt.Policy.MaxBackoff
}
return backoff
}
// RetryMiddleware 重试中间件
type RetryMiddleware struct {
policy *RetryPolicy
}
// NewRetryMiddleware 创建重试中间件
func NewRetryMiddleware(policy *RetryPolicy) *RetryMiddleware {
if policy == nil {
policy = DefaultRetryPolicy()
}
return &RetryMiddleware{
policy: policy,
}
}
// Middleware 中间件处理函数
func (m *RetryMiddleware) Transport(next http.RoundTripper) http.RoundTripper {
return NewRetryRoundTripper(next, m.policy)
}

View File

@@ -0,0 +1,552 @@
package proxy
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"hash/fnv"
"math/big"
"net"
"os"
"strings"
"time"
)
// 内置默认CA证书和私钥
var (
// 默认根证书
defaultRootCAPem = []byte(`-----BEGIN CERTIFICATE-----
MIICJzCCAcygAwIBAgIITWWCIQf8/VIwCgYIKoZIzj0EAwIwUzEOMAwGA1UEBhMF
Q2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0GA1UEBxMGWGlhbWVuMRAwDgYDVQQK
EwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMB4XDTIyMDMyNTA1NDgwMFoXDTQyMDQy
NTA1NDgwMFowUzEOMAwGA1UEBhMFQ2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0G
A1UEBxMGWGlhbWVuMRAwDgYDVQQKEwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMFkw
EwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQX
mRgsFV5KHHmxOrVJBFC/nDetmGowkARShWtBsX1Irm4w6i6Qk2QliKOBiTCBhjAO
BgNVHQ8BAf8EBAMCAQYwHQYDVR0lBBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIG
A1UdEwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFBI5TkWYcvUIWsBAdffs833FnBrI
MCIGA1UdEQQbMBmBF3FpbmdxaWFubHVkYW9AZ21haWwuY29tMAoGCCqGSM49BAMC
A0kAMEYCIQCk1DhW7AmIW/n/QLftQq8BHZKLevWYJ813zdrNr5kXlwIhAIVvqglY
9BkYWg4NEe/mVO4C5Vtu4FnzNU9I+rFpXVSO
-----END CERTIFICATE-----
`)
// 默认根私钥
defaultRootKeyPem = []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIAXeEHO0FtFqQhTvsn/DT4g3rEos97+1Nibp9RfKOKhroAoGCCqGSM49
AwEHoUQDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQXmRgsFV5KHHmxOrVJBFC/
nDetmGowkARShWtBsX1Irm4w6i6Qk2QliA==
-----END EC PRIVATE KEY-----
`)
)
// 加载和初始化默认根证书
var (
defaultRootCA *x509.Certificate
defaultRootKey *ecdsa.PrivateKey
)
func init() {
var err error
block, _ := pem.Decode(defaultRootCAPem)
if block == nil {
panic("解析默认根证书PEM块失败")
}
defaultRootCA, err = x509.ParseCertificate(block.Bytes)
if err != nil {
panic(fmt.Errorf("加载默认根证书失败: %s", err))
}
block, _ = pem.Decode(defaultRootKeyPem)
if block == nil {
panic("解析默认根私钥PEM块失败")
}
defaultRootKey, err = x509.ParseECPrivateKey(block.Bytes)
if err != nil {
panic(fmt.Errorf("加载默认根私钥失败: %s", err))
}
}
// CertManager 证书管理器
type CertManager struct {
// 证书缓存
cache CertificateCache
// 默认私钥,可用于多个证书共享
defaultPrivateKey interface{} // 改为interface{}以支持不同类型的私钥
// 默认使用ECDSA P-256曲线
curve elliptic.Curve
// 证书有效期(年)
validityYears int
// 是否使用ECDSA否则使用RSA
useECDSA bool
}
// NewCertManager 创建证书管理器
func NewCertManager(cache CertificateCache, options ...CertManagerOption) *CertManager {
manager := &CertManager{
cache: cache,
curve: elliptic.P256(), // 默认使用P-256曲线
validityYears: 1, // 默认证书有效期1年
useECDSA: true, // 默认使用ECDSA
}
// 应用选项
for _, option := range options {
option(manager)
}
return manager
}
// CertManagerOption 证书管理器选项
type CertManagerOption func(*CertManager)
// WithUseECDSA 设置是否使用ECDSA否则使用RSA
func WithUseECDSA(useECDSA bool) CertManagerOption {
return func(m *CertManager) {
m.useECDSA = useECDSA
}
}
// WithDefaultPrivateKey 设置是否使用默认私钥
func WithDefaultPrivateKey(enable bool) CertManagerOption {
return func(m *CertManager) {
if enable {
if m.useECDSA {
// 生成ECDSA私钥
priv, err := ecdsa.GenerateKey(m.curve, rand.Reader)
if err != nil {
panic(fmt.Errorf("生成默认ECDSA私钥失败: %s", err))
}
m.defaultPrivateKey = priv
} else {
// 生成RSA私钥
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
panic(fmt.Errorf("生成默认RSA私钥失败: %s", err))
}
m.defaultPrivateKey = priv
}
}
}
}
// WithCurve 设置椭圆曲线
func WithCurve(curve elliptic.Curve) CertManagerOption {
return func(m *CertManager) {
m.curve = curve
}
}
// WithValidityYears 设置证书有效期(年)
func WithValidityYears(years int) CertManagerOption {
return func(m *CertManager) {
if years > 0 {
m.validityYears = years
}
}
}
// GenerateTLSConfig 为指定主机生成TLS配置
func (m *CertManager) GenerateTLSConfig(host string) (*tls.Config, error) {
// 处理可能的端口
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
// 检查证书缓存
if m.cache != nil {
// 检查主域名和子域名
fields := strings.Split(host, ".")
domains := []string{host}
// 添加父域名
if len(fields) > 2 {
domains = append(domains, strings.Join(fields[1:], "."))
}
// 查找缓存
for _, domain := range domains {
if cert := m.cache.Get(domain); cert != nil {
return &tls.Config{
Certificates: []tls.Certificate{*cert},
}, nil
}
}
}
// 生成新证书
cert, err := m.GenerateCertificate(host, defaultRootCA, defaultRootKey)
if err != nil {
return nil, err
}
// 缓存证书
if m.cache != nil {
// 缓存主机名
m.cache.Set(host, cert)
// 如果是IP地址不进行其他处理
if net.ParseIP(host) != nil {
return &tls.Config{
Certificates: []tls.Certificate{*cert},
}, nil
}
// 缓存域名和子域名证书
fields := strings.Split(host, ".")
if len(fields) >= 2 {
// 缓存主域名
domain := strings.Join(fields[1:], ".")
m.cache.Set(domain, cert)
}
}
return &tls.Config{
Certificates: []tls.Certificate{*cert},
}, nil
}
// GenerateCertificate 生成证书
func (m *CertManager) GenerateCertificate(host string, rootCA *x509.Certificate, rootKey *ecdsa.PrivateKey) (*tls.Certificate, error) {
// 准备私钥
var priv interface{}
var pubKey interface{}
var err error
// 使用默认私钥或生成新私钥
if m.defaultPrivateKey != nil {
priv = m.defaultPrivateKey
} else if m.useECDSA {
// 生成ECDSA私钥
ecdsaKey, err := ecdsa.GenerateKey(m.curve, rand.Reader)
if err != nil {
return nil, fmt.Errorf("生成ECDSA私钥失败: %s", err)
}
priv = ecdsaKey
pubKey = &ecdsaKey.PublicKey
} else {
// 生成RSA私钥
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, fmt.Errorf("生成RSA私钥失败: %s", err)
}
priv = rsaKey
pubKey = &rsaKey.PublicKey
}
// 获取公钥
if pubKey == nil {
switch k := priv.(type) {
case *ecdsa.PrivateKey:
pubKey = &k.PublicKey
case *rsa.PrivateKey:
pubKey = &k.PublicKey
default:
return nil, fmt.Errorf("不支持的私钥类型")
}
}
// 创建证书模板
template := m.createCertificateTemplate(host)
// 签名证书
derBytes, err := x509.CreateCertificate(rand.Reader, template, rootCA, pubKey, rootKey)
if err != nil {
return nil, fmt.Errorf("创建证书失败: %s", err)
}
// 编码为PEM格式
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: derBytes,
})
// 编码私钥
var keyPEM []byte
switch k := priv.(type) {
case *ecdsa.PrivateKey:
privBytes, err := x509.MarshalECPrivateKey(k)
if err != nil {
return nil, fmt.Errorf("序列化ECDSA私钥失败: %s", err)
}
keyPEM = pem.EncodeToMemory(&pem.Block{
Type: "EC PRIVATE KEY",
Bytes: privBytes,
})
case *rsa.PrivateKey:
privBytes := x509.MarshalPKCS1PrivateKey(k)
keyPEM = pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: privBytes,
})
default:
return nil, fmt.Errorf("不支持的私钥类型")
}
// 创建TLS证书
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, fmt.Errorf("创建TLS证书对失败: %s", err)
}
return &tlsCert, nil
}
// createCertificateTemplate 创建证书模板
func (m *CertManager) createCertificateTemplate(host string) *x509.Certificate {
// 使用基于主机名的哈希值作为序列号
fv := fnv.New64a()
fv.Write([]byte(host))
serialNumber := big.NewInt(0).SetUint64(fv.Sum64())
// 准备模板
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: host,
Organization: []string{"GoProxy Dynamic CA"},
Country: []string{"CN"},
Province: []string{"GuangDong"},
Locality: []string{"Guangzhou"},
},
NotBefore: time.Now().Add(-10 * time.Minute), // 提前10分钟生效容忍时间偏差
NotAfter: time.Now().AddDate(m.validityYears, 0, 0),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
IsCA: false,
}
// 处理IP地址和域名
ipAddr := net.ParseIP(host)
if ipAddr != nil {
template.IPAddresses = []net.IP{ipAddr}
} else {
// 移除可能的端口部分
if strings.Contains(host, ":") {
host = strings.Split(host, ":")[0]
}
// 将主机名添加到DNS名称列表
template.DNSNames = []string{host}
// 添加通配符域名支持
fields := strings.Split(host, ".")
fieldNum := len(fields)
// 为每一级子域名添加通配符
for i := 0; i <= (fieldNum - 2); i++ {
wildcardDomain := "*." + strings.Join(fields[i:], ".")
// 避免重复
if wildcardDomain != host {
template.DNSNames = append(template.DNSNames, wildcardDomain)
}
}
}
return template
}
// LoadCAFromFiles 从文件加载CA证书和私钥
func LoadCAFromFiles(certFile, keyFile string) (*x509.Certificate, *ecdsa.PrivateKey, error) {
// 读取CA证书
caCertPEM, err := os.ReadFile(certFile)
if err != nil {
return nil, nil, fmt.Errorf("读取CA证书文件失败: %s", err)
}
block, _ := pem.Decode(caCertPEM)
if block == nil {
return nil, nil, fmt.Errorf("解析CA证书PEM块失败")
}
caCert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, nil, fmt.Errorf("解析CA证书失败: %s", err)
}
// 读取CA私钥
caKeyPEM, err := os.ReadFile(keyFile)
if err != nil {
return nil, nil, fmt.Errorf("读取CA私钥文件失败: %s", err)
}
block, _ = pem.Decode(caKeyPEM)
if block == nil {
return nil, nil, fmt.Errorf("解析CA私钥PEM块失败")
}
var caKey *ecdsa.PrivateKey
// 尝试不同的私钥格式
switch block.Type {
case "EC PRIVATE KEY":
caKey, err = x509.ParseECPrivateKey(block.Bytes)
if err != nil {
return nil, nil, fmt.Errorf("解析EC私钥失败: %s", err)
}
case "PRIVATE KEY":
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, nil, fmt.Errorf("解析PKCS8私钥失败: %s", err)
}
var ok bool
caKey, ok = key.(*ecdsa.PrivateKey)
if !ok {
return nil, nil, fmt.Errorf("私钥不是ECDSA类型")
}
default:
return nil, nil, fmt.Errorf("不支持的私钥类型: %s", block.Type)
}
return caCert, caKey, nil
}
// GetDefaultRootCA 获取默认根证书和私钥
func GetDefaultRootCA() (*x509.Certificate, *ecdsa.PrivateKey) {
return defaultRootCA, defaultRootKey
}
// GenerateRootCA 生成新的根证书和私钥
func GenerateRootCA(validYears int) (*x509.Certificate, *ecdsa.PrivateKey, error) {
// 生成私钥
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, nil, fmt.Errorf("生成根证书私钥失败: %s", err)
}
// 随机生成序列号
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, nil, fmt.Errorf("生成序列号失败: %s", err)
}
// 创建根证书模板
notBefore := time.Now().Add(-10 * time.Minute)
notAfter := notBefore.AddDate(validYears, 0, 0)
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: "GoProxy Root CA",
Organization: []string{"GoProxy"},
Country: []string{"CN"},
Province: []string{"GuangDong"},
Locality: []string{"Guangzhou"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
IsCA: true,
MaxPathLen: 2,
}
// 自签名
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
return nil, nil, fmt.Errorf("创建根证书失败: %s", err)
}
// 解析生成的证书
cert, err := x509.ParseCertificate(derBytes)
if err != nil {
return nil, nil, fmt.Errorf("解析生成的根证书失败: %s", err)
}
return cert, priv, nil
}
// SaveCertificateToFile 将证书和私钥保存到文件
func SaveCertificateToFile(cert *x509.Certificate, key *ecdsa.PrivateKey, certFile, keyFile string) error {
// 保存证书
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
})
if err := os.WriteFile(certFile, certPEM, 0644); err != nil {
return fmt.Errorf("保存证书到文件失败: %s", err)
}
// 保存私钥
keyBytes, err := x509.MarshalECPrivateKey(key)
if err != nil {
return fmt.Errorf("序列化私钥失败: %s", err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: "EC PRIVATE KEY",
Bytes: keyBytes,
})
if err := os.WriteFile(keyFile, keyPEM, 0600); err != nil {
return fmt.Errorf("保存私钥到文件失败: %s", err)
}
return nil
}
// GenerateRootCA 生成根证书
func (m *CertManager) GenerateRootCA() (*x509.Certificate, interface{}, error) {
// 生成私钥
var priv interface{}
var pubKey interface{}
var err error
if m.useECDSA {
// 生成ECDSA私钥
ecdsaKey, err := ecdsa.GenerateKey(m.curve, rand.Reader)
if err != nil {
return nil, nil, fmt.Errorf("生成ECDSA根私钥失败: %s", err)
}
priv = ecdsaKey
pubKey = &ecdsaKey.PublicKey
} else {
// 生成RSA私钥
rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, fmt.Errorf("生成RSA根私钥失败: %s", err)
}
priv = rsaKey
pubKey = &rsaKey.PublicKey
}
// 创建根证书模板
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "GoProxy Root CA",
Organization: []string{"GoProxy Authority"},
Country: []string{"CN"},
Province: []string{"GuangDong"},
Locality: []string{"Guangzhou"},
},
NotBefore: time.Now().Add(-10 * time.Minute),
NotAfter: time.Now().AddDate(10, 0, 0), // 10年有效期
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
IsCA: true,
MaxPathLen: 1,
}
// 自签名
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, pubKey, priv)
if err != nil {
return nil, nil, fmt.Errorf("创建根证书失败: %s", err)
}
// 解析证书
cert, err := x509.ParseCertificate(derBytes)
if err != nil {
return nil, nil, fmt.Errorf("解析生成的根证书失败: %s", err)
}
return cert, priv, nil
}

View File

@@ -0,0 +1,134 @@
package proxy
import (
"bufio"
"io"
"net"
"time"
)
// ConnBuffer 连接缓冲区
// 封装了底层网络连接和缓冲读取器,提供了更方便的读写接口
type ConnBuffer struct {
// 底层连接
conn net.Conn
// 缓冲读取器
reader *bufio.Reader
}
// NewConnBuffer 创建连接缓冲区
func NewConnBuffer(conn net.Conn, reader *bufio.Reader) *ConnBuffer {
if reader == nil {
reader = bufio.NewReader(conn)
}
return &ConnBuffer{
conn: conn,
reader: reader,
}
}
// Read 从连接读取数据
func (c *ConnBuffer) Read(b []byte) (int, error) {
return c.reader.Read(b)
}
// Write 向连接写入数据
func (c *ConnBuffer) Write(b []byte) (int, error) {
return c.conn.Write(b)
}
// Close 关闭连接
func (c *ConnBuffer) Close() error {
return c.conn.Close()
}
// LocalAddr 获取本地地址
func (c *ConnBuffer) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
// RemoteAddr 获取远程地址
func (c *ConnBuffer) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetDeadline 设置读写超时
func (c *ConnBuffer) SetDeadline(t time.Time) error {
return c.conn.SetDeadline(t)
}
// SetReadDeadline 设置读取超时
func (c *ConnBuffer) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
// SetWriteDeadline 设置写入超时
func (c *ConnBuffer) SetWriteDeadline(t time.Time) error {
return c.conn.SetWriteDeadline(t)
}
// BufferReader 获取缓冲读取器
func (c *ConnBuffer) BufferReader() *bufio.Reader {
return c.reader
}
// Peek 查看缓冲区中的数据,但不消费
func (c *ConnBuffer) Peek(n int) ([]byte, error) {
return c.reader.Peek(n)
}
// ReadByte 读取一个字节
func (c *ConnBuffer) ReadByte() (byte, error) {
return c.reader.ReadByte()
}
// UnreadByte 将最后读取的字节放回缓冲区
func (c *ConnBuffer) UnreadByte() error {
return c.reader.UnreadByte()
}
// ReadLine 读取一行数据
func (c *ConnBuffer) ReadLine() (string, error) {
line, isPrefix, err := c.reader.ReadLine()
if err != nil {
return "", err
}
// 如果一行数据没有读取完整,继续读取
if isPrefix {
var buf []byte
buf = append(buf, line...)
for isPrefix && err == nil {
line, isPrefix, err = c.reader.ReadLine()
if err != nil {
return "", err
}
buf = append(buf, line...)
}
return string(buf), nil
}
return string(line), nil
}
// ReadN 读取指定字节数的数据
func (c *ConnBuffer) ReadN(n int) ([]byte, error) {
buf := make([]byte, n)
_, err := io.ReadFull(c.reader, buf)
if err != nil {
return nil, err
}
return buf, nil
}
// Flush 刷新缓冲区
func (c *ConnBuffer) Flush() error {
// 由于我们只有读取缓冲区,没有写入缓冲区,所以这里不需要实际操作
return nil
}
// Reset 重置连接缓冲区
func (c *ConnBuffer) Reset(conn net.Conn) {
c.conn = conn
c.reader = bufio.NewReader(conn)
}

132
internal/proxy/context.go Normal file
View File

@@ -0,0 +1,132 @@
package proxy
import (
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// Context 代理上下文
// 包含了代理请求的上下文信息,用于在代理处理过程中传递数据
type Context struct {
// 原始请求
Req *http.Request
// 请求开始时间
StartTime time.Time
// 上下文数据,用于在各个处理阶段传递数据
Data map[interface{}]interface{}
// 是否是隧道代理
TunnelProxy bool
// 请求ID
RequestID string
// 目标地址
TargetAddr string
// 上级代理地址
ParentProxyURL *url.URL
// 是否中断执行
abort bool
// 请求标签,用于标记请求类型
Tags []string
// 是否已中止
aborted bool
// 互斥锁
mu sync.RWMutex
}
// IsHTTPS 是否是HTTPS请求
func (c *Context) IsHTTPS() bool {
return c.Req.URL.Scheme == "https" || c.Req.Method == http.MethodConnect
}
// defaultPorts 默认端口映射
var defaultPorts = map[string]string{
"https": "443",
"http": "80",
"": "80",
}
// WebSocketURL 获取WebSocket URL
func (c *Context) WebSocketURL() *url.URL {
u := *c.Req.URL
if c.IsHTTPS() {
u.Scheme = "wss"
} else {
u.Scheme = "ws"
}
return &u
}
// Addr 获取请求地址
func (c *Context) Addr() string {
addr := c.Req.Host
if !strings.Contains(c.Req.URL.Host, ":") {
addr += ":" + defaultPorts[c.Req.URL.Scheme]
}
return addr
}
// AddTag 添加请求标签
func (c *Context) AddTag(tag string) {
c.Tags = append(c.Tags, tag)
}
// HasTag 检查是否包含指定标签
func (c *Context) HasTag(tag string) bool {
for _, t := range c.Tags {
if t == tag {
return true
}
}
return false
}
// Abort 中断执行
func (c *Context) Abort() {
c.aborted = true
}
// IsAborted 是否已中断执行
func (c *Context) IsAborted() bool {
return c.aborted
}
// Reset 重置上下文
func (c *Context) Reset(req *http.Request) {
c.Req = req
c.StartTime = time.Now()
c.Data = make(map[interface{}]interface{})
c.abort = false
c.TunnelProxy = false
c.Tags = make([]string, 0)
c.RequestID = ""
c.TargetAddr = ""
c.ParentProxyURL = nil
c.aborted = false
}
// Set 设置数据
func (c *Context) Set(key, value interface{}) {
c.mu.Lock()
defer c.mu.Unlock()
if c.Data == nil {
c.Data = make(map[interface{}]interface{})
}
c.Data[key] = value
}
// Get 获取数据
func (c *Context) Get(key interface{}) (interface{}, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.Data == nil {
return nil, false
}
val, ok := c.Data[key]
return val, ok
}

123
internal/proxy/delegate.go Normal file
View File

@@ -0,0 +1,123 @@
package proxy
import (
"net/http"
"net/url"
)
// Delegate 代理委托接口
// 定义了代理处理请求的各个阶段的回调方法
type Delegate interface {
// Connect 连接事件
Connect(ctx *Context, rw http.ResponseWriter)
// Auth 认证事件
Auth(ctx *Context, rw http.ResponseWriter)
// BeforeRequest 请求前事件
BeforeRequest(ctx *Context)
// BeforeResponse 响应前事件
BeforeResponse(ctx *Context, resp *http.Response, err error)
// WebSocketSendMessage websocket发送消息拦截
// WebSocketSendMessage(ctx *Context, messageType *int, p *[]byte)
// WebSocketReceiveMessage websocket接收消息拦截
// WebSocketReceiveMessage(ctx *Context, messageType *int, p *[]byte)
// ParentProxy 获取上级代理
ParentProxy(req *http.Request) (*url.URL, error)
// ErrorLog 错误日志
ErrorLog(err error)
// Finish 完成事件
Finish(ctx *Context)
// 以下是反向代理相关的方法
// ResolveBackend 解析后端服务器
// 在反向代理模式下,根据请求确定应该转发到哪个后端服务器
ResolveBackend(req *http.Request) (string, error)
// ModifyRequest 修改请求
// 在反向代理模式下,可以修改发往后端服务器的请求
ModifyRequest(req *http.Request)
// ModifyResponse 修改响应
// 在反向代理模式下,可以修改来自后端服务器的响应
ModifyResponse(resp *http.Response) error
// HandleError 处理错误
// 在反向代理模式下,可以自定义错误处理逻辑
HandleError(rw http.ResponseWriter, req *http.Request, err error)
}
// DefaultDelegate 默认代理委托
type DefaultDelegate struct{}
// Connect 连接事件
func (d *DefaultDelegate) Connect(ctx *Context, rw http.ResponseWriter) {
// 默认实现不做任何处理
}
// Auth 认证事件
func (d *DefaultDelegate) Auth(ctx *Context, rw http.ResponseWriter) {
// 默认实现不做任何处理
}
// BeforeRequest 请求前事件
func (d *DefaultDelegate) BeforeRequest(ctx *Context) {
// 默认实现不做任何处理
}
// BeforeResponse 响应前事件
func (d *DefaultDelegate) BeforeResponse(ctx *Context, resp *http.Response, err error) {
// 默认实现不做任何处理
}
// ParentProxy 获取上级代理
func (d *DefaultDelegate) ParentProxy(req *http.Request) (*url.URL, error) {
// 默认实现不使用上级代理
return nil, nil
}
// WebSocketSendMessage websocket发送消息拦截
// func (h *DefaultDelegate) WebSocketSendMessage(ctx *Context, messageType *int, payload *[]byte) {}
// WebSocketReceiveMessage websocket接收消息拦截
// func (h *DefaultDelegate) WebSocketReceiveMessage(ctx *Context, messageType *int, payload *[]byte) {}
// ErrorLog 错误日志
func (d *DefaultDelegate) ErrorLog(err error) {
// 默认实现不做任何处理
}
// Finish 完成事件
func (d *DefaultDelegate) Finish(ctx *Context) {
// 默认实现不做任何处理
}
// ResolveBackend 解析后端服务器
func (d *DefaultDelegate) ResolveBackend(req *http.Request) (string, error) {
// 默认实现返回请求中的主机
return req.Host, nil
}
// ModifyRequest 修改请求
func (d *DefaultDelegate) ModifyRequest(req *http.Request) {
// 默认实现不做任何修改
}
// ModifyResponse 修改响应
func (d *DefaultDelegate) ModifyResponse(resp *http.Response) error {
// 默认实现不做任何修改
return nil
}
// HandleError 处理错误
func (d *DefaultDelegate) HandleError(rw http.ResponseWriter, req *http.Request, err error) {
// 默认实现返回502错误
http.Error(rw, err.Error(), http.StatusBadGateway)
}

262
internal/proxy/options.go Normal file
View File

@@ -0,0 +1,262 @@
package proxy
import (
"net/http"
"net/http/httptrace"
"time"
"github.com/goproxy/internal/cache"
"github.com/goproxy/internal/config"
"github.com/goproxy/internal/healthcheck"
"github.com/goproxy/internal/loadbalance"
"github.com/goproxy/internal/metrics"
)
// Option 用于配置代理选项的函数类型
type Option func(*Options)
// WithConfig 设置代理配置
func WithConfig(cfg *config.Config) Option {
return func(opt *Options) {
opt.Config = cfg
}
}
// WithDisableKeepAlive 设置连接是否重用
func WithDisableKeepAlive(disableKeepAlive bool) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
// 在transport中设置DisableKeepAlives
}
}
// WithClientTrace 设置HTTP客户端跟踪
func WithClientTrace(t *httptrace.ClientTrace) Option {
return func(opt *Options) {
opt.ClientTrace = t
}
}
// WithDelegate 设置委托类
func WithDelegate(delegate Delegate) Option {
return func(opt *Options) {
opt.Delegate = delegate
}
}
// WithTransport 使用自定义HTTP传输
func WithTransport(t *http.Transport) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
// 在New方法中处理transport
}
}
// WithDecryptHTTPS 启用中间人代理解密HTTPS
func WithDecryptHTTPS(c CertificateCache) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.DecryptHTTPS = true
opt.CertCache = c
}
}
// WithEnableWebsocketIntercept 启用WebSocket拦截
func WithEnableWebsocketIntercept() Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
// WebSocket拦截在代理处理逻辑中实现
}
}
// WithHTTPCache 设置HTTP缓存
func WithHTTPCache(c cache.Cache) Option {
return func(opt *Options) {
opt.HTTPCache = c
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.EnableCache = true
}
}
// WithLoadBalancer 设置负载均衡器
func WithLoadBalancer(lb loadbalance.LoadBalancer) Option {
return func(opt *Options) {
opt.LoadBalancer = lb
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.EnableLoadBalancing = true
}
}
// WithHealthChecker 设置健康检查器
func WithHealthChecker(hc *healthcheck.HealthChecker) Option {
return func(opt *Options) {
opt.HealthChecker = hc
}
}
// WithMetrics 设置监控指标
func WithMetrics(m metrics.Metrics) Option {
return func(opt *Options) {
opt.Metrics = m
}
}
// WithTLSCertAndKey 设置TLS证书和密钥
func WithTLSCertAndKey(certPath, keyPath string) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.TLSCert = certPath
opt.Config.TLSKey = keyPath
}
}
// WithCACertAndKey 设置CA证书和密钥用于生成动态证书
func WithCACertAndKey(caCertPath, caKeyPath string) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.CACert = caCertPath
opt.Config.CAKey = caKeyPath
}
}
// WithConnectionPoolSize 设置连接池大小
func WithConnectionPoolSize(size int) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.ConnectionPoolSize = size
}
}
// WithIdleTimeout 设置空闲超时时间
func WithIdleTimeout(timeout time.Duration) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.IdleTimeout = timeout
}
}
// WithRequestTimeout 设置请求超时时间
func WithRequestTimeout(timeout time.Duration) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.RequestTimeout = timeout
}
}
// WithReverseProxy 启用反向代理模式
func WithReverseProxy(enable bool) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.ReverseProxy = enable
}
}
// WithEnableRetry 启用请求重试
func WithEnableRetry(maxRetries int, baseBackoff, maxBackoff time.Duration) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.EnableRetry = true
opt.Config.MaxRetries = maxRetries
opt.Config.RetryBackoff = baseBackoff
opt.Config.MaxRetryBackoff = maxBackoff
}
}
// WithRateLimit 设置请求限流
func WithRateLimit(rps float64) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.EnableRateLimit = true
opt.Config.RateLimit = rps
}
}
// WithDNSCacheTTL 设置DNS缓存TTL
func WithDNSCacheTTL(ttl time.Duration) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.DNSCacheTTL = ttl
}
}
// WithURLRewrite 启用URL重写
func WithURLRewrite(enable bool) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.EnableURLRewrite = enable
}
}
// WithEnableCORS 启用CORS支持
func WithEnableCORS(enable bool) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.EnableCORS = enable
}
}
// WithCertManager 设置证书管理器
// 这是一个内部函数主要用于在New方法中设置CertManager
func WithCertManager(certManager *CertManager) Option {
return func(opt *Options) {
opt.CertManager = certManager
}
}
// WithEnableECDSA 启用ECDSA证书生成默认使用RSA
func WithEnableECDSA(enable bool) Option {
return func(opt *Options) {
if opt.Config == nil {
opt.Config = config.DefaultConfig()
}
opt.Config.UseECDSA = enable
}
}
// NewWithOptions 使用选项函数创建代理
func NewWithOptions(options ...Option) *Proxy {
opts := &Options{
Config: config.DefaultConfig(),
}
// 应用所有选项
for _, option := range options {
option(opts)
}
return New(opts)
}

1130
internal/proxy/proxy.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,277 @@
package proxy
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/goproxy/internal/rewriter"
"github.com/goproxy/internal/router"
)
// ReverseProxy 反向代理
type ReverseProxy struct {
// 代理对象
proxy *Proxy
// 路由器
router *router.Router
// URL重写器
rewriter *rewriter.Rewriter
// HTTP传输对象
transport http.RoundTripper
}
// NewReverseProxy 创建反向代理
func (p *Proxy) NewReverseProxy() *ReverseProxy {
rp := &ReverseProxy{
proxy: p,
router: router.NewRouter(),
rewriter: rewriter.NewRewriter(),
}
// 创建自定义的传输对象
transport := &http.Transport{
Proxy: func(req *http.Request) (*url.URL, error) {
// 使用代理委托中的方法获取代理
return p.delegate.ParentProxy(req)
},
DialContext: p.dialContextWithCache(),
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
MaxIdleConns: p.config.ConnectionPoolSize,
IdleConnTimeout: p.config.IdleTimeout,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
MaxIdleConnsPerHost: p.config.ConnectionPoolSize,
DisableCompression: !p.config.EnableCompression,
}
rp.transport = transport
// 如果配置了规则文件,加载规则
if p.config.ReverseProxyRulesFile != "" {
// 省略加载规则文件的实现
}
return rp
}
// ServeHTTP 处理反向代理请求
func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// 获取请求上下文
ctx := ctxPool.Get().(*Context)
ctx.Reset(req)
defer ctxPool.Put(ctx)
// 调用连接事件
rp.proxy.delegate.Connect(ctx, rw)
// 认证检查
rp.proxy.delegate.Auth(ctx, rw)
if ctx.IsAborted() {
return
}
// 请求前处理
rp.proxy.delegate.BeforeRequest(ctx)
if ctx.IsAborted() {
return
}
// 解析后端地址
backend, err := rp.proxy.delegate.ResolveBackend(req)
if err != nil {
rp.proxy.delegate.HandleError(rw, req, err)
return
}
// 创建请求代理对象
proxy := httputil.NewSingleHostReverseProxy(&url.URL{
Scheme: "http",
Host: backend,
})
// 使用自定义传输对象
proxy.Transport = rp.transport
// 设置自定义错误处理函数
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
rp.proxy.delegate.HandleError(rw, req, err)
}
// 设置请求修改函数
originalDirector := proxy.Director
proxy.Director = func(req *http.Request) {
// 调用原始Director函数
originalDirector(req)
// 处理URL重写
if rp.proxy.config.EnableURLRewrite {
rp.rewriter.Rewrite(req)
}
// 修改请求头
if rp.proxy.config.RewriteHostHeader {
req.Host = backend
}
// 添加X-Forwarded-For头
if rp.proxy.config.AddXForwardedFor {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
// 如果已经有X-Forwarded-For添加到末尾
if prior, ok := req.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
req.Header.Set("X-Forwarded-For", clientIP)
}
}
// 添加X-Real-IP头
if rp.proxy.config.AddXRealIP {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err == nil {
req.Header.Set("X-Real-IP", clientIP)
}
}
// 设置协议头
req.Header.Set("X-Forwarded-Proto", "http")
if req.TLS != nil {
req.Header.Set("X-Forwarded-Proto", "https")
}
// 调用委托的ModifyRequest方法
rp.proxy.delegate.ModifyRequest(req)
}
// 设置响应修改函数
proxy.ModifyResponse = func(resp *http.Response) error {
// 处理响应URL重写
if rp.proxy.config.EnableURLRewrite && resp != nil {
rp.rewriter.RewriteResponse(resp, req.Host)
}
// 添加CORS头
if rp.proxy.config.EnableCORS && resp != nil {
resp.Header.Set("Access-Control-Allow-Origin", "*")
resp.Header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
resp.Header.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
}
// 调用委托的ModifyResponse方法
return rp.proxy.delegate.ModifyResponse(resp)
}
// 更新监控指标
if rp.proxy.metrics != nil {
rp.proxy.metrics.IncActiveConnections()
defer rp.proxy.metrics.DecActiveConnections()
startTime := time.Now()
defer func() {
duration := time.Since(startTime)
rp.proxy.metrics.ObserveRequestDuration(duration.Seconds())
rp.proxy.metrics.IncRequestCount()
}()
}
// 处理WebSocket升级
if rp.proxy.config.SupportWebSocketUpgrade && isWebSocketRequest(req) {
rp.handleWebSocketUpgrade(rw, req, backend)
return
}
// 处理普通请求
proxy.ServeHTTP(rw, req)
// 完成事件
rp.proxy.delegate.Finish(ctx)
}
// 处理WebSocket升级
func (rp *ReverseProxy) handleWebSocketUpgrade(rw http.ResponseWriter, req *http.Request, backend string) {
// 创建WebSocket代理
target := &url.URL{
Scheme: "ws",
Host: backend,
}
if req.TLS != nil {
target.Scheme = "wss"
}
// 创建连接到后端的WebSocket连接
backendConn, err := rp.dialBackend(target.String(), req)
if err != nil {
rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("无法连接到后端WebSocket服务: %v", err))
return
}
defer backendConn.Close()
// 将请求转发给后端
err = req.Write(backendConn)
if err != nil {
rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("写入WebSocket请求错误: %v", err))
return
}
// 升级客户端连接
clientConn, err := hijacker(rw)
if err != nil {
rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("升级WebSocket连接错误: %v", err))
return
}
// 双向转发数据
rp.proxy.transfer(clientConn, backendConn)
}
// 连接到后端
func (rp *ReverseProxy) dialBackend(url string, req *http.Request) (net.Conn, error) {
ctx, cancel := context.WithTimeout(req.Context(), 15*time.Second)
defer cancel()
backend := strings.TrimPrefix(url, "ws://")
backend = strings.TrimPrefix(backend, "wss://")
if strings.Contains(backend, "/") {
backend = backend[:strings.Index(backend, "/")]
}
// 根据协议选择连接方式
if strings.HasPrefix(url, "wss://") {
// 使用 tls.Dialer 替代不存在的 tls.DialWithContext
dialer := &tls.Dialer{
Config: &tls.Config{
InsecureSkipVerify: true,
},
}
return dialer.DialContext(ctx, "tcp", backend)
}
var d net.Dialer
return d.DialContext(ctx, "tcp", backend)
}
// 添加路由规则
func (rp *ReverseProxy) AddRoute(pattern string, routeType router.RouteType, target string) {
route := &router.Route{
Pattern: pattern,
Type: routeType,
Target: target,
}
rp.router.AddRoute(route)
}
// 添加重写规则
func (rp *ReverseProxy) AddRewriteRule(pattern, replacement string, useRegex bool) error {
return rp.rewriter.AddRule(pattern, replacement, useRegex)
}

View File

@@ -0,0 +1,98 @@
package rewriter
import (
"net/http"
"regexp"
"strings"
)
// Rewriter URL重写器
// 用于在反向代理中重写请求URL
type Rewriter struct {
// 重写规则列表
rules []*RewriteRule
}
// RewriteRule 重写规则
type RewriteRule struct {
// 匹配模式
Pattern string
// 替换模式
Replacement string
// 是否使用正则表达式
UseRegex bool
// 编译后的正则表达式
regex *regexp.Regexp
}
// NewRewriter 创建URL重写器
func NewRewriter() *Rewriter {
return &Rewriter{
rules: make([]*RewriteRule, 0),
}
}
// AddRule 添加重写规则
func (r *Rewriter) AddRule(pattern, replacement string, useRegex bool) error {
rule := &RewriteRule{
Pattern: pattern,
Replacement: replacement,
UseRegex: useRegex,
}
if useRegex {
regex, err := regexp.Compile(pattern)
if err != nil {
return err
}
rule.regex = regex
}
r.rules = append(r.rules, rule)
return nil
}
// Rewrite 重写URL
func (r *Rewriter) Rewrite(req *http.Request) {
path := req.URL.Path
for _, rule := range r.rules {
if rule.UseRegex {
if rule.regex.MatchString(path) {
req.URL.Path = rule.regex.ReplaceAllString(path, rule.Replacement)
break
}
} else {
if strings.HasPrefix(path, rule.Pattern) {
req.URL.Path = strings.Replace(path, rule.Pattern, rule.Replacement, 1)
break
}
}
}
}
// RewriteResponse 重写响应
// 主要用于处理响应中的Location头和内容中的URL
func (r *Rewriter) RewriteResponse(resp *http.Response, originHost string) {
// 处理重定向头
location := resp.Header.Get("Location")
if location != "" {
// 将后端服务器的域名替换成代理服务器的域名
for _, rule := range r.rules {
if rule.UseRegex && rule.regex != nil {
if rule.regex.MatchString(location) {
newLocation := rule.regex.ReplaceAllString(location, rule.Replacement)
resp.Header.Set("Location", newLocation)
break
}
}
}
}
}
// LoadRulesFromFile 从文件加载重写规则
func (r *Rewriter) LoadRulesFromFile(filename string) error {
// 实现从配置文件加载规则的逻辑
// 这里省略实现细节
return nil
}

103
internal/router/router.go Normal file
View File

@@ -0,0 +1,103 @@
package router
import (
"net/http"
"regexp"
"strings"
)
// Route 路由规则
type Route struct {
// 匹配模式(主机名、路径、正则表达式)
Pattern string
// 匹配类型
Type RouteType
// 目标地址
Target string
// 路径重写规则
RewritePattern string
// 请求头修改
HeaderModifier HeaderModifier
// 自定义匹配函数
MatchFunc func(req *http.Request) bool
}
// RouteType 路由类型
type RouteType int
const (
// HostRoute 主机名路由
HostRoute RouteType = iota
// PathRoute 路径路由
PathRoute
// RegexRoute 正则表达式路由
RegexRoute
// CustomRoute 自定义路由
CustomRoute
)
// HeaderModifier 头部修改接口
type HeaderModifier interface {
// ModifyRequest 修改请求头
ModifyRequest(req *http.Request)
// ModifyResponse 修改响应头
ModifyResponse(resp *http.Response)
}
// Router 路由器
type Router struct {
routes []*Route
}
// NewRouter 创建路由器
func NewRouter() *Router {
return &Router{
routes: make([]*Route, 0),
}
}
// AddRoute 添加路由规则
func (r *Router) AddRoute(route *Route) {
r.routes = append(r.routes, route)
}
// Match 匹配请求
func (r *Router) Match(req *http.Request) (*Route, bool) {
for _, route := range r.routes {
switch route.Type {
case HostRoute:
if matchHost(req.Host, route.Pattern) {
return route, true
}
case PathRoute:
if matchPath(req.URL.Path, route.Pattern) {
return route, true
}
case RegexRoute:
if matchRegex(req.URL.String(), route.Pattern) {
return route, true
}
case CustomRoute:
if route.MatchFunc != nil && route.MatchFunc(req) {
return route, true
}
}
}
return nil, false
}
// 匹配主机名
func matchHost(host, pattern string) bool {
return host == pattern || strings.HasSuffix(host, "."+pattern)
}
// 匹配路径
func matchPath(path, pattern string) bool {
return strings.HasPrefix(path, pattern)
}
// 匹配正则表达式
func matchRegex(url, pattern string) bool {
matched, _ := regexp.MatchString(pattern, url)
return matched
}

14
mitm-proxy.crt Normal file
View File

@@ -0,0 +1,14 @@
-----BEGIN CERTIFICATE-----
MIICJzCCAcygAwIBAgIITWWCIQf8/VIwCgYIKoZIzj0EAwIwUzEOMAwGA1UEBhMF
Q2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0GA1UEBxMGWGlhbWVuMRAwDgYDVQQK
EwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMB4XDTIyMDMyNTA1NDgwMFoXDTQyMDQy
NTA1NDgwMFowUzEOMAwGA1UEBhMFQ2hpbmExDzANBgNVBAgTBkZ1SmlhbjEPMA0G
A1UEBxMGWGlhbWVuMRAwDgYDVQQKEwdHb3Byb3h5MQ0wCwYDVQQDEwRNYXJzMFkw
EwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEf0mhVJmuTmxnLimKshdEE4+PYdxvBfQX
mRgsFV5KHHmxOrVJBFC/nDetmGowkARShWtBsX1Irm4w6i6Qk2QliKOBiTCBhjAO
BgNVHQ8BAf8EBAMCAQYwHQYDVR0lBBYwFAYIKwYBBQUHAwIGCCsGAQUFBwMBMBIG
A1UdEwEB/wQIMAYBAf8CAQIwHQYDVR0OBBYEFBI5TkWYcvUIWsBAdffs833FnBrI
MCIGA1UdEQQbMBmBF3FpbmdxaWFubHVkYW9AZ21haWwuY29tMAoGCCqGSM49BAMC
A0kAMEYCIQCk1DhW7AmIW/n/QLftQq8BHZKLevWYJ813zdrNr5kXlwIhAIVvqglY
9BkYWg4NEe/mVO4C5Vtu4FnzNU9I+rFpXVSO
-----END CERTIFICATE-----

753
proxy.go Normal file
View File

@@ -0,0 +1,753 @@
// Copyright 2018 ouqiang authors
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// Package goproxy HTTP(S)代理, 支持中间人代理解密HTTPS数据
package goproxy
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/viki-org/dnscache"
"github.com/ouqiang/goproxy/cert"
"github.com/ouqiang/websocket"
)
const (
// 连接目标服务器超时时间
defaultTargetConnectTimeout = 5 * time.Second
// 目标服务器读写超时时间
defaultTargetReadWriteTimeout = 10 * time.Second
)
type DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// 隧道连接成功响应行
var tunnelEstablishedResponseLine = []byte("HTTP/1.1 200 Connection established\r\n\r\n")
var badGateway = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", http.StatusBadGateway, http.StatusText(http.StatusBadGateway)))
var (
bufPool = sync.Pool{
New: func() interface{} {
return make([]byte, 32*1024)
},
}
ctxPool = sync.Pool{
New: func() interface{} {
return new(Context)
},
}
headerPool = NewHeaderPool()
requestPool = newRequestPool()
)
type RequestPool struct {
pool sync.Pool
}
func newRequestPool() *RequestPool {
return &RequestPool{
pool: sync.Pool{
New: func() interface{} {
return new(http.Request)
},
},
}
}
func (p *RequestPool) Get() *http.Request {
req := p.pool.Get().(*http.Request)
req.Method = ""
req.URL = nil
req.Proto = ""
req.ProtoMajor = 0
req.ProtoMinor = 0
req.Header = nil
req.Body = nil
req.GetBody = nil
req.ContentLength = 0
req.TransferEncoding = nil
req.Close = false
req.Host = ""
req.Form = nil
req.PostForm = nil
req.MultipartForm = nil
req.Trailer = nil
req.RemoteAddr = ""
req.RequestURI = ""
req.TLS = nil
req.Cancel = nil
req.Response = nil
return req
}
func (p *RequestPool) Put(req *http.Request) {
if req != nil {
p.pool.Put(req)
}
}
type HeaderPool struct {
pool sync.Pool
}
func NewHeaderPool() *HeaderPool {
return &HeaderPool{
pool: sync.Pool{
New: func() interface{} {
return http.Header{}
},
},
}
}
func (p *HeaderPool) Get() http.Header {
header := p.pool.Get().(http.Header)
for k := range header {
delete(header, k)
}
return header
}
func (p *HeaderPool) Put(header http.Header) {
if header != nil {
p.pool.Put(header)
}
}
// 生成隧道建立请求行
func makeTunnelRequestLine(addr string) string {
return fmt.Sprintf("CONNECT %s HTTP/1.1\r\n\r\n", addr)
}
type options struct {
disableKeepAlive bool
delegate Delegate
decryptHTTPS bool
websocketIntercept bool
certCache cert.Cache
transport *http.Transport
clientTrace *httptrace.ClientTrace
}
type Option func(*options)
// WithDisableKeepAlive 连接是否重用
func WithDisableKeepAlive(disableKeepAlive bool) Option {
return func(opt *options) {
opt.disableKeepAlive = disableKeepAlive
}
}
func WithClientTrace(t *httptrace.ClientTrace) Option {
return func(opt *options) {
opt.clientTrace = t
}
}
// WithDelegate 设置委托类
func WithDelegate(delegate Delegate) Option {
return func(opt *options) {
opt.delegate = delegate
}
}
// WithTransport 自定义http transport
func WithTransport(t *http.Transport) Option {
return func(opt *options) {
opt.transport = t
}
}
// WithDecryptHTTPS 中间人代理, 解密HTTPS, 需实现证书缓存接口
func WithDecryptHTTPS(c cert.Cache) Option {
return func(opt *options) {
opt.decryptHTTPS = true
opt.certCache = c
}
}
// WithEnableWebsocketIntercept 拦截websocket
func WithEnableWebsocketIntercept() Option {
return func(opt *options) {
opt.websocketIntercept = true
}
}
// New 创建proxy实例
func New(opt ...Option) *Proxy {
opts := &options{}
for _, o := range opt {
o(opts)
}
if opts.delegate == nil {
opts.delegate = &DefaultDelegate{}
}
if opts.transport == nil {
opts.transport = &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
MaxIdleConns: 100,
MaxConnsPerHost: 10,
IdleConnTimeout: 10 * time.Second,
TLSHandshakeTimeout: 5 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}
p := &Proxy{}
p.delegate = opts.delegate
p.websocketIntercept = opts.websocketIntercept
p.decryptHTTPS = opts.decryptHTTPS
if p.decryptHTTPS {
p.cert = cert.NewCertificate(opts.certCache, true)
}
p.transport = opts.transport
p.transport.DialContext = p.dialContext()
p.dnsCache = dnscache.New(5 * time.Minute)
p.transport.DisableKeepAlives = opts.disableKeepAlive
p.transport.Proxy = p.delegate.ParentProxy
p.clientTrace = opts.clientTrace
return p
}
// Proxy 实现了http.Handler接口
type Proxy struct {
delegate Delegate
clientConnNum int32
decryptHTTPS bool
websocketIntercept bool
cert *cert.Certificate
transport *http.Transport
clientTrace *httptrace.ClientTrace
dnsCache *dnscache.Resolver
}
var _ http.Handler = &Proxy{}
// ServeHTTP 实现了http.Handler接口
func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if req.URL.Host == "" {
req.URL.Host = req.Host
}
atomic.AddInt32(&p.clientConnNum, 1)
ctx := ctxPool.Get().(*Context)
ctx.Reset(req)
defer func() {
p.delegate.Finish(ctx)
ctxPool.Put(ctx)
atomic.AddInt32(&p.clientConnNum, -1)
}()
p.delegate.Connect(ctx, rw)
if ctx.abort {
return
}
p.delegate.Auth(ctx, rw)
if ctx.abort {
return
}
switch {
case ctx.Req.Method == http.MethodConnect:
p.tunnelProxy(ctx, rw)
case websocket.IsWebSocketUpgrade(ctx.Req):
p.tunnelProxy(ctx, rw)
default:
p.httpProxy(ctx, rw)
}
}
// ClientConnNum 获取客户端连接数
func (p *Proxy) ClientConnNum() int32 {
return atomic.LoadInt32(&p.clientConnNum)
}
// DoRequest 执行HTTP请求并调用responseFunc处理response
func (p *Proxy) DoRequest(ctx *Context, responseFunc func(*http.Response, error)) {
if ctx.Data == nil {
ctx.Data = make(map[interface{}]interface{})
}
p.delegate.BeforeRequest(ctx)
if ctx.abort {
return
}
newReq := requestPool.Get()
*newReq = *ctx.Req
newHeader := headerPool.Get()
CloneHeader(newReq.Header, newHeader)
newReq.Header = newHeader
for _, item := range hopHeaders {
if newReq.Header.Get(item) != "" {
newReq.Header.Del(item)
}
}
if p.clientTrace != nil {
newReq = newReq.WithContext(httptrace.WithClientTrace(newReq.Context(), p.clientTrace))
}
resp, err := p.transport.RoundTrip(newReq)
p.delegate.BeforeResponse(ctx, resp, err)
if ctx.abort {
return
}
if err == nil {
for _, h := range hopHeaders {
resp.Header.Del(h)
}
}
responseFunc(resp, err)
headerPool.Put(newHeader)
requestPool.Put(newReq)
}
// HTTP代理
func (p *Proxy) httpProxy(ctx *Context, rw http.ResponseWriter) {
ctx.Req.URL.Scheme = "http"
p.DoRequest(ctx, func(resp *http.Response, err error) {
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - HTTP请求错误: %s", ctx.Req.URL, err))
rw.WriteHeader(http.StatusBadGateway)
return
}
defer func() {
_ = resp.Body.Close()
}()
CopyHeader(rw.Header(), resp.Header)
rw.WriteHeader(resp.StatusCode)
buf := bufPool.Get().([]byte)
_, _ = io.CopyBuffer(rw, resp.Body, buf)
bufPool.Put(buf)
})
}
// HTTPS代理
func (p *Proxy) httpsProxy(ctx *Context, tlsClientConn *tls.Conn) {
if websocket.IsWebSocketUpgrade(ctx.Req) {
p.websocketProxy(ctx, NewConnBuffer(tlsClientConn, nil))
return
}
p.DoRequest(ctx, func(resp *http.Response, err error) {
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 请求错误: %s", ctx.Req.URL, err))
_, _ = tlsClientConn.Write(badGateway)
return
}
err = resp.Write(tlsClientConn)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, response写入客户端失败, %s", ctx.Req.URL, err))
}
_ = resp.Body.Close()
})
}
// 隧道代理
func (p *Proxy) tunnelProxy(ctx *Context, rw http.ResponseWriter) {
clientConn, err := hijacker(rw)
if err != nil {
p.delegate.ErrorLog(err)
rw.WriteHeader(http.StatusBadGateway)
return
}
defer func() {
_ = clientConn.Close()
}()
if websocket.IsWebSocketUpgrade(ctx.Req) {
p.websocketProxy(ctx, clientConn)
return
}
parentProxyURL, err := p.delegate.ParentProxy(ctx.Req)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 解析代理地址错误: %s", ctx.Req.URL.Host, err))
rw.WriteHeader(http.StatusBadGateway)
return
}
if parentProxyURL == nil {
_, err = clientConn.Write(tunnelEstablishedResponseLine)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 隧道连接成功,通知客户端错误: %s", ctx.Req.URL.Host, err))
return
}
}
isWebsocket := p.detectConnProtocol(clientConn)
if isWebsocket {
req, err := http.ReadRequest(clientConn.BufferReader())
if err != nil {
if err != io.EOF {
p.delegate.ErrorLog(fmt.Errorf("%s - websocket读取客户端升级请求失败: %s", ctx.Req.URL.Host, err))
}
return
}
req.RemoteAddr = ctx.Req.RemoteAddr
req.URL.Scheme = "http"
req.URL.Host = req.Host
ctx.Req = req
p.websocketProxy(ctx, clientConn)
return
}
var tlsClientConn *tls.Conn
if p.decryptHTTPS {
tlsConfig, err := p.cert.GenerateTlsConfig(ctx.Req.URL.Host)
if err != nil {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 生成证书失败: %s", ctx.Req.URL.Host, err))
return
}
tlsClientConn = tls.Server(clientConn, tlsConfig)
defer func() {
_ = tlsClientConn.Close()
}()
if err := tlsClientConn.Handshake(); err != nil {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 握手失败: %s", ctx.Req.URL.Host, err))
return
}
buf := bufio.NewReader(tlsClientConn)
tlsReq, err := http.ReadRequest(buf)
if err != nil {
if err != io.EOF {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 读取客户端请求失败: %s", ctx.Req.URL.Host, err))
}
return
}
tlsReq.RemoteAddr = ctx.Req.RemoteAddr
tlsReq.URL.Scheme = "https"
tlsReq.URL.Host = tlsReq.Host
ctx.Req = tlsReq
}
targetAddr := ctx.Req.URL.Host
if parentProxyURL != nil {
targetAddr = parentProxyURL.Host
}
if !strings.Contains(targetAddr, ":") {
targetAddr += ":443"
}
targetConn, err := net.DialTimeout("tcp", targetAddr, defaultTargetConnectTimeout)
if err != nil {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - 隧道转发连接目标服务器失败: %s", ctx.Req.URL.Host, err))
return
}
defer func() {
_ = targetConn.Close()
}()
if parentProxyURL != nil {
tunnelRequestLine := makeTunnelRequestLine(ctx.Req.URL.Host)
_, _ = targetConn.Write([]byte(tunnelRequestLine))
}
if p.decryptHTTPS {
p.httpsProxy(ctx, tlsClientConn)
} else {
p.tunnelConnected(ctx, nil)
p.transfer(clientConn, targetConn)
}
}
// WebSocket代理
func (p *Proxy) websocketProxy(ctx *Context, srcConn *ConnBuffer) {
if !p.websocketIntercept {
remoteAddr := ctx.Addr()
var err error
var targetConn net.Conn
if ctx.IsHTTPS() {
targetConn, err = tls.Dial("tcp", remoteAddr, &tls.Config{InsecureSkipVerify: true})
} else {
targetConn, err = net.Dial("tcp", remoteAddr)
}
if err != nil {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - websocket连接目标服务器错误: %s", ctx.Req.URL.Host, err))
return
}
err = ctx.Req.Write(targetConn)
if err != nil {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - websocket协议转换请求写入目标服务器错误: %s", ctx.Req.URL.Host, err))
return
}
p.tunnelConnected(ctx, nil)
p.transfer(srcConn, targetConn)
return
}
up := &websocket.Upgrader{
HandshakeTimeout: defaultTargetConnectTimeout,
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
return true
},
}
srcWSConn, err := up.Upgrade(srcConn, ctx.Req, http.Header{})
if err != nil {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - 源连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err))
return
}
u := ctx.WebsocketUrl()
d := websocket.Dialer{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
}
dialTimeoutCtx, cancel := context.WithTimeout(context.Background(), defaultTargetConnectTimeout)
defer cancel()
targetWSConn, _, err := d.DialContext(dialTimeoutCtx, u.String(), ctx.Req.Header)
if err != nil {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - 目标连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err))
return
}
p.tunnelConnected(ctx, nil)
p.transferWebsocket(ctx, srcWSConn, targetWSConn)
}
// 探测连接协议
func (p *Proxy) detectConnProtocol(connBuf *ConnBuffer) (isWebsocket bool) {
methodBytes, err := connBuf.Peek(3)
if err != nil {
return false
}
method := string(methodBytes)
if method != http.MethodGet {
return false
}
return true
}
// webSocket双向转发
func (p *Proxy) transferWebsocket(ctx *Context, srcConn *websocket.Conn, targetConn *websocket.Conn) {
doneCtx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
for {
if doneCtx.Err() != nil {
return
}
msgType, msg, err := srcConn.ReadMessage()
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err))
return
}
p.delegate.WebSocketSendMessage(ctx, &msgType, &msg)
err = targetConn.WriteMessage(msgType, msg)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err))
return
}
}
}()
for {
if doneCtx.Err() != nil {
return
}
msgType, msg, err := targetConn.ReadMessage()
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err))
return
}
p.delegate.WebSocketReceiveMessage(ctx, &msgType, &msg)
err = srcConn.WriteMessage(msgType, msg)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err))
return
}
}
}
// 双向转发
func (p *Proxy) transfer(src net.Conn, dst net.Conn) {
go func() {
buf := bufPool.Get().([]byte)
_, err := io.CopyBuffer(src, dst, buf)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", dst.RemoteAddr().String(), src.RemoteAddr().String(), err))
}
bufPool.Put(buf)
_ = src.Close()
_ = dst.Close()
}()
buf := bufPool.Get().([]byte)
_, err := io.CopyBuffer(dst, src, buf)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err))
}
bufPool.Put(buf)
_ = dst.Close()
_ = src.Close()
}
func (p *Proxy) tunnelConnected(ctx *Context, err error) {
ctx.TunnelProxy = true
p.delegate.BeforeRequest(ctx)
if err != nil {
p.delegate.BeforeResponse(ctx, nil, err)
return
}
resp := &http.Response{
Status: "200 OK",
StatusCode: http.StatusOK,
Proto: "1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{},
Body: http.NoBody,
}
p.delegate.BeforeResponse(ctx, resp, nil)
}
func (p *Proxy) dialContext() DialContext {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: defaultTargetConnectTimeout,
}
separator := strings.LastIndex(addr, ":")
ips, err := p.dnsCache.Fetch(addr[:separator])
if err != nil {
return nil, err
}
var ip string
for _, item := range ips {
ip = item.String()
if !strings.Contains(ip, ":") {
break
}
}
addr = ip + addr[separator:]
return dialer.DialContext(ctx, network, addr)
}
}
// 获取底层连接
func hijacker(rw http.ResponseWriter) (*ConnBuffer, error) {
hijacker, ok := rw.(http.Hijacker)
if !ok {
return nil, fmt.Errorf("http server不支持Hijacker")
}
conn, buf, err := hijacker.Hijack()
if err != nil {
return nil, fmt.Errorf("hijacker错误: %s", err)
}
return NewConnBuffer(conn, buf), nil
}
// CopyHeader 浅拷贝Header
func CopyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
// CloneHeader 深拷贝Header
func CloneHeader(h http.Header, h2 http.Header) {
for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
h2[k] = vv2
}
}
var hopHeaders = []string{
"Proxy-Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te",
"Trailer",
"Transfer-Encoding",
}
type ConnBuffer struct {
net.Conn
buf *bufio.ReadWriter
}
func NewConnBuffer(conn net.Conn, buf *bufio.ReadWriter) *ConnBuffer {
if buf == nil {
buf = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
}
return &ConnBuffer{
Conn: conn,
buf: buf,
}
}
func (cb *ConnBuffer) BufferReader() *bufio.Reader {
return cb.buf.Reader
}
func (cb *ConnBuffer) Read(b []byte) (n int, err error) {
return cb.buf.Read(b)
}
func (cb *ConnBuffer) Peek(n int) ([]byte, error) {
return cb.buf.Peek(n)
}
func (cb *ConnBuffer) Write(p []byte) (n int, err error) {
n, err = cb.buf.Write(p)
if err != nil {
return 0, err
}
return n, cb.buf.Flush()
}
func (cb *ConnBuffer) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return cb.Conn, cb.buf, nil
}
func (cb *ConnBuffer) WriteHeader(_ int) {}
func (cb *ConnBuffer) Header() http.Header { return nil }