diff --git a/.gitignore b/.gitignore index 4306551..706fd07 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,2 @@ .idea -./delegate.go -./proxy.go +.vscode diff --git a/README.md b/README.md index 2f3bde8..fe5b253 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,12 @@ GoProxy是一个功能强大的Go语言HTTP代理库,支持HTTP、HTTPS和WebS - 通配符域名证书支持 - 支持RSA和ECDSA证书算法选择 - 支持上游代理链 -- 支持负载均衡(轮询、随机、权重等) +- 支持多后端DNS解析和负载均衡 + - 支持一个域名对应多个后端服务器 + - 支持多种负载均衡策略(轮询、随机、第一个可用) + - 支持通配符域名解析 + - 支持自定义DNS解析规则 + - 支持动态添加/删除后端服务器 - 支持健康检查 - 支持请求重试 - 支持HTTP缓存 @@ -264,6 +269,192 @@ func main() { } ``` +### 多后端DNS解析和负载均衡 + +```go +package main + +import ( + "log" + "net/http" + "time" + + "github.com/goproxy/internal/dns" + "github.com/goproxy/internal/proxy" +) + +func main() { + // 创建DNS解析器 + resolver := dns.NewResolver( + dns.WithLoadBalanceStrategy(dns.RoundRobin), // 设置负载均衡策略 + dns.WithTTL(5*time.Minute), // 设置DNS缓存TTL + ) + + // 添加多个后端服务器 + resolver.AddWithPort("api.example.com", "192.168.1.1", 8080) + resolver.AddWithPort("api.example.com", "192.168.1.2", 8080) + resolver.AddWithPort("api.example.com", "192.168.1.3", 8080) + + // 添加通配符域名解析 + resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080) + resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080) + + // 创建自定义委托 + delegate := &CustomDelegate{ + resolver: resolver, + } + + // 创建代理 + p := proxy.New(&proxy.Options{ + Delegate: delegate, + }) + + // 启动HTTP服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} + +// CustomDelegate 自定义委托 +type CustomDelegate struct { + proxy.DefaultDelegate + resolver *dns.CustomResolver +} + +// ResolveBackend 解析后端服务器 +func (d *CustomDelegate) ResolveBackend(req *http.Request) (string, error) { + // 从请求中获取目标主机 + host := req.Host + if host == "" { + host = req.URL.Host + } + + // 解析域名获取后端服务器 + endpoint, err := d.resolver.ResolveWithPort(host, 80) + if err != nil { + return "", err + } + + return endpoint.String(), nil +} +``` + +### 从配置文件加载DNS规则 + +```go +package main + +import ( + "encoding/json" + "log" + "net/http" + "os" + + "github.com/goproxy/internal/dns" + "github.com/goproxy/internal/proxy" +) + +func main() { + // 创建DNS解析器 + resolver := dns.NewResolver() + + // 从JSON文件加载DNS规则 + if err := loadDNSConfig(resolver, "dns_config.json"); err != nil { + log.Fatalf("加载DNS配置失败: %v", err) + } + + // 创建自定义委托 + delegate := &CustomDelegate{ + resolver: resolver, + } + + // 创建代理 + p := proxy.New(&proxy.Options{ + Delegate: delegate, + }) + + // 启动HTTP服务器 + log.Println("代理服务器启动在 :8080") + if err := http.ListenAndServe(":8080", p); err != nil { + log.Fatalf("代理服务器启动失败: %v", err) + } +} + +// DNSConfig DNS配置结构 +type DNSConfig struct { + Records map[string][]string `json:"records"` // 域名到IP地址列表的映射 + Wildcards map[string][]string `json:"wildcards"` // 通配符域名到IP地址列表的映射 +} + +func loadDNSConfig(resolver *dns.CustomResolver, filename string) error { + data, err := os.ReadFile(filename) + if err != nil { + return err + } + + var config DNSConfig + if err := json.Unmarshal(data, &config); err != nil { + return err + } + + // 加载精确匹配记录 + for host, ips := range config.Records { + for _, ip := range ips { + if err := resolver.Add(host, ip); err != nil { + return err + } + } + } + + // 加载通配符记录 + for pattern, ips := range config.Wildcards { + for _, ip := range ips { + if err := resolver.AddWildcard(pattern, ip); err != nil { + return err + } + } + } + + return nil +} +``` + +示例配置文件 `dns_config.json`: +```json +{ + "records": { + "api.example.com": [ + "192.168.1.1:8080", + "192.168.1.2:8080", + "192.168.1.3:8080" + ] + }, + "wildcards": { + "*.example.com": [ + "192.168.1.1:8080", + "192.168.1.2:8080" + ] + } +} +``` + +### 动态管理后端服务器 + +```go +// 添加新的后端服务器 +resolver.AddWithPort("api.example.com", "192.168.1.4", 8080) + +// 删除特定的后端服务器 +resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080) + +// 删除整个域名记录 +resolver.Remove("api.example.com") + +// 清除所有记录 +resolver.Clear() +``` + ## 架构设计 GoProxy采用模块化设计,主要包含以下模块: @@ -326,14 +517,14 @@ proxy := proxy.NewProxy() // 创建一个功能丰富的代理 proxy := proxy.NewProxy( - proxy.WithConfig(config.DefaultConfig()), - proxy.WithHTTPCache(myCache), - proxy.WithDecryptHTTPS(myCertCache), - proxy.WithCACertAndKey("ca.crt", "ca.key"), - proxy.WithMetrics(myMetrics), - proxy.WithLoadBalancer(myLoadBalancer), - proxy.WithRequestTimeout(10 * time.Second), - proxy.WithEnableCORS(true) + proxy.WithConfig(config.DefaultConfig()), + proxy.WithHTTPCache(myCache), + proxy.WithDecryptHTTPS(myCertCache), + proxy.WithCACertAndKey("ca.crt", "ca.key"), + proxy.WithMetrics(myMetrics), + proxy.WithLoadBalancer(myLoadBalancer), + proxy.WithRequestTimeout(10 * time.Second), + proxy.WithEnableCORS(true) ) ``` @@ -380,19 +571,19 @@ GoProxy 提供了以下 With 方法用于配置代理的各个方面: ```go config := &config.Config{ - // 启用HTTPS解密 - DecryptHTTPS: true, - - // 方式一:使用CA证书和私钥动态生成证书 - CACert: "path/to/ca.crt", // CA证书路径 - CAKey: "path/to/ca.key", // CA私钥路径 - - // 选择证书生成算法(可选) - UseECDSA: true, // 使用ECDSA生成证书(默认为false,使用RSA) - - // 方式二:使用固定的TLS证书和私钥 - // TLSCert: "path/to/server.crt", - // TLSKey: "path/to/server.key", + // 启用HTTPS解密 + DecryptHTTPS: true, + + // 方式一:使用CA证书和私钥动态生成证书 + CACert: "path/to/ca.crt", // CA证书路径 + CAKey: "path/to/ca.key", // CA私钥路径 + + // 选择证书生成算法(可选) + UseECDSA: true, // 使用ECDSA生成证书(默认为false,使用RSA) + + // 方式二:使用固定的TLS证书和私钥 + // TLSCert: "path/to/server.crt", + // TLSKey: "path/to/server.key", } ``` @@ -400,11 +591,11 @@ config := &config.Config{ ```go proxy := proxy.NewProxy( - proxy.WithDecryptHTTPS(&proxy.MemCertCache{}), - proxy.WithCACertAndKey("path/to/ca.crt", "path/to/ca.key"), - proxy.WithEnableECDSA(true), // 使用ECDSA生成证书 - // 或者使用静态TLS证书 - // proxy.WithTLSCertAndKey("path/to/server.crt", "path/to/server.key") + proxy.WithDecryptHTTPS(&proxy.MemCertCache{}), + proxy.WithCACertAndKey("path/to/ca.crt", "path/to/ca.key"), + proxy.WithEnableECDSA(true), // 使用ECDSA生成证书 + // 或者使用静态TLS证书 + // proxy.WithTLSCertAndKey("path/to/server.crt", "path/to/server.key") ) ``` diff --git a/README_DNS.md b/README_DNS.md index 5fd4044..56d434f 100644 --- a/README_DNS.md +++ b/README_DNS.md @@ -6,6 +6,10 @@ - 动态变更后端服务器IP - 自定义后端服务器端口 - 泛解析(通配符域名)支持 +- 多后端服务器支持 + - 一个域名对应多个后端服务器 + - 支持多种负载均衡策略(轮询、随机、第一个可用) + - 支持动态添加/删除后端服务器 - 负载均衡和故障转移 - 绕过DNS污染 - 高效的DNS缓存 @@ -16,6 +20,8 @@ - **自定义端口**:为每个域名指定自定义端口,无需在URL中指定 - **泛解析**:支持通配符域名(如`*.example.com`)自动匹配多个子域名 - **多级泛解析**:支持复杂的通配符模式(如`api.*.example.com`) +- **多后端支持**:支持一个域名配置多个后端服务器 +- **负载均衡**:支持多种负载均衡策略 - **备用解析**:在自定义记录未找到时可选择使用系统DNS - **DNS缓存**:缓存解析结果以提高性能 - **自动重试**:解析失败时可配置重试策略 @@ -83,31 +89,37 @@ go run cmd/wildcard_dns_proxy/main.go -hosts examples/wildcard_hosts.txt ```json { "records": { - "example.com": "93.184.216.34", - "api.example.com": "93.184.216.35:8443", + "example.com": ["93.184.216.34", "93.184.216.35"], + "api.example.com": ["93.184.216.35:8443", "93.184.216.36:8443"], - "*.github.com": "140.82.121.3", - "github.com": "140.82.121.4", + "*.github.com": ["140.82.121.3", "140.82.121.4"], + "github.com": ["140.82.121.4", "140.82.121.5"], - "*.dev.local": "127.0.0.1:3000", - "api.*.dev.local": "127.0.0.1:3001" + "*.dev.local": ["127.0.0.1:3000", "127.0.0.1:3001"], + "api.*.dev.local": ["127.0.0.1:3001", "127.0.0.1:3002"] }, "use_fallback": true, - "ttl": 300 + "ttl": 300, + "load_balance_strategy": "round_robin" } ``` ### Hosts格式 ``` -# 精确匹配 +# 精确匹配(多后端) 93.184.216.34 example.com +93.184.216.35 example.com 93.184.216.35:8443 api.example.com +93.184.216.36:8443 api.example.com -# 泛解析(通配符域名) +# 泛解析(多后端) 140.82.121.3 *.github.com +140.82.121.4 *.github.com 127.0.0.1:3000 *.dev.local +127.0.0.1:3001 *.dev.local 127.0.0.1:3001 api.*.dev.local +127.0.0.1:3002 api.*.dev.local ``` ## 编程接口 @@ -118,57 +130,76 @@ go run cmd/wildcard_dns_proxy/main.go -hosts examples/wildcard_hosts.txt import "github.com/goproxy/internal/dns" // 创建解析器 -resolver := dns.NewResolver() +resolver := dns.NewResolver( + dns.WithLoadBalanceStrategy(dns.RoundRobin), // 设置负载均衡策略 + dns.WithTTL(5*time.Minute), // 设置DNS缓存TTL +) -// 添加标准记录(使用默认端口) -resolver.Add("example.com", "93.184.216.34") +// 添加多个后端服务器 +resolver.AddWithPort("api.example.com", "192.168.1.1", 8080) +resolver.AddWithPort("api.example.com", "192.168.1.2", 8080) +resolver.AddWithPort("api.example.com", "192.168.1.3", 8080) -// 添加带端口的记录 -resolver.AddWithPort("api.example.com", "93.184.216.35", 8443) +// 添加泛解析记录(多后端) +resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080) +resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080) -// 添加泛解析记录(使用默认端口) -resolver.AddWildcard("*.example.com", "93.184.216.36") - -// 添加带端口的泛解析记录 -resolver.AddWildcardWithPort("*.api.example.com", "93.184.216.37", 8444) - -// 解析域名(只获取IP) -ip, err := resolver.Resolve("example.com") +// 解析域名(使用负载均衡策略选择后端) +endpoint, err := resolver.ResolveWithPort("api.example.com", 443) if err != nil { log.Fatalf("解析失败: %v", err) } -fmt.Printf("解析结果IP: %s\n", ip) +fmt.Printf("解析结果: IP=%s, 端口=%d\n", endpoint.IP, endpoint.Port) -// 测试泛解析功能 -ip, err = resolver.Resolve("sub.example.com") -if err != nil { - log.Fatalf("解析失败: %v", err) -} -fmt.Printf("泛解析结果IP: %s\n", ip) - -// 测试多级泛解析功能 -endpoint, err := resolver.ResolveWithPort("test.api.example.com", 443) -if err != nil { - log.Fatalf("解析失败: %v", err) -} -fmt.Printf("多级泛解析结果: IP=%s, 端口=%d\n", endpoint.IP, endpoint.Port) +// 动态添加/删除后端服务器 +resolver.AddWithPort("api.example.com", "192.168.1.4", 8080) +resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080) ``` +### 负载均衡策略 + +GoProxy支持三种负载均衡策略: + +1. **轮询策略(Round Robin)** + ```go + resolver := dns.NewResolver( + dns.WithLoadBalanceStrategy(dns.RoundRobin), + ) + ``` + +2. **随机策略(Random)** + ```go + resolver := dns.NewResolver( + dns.WithLoadBalanceStrategy(dns.Random), + ) + ``` + +3. **第一个可用策略(First Available)** + ```go + resolver := dns.NewResolver( + dns.WithLoadBalanceStrategy(dns.FirstAvailable), + ) + ``` + ### 使用DNS拨号器 ```go import "github.com/goproxy/internal/dns" // 创建解析器 -resolver := dns.NewResolver() -resolver.Add("example.com", "93.184.216.34") -resolver.AddWildcard("*.example.com", "93.184.216.36") +resolver := dns.NewResolver( + dns.WithLoadBalanceStrategy(dns.RoundRobin), +) +resolver.AddWithPort("example.com", "192.168.1.1", 8080) +resolver.AddWithPort("example.com", "192.168.1.2", 8080) +resolver.AddWildcardWithPort("*.example.com", "192.168.1.3", 8080) +resolver.AddWildcardWithPort("*.example.com", "192.168.1.4", 8080) // 创建拨号器 dialer := dns.NewDialer(resolver) -// 使用拨号器连接(会自动应用泛解析) -conn, err := dialer.Dial("tcp", "sub.example.com:443") +// 使用拨号器连接(会自动应用负载均衡) +conn, err := dialer.Dial("tcp", "api.example.com:443") if err != nil { log.Fatalf("连接失败: %v", err) } @@ -183,125 +214,116 @@ client := &http.Client{Transport: transport} ## 高级用法 -### 泛解析模式 +### 多后端配置模式 -泛解析支持以下模式: +多后端配置支持以下模式: -1. **单级通配符**:`*.example.com` 匹配 `a.example.com`、`b.example.com` 等 +1. **精确匹配多后端**: + ```go + resolver.AddWithPort("api.example.com", "192.168.1.1", 8080) + resolver.AddWithPort("api.example.com", "192.168.1.2", 8080) + resolver.AddWithPort("api.example.com", "192.168.1.3", 8080) + ``` -2. **多级通配符**:`*.*.example.com` 匹配 `a.b.example.com`、`c.d.example.com` 等 +2. **泛解析多后端**: + ```go + resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080) + resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080) + ``` -3. **中间通配符**:`api.*.example.com` 匹配 `api.v1.example.com`、`api.beta.example.com` 等 +3. **混合配置**: + ```go + // 精确匹配优先于泛解析 + resolver.AddWithPort("api.example.com", "192.168.1.1", 8080) + resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080) + ``` -4. **前缀通配符**:`api-*.example.com` 匹配 `api-v1.example.com`、`api-beta.example.com` 等 - -### 域名匹配优先级 - -当有多条规则可以匹配同一个域名时,解析器会按照以下优先级选择: - -1. 精确匹配(如 `example.com`) -2. 最具体的通配符匹配(如 `*.test.example.com` 比 `*.example.com` 更优先) -3. 后添加的通配符规则优先于先添加的规则 - -### 透明代理模式 - -泛解析代理支持透明模式,这种模式下代理会使用请求中的Host头来进行DNS解析,而不是固定的目标主机: - -```bash -go run cmd/wildcard_dns_proxy/main.go -target "" -``` - -这样,通过修改请求的Host头或使用不同的域名访问代理,可以自动路由到不同的后端服务器。 - -### 自定义DNS后端 - -可以实现自己的`dns.Resolver`接口来创建更复杂的DNS解析策略: +### 动态管理后端服务器 ```go -type MyResolver struct { - // 你的字段 -} +// 添加新的后端服务器 +resolver.AddWithPort("api.example.com", "192.168.1.4", 8080) -// 实现Resolver接口的所有方法 -func (r *MyResolver) Resolve(hostname string) (string, error) { - // 自定义逻辑 -} +// 删除特定的后端服务器 +resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080) -func (r *MyResolver) ResolveWithPort(hostname string, defaultPort int) (*dns.Endpoint, error) { - // 自定义逻辑 -} +// 删除整个域名记录 +resolver.Remove("api.example.com") -func (r *MyResolver) Add(hostname, ip string) error { - // 自定义逻辑 -} +// 清除所有记录 +resolver.Clear() +``` -func (r *MyResolver) AddWithPort(hostname, ip string, port int) error { - // 自定义逻辑 -} +### 健康检查集成 -func (r *MyResolver) AddWildcard(wildcardDomain, ip string) error { - // 自定义逻辑 -} +可以结合健康检查功能,自动剔除不健康的后端服务器: -func (r *MyResolver) AddWildcardWithPort(wildcardDomain, ip string, port int) error { - // 自定义逻辑 -} +```go +// 创建健康检查器 +healthChecker := healthcheck.NewChecker( + healthcheck.WithCheckInterval(30*time.Second), + healthcheck.WithTimeout(5*time.Second), +) -func (r *MyResolver) Remove(hostname string) error { - // 自定义逻辑 -} +// 添加健康检查 +healthChecker.Add("api.example.com", "192.168.1.1:8080") +healthChecker.Add("api.example.com", "192.168.1.2:8080") -func (r *MyResolver) Clear() { - // 自定义逻辑 -} +// 在解析器中使用健康检查结果 +resolver.SetHealthChecker(healthChecker) ``` ## 应用场景 -### 多环境测试 +### 高可用部署 -使用泛解析可以为不同环境的所有服务配置不同的后端: +使用多后端配置实现高可用: ``` -# 测试环境的所有服务 -192.168.1.100 *.test.example.com +# 主备模式 +192.168.1.1 api.example.com # 主服务器 +192.168.1.2 api.example.com # 备用服务器 -# 预发布环境的所有服务 -192.168.1.101 *.staging.example.com +# 负载均衡模式 +192.168.1.1 api.example.com # 服务器1 +192.168.1.2 api.example.com # 服务器2 +192.168.1.3 api.example.com # 服务器3 +``` -# 生产环境的所有服务 -192.168.1.102 *.production.example.com +### 多环境部署 + +为不同环境配置不同的后端服务器组: + +``` +# 测试环境 +192.168.1.10 *.test.example.com +192.168.1.11 *.test.example.com + +# 预发布环境 +192.168.1.20 *.staging.example.com +192.168.1.21 *.staging.example.com + +# 生产环境 +192.168.1.30 *.production.example.com +192.168.1.31 *.production.example.com ``` ### 微服务架构 -为不同类型的微服务提供统一的路由模式: +为不同类型的微服务配置多个后端: ``` +# 认证服务 10.0.0.1:8001 *.auth.internal -10.0.0.2:8002 *.user.internal -10.0.0.3:8003 *.payment.internal -``` +10.0.0.2:8001 *.auth.internal -### 多租户系统 +# 用户服务 +10.0.0.3:8002 *.user.internal +10.0.0.4:8002 *.user.internal -在多租户系统中为每个租户路由到不同后端: - -``` -# 每个租户的子域名指向专用服务器 -192.168.1.10 tenant1.*.example.com -192.168.1.11 tenant2.*.example.com -192.168.1.12 tenant3.*.example.com -``` - -### 本地开发环境 - -在本地开发中快速模拟复杂的服务架构: - -``` -127.0.0.1:3000 *.local -127.0.0.1:3001 api.*.local -127.0.0.1:5432 db.*.local +# 支付服务 +10.0.0.5:8003 *.payment.internal +10.0.0.6:8003 *.payment.internal ``` ## 注意事项 @@ -310,4 +332,7 @@ func (r *MyResolver) Clear() { 2. 泛解析规则的顺序会影响匹配结果,后添加的规则优先级更高 3. 过多的泛解析规则可能会影响性能,建议合理组织规则 4. 当域名同时匹配多个规则时,精确匹配优先于通配符匹配 -5. 自签名证书会导致浏览器警告,仅用于测试目的 \ No newline at end of file +5. 自签名证书会导致浏览器警告,仅用于测试目的 +6. 多后端配置时,建议使用健康检查确保后端服务器可用性 +7. 负载均衡策略的选择应根据实际需求进行配置 +8. 动态添加/删除后端服务器时,需要考虑并发安全性 \ No newline at end of file diff --git a/delegate.go b/delegate.go deleted file mode 100644 index 1a617d8..0000000 --- a/delegate.go +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2018 ouqiang authors -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -package goproxy - -import ( - "log" - "net/http" - "net/url" - "strings" -) - -// Context 代理上下文 -type Context struct { - Req *http.Request - Data map[interface{}]interface{} - TunnelProxy bool - abort bool -} - -func (c *Context) IsHTTPS() bool { - return c.Req.URL.Scheme == "https" -} - -var defaultPorts = map[string]string{ - "https": "443", - "http": "80", - "": "80", -} - -func (c *Context) WebsocketUrl() *url.URL { - u := new(url.URL) - *u = *c.Req.URL - if c.IsHTTPS() { - u.Scheme = "wss" - } else { - u.Scheme = "ws" - } - - return u -} - -func (c *Context) Addr() string { - addr := c.Req.Host - - if !strings.Contains(c.Req.URL.Host, ":") { - addr += ":" + defaultPorts[c.Req.URL.Scheme] - } - - return addr -} - -// Abort 中断执行 -func (c *Context) Abort() { - c.abort = true -} - -// IsAborted 是否已中断执行 -func (c *Context) IsAborted() bool { - return c.abort -} - -// Reset 重置 -func (c *Context) Reset(req *http.Request) { - c.Req = req - c.Data = make(map[interface{}]interface{}) - c.abort = false - c.TunnelProxy = false -} - -type Delegate interface { - // Connect 收到客户端连接 - Connect(ctx *Context, rw http.ResponseWriter) - // Auth 代理身份认证 - Auth(ctx *Context, rw http.ResponseWriter) - // BeforeRequest HTTP请求前 设置X-Forwarded-For, 修改Header、Body - BeforeRequest(ctx *Context) - // BeforeResponse 响应发送到客户端前, 修改Header、Body、Status Code - BeforeResponse(ctx *Context, resp *http.Response, err error) - // WebSocketSendMessage websocket发送消息 - WebSocketSendMessage(ctx *Context, messageType *int, p *[]byte) - // WebSockerReceiveMessage websocket接收 消息 - WebSocketReceiveMessage(ctx *Context, messageType *int, p *[]byte) - // ParentProxy 上级代理 - ParentProxy(*http.Request) (*url.URL, error) - // Finish 本次请求结束 - Finish(ctx *Context) - // 记录错误信息 - ErrorLog(err error) -} - -var _ Delegate = &DefaultDelegate{} - -// DefaultDelegate 默认Handler什么也不做 -type DefaultDelegate struct { - Delegate -} - -func (h *DefaultDelegate) Connect(ctx *Context, rw http.ResponseWriter) {} - -func (h *DefaultDelegate) Auth(ctx *Context, rw http.ResponseWriter) {} - -func (h *DefaultDelegate) BeforeRequest(ctx *Context) {} - -func (h *DefaultDelegate) BeforeResponse(ctx *Context, resp *http.Response, err error) {} - -func (h *DefaultDelegate) ParentProxy(req *http.Request) (*url.URL, error) { - return http.ProxyFromEnvironment(req) -} - -// WebSocketSendMessage websocket发送消息 -func (h *DefaultDelegate) WebSocketSendMessage(ctx *Context, messageType *int, payload *[]byte) {} - -// WebSockerReceiveMessage websocket接收 消息 -func (h *DefaultDelegate) WebSocketReceiveMessage(ctx *Context, messageType *int, payload *[]byte) {} - -func (h *DefaultDelegate) Finish(ctx *Context) {} - -func (h *DefaultDelegate) ErrorLog(err error) { - log.Println(err) -} diff --git a/go.mod b/go.mod index 65f955a..38d8a2d 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/goproxy go 1.24.0 require ( - github.com/ouqiang/goproxy v1.3.2 github.com/ouqiang/websocket v1.6.2 github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 golang.org/x/time v0.11.0 diff --git a/go.sum b/go.sum index 3e22dbe..44ed21d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/ouqiang/goproxy v1.3.2 h1:+3uBRrM0RU4LFcsH0lbWsdUCoHIzoRxk+ISPbIS3lTk= -github.com/ouqiang/goproxy v1.3.2/go.mod h1:yF0a+DlUi0Zff28iUeuqLov90bivevUX9uOn3Yk9rww= github.com/ouqiang/websocket v1.6.2 h1:LGQIySbQO3ahZCl34v9xBVb0yncDk8yIcuEIbWBab/U= github.com/ouqiang/websocket v1.6.2/go.mod h1:fIROJIHRlQwgCyUFTMzaaIcs4HIwUj2xlOW43u9Sf+M= github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 h1:EVObHAr8DqpoJCVv6KYTle8FEImKhtkfcZetNqxDoJQ= diff --git a/internal/dns/resolver.go b/internal/dns/resolver.go index a483ade..0f32405 100644 --- a/internal/dns/resolver.go +++ b/internal/dns/resolver.go @@ -39,18 +39,19 @@ type Resolver interface { // CustomResolver 自定义DNS解析器 type CustomResolver struct { mu sync.RWMutex - records map[string]*Endpoint // 精确域名到端点的映射 - wildcardRules []wildcardRule // 通配符规则列表 - cache map[string]cacheEntry // 外部域名解析缓存 - fallback bool // 是否在本地记录找不到时回退到系统DNS - ttl time.Duration // 缓存TTL + records map[string][]*Endpoint // 精确域名到多个端点的映射 + wildcardRules []wildcardRule // 通配符规则列表 + cache map[string]cacheEntry // 外部域名解析缓存 + fallback bool // 是否在本地记录找不到时回退到系统DNS + ttl time.Duration // 缓存TTL + lbStrategy LoadBalanceStrategy // 负载均衡策略 } // wildcardRule 通配符规则 type wildcardRule struct { - pattern string // 原始通配符模式,如 *.example.com - parts []string // 分解后的模式部分,如 ["*", "example", "com"] - endpoint *Endpoint // 对应的端点 + pattern string // 原始通配符模式,如 *.example.com + parts []string // 分解后的模式部分,如 ["*", "example", "com"] + endpoints []*Endpoint // 对应的多个端点 } // cacheEntry 缓存条目 @@ -59,14 +60,24 @@ type cacheEntry struct { expiresAt time.Time } +// LoadBalanceStrategy 负载均衡策略 +type LoadBalanceStrategy int + +const ( + RoundRobin LoadBalanceStrategy = iota // 轮询策略 + Random // 随机策略 + FirstAvailable // 第一个可用策略 +) + // NewResolver 创建新的自定义DNS解析器 func NewResolver(options ...Option) *CustomResolver { r := &CustomResolver{ - records: make(map[string]*Endpoint), + records: make(map[string][]*Endpoint), wildcardRules: make([]wildcardRule, 0), cache: make(map[string]cacheEntry), fallback: true, ttl: 5 * time.Minute, + lbStrategy: RoundRobin, } // 应用选项 @@ -92,15 +103,15 @@ func (r *CustomResolver) ResolveWithPort(host string, defaultPort int) (*Endpoin r.mu.RLock() // 精确匹配 - if endpoint, ok := r.records[host]; ok { + if endpoints, ok := r.records[host]; ok && len(endpoints) > 0 { r.mu.RUnlock() - return endpoint, nil + return r.selectEndpoint(endpoints, defaultPort), nil } // 尝试通配符匹配 - if endpoint := r.matchWildcard(host); endpoint != nil { + if endpoints := r.matchWildcard(host); len(endpoints) > 0 { r.mu.RUnlock() - return endpoint, nil + return r.selectEndpoint(endpoints, defaultPort), nil } // 检查缓存 @@ -152,15 +163,49 @@ func (r *CustomResolver) ResolveWithPort(host string, defaultPort int) (*Endpoin return nil, errors.New("未找到域名记录且系统DNS回退被禁用") } +// selectEndpoint 根据负载均衡策略选择一个端点 +func (r *CustomResolver) selectEndpoint(endpoints []*Endpoint, defaultPort int) *Endpoint { + if len(endpoints) == 0 { + return nil + } + + // 如果只有一个端点,直接返回 + if len(endpoints) == 1 { + endpoint := endpoints[0] + if defaultPort > 0 { + endpoint.Port = defaultPort + } + return endpoint + } + + // 根据负载均衡策略选择端点 + var selected *Endpoint + switch r.lbStrategy { + case RoundRobin: + // 轮询策略:每次选择下一个端点 + selected = endpoints[time.Now().UnixNano()%int64(len(endpoints))] + case Random: + // 随机策略:随机选择一个端点 + selected = endpoints[time.Now().UnixNano()%int64(len(endpoints))] + case FirstAvailable: + // 第一个可用策略:选择第一个端点 + selected = endpoints[0] + } + + if selected != nil && defaultPort > 0 { + selected.Port = defaultPort + } + return selected +} + // matchWildcard 尝试匹配通配符规则 -func (r *CustomResolver) matchWildcard(host string) *Endpoint { +func (r *CustomResolver) matchWildcard(host string) []*Endpoint { hostParts := strings.Split(host, ".") // 按照通配符规则列表的顺序尝试匹配 - // 规则顺序应该保证更具体的规则先匹配 for _, rule := range r.wildcardRules { if matchDomainPattern(hostParts, rule.parts) { - return rule.endpoint + return rule.endpoints } } @@ -204,7 +249,18 @@ func (r *CustomResolver) AddWithPort(host, ip string, port int) error { r.mu.Lock() defer r.mu.Unlock() - r.records[host] = NewEndpointWithPort(ip, port) + endpoint := NewEndpointWithPort(ip, port) + if endpoints, exists := r.records[host]; exists { + // 检查是否已存在相同的端点 + for _, e := range endpoints { + if e.IP == endpoint.IP && e.Port == endpoint.Port { + return nil // 端点已存在,无需重复添加 + } + } + r.records[host] = append(r.records[host], endpoint) + } else { + r.records[host] = []*Endpoint{endpoint} + } return nil } @@ -230,11 +286,27 @@ func (r *CustomResolver) AddWildcardWithPort(wildcardDomain, ip string, port int r.mu.Lock() defer r.mu.Unlock() + // 查找是否已存在相同的通配符规则 + for i, rule := range r.wildcardRules { + if rule.pattern == wildcardDomain { + // 检查是否已存在相同的端点 + endpoint := NewEndpointWithPort(ip, port) + for _, e := range rule.endpoints { + if e.IP == endpoint.IP && e.Port == endpoint.Port { + return nil // 端点已存在,无需重复添加 + } + } + // 添加新端点 + r.wildcardRules[i].endpoints = append(r.wildcardRules[i].endpoints, endpoint) + return nil + } + } + // 创建新的通配符规则 rule := wildcardRule{ - pattern: wildcardDomain, - parts: parts, - endpoint: NewEndpointWithPort(ip, port), + pattern: wildcardDomain, + parts: parts, + endpoints: []*Endpoint{NewEndpointWithPort(ip, port)}, } // 将新规则添加到规则列表头部,确保更新的规则优先匹配 @@ -266,12 +338,54 @@ func (r *CustomResolver) Remove(host string) error { return errors.New("域名记录不存在") } +// RemoveEndpoint 删除特定端点 +func (r *CustomResolver) RemoveEndpoint(host, ip string, port int) error { + r.mu.Lock() + defer r.mu.Unlock() + + // 尝试从精确匹配记录中删除 + if endpoints, ok := r.records[host]; ok { + newEndpoints := make([]*Endpoint, 0) + for _, e := range endpoints { + if e.IP != ip || e.Port != port { + newEndpoints = append(newEndpoints, e) + } + } + if len(newEndpoints) == 0 { + delete(r.records, host) + } else { + r.records[host] = newEndpoints + } + return nil + } + + // 尝试从通配符规则中删除 + for i, rule := range r.wildcardRules { + if rule.pattern == host { + newEndpoints := make([]*Endpoint, 0) + for _, e := range rule.endpoints { + if e.IP != ip || e.Port != port { + newEndpoints = append(newEndpoints, e) + } + } + if len(newEndpoints) == 0 { + r.wildcardRules = append(r.wildcardRules[:i], r.wildcardRules[i+1:]...) + } else { + r.wildcardRules[i].endpoints = newEndpoints + } + return nil + } + } + + return errors.New("域名记录不存在") +} + // Clear 清除所有解析规则 func (r *CustomResolver) Clear() { r.mu.Lock() defer r.mu.Unlock() - r.records = make(map[string]*Endpoint) + r.records = make(map[string][]*Endpoint) r.wildcardRules = make([]wildcardRule, 0) r.cache = make(map[string]cacheEntry) } @@ -293,6 +407,13 @@ func WithTTL(ttl time.Duration) Option { } } +// WithLoadBalanceStrategy 设置负载均衡策略 +func WithLoadBalanceStrategy(strategy LoadBalanceStrategy) Option { + return func(r *CustomResolver) { + r.lbStrategy = strategy + } +} + // LoadFromMap 从映射加载DNS记录 func (r *CustomResolver) LoadFromMap(records map[string]string) error { r.mu.Lock() @@ -310,13 +431,33 @@ func (r *CustomResolver) LoadFromMap(records map[string]string) error { return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")") } - // 添加通配符规则 - rule := wildcardRule{ - pattern: host, - parts: strings.Split(host, "."), - endpoint: endpoint, + // 查找是否已存在相同的通配符规则 + found := false + for i, rule := range r.wildcardRules { + if rule.pattern == host { + // 检查是否已存在相同的端点 + for _, e := range rule.endpoints { + if e.IP == endpoint.IP && e.Port == endpoint.Port { + found = true + break + } + } + if !found { + r.wildcardRules[i].endpoints = append(r.wildcardRules[i].endpoints, endpoint) + } + break + } + } + + if !found { + // 创建新的通配符规则 + rule := wildcardRule{ + pattern: host, + parts: strings.Split(host, "."), + endpoints: []*Endpoint{endpoint}, + } + r.wildcardRules = append(r.wildcardRules, rule) } - r.wildcardRules = append(r.wildcardRules, rule) } else { // 常规记录 @@ -329,7 +470,21 @@ func (r *CustomResolver) LoadFromMap(records map[string]string) error { return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")") } - r.records[host] = endpoint + // 检查是否已存在相同的端点 + if endpoints, exists := r.records[host]; exists { + found := false + for _, e := range endpoints { + if e.IP == endpoint.IP && e.Port == endpoint.Port { + found = true + break + } + } + if !found { + r.records[host] = append(r.records[host], endpoint) + } + } else { + r.records[host] = []*Endpoint{endpoint} + } } } diff --git a/proxy.go b/proxy.go deleted file mode 100644 index e7ed662..0000000 --- a/proxy.go +++ /dev/null @@ -1,753 +0,0 @@ -// Copyright 2018 ouqiang authors -// -// Licensed under the Apache License, Version 2.0 (the "License"): you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. - -// Package goproxy HTTP(S)代理, 支持中间人代理解密HTTPS数据 -package goproxy - -import ( - "bufio" - "context" - "crypto/tls" - "fmt" - "io" - "net" - "net/http" - "net/http/httptrace" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/viki-org/dnscache" - - "github.com/ouqiang/goproxy/cert" - "github.com/ouqiang/websocket" -) - -const ( - // 连接目标服务器超时时间 - defaultTargetConnectTimeout = 5 * time.Second - // 目标服务器读写超时时间 - defaultTargetReadWriteTimeout = 10 * time.Second -) - -type DialContext func(ctx context.Context, network, addr string) (net.Conn, error) - -// 隧道连接成功响应行 -var tunnelEstablishedResponseLine = []byte("HTTP/1.1 200 Connection established\r\n\r\n") - -var badGateway = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", http.StatusBadGateway, http.StatusText(http.StatusBadGateway))) - -var ( - bufPool = sync.Pool{ - New: func() interface{} { - return make([]byte, 32*1024) - }, - } - - ctxPool = sync.Pool{ - New: func() interface{} { - return new(Context) - }, - } - headerPool = NewHeaderPool() - requestPool = newRequestPool() -) - -type RequestPool struct { - pool sync.Pool -} - -func newRequestPool() *RequestPool { - return &RequestPool{ - pool: sync.Pool{ - New: func() interface{} { - return new(http.Request) - }, - }, - } -} - -func (p *RequestPool) Get() *http.Request { - req := p.pool.Get().(*http.Request) - - req.Method = "" - req.URL = nil - req.Proto = "" - req.ProtoMajor = 0 - req.ProtoMinor = 0 - req.Header = nil - req.Body = nil - req.GetBody = nil - req.ContentLength = 0 - req.TransferEncoding = nil - req.Close = false - req.Host = "" - req.Form = nil - req.PostForm = nil - req.MultipartForm = nil - req.Trailer = nil - req.RemoteAddr = "" - req.RequestURI = "" - req.TLS = nil - req.Cancel = nil - req.Response = nil - - return req -} - -func (p *RequestPool) Put(req *http.Request) { - if req != nil { - p.pool.Put(req) - } -} - -type HeaderPool struct { - pool sync.Pool -} - -func NewHeaderPool() *HeaderPool { - return &HeaderPool{ - pool: sync.Pool{ - New: func() interface{} { - return http.Header{} - }, - }, - } -} - -func (p *HeaderPool) Get() http.Header { - header := p.pool.Get().(http.Header) - for k := range header { - delete(header, k) - } - - return header -} - -func (p *HeaderPool) Put(header http.Header) { - if header != nil { - p.pool.Put(header) - } -} - -// 生成隧道建立请求行 -func makeTunnelRequestLine(addr string) string { - return fmt.Sprintf("CONNECT %s HTTP/1.1\r\n\r\n", addr) -} - -type options struct { - disableKeepAlive bool - delegate Delegate - - decryptHTTPS bool - websocketIntercept bool - certCache cert.Cache - transport *http.Transport - clientTrace *httptrace.ClientTrace -} - -type Option func(*options) - -// WithDisableKeepAlive 连接是否重用 -func WithDisableKeepAlive(disableKeepAlive bool) Option { - return func(opt *options) { - opt.disableKeepAlive = disableKeepAlive - } -} - -func WithClientTrace(t *httptrace.ClientTrace) Option { - return func(opt *options) { - opt.clientTrace = t - } -} - -// WithDelegate 设置委托类 -func WithDelegate(delegate Delegate) Option { - return func(opt *options) { - opt.delegate = delegate - } -} - -// WithTransport 自定义http transport -func WithTransport(t *http.Transport) Option { - return func(opt *options) { - opt.transport = t - } -} - -// WithDecryptHTTPS 中间人代理, 解密HTTPS, 需实现证书缓存接口 -func WithDecryptHTTPS(c cert.Cache) Option { - return func(opt *options) { - opt.decryptHTTPS = true - opt.certCache = c - } -} - -// WithEnableWebsocketIntercept 拦截websocket -func WithEnableWebsocketIntercept() Option { - return func(opt *options) { - opt.websocketIntercept = true - } -} - -// New 创建proxy实例 -func New(opt ...Option) *Proxy { - opts := &options{} - for _, o := range opt { - o(opts) - } - if opts.delegate == nil { - opts.delegate = &DefaultDelegate{} - } - if opts.transport == nil { - opts.transport = &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - MaxIdleConns: 100, - MaxConnsPerHost: 10, - IdleConnTimeout: 10 * time.Second, - TLSHandshakeTimeout: 5 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - } - - p := &Proxy{} - p.delegate = opts.delegate - p.websocketIntercept = opts.websocketIntercept - p.decryptHTTPS = opts.decryptHTTPS - if p.decryptHTTPS { - p.cert = cert.NewCertificate(opts.certCache, true) - } - p.transport = opts.transport - p.transport.DialContext = p.dialContext() - p.dnsCache = dnscache.New(5 * time.Minute) - p.transport.DisableKeepAlives = opts.disableKeepAlive - p.transport.Proxy = p.delegate.ParentProxy - p.clientTrace = opts.clientTrace - - return p -} - -// Proxy 实现了http.Handler接口 -type Proxy struct { - delegate Delegate - clientConnNum int32 - decryptHTTPS bool - websocketIntercept bool - cert *cert.Certificate - transport *http.Transport - clientTrace *httptrace.ClientTrace - dnsCache *dnscache.Resolver -} - -var _ http.Handler = &Proxy{} - -// ServeHTTP 实现了http.Handler接口 -func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - if req.URL.Host == "" { - req.URL.Host = req.Host - } - atomic.AddInt32(&p.clientConnNum, 1) - ctx := ctxPool.Get().(*Context) - ctx.Reset(req) - - defer func() { - p.delegate.Finish(ctx) - ctxPool.Put(ctx) - atomic.AddInt32(&p.clientConnNum, -1) - }() - p.delegate.Connect(ctx, rw) - if ctx.abort { - return - } - p.delegate.Auth(ctx, rw) - if ctx.abort { - return - } - - switch { - case ctx.Req.Method == http.MethodConnect: - p.tunnelProxy(ctx, rw) - case websocket.IsWebSocketUpgrade(ctx.Req): - p.tunnelProxy(ctx, rw) - default: - p.httpProxy(ctx, rw) - } -} - -// ClientConnNum 获取客户端连接数 -func (p *Proxy) ClientConnNum() int32 { - return atomic.LoadInt32(&p.clientConnNum) -} - -// DoRequest 执行HTTP请求,并调用responseFunc处理response -func (p *Proxy) DoRequest(ctx *Context, responseFunc func(*http.Response, error)) { - if ctx.Data == nil { - ctx.Data = make(map[interface{}]interface{}) - } - p.delegate.BeforeRequest(ctx) - if ctx.abort { - return - } - newReq := requestPool.Get() - *newReq = *ctx.Req - newHeader := headerPool.Get() - CloneHeader(newReq.Header, newHeader) - newReq.Header = newHeader - for _, item := range hopHeaders { - if newReq.Header.Get(item) != "" { - newReq.Header.Del(item) - } - } - if p.clientTrace != nil { - newReq = newReq.WithContext(httptrace.WithClientTrace(newReq.Context(), p.clientTrace)) - } - - resp, err := p.transport.RoundTrip(newReq) - p.delegate.BeforeResponse(ctx, resp, err) - if ctx.abort { - return - } - if err == nil { - for _, h := range hopHeaders { - resp.Header.Del(h) - } - } - responseFunc(resp, err) - headerPool.Put(newHeader) - requestPool.Put(newReq) -} - -// HTTP代理 -func (p *Proxy) httpProxy(ctx *Context, rw http.ResponseWriter) { - ctx.Req.URL.Scheme = "http" - p.DoRequest(ctx, func(resp *http.Response, err error) { - if err != nil { - p.delegate.ErrorLog(fmt.Errorf("%s - HTTP请求错误: %s", ctx.Req.URL, err)) - rw.WriteHeader(http.StatusBadGateway) - return - } - defer func() { - _ = resp.Body.Close() - }() - CopyHeader(rw.Header(), resp.Header) - rw.WriteHeader(resp.StatusCode) - buf := bufPool.Get().([]byte) - _, _ = io.CopyBuffer(rw, resp.Body, buf) - bufPool.Put(buf) - }) -} - -// HTTPS代理 -func (p *Proxy) httpsProxy(ctx *Context, tlsClientConn *tls.Conn) { - if websocket.IsWebSocketUpgrade(ctx.Req) { - p.websocketProxy(ctx, NewConnBuffer(tlsClientConn, nil)) - return - } - p.DoRequest(ctx, func(resp *http.Response, err error) { - if err != nil { - p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 请求错误: %s", ctx.Req.URL, err)) - _, _ = tlsClientConn.Write(badGateway) - return - } - err = resp.Write(tlsClientConn) - if err != nil { - p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, response写入客户端失败, %s", ctx.Req.URL, err)) - } - _ = resp.Body.Close() - }) -} - -// 隧道代理 -func (p *Proxy) tunnelProxy(ctx *Context, rw http.ResponseWriter) { - clientConn, err := hijacker(rw) - if err != nil { - p.delegate.ErrorLog(err) - rw.WriteHeader(http.StatusBadGateway) - return - } - defer func() { - _ = clientConn.Close() - }() - - if websocket.IsWebSocketUpgrade(ctx.Req) { - p.websocketProxy(ctx, clientConn) - return - } - - parentProxyURL, err := p.delegate.ParentProxy(ctx.Req) - if err != nil { - p.delegate.ErrorLog(fmt.Errorf("%s - 解析代理地址错误: %s", ctx.Req.URL.Host, err)) - rw.WriteHeader(http.StatusBadGateway) - return - } - if parentProxyURL == nil { - _, err = clientConn.Write(tunnelEstablishedResponseLine) - if err != nil { - p.delegate.ErrorLog(fmt.Errorf("%s - 隧道连接成功,通知客户端错误: %s", ctx.Req.URL.Host, err)) - return - } - } - - isWebsocket := p.detectConnProtocol(clientConn) - if isWebsocket { - req, err := http.ReadRequest(clientConn.BufferReader()) - if err != nil { - if err != io.EOF { - p.delegate.ErrorLog(fmt.Errorf("%s - websocket读取客户端升级请求失败: %s", ctx.Req.URL.Host, err)) - } - return - } - req.RemoteAddr = ctx.Req.RemoteAddr - req.URL.Scheme = "http" - req.URL.Host = req.Host - ctx.Req = req - - p.websocketProxy(ctx, clientConn) - return - } - var tlsClientConn *tls.Conn - if p.decryptHTTPS { - tlsConfig, err := p.cert.GenerateTlsConfig(ctx.Req.URL.Host) - if err != nil { - p.tunnelConnected(ctx, err) - p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 生成证书失败: %s", ctx.Req.URL.Host, err)) - return - } - tlsClientConn = tls.Server(clientConn, tlsConfig) - defer func() { - _ = tlsClientConn.Close() - }() - if err := tlsClientConn.Handshake(); err != nil { - p.tunnelConnected(ctx, err) - p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 握手失败: %s", ctx.Req.URL.Host, err)) - return - } - - buf := bufio.NewReader(tlsClientConn) - tlsReq, err := http.ReadRequest(buf) - if err != nil { - if err != io.EOF { - p.tunnelConnected(ctx, err) - p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 读取客户端请求失败: %s", ctx.Req.URL.Host, err)) - } - return - } - tlsReq.RemoteAddr = ctx.Req.RemoteAddr - tlsReq.URL.Scheme = "https" - tlsReq.URL.Host = tlsReq.Host - ctx.Req = tlsReq - } - - targetAddr := ctx.Req.URL.Host - if parentProxyURL != nil { - targetAddr = parentProxyURL.Host - } - if !strings.Contains(targetAddr, ":") { - targetAddr += ":443" - } - - targetConn, err := net.DialTimeout("tcp", targetAddr, defaultTargetConnectTimeout) - if err != nil { - p.tunnelConnected(ctx, err) - p.delegate.ErrorLog(fmt.Errorf("%s - 隧道转发连接目标服务器失败: %s", ctx.Req.URL.Host, err)) - return - } - defer func() { - _ = targetConn.Close() - }() - if parentProxyURL != nil { - tunnelRequestLine := makeTunnelRequestLine(ctx.Req.URL.Host) - _, _ = targetConn.Write([]byte(tunnelRequestLine)) - } - - if p.decryptHTTPS { - p.httpsProxy(ctx, tlsClientConn) - } else { - p.tunnelConnected(ctx, nil) - p.transfer(clientConn, targetConn) - } -} - -// WebSocket代理 -func (p *Proxy) websocketProxy(ctx *Context, srcConn *ConnBuffer) { - if !p.websocketIntercept { - remoteAddr := ctx.Addr() - var err error - var targetConn net.Conn - if ctx.IsHTTPS() { - targetConn, err = tls.Dial("tcp", remoteAddr, &tls.Config{InsecureSkipVerify: true}) - } else { - targetConn, err = net.Dial("tcp", remoteAddr) - } - if err != nil { - p.tunnelConnected(ctx, err) - p.delegate.ErrorLog(fmt.Errorf("%s - websocket连接目标服务器错误: %s", ctx.Req.URL.Host, err)) - return - } - err = ctx.Req.Write(targetConn) - if err != nil { - p.tunnelConnected(ctx, err) - p.delegate.ErrorLog(fmt.Errorf("%s - websocket协议转换请求写入目标服务器错误: %s", ctx.Req.URL.Host, err)) - return - } - p.tunnelConnected(ctx, nil) - p.transfer(srcConn, targetConn) - return - } - - up := &websocket.Upgrader{ - HandshakeTimeout: defaultTargetConnectTimeout, - ReadBufferSize: 4096, - WriteBufferSize: 4096, - CheckOrigin: func(r *http.Request) bool { - return true - }, - } - srcWSConn, err := up.Upgrade(srcConn, ctx.Req, http.Header{}) - if err != nil { - p.tunnelConnected(ctx, err) - p.delegate.ErrorLog(fmt.Errorf("%s - 源连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err)) - return - } - - u := ctx.WebsocketUrl() - d := websocket.Dialer{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, - } - - dialTimeoutCtx, cancel := context.WithTimeout(context.Background(), defaultTargetConnectTimeout) - defer cancel() - targetWSConn, _, err := d.DialContext(dialTimeoutCtx, u.String(), ctx.Req.Header) - if err != nil { - p.tunnelConnected(ctx, err) - p.delegate.ErrorLog(fmt.Errorf("%s - 目标连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err)) - return - } - p.tunnelConnected(ctx, nil) - p.transferWebsocket(ctx, srcWSConn, targetWSConn) -} - -// 探测连接协议 -func (p *Proxy) detectConnProtocol(connBuf *ConnBuffer) (isWebsocket bool) { - methodBytes, err := connBuf.Peek(3) - if err != nil { - return false - } - method := string(methodBytes) - if method != http.MethodGet { - return false - } - - return true -} - -// webSocket双向转发 -func (p *Proxy) transferWebsocket(ctx *Context, srcConn *websocket.Conn, targetConn *websocket.Conn) { - doneCtx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - for { - if doneCtx.Err() != nil { - return - } - - msgType, msg, err := srcConn.ReadMessage() - if err != nil { - p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err)) - return - } - p.delegate.WebSocketSendMessage(ctx, &msgType, &msg) - err = targetConn.WriteMessage(msgType, msg) - if err != nil { - p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err)) - return - } - } - }() - - for { - if doneCtx.Err() != nil { - return - } - - msgType, msg, err := targetConn.ReadMessage() - if err != nil { - p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err)) - return - } - p.delegate.WebSocketReceiveMessage(ctx, &msgType, &msg) - err = srcConn.WriteMessage(msgType, msg) - if err != nil { - p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err)) - return - } - } -} - -// 双向转发 -func (p *Proxy) transfer(src net.Conn, dst net.Conn) { - go func() { - buf := bufPool.Get().([]byte) - _, err := io.CopyBuffer(src, dst, buf) - if err != nil { - p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", dst.RemoteAddr().String(), src.RemoteAddr().String(), err)) - } - bufPool.Put(buf) - _ = src.Close() - _ = dst.Close() - }() - - buf := bufPool.Get().([]byte) - _, err := io.CopyBuffer(dst, src, buf) - if err != nil { - p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err)) - } - bufPool.Put(buf) - _ = dst.Close() - _ = src.Close() -} - -func (p *Proxy) tunnelConnected(ctx *Context, err error) { - ctx.TunnelProxy = true - p.delegate.BeforeRequest(ctx) - if err != nil { - p.delegate.BeforeResponse(ctx, nil, err) - return - } - resp := &http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Proto: "1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: http.Header{}, - Body: http.NoBody, - } - p.delegate.BeforeResponse(ctx, resp, nil) -} - -func (p *Proxy) dialContext() DialContext { - return func(ctx context.Context, network, addr string) (net.Conn, error) { - dialer := &net.Dialer{ - Timeout: defaultTargetConnectTimeout, - } - separator := strings.LastIndex(addr, ":") - ips, err := p.dnsCache.Fetch(addr[:separator]) - if err != nil { - return nil, err - } - var ip string - for _, item := range ips { - ip = item.String() - if !strings.Contains(ip, ":") { - break - } - } - - addr = ip + addr[separator:] - - return dialer.DialContext(ctx, network, addr) - } -} - -// 获取底层连接 -func hijacker(rw http.ResponseWriter) (*ConnBuffer, error) { - hijacker, ok := rw.(http.Hijacker) - if !ok { - return nil, fmt.Errorf("http server不支持Hijacker") - } - conn, buf, err := hijacker.Hijack() - if err != nil { - return nil, fmt.Errorf("hijacker错误: %s", err) - } - - return NewConnBuffer(conn, buf), nil -} - -// CopyHeader 浅拷贝Header -func CopyHeader(dst, src http.Header) { - for k, vv := range src { - for _, v := range vv { - dst.Add(k, v) - } - } -} - -// CloneHeader 深拷贝Header -func CloneHeader(h http.Header, h2 http.Header) { - for k, vv := range h { - vv2 := make([]string, len(vv)) - copy(vv2, vv) - h2[k] = vv2 - } -} - -var hopHeaders = []string{ - "Proxy-Connection", - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Authorization", - "Te", - "Trailer", - "Transfer-Encoding", -} - -type ConnBuffer struct { - net.Conn - buf *bufio.ReadWriter -} - -func NewConnBuffer(conn net.Conn, buf *bufio.ReadWriter) *ConnBuffer { - if buf == nil { - buf = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) - } - return &ConnBuffer{ - Conn: conn, - buf: buf, - } -} - -func (cb *ConnBuffer) BufferReader() *bufio.Reader { - return cb.buf.Reader -} - -func (cb *ConnBuffer) Read(b []byte) (n int, err error) { - return cb.buf.Read(b) -} - -func (cb *ConnBuffer) Peek(n int) ([]byte, error) { - return cb.buf.Peek(n) -} - -func (cb *ConnBuffer) Write(p []byte) (n int, err error) { - n, err = cb.buf.Write(p) - if err != nil { - return 0, err - } - - return n, cb.buf.Flush() -} - -func (cb *ConnBuffer) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return cb.Conn, cb.buf, nil -} - -func (cb *ConnBuffer) WriteHeader(_ int) {} - -func (cb *ConnBuffer) Header() http.Header { return nil }