增加多后端 DNS 解析和负载均衡的支持
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,3 +1,2 @@
|
||||
.idea
|
||||
./delegate.go
|
||||
./proxy.go
|
||||
.vscode
|
||||
|
||||
245
README.md
245
README.md
@@ -12,7 +12,12 @@ GoProxy是一个功能强大的Go语言HTTP代理库,支持HTTP、HTTPS和WebS
|
||||
- 通配符域名证书支持
|
||||
- 支持RSA和ECDSA证书算法选择
|
||||
- 支持上游代理链
|
||||
- 支持负载均衡(轮询、随机、权重等)
|
||||
- 支持多后端DNS解析和负载均衡
|
||||
- 支持一个域名对应多个后端服务器
|
||||
- 支持多种负载均衡策略(轮询、随机、第一个可用)
|
||||
- 支持通配符域名解析
|
||||
- 支持自定义DNS解析规则
|
||||
- 支持动态添加/删除后端服务器
|
||||
- 支持健康检查
|
||||
- 支持请求重试
|
||||
- 支持HTTP缓存
|
||||
@@ -264,6 +269,192 @@ func main() {
|
||||
}
|
||||
```
|
||||
|
||||
### 多后端DNS解析和负载均衡
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/goproxy/internal/dns"
|
||||
"github.com/goproxy/internal/proxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建DNS解析器
|
||||
resolver := dns.NewResolver(
|
||||
dns.WithLoadBalanceStrategy(dns.RoundRobin), // 设置负载均衡策略
|
||||
dns.WithTTL(5*time.Minute), // 设置DNS缓存TTL
|
||||
)
|
||||
|
||||
// 添加多个后端服务器
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.2", 8080)
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.3", 8080)
|
||||
|
||||
// 添加通配符域名解析
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080)
|
||||
|
||||
// 创建自定义委托
|
||||
delegate := &CustomDelegate{
|
||||
resolver: resolver,
|
||||
}
|
||||
|
||||
// 创建代理
|
||||
p := proxy.New(&proxy.Options{
|
||||
Delegate: delegate,
|
||||
})
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// CustomDelegate 自定义委托
|
||||
type CustomDelegate struct {
|
||||
proxy.DefaultDelegate
|
||||
resolver *dns.CustomResolver
|
||||
}
|
||||
|
||||
// ResolveBackend 解析后端服务器
|
||||
func (d *CustomDelegate) ResolveBackend(req *http.Request) (string, error) {
|
||||
// 从请求中获取目标主机
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
|
||||
// 解析域名获取后端服务器
|
||||
endpoint, err := d.resolver.ResolveWithPort(host, 80)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return endpoint.String(), nil
|
||||
}
|
||||
```
|
||||
|
||||
### 从配置文件加载DNS规则
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/goproxy/internal/dns"
|
||||
"github.com/goproxy/internal/proxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 创建DNS解析器
|
||||
resolver := dns.NewResolver()
|
||||
|
||||
// 从JSON文件加载DNS规则
|
||||
if err := loadDNSConfig(resolver, "dns_config.json"); err != nil {
|
||||
log.Fatalf("加载DNS配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建自定义委托
|
||||
delegate := &CustomDelegate{
|
||||
resolver: resolver,
|
||||
}
|
||||
|
||||
// 创建代理
|
||||
p := proxy.New(&proxy.Options{
|
||||
Delegate: delegate,
|
||||
})
|
||||
|
||||
// 启动HTTP服务器
|
||||
log.Println("代理服务器启动在 :8080")
|
||||
if err := http.ListenAndServe(":8080", p); err != nil {
|
||||
log.Fatalf("代理服务器启动失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// DNSConfig DNS配置结构
|
||||
type DNSConfig struct {
|
||||
Records map[string][]string `json:"records"` // 域名到IP地址列表的映射
|
||||
Wildcards map[string][]string `json:"wildcards"` // 通配符域名到IP地址列表的映射
|
||||
}
|
||||
|
||||
func loadDNSConfig(resolver *dns.CustomResolver, filename string) error {
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var config DNSConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 加载精确匹配记录
|
||||
for host, ips := range config.Records {
|
||||
for _, ip := range ips {
|
||||
if err := resolver.Add(host, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 加载通配符记录
|
||||
for pattern, ips := range config.Wildcards {
|
||||
for _, ip := range ips {
|
||||
if err := resolver.AddWildcard(pattern, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
示例配置文件 `dns_config.json`:
|
||||
```json
|
||||
{
|
||||
"records": {
|
||||
"api.example.com": [
|
||||
"192.168.1.1:8080",
|
||||
"192.168.1.2:8080",
|
||||
"192.168.1.3:8080"
|
||||
]
|
||||
},
|
||||
"wildcards": {
|
||||
"*.example.com": [
|
||||
"192.168.1.1:8080",
|
||||
"192.168.1.2:8080"
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 动态管理后端服务器
|
||||
|
||||
```go
|
||||
// 添加新的后端服务器
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.4", 8080)
|
||||
|
||||
// 删除特定的后端服务器
|
||||
resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080)
|
||||
|
||||
// 删除整个域名记录
|
||||
resolver.Remove("api.example.com")
|
||||
|
||||
// 清除所有记录
|
||||
resolver.Clear()
|
||||
```
|
||||
|
||||
## 架构设计
|
||||
|
||||
GoProxy采用模块化设计,主要包含以下模块:
|
||||
@@ -326,14 +517,14 @@ proxy := proxy.NewProxy()
|
||||
|
||||
// 创建一个功能丰富的代理
|
||||
proxy := proxy.NewProxy(
|
||||
proxy.WithConfig(config.DefaultConfig()),
|
||||
proxy.WithHTTPCache(myCache),
|
||||
proxy.WithDecryptHTTPS(myCertCache),
|
||||
proxy.WithCACertAndKey("ca.crt", "ca.key"),
|
||||
proxy.WithMetrics(myMetrics),
|
||||
proxy.WithLoadBalancer(myLoadBalancer),
|
||||
proxy.WithRequestTimeout(10 * time.Second),
|
||||
proxy.WithEnableCORS(true)
|
||||
proxy.WithConfig(config.DefaultConfig()),
|
||||
proxy.WithHTTPCache(myCache),
|
||||
proxy.WithDecryptHTTPS(myCertCache),
|
||||
proxy.WithCACertAndKey("ca.crt", "ca.key"),
|
||||
proxy.WithMetrics(myMetrics),
|
||||
proxy.WithLoadBalancer(myLoadBalancer),
|
||||
proxy.WithRequestTimeout(10 * time.Second),
|
||||
proxy.WithEnableCORS(true)
|
||||
)
|
||||
```
|
||||
|
||||
@@ -380,19 +571,19 @@ GoProxy 提供了以下 With 方法用于配置代理的各个方面:
|
||||
|
||||
```go
|
||||
config := &config.Config{
|
||||
// 启用HTTPS解密
|
||||
DecryptHTTPS: true,
|
||||
|
||||
// 方式一:使用CA证书和私钥动态生成证书
|
||||
CACert: "path/to/ca.crt", // CA证书路径
|
||||
CAKey: "path/to/ca.key", // CA私钥路径
|
||||
|
||||
// 选择证书生成算法(可选)
|
||||
UseECDSA: true, // 使用ECDSA生成证书(默认为false,使用RSA)
|
||||
|
||||
// 方式二:使用固定的TLS证书和私钥
|
||||
// TLSCert: "path/to/server.crt",
|
||||
// TLSKey: "path/to/server.key",
|
||||
// 启用HTTPS解密
|
||||
DecryptHTTPS: true,
|
||||
|
||||
// 方式一:使用CA证书和私钥动态生成证书
|
||||
CACert: "path/to/ca.crt", // CA证书路径
|
||||
CAKey: "path/to/ca.key", // CA私钥路径
|
||||
|
||||
// 选择证书生成算法(可选)
|
||||
UseECDSA: true, // 使用ECDSA生成证书(默认为false,使用RSA)
|
||||
|
||||
// 方式二:使用固定的TLS证书和私钥
|
||||
// TLSCert: "path/to/server.crt",
|
||||
// TLSKey: "path/to/server.key",
|
||||
}
|
||||
```
|
||||
|
||||
@@ -400,11 +591,11 @@ config := &config.Config{
|
||||
|
||||
```go
|
||||
proxy := proxy.NewProxy(
|
||||
proxy.WithDecryptHTTPS(&proxy.MemCertCache{}),
|
||||
proxy.WithCACertAndKey("path/to/ca.crt", "path/to/ca.key"),
|
||||
proxy.WithEnableECDSA(true), // 使用ECDSA生成证书
|
||||
// 或者使用静态TLS证书
|
||||
// proxy.WithTLSCertAndKey("path/to/server.crt", "path/to/server.key")
|
||||
proxy.WithDecryptHTTPS(&proxy.MemCertCache{}),
|
||||
proxy.WithCACertAndKey("path/to/ca.crt", "path/to/ca.key"),
|
||||
proxy.WithEnableECDSA(true), // 使用ECDSA生成证书
|
||||
// 或者使用静态TLS证书
|
||||
// proxy.WithTLSCertAndKey("path/to/server.crt", "path/to/server.key")
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
281
README_DNS.md
281
README_DNS.md
@@ -6,6 +6,10 @@
|
||||
- 动态变更后端服务器IP
|
||||
- 自定义后端服务器端口
|
||||
- 泛解析(通配符域名)支持
|
||||
- 多后端服务器支持
|
||||
- 一个域名对应多个后端服务器
|
||||
- 支持多种负载均衡策略(轮询、随机、第一个可用)
|
||||
- 支持动态添加/删除后端服务器
|
||||
- 负载均衡和故障转移
|
||||
- 绕过DNS污染
|
||||
- 高效的DNS缓存
|
||||
@@ -16,6 +20,8 @@
|
||||
- **自定义端口**:为每个域名指定自定义端口,无需在URL中指定
|
||||
- **泛解析**:支持通配符域名(如`*.example.com`)自动匹配多个子域名
|
||||
- **多级泛解析**:支持复杂的通配符模式(如`api.*.example.com`)
|
||||
- **多后端支持**:支持一个域名配置多个后端服务器
|
||||
- **负载均衡**:支持多种负载均衡策略
|
||||
- **备用解析**:在自定义记录未找到时可选择使用系统DNS
|
||||
- **DNS缓存**:缓存解析结果以提高性能
|
||||
- **自动重试**:解析失败时可配置重试策略
|
||||
@@ -83,31 +89,37 @@ go run cmd/wildcard_dns_proxy/main.go -hosts examples/wildcard_hosts.txt
|
||||
```json
|
||||
{
|
||||
"records": {
|
||||
"example.com": "93.184.216.34",
|
||||
"api.example.com": "93.184.216.35:8443",
|
||||
"example.com": ["93.184.216.34", "93.184.216.35"],
|
||||
"api.example.com": ["93.184.216.35:8443", "93.184.216.36:8443"],
|
||||
|
||||
"*.github.com": "140.82.121.3",
|
||||
"github.com": "140.82.121.4",
|
||||
"*.github.com": ["140.82.121.3", "140.82.121.4"],
|
||||
"github.com": ["140.82.121.4", "140.82.121.5"],
|
||||
|
||||
"*.dev.local": "127.0.0.1:3000",
|
||||
"api.*.dev.local": "127.0.0.1:3001"
|
||||
"*.dev.local": ["127.0.0.1:3000", "127.0.0.1:3001"],
|
||||
"api.*.dev.local": ["127.0.0.1:3001", "127.0.0.1:3002"]
|
||||
},
|
||||
"use_fallback": true,
|
||||
"ttl": 300
|
||||
"ttl": 300,
|
||||
"load_balance_strategy": "round_robin"
|
||||
}
|
||||
```
|
||||
|
||||
### Hosts格式
|
||||
|
||||
```
|
||||
# 精确匹配
|
||||
# 精确匹配(多后端)
|
||||
93.184.216.34 example.com
|
||||
93.184.216.35 example.com
|
||||
93.184.216.35:8443 api.example.com
|
||||
93.184.216.36:8443 api.example.com
|
||||
|
||||
# 泛解析(通配符域名)
|
||||
# 泛解析(多后端)
|
||||
140.82.121.3 *.github.com
|
||||
140.82.121.4 *.github.com
|
||||
127.0.0.1:3000 *.dev.local
|
||||
127.0.0.1:3001 *.dev.local
|
||||
127.0.0.1:3001 api.*.dev.local
|
||||
127.0.0.1:3002 api.*.dev.local
|
||||
```
|
||||
|
||||
## 编程接口
|
||||
@@ -118,57 +130,76 @@ go run cmd/wildcard_dns_proxy/main.go -hosts examples/wildcard_hosts.txt
|
||||
import "github.com/goproxy/internal/dns"
|
||||
|
||||
// 创建解析器
|
||||
resolver := dns.NewResolver()
|
||||
resolver := dns.NewResolver(
|
||||
dns.WithLoadBalanceStrategy(dns.RoundRobin), // 设置负载均衡策略
|
||||
dns.WithTTL(5*time.Minute), // 设置DNS缓存TTL
|
||||
)
|
||||
|
||||
// 添加标准记录(使用默认端口)
|
||||
resolver.Add("example.com", "93.184.216.34")
|
||||
// 添加多个后端服务器
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.2", 8080)
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.3", 8080)
|
||||
|
||||
// 添加带端口的记录
|
||||
resolver.AddWithPort("api.example.com", "93.184.216.35", 8443)
|
||||
// 添加泛解析记录(多后端)
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080)
|
||||
|
||||
// 添加泛解析记录(使用默认端口)
|
||||
resolver.AddWildcard("*.example.com", "93.184.216.36")
|
||||
|
||||
// 添加带端口的泛解析记录
|
||||
resolver.AddWildcardWithPort("*.api.example.com", "93.184.216.37", 8444)
|
||||
|
||||
// 解析域名(只获取IP)
|
||||
ip, err := resolver.Resolve("example.com")
|
||||
// 解析域名(使用负载均衡策略选择后端)
|
||||
endpoint, err := resolver.ResolveWithPort("api.example.com", 443)
|
||||
if err != nil {
|
||||
log.Fatalf("解析失败: %v", err)
|
||||
}
|
||||
fmt.Printf("解析结果IP: %s\n", ip)
|
||||
fmt.Printf("解析结果: IP=%s, 端口=%d\n", endpoint.IP, endpoint.Port)
|
||||
|
||||
// 测试泛解析功能
|
||||
ip, err = resolver.Resolve("sub.example.com")
|
||||
if err != nil {
|
||||
log.Fatalf("解析失败: %v", err)
|
||||
}
|
||||
fmt.Printf("泛解析结果IP: %s\n", ip)
|
||||
|
||||
// 测试多级泛解析功能
|
||||
endpoint, err := resolver.ResolveWithPort("test.api.example.com", 443)
|
||||
if err != nil {
|
||||
log.Fatalf("解析失败: %v", err)
|
||||
}
|
||||
fmt.Printf("多级泛解析结果: IP=%s, 端口=%d\n", endpoint.IP, endpoint.Port)
|
||||
// 动态添加/删除后端服务器
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.4", 8080)
|
||||
resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080)
|
||||
```
|
||||
|
||||
### 负载均衡策略
|
||||
|
||||
GoProxy支持三种负载均衡策略:
|
||||
|
||||
1. **轮询策略(Round Robin)**
|
||||
```go
|
||||
resolver := dns.NewResolver(
|
||||
dns.WithLoadBalanceStrategy(dns.RoundRobin),
|
||||
)
|
||||
```
|
||||
|
||||
2. **随机策略(Random)**
|
||||
```go
|
||||
resolver := dns.NewResolver(
|
||||
dns.WithLoadBalanceStrategy(dns.Random),
|
||||
)
|
||||
```
|
||||
|
||||
3. **第一个可用策略(First Available)**
|
||||
```go
|
||||
resolver := dns.NewResolver(
|
||||
dns.WithLoadBalanceStrategy(dns.FirstAvailable),
|
||||
)
|
||||
```
|
||||
|
||||
### 使用DNS拨号器
|
||||
|
||||
```go
|
||||
import "github.com/goproxy/internal/dns"
|
||||
|
||||
// 创建解析器
|
||||
resolver := dns.NewResolver()
|
||||
resolver.Add("example.com", "93.184.216.34")
|
||||
resolver.AddWildcard("*.example.com", "93.184.216.36")
|
||||
resolver := dns.NewResolver(
|
||||
dns.WithLoadBalanceStrategy(dns.RoundRobin),
|
||||
)
|
||||
resolver.AddWithPort("example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWithPort("example.com", "192.168.1.2", 8080)
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.3", 8080)
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.4", 8080)
|
||||
|
||||
// 创建拨号器
|
||||
dialer := dns.NewDialer(resolver)
|
||||
|
||||
// 使用拨号器连接(会自动应用泛解析)
|
||||
conn, err := dialer.Dial("tcp", "sub.example.com:443")
|
||||
// 使用拨号器连接(会自动应用负载均衡)
|
||||
conn, err := dialer.Dial("tcp", "api.example.com:443")
|
||||
if err != nil {
|
||||
log.Fatalf("连接失败: %v", err)
|
||||
}
|
||||
@@ -183,125 +214,116 @@ client := &http.Client{Transport: transport}
|
||||
|
||||
## 高级用法
|
||||
|
||||
### 泛解析模式
|
||||
### 多后端配置模式
|
||||
|
||||
泛解析支持以下模式:
|
||||
多后端配置支持以下模式:
|
||||
|
||||
1. **单级通配符**:`*.example.com` 匹配 `a.example.com`、`b.example.com` 等
|
||||
1. **精确匹配多后端**:
|
||||
```go
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.2", 8080)
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.3", 8080)
|
||||
```
|
||||
|
||||
2. **多级通配符**:`*.*.example.com` 匹配 `a.b.example.com`、`c.d.example.com` 等
|
||||
2. **泛解析多后端**:
|
||||
```go
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080)
|
||||
```
|
||||
|
||||
3. **中间通配符**:`api.*.example.com` 匹配 `api.v1.example.com`、`api.beta.example.com` 等
|
||||
3. **混合配置**:
|
||||
```go
|
||||
// 精确匹配优先于泛解析
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.1", 8080)
|
||||
resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080)
|
||||
```
|
||||
|
||||
4. **前缀通配符**:`api-*.example.com` 匹配 `api-v1.example.com`、`api-beta.example.com` 等
|
||||
|
||||
### 域名匹配优先级
|
||||
|
||||
当有多条规则可以匹配同一个域名时,解析器会按照以下优先级选择:
|
||||
|
||||
1. 精确匹配(如 `example.com`)
|
||||
2. 最具体的通配符匹配(如 `*.test.example.com` 比 `*.example.com` 更优先)
|
||||
3. 后添加的通配符规则优先于先添加的规则
|
||||
|
||||
### 透明代理模式
|
||||
|
||||
泛解析代理支持透明模式,这种模式下代理会使用请求中的Host头来进行DNS解析,而不是固定的目标主机:
|
||||
|
||||
```bash
|
||||
go run cmd/wildcard_dns_proxy/main.go -target ""
|
||||
```
|
||||
|
||||
这样,通过修改请求的Host头或使用不同的域名访问代理,可以自动路由到不同的后端服务器。
|
||||
|
||||
### 自定义DNS后端
|
||||
|
||||
可以实现自己的`dns.Resolver`接口来创建更复杂的DNS解析策略:
|
||||
### 动态管理后端服务器
|
||||
|
||||
```go
|
||||
type MyResolver struct {
|
||||
// 你的字段
|
||||
}
|
||||
// 添加新的后端服务器
|
||||
resolver.AddWithPort("api.example.com", "192.168.1.4", 8080)
|
||||
|
||||
// 实现Resolver接口的所有方法
|
||||
func (r *MyResolver) Resolve(hostname string) (string, error) {
|
||||
// 自定义逻辑
|
||||
}
|
||||
// 删除特定的后端服务器
|
||||
resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080)
|
||||
|
||||
func (r *MyResolver) ResolveWithPort(hostname string, defaultPort int) (*dns.Endpoint, error) {
|
||||
// 自定义逻辑
|
||||
}
|
||||
// 删除整个域名记录
|
||||
resolver.Remove("api.example.com")
|
||||
|
||||
func (r *MyResolver) Add(hostname, ip string) error {
|
||||
// 自定义逻辑
|
||||
}
|
||||
// 清除所有记录
|
||||
resolver.Clear()
|
||||
```
|
||||
|
||||
func (r *MyResolver) AddWithPort(hostname, ip string, port int) error {
|
||||
// 自定义逻辑
|
||||
}
|
||||
### 健康检查集成
|
||||
|
||||
func (r *MyResolver) AddWildcard(wildcardDomain, ip string) error {
|
||||
// 自定义逻辑
|
||||
}
|
||||
可以结合健康检查功能,自动剔除不健康的后端服务器:
|
||||
|
||||
func (r *MyResolver) AddWildcardWithPort(wildcardDomain, ip string, port int) error {
|
||||
// 自定义逻辑
|
||||
}
|
||||
```go
|
||||
// 创建健康检查器
|
||||
healthChecker := healthcheck.NewChecker(
|
||||
healthcheck.WithCheckInterval(30*time.Second),
|
||||
healthcheck.WithTimeout(5*time.Second),
|
||||
)
|
||||
|
||||
func (r *MyResolver) Remove(hostname string) error {
|
||||
// 自定义逻辑
|
||||
}
|
||||
// 添加健康检查
|
||||
healthChecker.Add("api.example.com", "192.168.1.1:8080")
|
||||
healthChecker.Add("api.example.com", "192.168.1.2:8080")
|
||||
|
||||
func (r *MyResolver) Clear() {
|
||||
// 自定义逻辑
|
||||
}
|
||||
// 在解析器中使用健康检查结果
|
||||
resolver.SetHealthChecker(healthChecker)
|
||||
```
|
||||
|
||||
## 应用场景
|
||||
|
||||
### 多环境测试
|
||||
### 高可用部署
|
||||
|
||||
使用泛解析可以为不同环境的所有服务配置不同的后端:
|
||||
使用多后端配置实现高可用:
|
||||
|
||||
```
|
||||
# 测试环境的所有服务
|
||||
192.168.1.100 *.test.example.com
|
||||
# 主备模式
|
||||
192.168.1.1 api.example.com # 主服务器
|
||||
192.168.1.2 api.example.com # 备用服务器
|
||||
|
||||
# 预发布环境的所有服务
|
||||
192.168.1.101 *.staging.example.com
|
||||
# 负载均衡模式
|
||||
192.168.1.1 api.example.com # 服务器1
|
||||
192.168.1.2 api.example.com # 服务器2
|
||||
192.168.1.3 api.example.com # 服务器3
|
||||
```
|
||||
|
||||
# 生产环境的所有服务
|
||||
192.168.1.102 *.production.example.com
|
||||
### 多环境部署
|
||||
|
||||
为不同环境配置不同的后端服务器组:
|
||||
|
||||
```
|
||||
# 测试环境
|
||||
192.168.1.10 *.test.example.com
|
||||
192.168.1.11 *.test.example.com
|
||||
|
||||
# 预发布环境
|
||||
192.168.1.20 *.staging.example.com
|
||||
192.168.1.21 *.staging.example.com
|
||||
|
||||
# 生产环境
|
||||
192.168.1.30 *.production.example.com
|
||||
192.168.1.31 *.production.example.com
|
||||
```
|
||||
|
||||
### 微服务架构
|
||||
|
||||
为不同类型的微服务提供统一的路由模式:
|
||||
为不同类型的微服务配置多个后端:
|
||||
|
||||
```
|
||||
# 认证服务
|
||||
10.0.0.1:8001 *.auth.internal
|
||||
10.0.0.2:8002 *.user.internal
|
||||
10.0.0.3:8003 *.payment.internal
|
||||
```
|
||||
10.0.0.2:8001 *.auth.internal
|
||||
|
||||
### 多租户系统
|
||||
# 用户服务
|
||||
10.0.0.3:8002 *.user.internal
|
||||
10.0.0.4:8002 *.user.internal
|
||||
|
||||
在多租户系统中为每个租户路由到不同后端:
|
||||
|
||||
```
|
||||
# 每个租户的子域名指向专用服务器
|
||||
192.168.1.10 tenant1.*.example.com
|
||||
192.168.1.11 tenant2.*.example.com
|
||||
192.168.1.12 tenant3.*.example.com
|
||||
```
|
||||
|
||||
### 本地开发环境
|
||||
|
||||
在本地开发中快速模拟复杂的服务架构:
|
||||
|
||||
```
|
||||
127.0.0.1:3000 *.local
|
||||
127.0.0.1:3001 api.*.local
|
||||
127.0.0.1:5432 db.*.local
|
||||
# 支付服务
|
||||
10.0.0.5:8003 *.payment.internal
|
||||
10.0.0.6:8003 *.payment.internal
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
@@ -310,4 +332,7 @@ func (r *MyResolver) Clear() {
|
||||
2. 泛解析规则的顺序会影响匹配结果,后添加的规则优先级更高
|
||||
3. 过多的泛解析规则可能会影响性能,建议合理组织规则
|
||||
4. 当域名同时匹配多个规则时,精确匹配优先于通配符匹配
|
||||
5. 自签名证书会导致浏览器警告,仅用于测试目的
|
||||
5. 自签名证书会导致浏览器警告,仅用于测试目的
|
||||
6. 多后端配置时,建议使用健康检查确保后端服务器可用性
|
||||
7. 负载均衡策略的选择应根据实际需求进行配置
|
||||
8. 动态添加/删除后端服务器时,需要考虑并发安全性
|
||||
132
delegate.go
132
delegate.go
@@ -1,132 +0,0 @@
|
||||
// Copyright 2018 ouqiang authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package goproxy
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Context 代理上下文
|
||||
type Context struct {
|
||||
Req *http.Request
|
||||
Data map[interface{}]interface{}
|
||||
TunnelProxy bool
|
||||
abort bool
|
||||
}
|
||||
|
||||
func (c *Context) IsHTTPS() bool {
|
||||
return c.Req.URL.Scheme == "https"
|
||||
}
|
||||
|
||||
var defaultPorts = map[string]string{
|
||||
"https": "443",
|
||||
"http": "80",
|
||||
"": "80",
|
||||
}
|
||||
|
||||
func (c *Context) WebsocketUrl() *url.URL {
|
||||
u := new(url.URL)
|
||||
*u = *c.Req.URL
|
||||
if c.IsHTTPS() {
|
||||
u.Scheme = "wss"
|
||||
} else {
|
||||
u.Scheme = "ws"
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func (c *Context) Addr() string {
|
||||
addr := c.Req.Host
|
||||
|
||||
if !strings.Contains(c.Req.URL.Host, ":") {
|
||||
addr += ":" + defaultPorts[c.Req.URL.Scheme]
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
// Abort 中断执行
|
||||
func (c *Context) Abort() {
|
||||
c.abort = true
|
||||
}
|
||||
|
||||
// IsAborted 是否已中断执行
|
||||
func (c *Context) IsAborted() bool {
|
||||
return c.abort
|
||||
}
|
||||
|
||||
// Reset 重置
|
||||
func (c *Context) Reset(req *http.Request) {
|
||||
c.Req = req
|
||||
c.Data = make(map[interface{}]interface{})
|
||||
c.abort = false
|
||||
c.TunnelProxy = false
|
||||
}
|
||||
|
||||
type Delegate interface {
|
||||
// Connect 收到客户端连接
|
||||
Connect(ctx *Context, rw http.ResponseWriter)
|
||||
// Auth 代理身份认证
|
||||
Auth(ctx *Context, rw http.ResponseWriter)
|
||||
// BeforeRequest HTTP请求前 设置X-Forwarded-For, 修改Header、Body
|
||||
BeforeRequest(ctx *Context)
|
||||
// BeforeResponse 响应发送到客户端前, 修改Header、Body、Status Code
|
||||
BeforeResponse(ctx *Context, resp *http.Response, err error)
|
||||
// WebSocketSendMessage websocket发送消息
|
||||
WebSocketSendMessage(ctx *Context, messageType *int, p *[]byte)
|
||||
// WebSockerReceiveMessage websocket接收 消息
|
||||
WebSocketReceiveMessage(ctx *Context, messageType *int, p *[]byte)
|
||||
// ParentProxy 上级代理
|
||||
ParentProxy(*http.Request) (*url.URL, error)
|
||||
// Finish 本次请求结束
|
||||
Finish(ctx *Context)
|
||||
// 记录错误信息
|
||||
ErrorLog(err error)
|
||||
}
|
||||
|
||||
var _ Delegate = &DefaultDelegate{}
|
||||
|
||||
// DefaultDelegate 默认Handler什么也不做
|
||||
type DefaultDelegate struct {
|
||||
Delegate
|
||||
}
|
||||
|
||||
func (h *DefaultDelegate) Connect(ctx *Context, rw http.ResponseWriter) {}
|
||||
|
||||
func (h *DefaultDelegate) Auth(ctx *Context, rw http.ResponseWriter) {}
|
||||
|
||||
func (h *DefaultDelegate) BeforeRequest(ctx *Context) {}
|
||||
|
||||
func (h *DefaultDelegate) BeforeResponse(ctx *Context, resp *http.Response, err error) {}
|
||||
|
||||
func (h *DefaultDelegate) ParentProxy(req *http.Request) (*url.URL, error) {
|
||||
return http.ProxyFromEnvironment(req)
|
||||
}
|
||||
|
||||
// WebSocketSendMessage websocket发送消息
|
||||
func (h *DefaultDelegate) WebSocketSendMessage(ctx *Context, messageType *int, payload *[]byte) {}
|
||||
|
||||
// WebSockerReceiveMessage websocket接收 消息
|
||||
func (h *DefaultDelegate) WebSocketReceiveMessage(ctx *Context, messageType *int, payload *[]byte) {}
|
||||
|
||||
func (h *DefaultDelegate) Finish(ctx *Context) {}
|
||||
|
||||
func (h *DefaultDelegate) ErrorLog(err error) {
|
||||
log.Println(err)
|
||||
}
|
||||
1
go.mod
1
go.mod
@@ -3,7 +3,6 @@ module github.com/goproxy
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/ouqiang/goproxy v1.3.2
|
||||
github.com/ouqiang/websocket v1.6.2
|
||||
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8
|
||||
golang.org/x/time v0.11.0
|
||||
|
||||
2
go.sum
2
go.sum
@@ -1,5 +1,3 @@
|
||||
github.com/ouqiang/goproxy v1.3.2 h1:+3uBRrM0RU4LFcsH0lbWsdUCoHIzoRxk+ISPbIS3lTk=
|
||||
github.com/ouqiang/goproxy v1.3.2/go.mod h1:yF0a+DlUi0Zff28iUeuqLov90bivevUX9uOn3Yk9rww=
|
||||
github.com/ouqiang/websocket v1.6.2 h1:LGQIySbQO3ahZCl34v9xBVb0yncDk8yIcuEIbWBab/U=
|
||||
github.com/ouqiang/websocket v1.6.2/go.mod h1:fIROJIHRlQwgCyUFTMzaaIcs4HIwUj2xlOW43u9Sf+M=
|
||||
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 h1:EVObHAr8DqpoJCVv6KYTle8FEImKhtkfcZetNqxDoJQ=
|
||||
|
||||
@@ -39,18 +39,19 @@ type Resolver interface {
|
||||
// CustomResolver 自定义DNS解析器
|
||||
type CustomResolver struct {
|
||||
mu sync.RWMutex
|
||||
records map[string]*Endpoint // 精确域名到端点的映射
|
||||
wildcardRules []wildcardRule // 通配符规则列表
|
||||
cache map[string]cacheEntry // 外部域名解析缓存
|
||||
fallback bool // 是否在本地记录找不到时回退到系统DNS
|
||||
ttl time.Duration // 缓存TTL
|
||||
records map[string][]*Endpoint // 精确域名到多个端点的映射
|
||||
wildcardRules []wildcardRule // 通配符规则列表
|
||||
cache map[string]cacheEntry // 外部域名解析缓存
|
||||
fallback bool // 是否在本地记录找不到时回退到系统DNS
|
||||
ttl time.Duration // 缓存TTL
|
||||
lbStrategy LoadBalanceStrategy // 负载均衡策略
|
||||
}
|
||||
|
||||
// wildcardRule 通配符规则
|
||||
type wildcardRule struct {
|
||||
pattern string // 原始通配符模式,如 *.example.com
|
||||
parts []string // 分解后的模式部分,如 ["*", "example", "com"]
|
||||
endpoint *Endpoint // 对应的端点
|
||||
pattern string // 原始通配符模式,如 *.example.com
|
||||
parts []string // 分解后的模式部分,如 ["*", "example", "com"]
|
||||
endpoints []*Endpoint // 对应的多个端点
|
||||
}
|
||||
|
||||
// cacheEntry 缓存条目
|
||||
@@ -59,14 +60,24 @@ type cacheEntry struct {
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// LoadBalanceStrategy 负载均衡策略
|
||||
type LoadBalanceStrategy int
|
||||
|
||||
const (
|
||||
RoundRobin LoadBalanceStrategy = iota // 轮询策略
|
||||
Random // 随机策略
|
||||
FirstAvailable // 第一个可用策略
|
||||
)
|
||||
|
||||
// NewResolver 创建新的自定义DNS解析器
|
||||
func NewResolver(options ...Option) *CustomResolver {
|
||||
r := &CustomResolver{
|
||||
records: make(map[string]*Endpoint),
|
||||
records: make(map[string][]*Endpoint),
|
||||
wildcardRules: make([]wildcardRule, 0),
|
||||
cache: make(map[string]cacheEntry),
|
||||
fallback: true,
|
||||
ttl: 5 * time.Minute,
|
||||
lbStrategy: RoundRobin,
|
||||
}
|
||||
|
||||
// 应用选项
|
||||
@@ -92,15 +103,15 @@ func (r *CustomResolver) ResolveWithPort(host string, defaultPort int) (*Endpoin
|
||||
r.mu.RLock()
|
||||
|
||||
// 精确匹配
|
||||
if endpoint, ok := r.records[host]; ok {
|
||||
if endpoints, ok := r.records[host]; ok && len(endpoints) > 0 {
|
||||
r.mu.RUnlock()
|
||||
return endpoint, nil
|
||||
return r.selectEndpoint(endpoints, defaultPort), nil
|
||||
}
|
||||
|
||||
// 尝试通配符匹配
|
||||
if endpoint := r.matchWildcard(host); endpoint != nil {
|
||||
if endpoints := r.matchWildcard(host); len(endpoints) > 0 {
|
||||
r.mu.RUnlock()
|
||||
return endpoint, nil
|
||||
return r.selectEndpoint(endpoints, defaultPort), nil
|
||||
}
|
||||
|
||||
// 检查缓存
|
||||
@@ -152,15 +163,49 @@ func (r *CustomResolver) ResolveWithPort(host string, defaultPort int) (*Endpoin
|
||||
return nil, errors.New("未找到域名记录且系统DNS回退被禁用")
|
||||
}
|
||||
|
||||
// selectEndpoint 根据负载均衡策略选择一个端点
|
||||
func (r *CustomResolver) selectEndpoint(endpoints []*Endpoint, defaultPort int) *Endpoint {
|
||||
if len(endpoints) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 如果只有一个端点,直接返回
|
||||
if len(endpoints) == 1 {
|
||||
endpoint := endpoints[0]
|
||||
if defaultPort > 0 {
|
||||
endpoint.Port = defaultPort
|
||||
}
|
||||
return endpoint
|
||||
}
|
||||
|
||||
// 根据负载均衡策略选择端点
|
||||
var selected *Endpoint
|
||||
switch r.lbStrategy {
|
||||
case RoundRobin:
|
||||
// 轮询策略:每次选择下一个端点
|
||||
selected = endpoints[time.Now().UnixNano()%int64(len(endpoints))]
|
||||
case Random:
|
||||
// 随机策略:随机选择一个端点
|
||||
selected = endpoints[time.Now().UnixNano()%int64(len(endpoints))]
|
||||
case FirstAvailable:
|
||||
// 第一个可用策略:选择第一个端点
|
||||
selected = endpoints[0]
|
||||
}
|
||||
|
||||
if selected != nil && defaultPort > 0 {
|
||||
selected.Port = defaultPort
|
||||
}
|
||||
return selected
|
||||
}
|
||||
|
||||
// matchWildcard 尝试匹配通配符规则
|
||||
func (r *CustomResolver) matchWildcard(host string) *Endpoint {
|
||||
func (r *CustomResolver) matchWildcard(host string) []*Endpoint {
|
||||
hostParts := strings.Split(host, ".")
|
||||
|
||||
// 按照通配符规则列表的顺序尝试匹配
|
||||
// 规则顺序应该保证更具体的规则先匹配
|
||||
for _, rule := range r.wildcardRules {
|
||||
if matchDomainPattern(hostParts, rule.parts) {
|
||||
return rule.endpoint
|
||||
return rule.endpoints
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,7 +249,18 @@ func (r *CustomResolver) AddWithPort(host, ip string, port int) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.records[host] = NewEndpointWithPort(ip, port)
|
||||
endpoint := NewEndpointWithPort(ip, port)
|
||||
if endpoints, exists := r.records[host]; exists {
|
||||
// 检查是否已存在相同的端点
|
||||
for _, e := range endpoints {
|
||||
if e.IP == endpoint.IP && e.Port == endpoint.Port {
|
||||
return nil // 端点已存在,无需重复添加
|
||||
}
|
||||
}
|
||||
r.records[host] = append(r.records[host], endpoint)
|
||||
} else {
|
||||
r.records[host] = []*Endpoint{endpoint}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -230,11 +286,27 @@ func (r *CustomResolver) AddWildcardWithPort(wildcardDomain, ip string, port int
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// 查找是否已存在相同的通配符规则
|
||||
for i, rule := range r.wildcardRules {
|
||||
if rule.pattern == wildcardDomain {
|
||||
// 检查是否已存在相同的端点
|
||||
endpoint := NewEndpointWithPort(ip, port)
|
||||
for _, e := range rule.endpoints {
|
||||
if e.IP == endpoint.IP && e.Port == endpoint.Port {
|
||||
return nil // 端点已存在,无需重复添加
|
||||
}
|
||||
}
|
||||
// 添加新端点
|
||||
r.wildcardRules[i].endpoints = append(r.wildcardRules[i].endpoints, endpoint)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新的通配符规则
|
||||
rule := wildcardRule{
|
||||
pattern: wildcardDomain,
|
||||
parts: parts,
|
||||
endpoint: NewEndpointWithPort(ip, port),
|
||||
pattern: wildcardDomain,
|
||||
parts: parts,
|
||||
endpoints: []*Endpoint{NewEndpointWithPort(ip, port)},
|
||||
}
|
||||
|
||||
// 将新规则添加到规则列表头部,确保更新的规则优先匹配
|
||||
@@ -266,12 +338,54 @@ func (r *CustomResolver) Remove(host string) error {
|
||||
return errors.New("域名记录不存在")
|
||||
}
|
||||
|
||||
// RemoveEndpoint 删除特定端点
|
||||
func (r *CustomResolver) RemoveEndpoint(host, ip string, port int) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// 尝试从精确匹配记录中删除
|
||||
if endpoints, ok := r.records[host]; ok {
|
||||
newEndpoints := make([]*Endpoint, 0)
|
||||
for _, e := range endpoints {
|
||||
if e.IP != ip || e.Port != port {
|
||||
newEndpoints = append(newEndpoints, e)
|
||||
}
|
||||
}
|
||||
if len(newEndpoints) == 0 {
|
||||
delete(r.records, host)
|
||||
} else {
|
||||
r.records[host] = newEndpoints
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 尝试从通配符规则中删除
|
||||
for i, rule := range r.wildcardRules {
|
||||
if rule.pattern == host {
|
||||
newEndpoints := make([]*Endpoint, 0)
|
||||
for _, e := range rule.endpoints {
|
||||
if e.IP != ip || e.Port != port {
|
||||
newEndpoints = append(newEndpoints, e)
|
||||
}
|
||||
}
|
||||
if len(newEndpoints) == 0 {
|
||||
r.wildcardRules = append(r.wildcardRules[:i], r.wildcardRules[i+1:]...)
|
||||
} else {
|
||||
r.wildcardRules[i].endpoints = newEndpoints
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New("域名记录不存在")
|
||||
}
|
||||
|
||||
// Clear 清除所有解析规则
|
||||
func (r *CustomResolver) Clear() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.records = make(map[string]*Endpoint)
|
||||
r.records = make(map[string][]*Endpoint)
|
||||
r.wildcardRules = make([]wildcardRule, 0)
|
||||
r.cache = make(map[string]cacheEntry)
|
||||
}
|
||||
@@ -293,6 +407,13 @@ func WithTTL(ttl time.Duration) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithLoadBalanceStrategy 设置负载均衡策略
|
||||
func WithLoadBalanceStrategy(strategy LoadBalanceStrategy) Option {
|
||||
return func(r *CustomResolver) {
|
||||
r.lbStrategy = strategy
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromMap 从映射加载DNS记录
|
||||
func (r *CustomResolver) LoadFromMap(records map[string]string) error {
|
||||
r.mu.Lock()
|
||||
@@ -310,13 +431,33 @@ func (r *CustomResolver) LoadFromMap(records map[string]string) error {
|
||||
return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")")
|
||||
}
|
||||
|
||||
// 添加通配符规则
|
||||
rule := wildcardRule{
|
||||
pattern: host,
|
||||
parts: strings.Split(host, "."),
|
||||
endpoint: endpoint,
|
||||
// 查找是否已存在相同的通配符规则
|
||||
found := false
|
||||
for i, rule := range r.wildcardRules {
|
||||
if rule.pattern == host {
|
||||
// 检查是否已存在相同的端点
|
||||
for _, e := range rule.endpoints {
|
||||
if e.IP == endpoint.IP && e.Port == endpoint.Port {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
r.wildcardRules[i].endpoints = append(r.wildcardRules[i].endpoints, endpoint)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
// 创建新的通配符规则
|
||||
rule := wildcardRule{
|
||||
pattern: host,
|
||||
parts: strings.Split(host, "."),
|
||||
endpoints: []*Endpoint{endpoint},
|
||||
}
|
||||
r.wildcardRules = append(r.wildcardRules, rule)
|
||||
}
|
||||
r.wildcardRules = append(r.wildcardRules, rule)
|
||||
|
||||
} else {
|
||||
// 常规记录
|
||||
@@ -329,7 +470,21 @@ func (r *CustomResolver) LoadFromMap(records map[string]string) error {
|
||||
return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")")
|
||||
}
|
||||
|
||||
r.records[host] = endpoint
|
||||
// 检查是否已存在相同的端点
|
||||
if endpoints, exists := r.records[host]; exists {
|
||||
found := false
|
||||
for _, e := range endpoints {
|
||||
if e.IP == endpoint.IP && e.Port == endpoint.Port {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
r.records[host] = append(r.records[host], endpoint)
|
||||
}
|
||||
} else {
|
||||
r.records[host] = []*Endpoint{endpoint}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
753
proxy.go
753
proxy.go
@@ -1,753 +0,0 @@
|
||||
// Copyright 2018 ouqiang authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
// Package goproxy HTTP(S)代理, 支持中间人代理解密HTTPS数据
|
||||
package goproxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/viki-org/dnscache"
|
||||
|
||||
"github.com/ouqiang/goproxy/cert"
|
||||
"github.com/ouqiang/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
// 连接目标服务器超时时间
|
||||
defaultTargetConnectTimeout = 5 * time.Second
|
||||
// 目标服务器读写超时时间
|
||||
defaultTargetReadWriteTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// 隧道连接成功响应行
|
||||
var tunnelEstablishedResponseLine = []byte("HTTP/1.1 200 Connection established\r\n\r\n")
|
||||
|
||||
var badGateway = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", http.StatusBadGateway, http.StatusText(http.StatusBadGateway)))
|
||||
|
||||
var (
|
||||
bufPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, 32*1024)
|
||||
},
|
||||
}
|
||||
|
||||
ctxPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(Context)
|
||||
},
|
||||
}
|
||||
headerPool = NewHeaderPool()
|
||||
requestPool = newRequestPool()
|
||||
)
|
||||
|
||||
type RequestPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func newRequestPool() *RequestPool {
|
||||
return &RequestPool{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(http.Request)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RequestPool) Get() *http.Request {
|
||||
req := p.pool.Get().(*http.Request)
|
||||
|
||||
req.Method = ""
|
||||
req.URL = nil
|
||||
req.Proto = ""
|
||||
req.ProtoMajor = 0
|
||||
req.ProtoMinor = 0
|
||||
req.Header = nil
|
||||
req.Body = nil
|
||||
req.GetBody = nil
|
||||
req.ContentLength = 0
|
||||
req.TransferEncoding = nil
|
||||
req.Close = false
|
||||
req.Host = ""
|
||||
req.Form = nil
|
||||
req.PostForm = nil
|
||||
req.MultipartForm = nil
|
||||
req.Trailer = nil
|
||||
req.RemoteAddr = ""
|
||||
req.RequestURI = ""
|
||||
req.TLS = nil
|
||||
req.Cancel = nil
|
||||
req.Response = nil
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func (p *RequestPool) Put(req *http.Request) {
|
||||
if req != nil {
|
||||
p.pool.Put(req)
|
||||
}
|
||||
}
|
||||
|
||||
type HeaderPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func NewHeaderPool() *HeaderPool {
|
||||
return &HeaderPool{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return http.Header{}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *HeaderPool) Get() http.Header {
|
||||
header := p.pool.Get().(http.Header)
|
||||
for k := range header {
|
||||
delete(header, k)
|
||||
}
|
||||
|
||||
return header
|
||||
}
|
||||
|
||||
func (p *HeaderPool) Put(header http.Header) {
|
||||
if header != nil {
|
||||
p.pool.Put(header)
|
||||
}
|
||||
}
|
||||
|
||||
// 生成隧道建立请求行
|
||||
func makeTunnelRequestLine(addr string) string {
|
||||
return fmt.Sprintf("CONNECT %s HTTP/1.1\r\n\r\n", addr)
|
||||
}
|
||||
|
||||
type options struct {
|
||||
disableKeepAlive bool
|
||||
delegate Delegate
|
||||
|
||||
decryptHTTPS bool
|
||||
websocketIntercept bool
|
||||
certCache cert.Cache
|
||||
transport *http.Transport
|
||||
clientTrace *httptrace.ClientTrace
|
||||
}
|
||||
|
||||
type Option func(*options)
|
||||
|
||||
// WithDisableKeepAlive 连接是否重用
|
||||
func WithDisableKeepAlive(disableKeepAlive bool) Option {
|
||||
return func(opt *options) {
|
||||
opt.disableKeepAlive = disableKeepAlive
|
||||
}
|
||||
}
|
||||
|
||||
func WithClientTrace(t *httptrace.ClientTrace) Option {
|
||||
return func(opt *options) {
|
||||
opt.clientTrace = t
|
||||
}
|
||||
}
|
||||
|
||||
// WithDelegate 设置委托类
|
||||
func WithDelegate(delegate Delegate) Option {
|
||||
return func(opt *options) {
|
||||
opt.delegate = delegate
|
||||
}
|
||||
}
|
||||
|
||||
// WithTransport 自定义http transport
|
||||
func WithTransport(t *http.Transport) Option {
|
||||
return func(opt *options) {
|
||||
opt.transport = t
|
||||
}
|
||||
}
|
||||
|
||||
// WithDecryptHTTPS 中间人代理, 解密HTTPS, 需实现证书缓存接口
|
||||
func WithDecryptHTTPS(c cert.Cache) Option {
|
||||
return func(opt *options) {
|
||||
opt.decryptHTTPS = true
|
||||
opt.certCache = c
|
||||
}
|
||||
}
|
||||
|
||||
// WithEnableWebsocketIntercept 拦截websocket
|
||||
func WithEnableWebsocketIntercept() Option {
|
||||
return func(opt *options) {
|
||||
opt.websocketIntercept = true
|
||||
}
|
||||
}
|
||||
|
||||
// New 创建proxy实例
|
||||
func New(opt ...Option) *Proxy {
|
||||
opts := &options{}
|
||||
for _, o := range opt {
|
||||
o(opts)
|
||||
}
|
||||
if opts.delegate == nil {
|
||||
opts.delegate = &DefaultDelegate{}
|
||||
}
|
||||
if opts.transport == nil {
|
||||
opts.transport = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
MaxIdleConns: 100,
|
||||
MaxConnsPerHost: 10,
|
||||
IdleConnTimeout: 10 * time.Second,
|
||||
TLSHandshakeTimeout: 5 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
p := &Proxy{}
|
||||
p.delegate = opts.delegate
|
||||
p.websocketIntercept = opts.websocketIntercept
|
||||
p.decryptHTTPS = opts.decryptHTTPS
|
||||
if p.decryptHTTPS {
|
||||
p.cert = cert.NewCertificate(opts.certCache, true)
|
||||
}
|
||||
p.transport = opts.transport
|
||||
p.transport.DialContext = p.dialContext()
|
||||
p.dnsCache = dnscache.New(5 * time.Minute)
|
||||
p.transport.DisableKeepAlives = opts.disableKeepAlive
|
||||
p.transport.Proxy = p.delegate.ParentProxy
|
||||
p.clientTrace = opts.clientTrace
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// Proxy 实现了http.Handler接口
|
||||
type Proxy struct {
|
||||
delegate Delegate
|
||||
clientConnNum int32
|
||||
decryptHTTPS bool
|
||||
websocketIntercept bool
|
||||
cert *cert.Certificate
|
||||
transport *http.Transport
|
||||
clientTrace *httptrace.ClientTrace
|
||||
dnsCache *dnscache.Resolver
|
||||
}
|
||||
|
||||
var _ http.Handler = &Proxy{}
|
||||
|
||||
// ServeHTTP 实现了http.Handler接口
|
||||
func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.Host == "" {
|
||||
req.URL.Host = req.Host
|
||||
}
|
||||
atomic.AddInt32(&p.clientConnNum, 1)
|
||||
ctx := ctxPool.Get().(*Context)
|
||||
ctx.Reset(req)
|
||||
|
||||
defer func() {
|
||||
p.delegate.Finish(ctx)
|
||||
ctxPool.Put(ctx)
|
||||
atomic.AddInt32(&p.clientConnNum, -1)
|
||||
}()
|
||||
p.delegate.Connect(ctx, rw)
|
||||
if ctx.abort {
|
||||
return
|
||||
}
|
||||
p.delegate.Auth(ctx, rw)
|
||||
if ctx.abort {
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case ctx.Req.Method == http.MethodConnect:
|
||||
p.tunnelProxy(ctx, rw)
|
||||
case websocket.IsWebSocketUpgrade(ctx.Req):
|
||||
p.tunnelProxy(ctx, rw)
|
||||
default:
|
||||
p.httpProxy(ctx, rw)
|
||||
}
|
||||
}
|
||||
|
||||
// ClientConnNum 获取客户端连接数
|
||||
func (p *Proxy) ClientConnNum() int32 {
|
||||
return atomic.LoadInt32(&p.clientConnNum)
|
||||
}
|
||||
|
||||
// DoRequest 执行HTTP请求,并调用responseFunc处理response
|
||||
func (p *Proxy) DoRequest(ctx *Context, responseFunc func(*http.Response, error)) {
|
||||
if ctx.Data == nil {
|
||||
ctx.Data = make(map[interface{}]interface{})
|
||||
}
|
||||
p.delegate.BeforeRequest(ctx)
|
||||
if ctx.abort {
|
||||
return
|
||||
}
|
||||
newReq := requestPool.Get()
|
||||
*newReq = *ctx.Req
|
||||
newHeader := headerPool.Get()
|
||||
CloneHeader(newReq.Header, newHeader)
|
||||
newReq.Header = newHeader
|
||||
for _, item := range hopHeaders {
|
||||
if newReq.Header.Get(item) != "" {
|
||||
newReq.Header.Del(item)
|
||||
}
|
||||
}
|
||||
if p.clientTrace != nil {
|
||||
newReq = newReq.WithContext(httptrace.WithClientTrace(newReq.Context(), p.clientTrace))
|
||||
}
|
||||
|
||||
resp, err := p.transport.RoundTrip(newReq)
|
||||
p.delegate.BeforeResponse(ctx, resp, err)
|
||||
if ctx.abort {
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
for _, h := range hopHeaders {
|
||||
resp.Header.Del(h)
|
||||
}
|
||||
}
|
||||
responseFunc(resp, err)
|
||||
headerPool.Put(newHeader)
|
||||
requestPool.Put(newReq)
|
||||
}
|
||||
|
||||
// HTTP代理
|
||||
func (p *Proxy) httpProxy(ctx *Context, rw http.ResponseWriter) {
|
||||
ctx.Req.URL.Scheme = "http"
|
||||
p.DoRequest(ctx, func(resp *http.Response, err error) {
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - HTTP请求错误: %s", ctx.Req.URL, err))
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
CopyHeader(rw.Header(), resp.Header)
|
||||
rw.WriteHeader(resp.StatusCode)
|
||||
buf := bufPool.Get().([]byte)
|
||||
_, _ = io.CopyBuffer(rw, resp.Body, buf)
|
||||
bufPool.Put(buf)
|
||||
})
|
||||
}
|
||||
|
||||
// HTTPS代理
|
||||
func (p *Proxy) httpsProxy(ctx *Context, tlsClientConn *tls.Conn) {
|
||||
if websocket.IsWebSocketUpgrade(ctx.Req) {
|
||||
p.websocketProxy(ctx, NewConnBuffer(tlsClientConn, nil))
|
||||
return
|
||||
}
|
||||
p.DoRequest(ctx, func(resp *http.Response, err error) {
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 请求错误: %s", ctx.Req.URL, err))
|
||||
_, _ = tlsClientConn.Write(badGateway)
|
||||
return
|
||||
}
|
||||
err = resp.Write(tlsClientConn)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, response写入客户端失败, %s", ctx.Req.URL, err))
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
})
|
||||
}
|
||||
|
||||
// 隧道代理
|
||||
func (p *Proxy) tunnelProxy(ctx *Context, rw http.ResponseWriter) {
|
||||
clientConn, err := hijacker(rw)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(err)
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = clientConn.Close()
|
||||
}()
|
||||
|
||||
if websocket.IsWebSocketUpgrade(ctx.Req) {
|
||||
p.websocketProxy(ctx, clientConn)
|
||||
return
|
||||
}
|
||||
|
||||
parentProxyURL, err := p.delegate.ParentProxy(ctx.Req)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - 解析代理地址错误: %s", ctx.Req.URL.Host, err))
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
if parentProxyURL == nil {
|
||||
_, err = clientConn.Write(tunnelEstablishedResponseLine)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - 隧道连接成功,通知客户端错误: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
isWebsocket := p.detectConnProtocol(clientConn)
|
||||
if isWebsocket {
|
||||
req, err := http.ReadRequest(clientConn.BufferReader())
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - websocket读取客户端升级请求失败: %s", ctx.Req.URL.Host, err))
|
||||
}
|
||||
return
|
||||
}
|
||||
req.RemoteAddr = ctx.Req.RemoteAddr
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = req.Host
|
||||
ctx.Req = req
|
||||
|
||||
p.websocketProxy(ctx, clientConn)
|
||||
return
|
||||
}
|
||||
var tlsClientConn *tls.Conn
|
||||
if p.decryptHTTPS {
|
||||
tlsConfig, err := p.cert.GenerateTlsConfig(ctx.Req.URL.Host)
|
||||
if err != nil {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 生成证书失败: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
tlsClientConn = tls.Server(clientConn, tlsConfig)
|
||||
defer func() {
|
||||
_ = tlsClientConn.Close()
|
||||
}()
|
||||
if err := tlsClientConn.Handshake(); err != nil {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 握手失败: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
|
||||
buf := bufio.NewReader(tlsClientConn)
|
||||
tlsReq, err := http.ReadRequest(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 读取客户端请求失败: %s", ctx.Req.URL.Host, err))
|
||||
}
|
||||
return
|
||||
}
|
||||
tlsReq.RemoteAddr = ctx.Req.RemoteAddr
|
||||
tlsReq.URL.Scheme = "https"
|
||||
tlsReq.URL.Host = tlsReq.Host
|
||||
ctx.Req = tlsReq
|
||||
}
|
||||
|
||||
targetAddr := ctx.Req.URL.Host
|
||||
if parentProxyURL != nil {
|
||||
targetAddr = parentProxyURL.Host
|
||||
}
|
||||
if !strings.Contains(targetAddr, ":") {
|
||||
targetAddr += ":443"
|
||||
}
|
||||
|
||||
targetConn, err := net.DialTimeout("tcp", targetAddr, defaultTargetConnectTimeout)
|
||||
if err != nil {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - 隧道转发连接目标服务器失败: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = targetConn.Close()
|
||||
}()
|
||||
if parentProxyURL != nil {
|
||||
tunnelRequestLine := makeTunnelRequestLine(ctx.Req.URL.Host)
|
||||
_, _ = targetConn.Write([]byte(tunnelRequestLine))
|
||||
}
|
||||
|
||||
if p.decryptHTTPS {
|
||||
p.httpsProxy(ctx, tlsClientConn)
|
||||
} else {
|
||||
p.tunnelConnected(ctx, nil)
|
||||
p.transfer(clientConn, targetConn)
|
||||
}
|
||||
}
|
||||
|
||||
// WebSocket代理
|
||||
func (p *Proxy) websocketProxy(ctx *Context, srcConn *ConnBuffer) {
|
||||
if !p.websocketIntercept {
|
||||
remoteAddr := ctx.Addr()
|
||||
var err error
|
||||
var targetConn net.Conn
|
||||
if ctx.IsHTTPS() {
|
||||
targetConn, err = tls.Dial("tcp", remoteAddr, &tls.Config{InsecureSkipVerify: true})
|
||||
} else {
|
||||
targetConn, err = net.Dial("tcp", remoteAddr)
|
||||
}
|
||||
if err != nil {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - websocket连接目标服务器错误: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
err = ctx.Req.Write(targetConn)
|
||||
if err != nil {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - websocket协议转换请求写入目标服务器错误: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
p.tunnelConnected(ctx, nil)
|
||||
p.transfer(srcConn, targetConn)
|
||||
return
|
||||
}
|
||||
|
||||
up := &websocket.Upgrader{
|
||||
HandshakeTimeout: defaultTargetConnectTimeout,
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
srcWSConn, err := up.Upgrade(srcConn, ctx.Req, http.Header{})
|
||||
if err != nil {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - 源连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
|
||||
u := ctx.WebsocketUrl()
|
||||
d := websocket.Dialer{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
}
|
||||
|
||||
dialTimeoutCtx, cancel := context.WithTimeout(context.Background(), defaultTargetConnectTimeout)
|
||||
defer cancel()
|
||||
targetWSConn, _, err := d.DialContext(dialTimeoutCtx, u.String(), ctx.Req.Header)
|
||||
if err != nil {
|
||||
p.tunnelConnected(ctx, err)
|
||||
p.delegate.ErrorLog(fmt.Errorf("%s - 目标连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err))
|
||||
return
|
||||
}
|
||||
p.tunnelConnected(ctx, nil)
|
||||
p.transferWebsocket(ctx, srcWSConn, targetWSConn)
|
||||
}
|
||||
|
||||
// 探测连接协议
|
||||
func (p *Proxy) detectConnProtocol(connBuf *ConnBuffer) (isWebsocket bool) {
|
||||
methodBytes, err := connBuf.Peek(3)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
method := string(methodBytes)
|
||||
if method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// webSocket双向转发
|
||||
func (p *Proxy) transferWebsocket(ctx *Context, srcConn *websocket.Conn, targetConn *websocket.Conn) {
|
||||
doneCtx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
if doneCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msgType, msg, err := srcConn.ReadMessage()
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err))
|
||||
return
|
||||
}
|
||||
p.delegate.WebSocketSendMessage(ctx, &msgType, &msg)
|
||||
err = targetConn.WriteMessage(msgType, msg)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
if doneCtx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msgType, msg, err := targetConn.ReadMessage()
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err))
|
||||
return
|
||||
}
|
||||
p.delegate.WebSocketReceiveMessage(ctx, &msgType, &msg)
|
||||
err = srcConn.WriteMessage(msgType, msg)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 双向转发
|
||||
func (p *Proxy) transfer(src net.Conn, dst net.Conn) {
|
||||
go func() {
|
||||
buf := bufPool.Get().([]byte)
|
||||
_, err := io.CopyBuffer(src, dst, buf)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", dst.RemoteAddr().String(), src.RemoteAddr().String(), err))
|
||||
}
|
||||
bufPool.Put(buf)
|
||||
_ = src.Close()
|
||||
_ = dst.Close()
|
||||
}()
|
||||
|
||||
buf := bufPool.Get().([]byte)
|
||||
_, err := io.CopyBuffer(dst, src, buf)
|
||||
if err != nil {
|
||||
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err))
|
||||
}
|
||||
bufPool.Put(buf)
|
||||
_ = dst.Close()
|
||||
_ = src.Close()
|
||||
}
|
||||
|
||||
func (p *Proxy) tunnelConnected(ctx *Context, err error) {
|
||||
ctx.TunnelProxy = true
|
||||
p.delegate.BeforeRequest(ctx)
|
||||
if err != nil {
|
||||
p.delegate.BeforeResponse(ctx, nil, err)
|
||||
return
|
||||
}
|
||||
resp := &http.Response{
|
||||
Status: "200 OK",
|
||||
StatusCode: http.StatusOK,
|
||||
Proto: "1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: http.Header{},
|
||||
Body: http.NoBody,
|
||||
}
|
||||
p.delegate.BeforeResponse(ctx, resp, nil)
|
||||
}
|
||||
|
||||
func (p *Proxy) dialContext() DialContext {
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: defaultTargetConnectTimeout,
|
||||
}
|
||||
separator := strings.LastIndex(addr, ":")
|
||||
ips, err := p.dnsCache.Fetch(addr[:separator])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ip string
|
||||
for _, item := range ips {
|
||||
ip = item.String()
|
||||
if !strings.Contains(ip, ":") {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
addr = ip + addr[separator:]
|
||||
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取底层连接
|
||||
func hijacker(rw http.ResponseWriter) (*ConnBuffer, error) {
|
||||
hijacker, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("http server不支持Hijacker")
|
||||
}
|
||||
conn, buf, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hijacker错误: %s", err)
|
||||
}
|
||||
|
||||
return NewConnBuffer(conn, buf), nil
|
||||
}
|
||||
|
||||
// CopyHeader 浅拷贝Header
|
||||
func CopyHeader(dst, src http.Header) {
|
||||
for k, vv := range src {
|
||||
for _, v := range vv {
|
||||
dst.Add(k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CloneHeader 深拷贝Header
|
||||
func CloneHeader(h http.Header, h2 http.Header) {
|
||||
for k, vv := range h {
|
||||
vv2 := make([]string, len(vv))
|
||||
copy(vv2, vv)
|
||||
h2[k] = vv2
|
||||
}
|
||||
}
|
||||
|
||||
var hopHeaders = []string{
|
||||
"Proxy-Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te",
|
||||
"Trailer",
|
||||
"Transfer-Encoding",
|
||||
}
|
||||
|
||||
type ConnBuffer struct {
|
||||
net.Conn
|
||||
buf *bufio.ReadWriter
|
||||
}
|
||||
|
||||
func NewConnBuffer(conn net.Conn, buf *bufio.ReadWriter) *ConnBuffer {
|
||||
if buf == nil {
|
||||
buf = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
||||
}
|
||||
return &ConnBuffer{
|
||||
Conn: conn,
|
||||
buf: buf,
|
||||
}
|
||||
}
|
||||
|
||||
func (cb *ConnBuffer) BufferReader() *bufio.Reader {
|
||||
return cb.buf.Reader
|
||||
}
|
||||
|
||||
func (cb *ConnBuffer) Read(b []byte) (n int, err error) {
|
||||
return cb.buf.Read(b)
|
||||
}
|
||||
|
||||
func (cb *ConnBuffer) Peek(n int) ([]byte, error) {
|
||||
return cb.buf.Peek(n)
|
||||
}
|
||||
|
||||
func (cb *ConnBuffer) Write(p []byte) (n int, err error) {
|
||||
n, err = cb.buf.Write(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return n, cb.buf.Flush()
|
||||
}
|
||||
|
||||
func (cb *ConnBuffer) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return cb.Conn, cb.buf, nil
|
||||
}
|
||||
|
||||
func (cb *ConnBuffer) WriteHeader(_ int) {}
|
||||
|
||||
func (cb *ConnBuffer) Header() http.Header { return nil }
|
||||
Reference in New Issue
Block a user