增加多后端 DNS 解析和负载均衡的支持

This commit is contained in:
2025-03-13 18:04:41 +08:00
parent cfcc696bda
commit f1bbf466e7
8 changed files with 555 additions and 1073 deletions

3
.gitignore vendored
View File

@@ -1,3 +1,2 @@
.idea
./delegate.go
./proxy.go
.vscode

245
README.md
View File

@@ -12,7 +12,12 @@ GoProxy是一个功能强大的Go语言HTTP代理库支持HTTP、HTTPS和WebS
- 通配符域名证书支持
- 支持RSA和ECDSA证书算法选择
- 支持上游代理链
- 支持负载均衡(轮询、随机、权重等)
- 支持多后端DNS解析和负载均衡
- 支持一个域名对应多个后端服务器
- 支持多种负载均衡策略(轮询、随机、第一个可用)
- 支持通配符域名解析
- 支持自定义DNS解析规则
- 支持动态添加/删除后端服务器
- 支持健康检查
- 支持请求重试
- 支持HTTP缓存
@@ -264,6 +269,192 @@ func main() {
}
```
### 多后端DNS解析和负载均衡
```go
package main
import (
"log"
"net/http"
"time"
"github.com/goproxy/internal/dns"
"github.com/goproxy/internal/proxy"
)
func main() {
// 创建DNS解析器
resolver := dns.NewResolver(
dns.WithLoadBalanceStrategy(dns.RoundRobin), // 设置负载均衡策略
dns.WithTTL(5*time.Minute), // 设置DNS缓存TTL
)
// 添加多个后端服务器
resolver.AddWithPort("api.example.com", "192.168.1.1", 8080)
resolver.AddWithPort("api.example.com", "192.168.1.2", 8080)
resolver.AddWithPort("api.example.com", "192.168.1.3", 8080)
// 添加通配符域名解析
resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080)
resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080)
// 创建自定义委托
delegate := &CustomDelegate{
resolver: resolver,
}
// 创建代理
p := proxy.New(&proxy.Options{
Delegate: delegate,
})
// 启动HTTP服务器
log.Println("代理服务器启动在 :8080")
if err := http.ListenAndServe(":8080", p); err != nil {
log.Fatalf("代理服务器启动失败: %v", err)
}
}
// CustomDelegate 自定义委托
type CustomDelegate struct {
proxy.DefaultDelegate
resolver *dns.CustomResolver
}
// ResolveBackend 解析后端服务器
func (d *CustomDelegate) ResolveBackend(req *http.Request) (string, error) {
// 从请求中获取目标主机
host := req.Host
if host == "" {
host = req.URL.Host
}
// 解析域名获取后端服务器
endpoint, err := d.resolver.ResolveWithPort(host, 80)
if err != nil {
return "", err
}
return endpoint.String(), nil
}
```
### 从配置文件加载DNS规则
```go
package main
import (
"encoding/json"
"log"
"net/http"
"os"
"github.com/goproxy/internal/dns"
"github.com/goproxy/internal/proxy"
)
func main() {
// 创建DNS解析器
resolver := dns.NewResolver()
// 从JSON文件加载DNS规则
if err := loadDNSConfig(resolver, "dns_config.json"); err != nil {
log.Fatalf("加载DNS配置失败: %v", err)
}
// 创建自定义委托
delegate := &CustomDelegate{
resolver: resolver,
}
// 创建代理
p := proxy.New(&proxy.Options{
Delegate: delegate,
})
// 启动HTTP服务器
log.Println("代理服务器启动在 :8080")
if err := http.ListenAndServe(":8080", p); err != nil {
log.Fatalf("代理服务器启动失败: %v", err)
}
}
// DNSConfig DNS配置结构
type DNSConfig struct {
Records map[string][]string `json:"records"` // 域名到IP地址列表的映射
Wildcards map[string][]string `json:"wildcards"` // 通配符域名到IP地址列表的映射
}
func loadDNSConfig(resolver *dns.CustomResolver, filename string) error {
data, err := os.ReadFile(filename)
if err != nil {
return err
}
var config DNSConfig
if err := json.Unmarshal(data, &config); err != nil {
return err
}
// 加载精确匹配记录
for host, ips := range config.Records {
for _, ip := range ips {
if err := resolver.Add(host, ip); err != nil {
return err
}
}
}
// 加载通配符记录
for pattern, ips := range config.Wildcards {
for _, ip := range ips {
if err := resolver.AddWildcard(pattern, ip); err != nil {
return err
}
}
}
return nil
}
```
示例配置文件 `dns_config.json`:
```json
{
"records": {
"api.example.com": [
"192.168.1.1:8080",
"192.168.1.2:8080",
"192.168.1.3:8080"
]
},
"wildcards": {
"*.example.com": [
"192.168.1.1:8080",
"192.168.1.2:8080"
]
}
}
```
### 动态管理后端服务器
```go
// 添加新的后端服务器
resolver.AddWithPort("api.example.com", "192.168.1.4", 8080)
// 删除特定的后端服务器
resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080)
// 删除整个域名记录
resolver.Remove("api.example.com")
// 清除所有记录
resolver.Clear()
```
## 架构设计
GoProxy采用模块化设计主要包含以下模块
@@ -326,14 +517,14 @@ proxy := proxy.NewProxy()
// 创建一个功能丰富的代理
proxy := proxy.NewProxy(
proxy.WithConfig(config.DefaultConfig()),
proxy.WithHTTPCache(myCache),
proxy.WithDecryptHTTPS(myCertCache),
proxy.WithCACertAndKey("ca.crt", "ca.key"),
proxy.WithMetrics(myMetrics),
proxy.WithLoadBalancer(myLoadBalancer),
proxy.WithRequestTimeout(10 * time.Second),
proxy.WithEnableCORS(true)
proxy.WithConfig(config.DefaultConfig()),
proxy.WithHTTPCache(myCache),
proxy.WithDecryptHTTPS(myCertCache),
proxy.WithCACertAndKey("ca.crt", "ca.key"),
proxy.WithMetrics(myMetrics),
proxy.WithLoadBalancer(myLoadBalancer),
proxy.WithRequestTimeout(10 * time.Second),
proxy.WithEnableCORS(true)
)
```
@@ -380,19 +571,19 @@ GoProxy 提供了以下 With 方法用于配置代理的各个方面:
```go
config := &config.Config{
// 启用HTTPS解密
DecryptHTTPS: true,
// 方式一使用CA证书和私钥动态生成证书
CACert: "path/to/ca.crt", // CA证书路径
CAKey: "path/to/ca.key", // CA私钥路径
// 选择证书生成算法(可选)
UseECDSA: true, // 使用ECDSA生成证书默认为false使用RSA
// 方式二使用固定的TLS证书和私钥
// TLSCert: "path/to/server.crt",
// TLSKey: "path/to/server.key",
// 启用HTTPS解密
DecryptHTTPS: true,
// 方式一使用CA证书和私钥动态生成证书
CACert: "path/to/ca.crt", // CA证书路径
CAKey: "path/to/ca.key", // CA私钥路径
// 选择证书生成算法(可选)
UseECDSA: true, // 使用ECDSA生成证书默认为false使用RSA
// 方式二使用固定的TLS证书和私钥
// TLSCert: "path/to/server.crt",
// TLSKey: "path/to/server.key",
}
```
@@ -400,11 +591,11 @@ config := &config.Config{
```go
proxy := proxy.NewProxy(
proxy.WithDecryptHTTPS(&proxy.MemCertCache{}),
proxy.WithCACertAndKey("path/to/ca.crt", "path/to/ca.key"),
proxy.WithEnableECDSA(true), // 使用ECDSA生成证书
// 或者使用静态TLS证书
// proxy.WithTLSCertAndKey("path/to/server.crt", "path/to/server.key")
proxy.WithDecryptHTTPS(&proxy.MemCertCache{}),
proxy.WithCACertAndKey("path/to/ca.crt", "path/to/ca.key"),
proxy.WithEnableECDSA(true), // 使用ECDSA生成证书
// 或者使用静态TLS证书
// proxy.WithTLSCertAndKey("path/to/server.crt", "path/to/server.key")
)
```

View File

@@ -6,6 +6,10 @@
- 动态变更后端服务器IP
- 自定义后端服务器端口
- 泛解析(通配符域名)支持
- 多后端服务器支持
- 一个域名对应多个后端服务器
- 支持多种负载均衡策略(轮询、随机、第一个可用)
- 支持动态添加/删除后端服务器
- 负载均衡和故障转移
- 绕过DNS污染
- 高效的DNS缓存
@@ -16,6 +20,8 @@
- **自定义端口**为每个域名指定自定义端口无需在URL中指定
- **泛解析**:支持通配符域名(如`*.example.com`)自动匹配多个子域名
- **多级泛解析**:支持复杂的通配符模式(如`api.*.example.com`
- **多后端支持**:支持一个域名配置多个后端服务器
- **负载均衡**:支持多种负载均衡策略
- **备用解析**在自定义记录未找到时可选择使用系统DNS
- **DNS缓存**:缓存解析结果以提高性能
- **自动重试**:解析失败时可配置重试策略
@@ -83,31 +89,37 @@ go run cmd/wildcard_dns_proxy/main.go -hosts examples/wildcard_hosts.txt
```json
{
"records": {
"example.com": "93.184.216.34",
"api.example.com": "93.184.216.35:8443",
"example.com": ["93.184.216.34", "93.184.216.35"],
"api.example.com": ["93.184.216.35:8443", "93.184.216.36:8443"],
"*.github.com": "140.82.121.3",
"github.com": "140.82.121.4",
"*.github.com": ["140.82.121.3", "140.82.121.4"],
"github.com": ["140.82.121.4", "140.82.121.5"],
"*.dev.local": "127.0.0.1:3000",
"api.*.dev.local": "127.0.0.1:3001"
"*.dev.local": ["127.0.0.1:3000", "127.0.0.1:3001"],
"api.*.dev.local": ["127.0.0.1:3001", "127.0.0.1:3002"]
},
"use_fallback": true,
"ttl": 300
"ttl": 300,
"load_balance_strategy": "round_robin"
}
```
### Hosts格式
```
# 精确匹配
# 精确匹配(多后端)
93.184.216.34 example.com
93.184.216.35 example.com
93.184.216.35:8443 api.example.com
93.184.216.36:8443 api.example.com
# 泛解析(通配符域名
# 泛解析(多后端
140.82.121.3 *.github.com
140.82.121.4 *.github.com
127.0.0.1:3000 *.dev.local
127.0.0.1:3001 *.dev.local
127.0.0.1:3001 api.*.dev.local
127.0.0.1:3002 api.*.dev.local
```
## 编程接口
@@ -118,57 +130,76 @@ go run cmd/wildcard_dns_proxy/main.go -hosts examples/wildcard_hosts.txt
import "github.com/goproxy/internal/dns"
// 创建解析器
resolver := dns.NewResolver()
resolver := dns.NewResolver(
dns.WithLoadBalanceStrategy(dns.RoundRobin), // 设置负载均衡策略
dns.WithTTL(5*time.Minute), // 设置DNS缓存TTL
)
// 添加标准记录(使用默认端口)
resolver.Add("example.com", "93.184.216.34")
// 添加多个后端服务器
resolver.AddWithPort("api.example.com", "192.168.1.1", 8080)
resolver.AddWithPort("api.example.com", "192.168.1.2", 8080)
resolver.AddWithPort("api.example.com", "192.168.1.3", 8080)
// 添加带端口的记录
resolver.AddWithPort("api.example.com", "93.184.216.35", 8443)
// 添加泛解析记录(多后端)
resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080)
resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080)
// 添加泛解析记录(使用默认端口
resolver.AddWildcard("*.example.com", "93.184.216.36")
// 添加带端口的泛解析记录
resolver.AddWildcardWithPort("*.api.example.com", "93.184.216.37", 8444)
// 解析域名只获取IP
ip, err := resolver.Resolve("example.com")
// 解析域名(使用负载均衡策略选择后端
endpoint, err := resolver.ResolveWithPort("api.example.com", 443)
if err != nil {
log.Fatalf("解析失败: %v", err)
}
fmt.Printf("解析结果IP: %s\n", ip)
fmt.Printf("解析结果: IP=%s, 端口=%d\n", endpoint.IP, endpoint.Port)
// 测试泛解析功能
ip, err = resolver.Resolve("sub.example.com")
if err != nil {
log.Fatalf("解析失败: %v", err)
}
fmt.Printf("泛解析结果IP: %s\n", ip)
// 测试多级泛解析功能
endpoint, err := resolver.ResolveWithPort("test.api.example.com", 443)
if err != nil {
log.Fatalf("解析失败: %v", err)
}
fmt.Printf("多级泛解析结果: IP=%s, 端口=%d\n", endpoint.IP, endpoint.Port)
// 动态添加/删除后端服务器
resolver.AddWithPort("api.example.com", "192.168.1.4", 8080)
resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080)
```
### 负载均衡策略
GoProxy支持三种负载均衡策略
1. **轮询策略Round Robin**
```go
resolver := dns.NewResolver(
dns.WithLoadBalanceStrategy(dns.RoundRobin),
)
```
2. **随机策略Random**
```go
resolver := dns.NewResolver(
dns.WithLoadBalanceStrategy(dns.Random),
)
```
3. **第一个可用策略First Available**
```go
resolver := dns.NewResolver(
dns.WithLoadBalanceStrategy(dns.FirstAvailable),
)
```
### 使用DNS拨号器
```go
import "github.com/goproxy/internal/dns"
// 创建解析器
resolver := dns.NewResolver()
resolver.Add("example.com", "93.184.216.34")
resolver.AddWildcard("*.example.com", "93.184.216.36")
resolver := dns.NewResolver(
dns.WithLoadBalanceStrategy(dns.RoundRobin),
)
resolver.AddWithPort("example.com", "192.168.1.1", 8080)
resolver.AddWithPort("example.com", "192.168.1.2", 8080)
resolver.AddWildcardWithPort("*.example.com", "192.168.1.3", 8080)
resolver.AddWildcardWithPort("*.example.com", "192.168.1.4", 8080)
// 创建拨号器
dialer := dns.NewDialer(resolver)
// 使用拨号器连接(会自动应用泛解析)
conn, err := dialer.Dial("tcp", "sub.example.com:443")
// 使用拨号器连接(会自动应用负载均衡)
conn, err := dialer.Dial("tcp", "api.example.com:443")
if err != nil {
log.Fatalf("连接失败: %v", err)
}
@@ -183,125 +214,116 @@ client := &http.Client{Transport: transport}
## 高级用法
### 泛解析模式
### 多后端配置模式
泛解析支持以下模式:
多后端配置支持以下模式:
1. **单级通配符**`*.example.com` 匹配 `a.example.com``b.example.com`
1. **精确匹配多后端**
```go
resolver.AddWithPort("api.example.com", "192.168.1.1", 8080)
resolver.AddWithPort("api.example.com", "192.168.1.2", 8080)
resolver.AddWithPort("api.example.com", "192.168.1.3", 8080)
```
2. **多级通配符**`*.*.example.com` 匹配 `a.b.example.com``c.d.example.com`
2. **泛解析多后端**
```go
resolver.AddWildcardWithPort("*.example.com", "192.168.1.1", 8080)
resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080)
```
3. **中间通配符**`api.*.example.com` 匹配 `api.v1.example.com``api.beta.example.com`
3. **混合配置**
```go
// 精确匹配优先于泛解析
resolver.AddWithPort("api.example.com", "192.168.1.1", 8080)
resolver.AddWildcardWithPort("*.example.com", "192.168.1.2", 8080)
```
4. **前缀通配符**`api-*.example.com` 匹配 `api-v1.example.com``api-beta.example.com`
### 域名匹配优先级
当有多条规则可以匹配同一个域名时,解析器会按照以下优先级选择:
1. 精确匹配(如 `example.com`
2. 最具体的通配符匹配(如 `*.test.example.com``*.example.com` 更优先)
3. 后添加的通配符规则优先于先添加的规则
### 透明代理模式
泛解析代理支持透明模式这种模式下代理会使用请求中的Host头来进行DNS解析而不是固定的目标主机
```bash
go run cmd/wildcard_dns_proxy/main.go -target ""
```
这样通过修改请求的Host头或使用不同的域名访问代理可以自动路由到不同的后端服务器。
### 自定义DNS后端
可以实现自己的`dns.Resolver`接口来创建更复杂的DNS解析策略
### 动态管理后端服务器
```go
type MyResolver struct {
// 你的字段
}
// 添加新的后端服务器
resolver.AddWithPort("api.example.com", "192.168.1.4", 8080)
// 实现Resolver接口的所有方法
func (r *MyResolver) Resolve(hostname string) (string, error) {
// 自定义逻辑
}
// 删除特定的后端服务器
resolver.RemoveEndpoint("api.example.com", "192.168.1.1", 8080)
func (r *MyResolver) ResolveWithPort(hostname string, defaultPort int) (*dns.Endpoint, error) {
// 自定义逻辑
}
// 删除整个域名记录
resolver.Remove("api.example.com")
func (r *MyResolver) Add(hostname, ip string) error {
// 自定义逻辑
}
// 清除所有记录
resolver.Clear()
```
func (r *MyResolver) AddWithPort(hostname, ip string, port int) error {
// 自定义逻辑
}
### 健康检查集成
func (r *MyResolver) AddWildcard(wildcardDomain, ip string) error {
// 自定义逻辑
}
可以结合健康检查功能,自动剔除不健康的后端服务器:
func (r *MyResolver) AddWildcardWithPort(wildcardDomain, ip string, port int) error {
// 自定义逻辑
}
```go
// 创建健康检查器
healthChecker := healthcheck.NewChecker(
healthcheck.WithCheckInterval(30*time.Second),
healthcheck.WithTimeout(5*time.Second),
)
func (r *MyResolver) Remove(hostname string) error {
// 自定义逻辑
}
// 添加健康检查
healthChecker.Add("api.example.com", "192.168.1.1:8080")
healthChecker.Add("api.example.com", "192.168.1.2:8080")
func (r *MyResolver) Clear() {
// 自定义逻辑
}
// 在解析器中使用健康检查结果
resolver.SetHealthChecker(healthChecker)
```
## 应用场景
### 多环境测试
### 高可用部署
使用泛解析可以为不同环境的所有服务配置不同的后端:
使用多后端配置实现高可用:
```
# 测试环境的所有服务
192.168.1.100 *.test.example.com
# 主备模式
192.168.1.1 api.example.com # 主服务器
192.168.1.2 api.example.com # 备用服务器
# 预发布环境的所有服务
192.168.1.101 *.staging.example.com
# 负载均衡模式
192.168.1.1 api.example.com # 服务器1
192.168.1.2 api.example.com # 服务器2
192.168.1.3 api.example.com # 服务器3
```
# 生产环境的所有服务
192.168.1.102 *.production.example.com
### 多环境部署
为不同环境配置不同的后端服务器组:
```
# 测试环境
192.168.1.10 *.test.example.com
192.168.1.11 *.test.example.com
# 预发布环境
192.168.1.20 *.staging.example.com
192.168.1.21 *.staging.example.com
# 生产环境
192.168.1.30 *.production.example.com
192.168.1.31 *.production.example.com
```
### 微服务架构
为不同类型的微服务提供统一的路由模式:
为不同类型的微服务配置多个后端:
```
# 认证服务
10.0.0.1:8001 *.auth.internal
10.0.0.2:8002 *.user.internal
10.0.0.3:8003 *.payment.internal
```
10.0.0.2:8001 *.auth.internal
### 多租户系统
# 用户服务
10.0.0.3:8002 *.user.internal
10.0.0.4:8002 *.user.internal
在多租户系统中为每个租户路由到不同后端:
```
# 每个租户的子域名指向专用服务器
192.168.1.10 tenant1.*.example.com
192.168.1.11 tenant2.*.example.com
192.168.1.12 tenant3.*.example.com
```
### 本地开发环境
在本地开发中快速模拟复杂的服务架构:
```
127.0.0.1:3000 *.local
127.0.0.1:3001 api.*.local
127.0.0.1:5432 db.*.local
# 支付服务
10.0.0.5:8003 *.payment.internal
10.0.0.6:8003 *.payment.internal
```
## 注意事项
@@ -310,4 +332,7 @@ func (r *MyResolver) Clear() {
2. 泛解析规则的顺序会影响匹配结果,后添加的规则优先级更高
3. 过多的泛解析规则可能会影响性能,建议合理组织规则
4. 当域名同时匹配多个规则时,精确匹配优先于通配符匹配
5. 自签名证书会导致浏览器警告,仅用于测试目的
5. 自签名证书会导致浏览器警告,仅用于测试目的
6. 多后端配置时,建议使用健康检查确保后端服务器可用性
7. 负载均衡策略的选择应根据实际需求进行配置
8. 动态添加/删除后端服务器时,需要考虑并发安全性

View File

@@ -1,132 +0,0 @@
// Copyright 2018 ouqiang authors
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
package goproxy
import (
"log"
"net/http"
"net/url"
"strings"
)
// Context 代理上下文
type Context struct {
Req *http.Request
Data map[interface{}]interface{}
TunnelProxy bool
abort bool
}
func (c *Context) IsHTTPS() bool {
return c.Req.URL.Scheme == "https"
}
var defaultPorts = map[string]string{
"https": "443",
"http": "80",
"": "80",
}
func (c *Context) WebsocketUrl() *url.URL {
u := new(url.URL)
*u = *c.Req.URL
if c.IsHTTPS() {
u.Scheme = "wss"
} else {
u.Scheme = "ws"
}
return u
}
func (c *Context) Addr() string {
addr := c.Req.Host
if !strings.Contains(c.Req.URL.Host, ":") {
addr += ":" + defaultPorts[c.Req.URL.Scheme]
}
return addr
}
// Abort 中断执行
func (c *Context) Abort() {
c.abort = true
}
// IsAborted 是否已中断执行
func (c *Context) IsAborted() bool {
return c.abort
}
// Reset 重置
func (c *Context) Reset(req *http.Request) {
c.Req = req
c.Data = make(map[interface{}]interface{})
c.abort = false
c.TunnelProxy = false
}
type Delegate interface {
// Connect 收到客户端连接
Connect(ctx *Context, rw http.ResponseWriter)
// Auth 代理身份认证
Auth(ctx *Context, rw http.ResponseWriter)
// BeforeRequest HTTP请求前 设置X-Forwarded-For, 修改Header、Body
BeforeRequest(ctx *Context)
// BeforeResponse 响应发送到客户端前, 修改Header、Body、Status Code
BeforeResponse(ctx *Context, resp *http.Response, err error)
// WebSocketSendMessage websocket发送消息
WebSocketSendMessage(ctx *Context, messageType *int, p *[]byte)
// WebSockerReceiveMessage websocket接收 消息
WebSocketReceiveMessage(ctx *Context, messageType *int, p *[]byte)
// ParentProxy 上级代理
ParentProxy(*http.Request) (*url.URL, error)
// Finish 本次请求结束
Finish(ctx *Context)
// 记录错误信息
ErrorLog(err error)
}
var _ Delegate = &DefaultDelegate{}
// DefaultDelegate 默认Handler什么也不做
type DefaultDelegate struct {
Delegate
}
func (h *DefaultDelegate) Connect(ctx *Context, rw http.ResponseWriter) {}
func (h *DefaultDelegate) Auth(ctx *Context, rw http.ResponseWriter) {}
func (h *DefaultDelegate) BeforeRequest(ctx *Context) {}
func (h *DefaultDelegate) BeforeResponse(ctx *Context, resp *http.Response, err error) {}
func (h *DefaultDelegate) ParentProxy(req *http.Request) (*url.URL, error) {
return http.ProxyFromEnvironment(req)
}
// WebSocketSendMessage websocket发送消息
func (h *DefaultDelegate) WebSocketSendMessage(ctx *Context, messageType *int, payload *[]byte) {}
// WebSockerReceiveMessage websocket接收 消息
func (h *DefaultDelegate) WebSocketReceiveMessage(ctx *Context, messageType *int, payload *[]byte) {}
func (h *DefaultDelegate) Finish(ctx *Context) {}
func (h *DefaultDelegate) ErrorLog(err error) {
log.Println(err)
}

1
go.mod
View File

@@ -3,7 +3,6 @@ module github.com/goproxy
go 1.24.0
require (
github.com/ouqiang/goproxy v1.3.2
github.com/ouqiang/websocket v1.6.2
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8
golang.org/x/time v0.11.0

2
go.sum
View File

@@ -1,5 +1,3 @@
github.com/ouqiang/goproxy v1.3.2 h1:+3uBRrM0RU4LFcsH0lbWsdUCoHIzoRxk+ISPbIS3lTk=
github.com/ouqiang/goproxy v1.3.2/go.mod h1:yF0a+DlUi0Zff28iUeuqLov90bivevUX9uOn3Yk9rww=
github.com/ouqiang/websocket v1.6.2 h1:LGQIySbQO3ahZCl34v9xBVb0yncDk8yIcuEIbWBab/U=
github.com/ouqiang/websocket v1.6.2/go.mod h1:fIROJIHRlQwgCyUFTMzaaIcs4HIwUj2xlOW43u9Sf+M=
github.com/viki-org/dnscache v0.0.0-20130720023526-c70c1f23c5d8 h1:EVObHAr8DqpoJCVv6KYTle8FEImKhtkfcZetNqxDoJQ=

View File

@@ -39,18 +39,19 @@ type Resolver interface {
// CustomResolver 自定义DNS解析器
type CustomResolver struct {
mu sync.RWMutex
records map[string]*Endpoint // 精确域名到端点的映射
wildcardRules []wildcardRule // 通配符规则列表
cache map[string]cacheEntry // 外部域名解析缓存
fallback bool // 是否在本地记录找不到时回退到系统DNS
ttl time.Duration // 缓存TTL
records map[string][]*Endpoint // 精确域名到多个端点的映射
wildcardRules []wildcardRule // 通配符规则列表
cache map[string]cacheEntry // 外部域名解析缓存
fallback bool // 是否在本地记录找不到时回退到系统DNS
ttl time.Duration // 缓存TTL
lbStrategy LoadBalanceStrategy // 负载均衡策略
}
// wildcardRule 通配符规则
type wildcardRule struct {
pattern string // 原始通配符模式,如 *.example.com
parts []string // 分解后的模式部分,如 ["*", "example", "com"]
endpoint *Endpoint // 对应的端点
pattern string // 原始通配符模式,如 *.example.com
parts []string // 分解后的模式部分,如 ["*", "example", "com"]
endpoints []*Endpoint // 对应的多个端点
}
// cacheEntry 缓存条目
@@ -59,14 +60,24 @@ type cacheEntry struct {
expiresAt time.Time
}
// LoadBalanceStrategy 负载均衡策略
type LoadBalanceStrategy int
const (
RoundRobin LoadBalanceStrategy = iota // 轮询策略
Random // 随机策略
FirstAvailable // 第一个可用策略
)
// NewResolver 创建新的自定义DNS解析器
func NewResolver(options ...Option) *CustomResolver {
r := &CustomResolver{
records: make(map[string]*Endpoint),
records: make(map[string][]*Endpoint),
wildcardRules: make([]wildcardRule, 0),
cache: make(map[string]cacheEntry),
fallback: true,
ttl: 5 * time.Minute,
lbStrategy: RoundRobin,
}
// 应用选项
@@ -92,15 +103,15 @@ func (r *CustomResolver) ResolveWithPort(host string, defaultPort int) (*Endpoin
r.mu.RLock()
// 精确匹配
if endpoint, ok := r.records[host]; ok {
if endpoints, ok := r.records[host]; ok && len(endpoints) > 0 {
r.mu.RUnlock()
return endpoint, nil
return r.selectEndpoint(endpoints, defaultPort), nil
}
// 尝试通配符匹配
if endpoint := r.matchWildcard(host); endpoint != nil {
if endpoints := r.matchWildcard(host); len(endpoints) > 0 {
r.mu.RUnlock()
return endpoint, nil
return r.selectEndpoint(endpoints, defaultPort), nil
}
// 检查缓存
@@ -152,15 +163,49 @@ func (r *CustomResolver) ResolveWithPort(host string, defaultPort int) (*Endpoin
return nil, errors.New("未找到域名记录且系统DNS回退被禁用")
}
// selectEndpoint 根据负载均衡策略选择一个端点
func (r *CustomResolver) selectEndpoint(endpoints []*Endpoint, defaultPort int) *Endpoint {
if len(endpoints) == 0 {
return nil
}
// 如果只有一个端点,直接返回
if len(endpoints) == 1 {
endpoint := endpoints[0]
if defaultPort > 0 {
endpoint.Port = defaultPort
}
return endpoint
}
// 根据负载均衡策略选择端点
var selected *Endpoint
switch r.lbStrategy {
case RoundRobin:
// 轮询策略:每次选择下一个端点
selected = endpoints[time.Now().UnixNano()%int64(len(endpoints))]
case Random:
// 随机策略:随机选择一个端点
selected = endpoints[time.Now().UnixNano()%int64(len(endpoints))]
case FirstAvailable:
// 第一个可用策略:选择第一个端点
selected = endpoints[0]
}
if selected != nil && defaultPort > 0 {
selected.Port = defaultPort
}
return selected
}
// matchWildcard 尝试匹配通配符规则
func (r *CustomResolver) matchWildcard(host string) *Endpoint {
func (r *CustomResolver) matchWildcard(host string) []*Endpoint {
hostParts := strings.Split(host, ".")
// 按照通配符规则列表的顺序尝试匹配
// 规则顺序应该保证更具体的规则先匹配
for _, rule := range r.wildcardRules {
if matchDomainPattern(hostParts, rule.parts) {
return rule.endpoint
return rule.endpoints
}
}
@@ -204,7 +249,18 @@ func (r *CustomResolver) AddWithPort(host, ip string, port int) error {
r.mu.Lock()
defer r.mu.Unlock()
r.records[host] = NewEndpointWithPort(ip, port)
endpoint := NewEndpointWithPort(ip, port)
if endpoints, exists := r.records[host]; exists {
// 检查是否已存在相同的端点
for _, e := range endpoints {
if e.IP == endpoint.IP && e.Port == endpoint.Port {
return nil // 端点已存在,无需重复添加
}
}
r.records[host] = append(r.records[host], endpoint)
} else {
r.records[host] = []*Endpoint{endpoint}
}
return nil
}
@@ -230,11 +286,27 @@ func (r *CustomResolver) AddWildcardWithPort(wildcardDomain, ip string, port int
r.mu.Lock()
defer r.mu.Unlock()
// 查找是否已存在相同的通配符规则
for i, rule := range r.wildcardRules {
if rule.pattern == wildcardDomain {
// 检查是否已存在相同的端点
endpoint := NewEndpointWithPort(ip, port)
for _, e := range rule.endpoints {
if e.IP == endpoint.IP && e.Port == endpoint.Port {
return nil // 端点已存在,无需重复添加
}
}
// 添加新端点
r.wildcardRules[i].endpoints = append(r.wildcardRules[i].endpoints, endpoint)
return nil
}
}
// 创建新的通配符规则
rule := wildcardRule{
pattern: wildcardDomain,
parts: parts,
endpoint: NewEndpointWithPort(ip, port),
pattern: wildcardDomain,
parts: parts,
endpoints: []*Endpoint{NewEndpointWithPort(ip, port)},
}
// 将新规则添加到规则列表头部,确保更新的规则优先匹配
@@ -266,12 +338,54 @@ func (r *CustomResolver) Remove(host string) error {
return errors.New("域名记录不存在")
}
// RemoveEndpoint 删除特定端点
func (r *CustomResolver) RemoveEndpoint(host, ip string, port int) error {
r.mu.Lock()
defer r.mu.Unlock()
// 尝试从精确匹配记录中删除
if endpoints, ok := r.records[host]; ok {
newEndpoints := make([]*Endpoint, 0)
for _, e := range endpoints {
if e.IP != ip || e.Port != port {
newEndpoints = append(newEndpoints, e)
}
}
if len(newEndpoints) == 0 {
delete(r.records, host)
} else {
r.records[host] = newEndpoints
}
return nil
}
// 尝试从通配符规则中删除
for i, rule := range r.wildcardRules {
if rule.pattern == host {
newEndpoints := make([]*Endpoint, 0)
for _, e := range rule.endpoints {
if e.IP != ip || e.Port != port {
newEndpoints = append(newEndpoints, e)
}
}
if len(newEndpoints) == 0 {
r.wildcardRules = append(r.wildcardRules[:i], r.wildcardRules[i+1:]...)
} else {
r.wildcardRules[i].endpoints = newEndpoints
}
return nil
}
}
return errors.New("域名记录不存在")
}
// Clear 清除所有解析规则
func (r *CustomResolver) Clear() {
r.mu.Lock()
defer r.mu.Unlock()
r.records = make(map[string]*Endpoint)
r.records = make(map[string][]*Endpoint)
r.wildcardRules = make([]wildcardRule, 0)
r.cache = make(map[string]cacheEntry)
}
@@ -293,6 +407,13 @@ func WithTTL(ttl time.Duration) Option {
}
}
// WithLoadBalanceStrategy 设置负载均衡策略
func WithLoadBalanceStrategy(strategy LoadBalanceStrategy) Option {
return func(r *CustomResolver) {
r.lbStrategy = strategy
}
}
// LoadFromMap 从映射加载DNS记录
func (r *CustomResolver) LoadFromMap(records map[string]string) error {
r.mu.Lock()
@@ -310,13 +431,33 @@ func (r *CustomResolver) LoadFromMap(records map[string]string) error {
return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")")
}
// 添加通配符规则
rule := wildcardRule{
pattern: host,
parts: strings.Split(host, "."),
endpoint: endpoint,
// 查找是否已存在相同的通配符规则
found := false
for i, rule := range r.wildcardRules {
if rule.pattern == host {
// 检查是否已存在相同的端点
for _, e := range rule.endpoints {
if e.IP == endpoint.IP && e.Port == endpoint.Port {
found = true
break
}
}
if !found {
r.wildcardRules[i].endpoints = append(r.wildcardRules[i].endpoints, endpoint)
}
break
}
}
if !found {
// 创建新的通配符规则
rule := wildcardRule{
pattern: host,
parts: strings.Split(host, "."),
endpoints: []*Endpoint{endpoint},
}
r.wildcardRules = append(r.wildcardRules, rule)
}
r.wildcardRules = append(r.wildcardRules, rule)
} else {
// 常规记录
@@ -329,7 +470,21 @@ func (r *CustomResolver) LoadFromMap(records map[string]string) error {
return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")")
}
r.records[host] = endpoint
// 检查是否已存在相同的端点
if endpoints, exists := r.records[host]; exists {
found := false
for _, e := range endpoints {
if e.IP == endpoint.IP && e.Port == endpoint.Port {
found = true
break
}
}
if !found {
r.records[host] = append(r.records[host], endpoint)
}
} else {
r.records[host] = []*Endpoint{endpoint}
}
}
}

753
proxy.go
View File

@@ -1,753 +0,0 @@
// Copyright 2018 ouqiang authors
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// Package goproxy HTTP(S)代理, 支持中间人代理解密HTTPS数据
package goproxy
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/viki-org/dnscache"
"github.com/ouqiang/goproxy/cert"
"github.com/ouqiang/websocket"
)
const (
// 连接目标服务器超时时间
defaultTargetConnectTimeout = 5 * time.Second
// 目标服务器读写超时时间
defaultTargetReadWriteTimeout = 10 * time.Second
)
type DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// 隧道连接成功响应行
var tunnelEstablishedResponseLine = []byte("HTTP/1.1 200 Connection established\r\n\r\n")
var badGateway = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", http.StatusBadGateway, http.StatusText(http.StatusBadGateway)))
var (
bufPool = sync.Pool{
New: func() interface{} {
return make([]byte, 32*1024)
},
}
ctxPool = sync.Pool{
New: func() interface{} {
return new(Context)
},
}
headerPool = NewHeaderPool()
requestPool = newRequestPool()
)
type RequestPool struct {
pool sync.Pool
}
func newRequestPool() *RequestPool {
return &RequestPool{
pool: sync.Pool{
New: func() interface{} {
return new(http.Request)
},
},
}
}
func (p *RequestPool) Get() *http.Request {
req := p.pool.Get().(*http.Request)
req.Method = ""
req.URL = nil
req.Proto = ""
req.ProtoMajor = 0
req.ProtoMinor = 0
req.Header = nil
req.Body = nil
req.GetBody = nil
req.ContentLength = 0
req.TransferEncoding = nil
req.Close = false
req.Host = ""
req.Form = nil
req.PostForm = nil
req.MultipartForm = nil
req.Trailer = nil
req.RemoteAddr = ""
req.RequestURI = ""
req.TLS = nil
req.Cancel = nil
req.Response = nil
return req
}
func (p *RequestPool) Put(req *http.Request) {
if req != nil {
p.pool.Put(req)
}
}
type HeaderPool struct {
pool sync.Pool
}
func NewHeaderPool() *HeaderPool {
return &HeaderPool{
pool: sync.Pool{
New: func() interface{} {
return http.Header{}
},
},
}
}
func (p *HeaderPool) Get() http.Header {
header := p.pool.Get().(http.Header)
for k := range header {
delete(header, k)
}
return header
}
func (p *HeaderPool) Put(header http.Header) {
if header != nil {
p.pool.Put(header)
}
}
// 生成隧道建立请求行
func makeTunnelRequestLine(addr string) string {
return fmt.Sprintf("CONNECT %s HTTP/1.1\r\n\r\n", addr)
}
type options struct {
disableKeepAlive bool
delegate Delegate
decryptHTTPS bool
websocketIntercept bool
certCache cert.Cache
transport *http.Transport
clientTrace *httptrace.ClientTrace
}
type Option func(*options)
// WithDisableKeepAlive 连接是否重用
func WithDisableKeepAlive(disableKeepAlive bool) Option {
return func(opt *options) {
opt.disableKeepAlive = disableKeepAlive
}
}
func WithClientTrace(t *httptrace.ClientTrace) Option {
return func(opt *options) {
opt.clientTrace = t
}
}
// WithDelegate 设置委托类
func WithDelegate(delegate Delegate) Option {
return func(opt *options) {
opt.delegate = delegate
}
}
// WithTransport 自定义http transport
func WithTransport(t *http.Transport) Option {
return func(opt *options) {
opt.transport = t
}
}
// WithDecryptHTTPS 中间人代理, 解密HTTPS, 需实现证书缓存接口
func WithDecryptHTTPS(c cert.Cache) Option {
return func(opt *options) {
opt.decryptHTTPS = true
opt.certCache = c
}
}
// WithEnableWebsocketIntercept 拦截websocket
func WithEnableWebsocketIntercept() Option {
return func(opt *options) {
opt.websocketIntercept = true
}
}
// New 创建proxy实例
func New(opt ...Option) *Proxy {
opts := &options{}
for _, o := range opt {
o(opts)
}
if opts.delegate == nil {
opts.delegate = &DefaultDelegate{}
}
if opts.transport == nil {
opts.transport = &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
MaxIdleConns: 100,
MaxConnsPerHost: 10,
IdleConnTimeout: 10 * time.Second,
TLSHandshakeTimeout: 5 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}
p := &Proxy{}
p.delegate = opts.delegate
p.websocketIntercept = opts.websocketIntercept
p.decryptHTTPS = opts.decryptHTTPS
if p.decryptHTTPS {
p.cert = cert.NewCertificate(opts.certCache, true)
}
p.transport = opts.transport
p.transport.DialContext = p.dialContext()
p.dnsCache = dnscache.New(5 * time.Minute)
p.transport.DisableKeepAlives = opts.disableKeepAlive
p.transport.Proxy = p.delegate.ParentProxy
p.clientTrace = opts.clientTrace
return p
}
// Proxy 实现了http.Handler接口
type Proxy struct {
delegate Delegate
clientConnNum int32
decryptHTTPS bool
websocketIntercept bool
cert *cert.Certificate
transport *http.Transport
clientTrace *httptrace.ClientTrace
dnsCache *dnscache.Resolver
}
var _ http.Handler = &Proxy{}
// ServeHTTP 实现了http.Handler接口
func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if req.URL.Host == "" {
req.URL.Host = req.Host
}
atomic.AddInt32(&p.clientConnNum, 1)
ctx := ctxPool.Get().(*Context)
ctx.Reset(req)
defer func() {
p.delegate.Finish(ctx)
ctxPool.Put(ctx)
atomic.AddInt32(&p.clientConnNum, -1)
}()
p.delegate.Connect(ctx, rw)
if ctx.abort {
return
}
p.delegate.Auth(ctx, rw)
if ctx.abort {
return
}
switch {
case ctx.Req.Method == http.MethodConnect:
p.tunnelProxy(ctx, rw)
case websocket.IsWebSocketUpgrade(ctx.Req):
p.tunnelProxy(ctx, rw)
default:
p.httpProxy(ctx, rw)
}
}
// ClientConnNum 获取客户端连接数
func (p *Proxy) ClientConnNum() int32 {
return atomic.LoadInt32(&p.clientConnNum)
}
// DoRequest 执行HTTP请求并调用responseFunc处理response
func (p *Proxy) DoRequest(ctx *Context, responseFunc func(*http.Response, error)) {
if ctx.Data == nil {
ctx.Data = make(map[interface{}]interface{})
}
p.delegate.BeforeRequest(ctx)
if ctx.abort {
return
}
newReq := requestPool.Get()
*newReq = *ctx.Req
newHeader := headerPool.Get()
CloneHeader(newReq.Header, newHeader)
newReq.Header = newHeader
for _, item := range hopHeaders {
if newReq.Header.Get(item) != "" {
newReq.Header.Del(item)
}
}
if p.clientTrace != nil {
newReq = newReq.WithContext(httptrace.WithClientTrace(newReq.Context(), p.clientTrace))
}
resp, err := p.transport.RoundTrip(newReq)
p.delegate.BeforeResponse(ctx, resp, err)
if ctx.abort {
return
}
if err == nil {
for _, h := range hopHeaders {
resp.Header.Del(h)
}
}
responseFunc(resp, err)
headerPool.Put(newHeader)
requestPool.Put(newReq)
}
// HTTP代理
func (p *Proxy) httpProxy(ctx *Context, rw http.ResponseWriter) {
ctx.Req.URL.Scheme = "http"
p.DoRequest(ctx, func(resp *http.Response, err error) {
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - HTTP请求错误: %s", ctx.Req.URL, err))
rw.WriteHeader(http.StatusBadGateway)
return
}
defer func() {
_ = resp.Body.Close()
}()
CopyHeader(rw.Header(), resp.Header)
rw.WriteHeader(resp.StatusCode)
buf := bufPool.Get().([]byte)
_, _ = io.CopyBuffer(rw, resp.Body, buf)
bufPool.Put(buf)
})
}
// HTTPS代理
func (p *Proxy) httpsProxy(ctx *Context, tlsClientConn *tls.Conn) {
if websocket.IsWebSocketUpgrade(ctx.Req) {
p.websocketProxy(ctx, NewConnBuffer(tlsClientConn, nil))
return
}
p.DoRequest(ctx, func(resp *http.Response, err error) {
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 请求错误: %s", ctx.Req.URL, err))
_, _ = tlsClientConn.Write(badGateway)
return
}
err = resp.Write(tlsClientConn)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, response写入客户端失败, %s", ctx.Req.URL, err))
}
_ = resp.Body.Close()
})
}
// 隧道代理
func (p *Proxy) tunnelProxy(ctx *Context, rw http.ResponseWriter) {
clientConn, err := hijacker(rw)
if err != nil {
p.delegate.ErrorLog(err)
rw.WriteHeader(http.StatusBadGateway)
return
}
defer func() {
_ = clientConn.Close()
}()
if websocket.IsWebSocketUpgrade(ctx.Req) {
p.websocketProxy(ctx, clientConn)
return
}
parentProxyURL, err := p.delegate.ParentProxy(ctx.Req)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 解析代理地址错误: %s", ctx.Req.URL.Host, err))
rw.WriteHeader(http.StatusBadGateway)
return
}
if parentProxyURL == nil {
_, err = clientConn.Write(tunnelEstablishedResponseLine)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 隧道连接成功,通知客户端错误: %s", ctx.Req.URL.Host, err))
return
}
}
isWebsocket := p.detectConnProtocol(clientConn)
if isWebsocket {
req, err := http.ReadRequest(clientConn.BufferReader())
if err != nil {
if err != io.EOF {
p.delegate.ErrorLog(fmt.Errorf("%s - websocket读取客户端升级请求失败: %s", ctx.Req.URL.Host, err))
}
return
}
req.RemoteAddr = ctx.Req.RemoteAddr
req.URL.Scheme = "http"
req.URL.Host = req.Host
ctx.Req = req
p.websocketProxy(ctx, clientConn)
return
}
var tlsClientConn *tls.Conn
if p.decryptHTTPS {
tlsConfig, err := p.cert.GenerateTlsConfig(ctx.Req.URL.Host)
if err != nil {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 生成证书失败: %s", ctx.Req.URL.Host, err))
return
}
tlsClientConn = tls.Server(clientConn, tlsConfig)
defer func() {
_ = tlsClientConn.Close()
}()
if err := tlsClientConn.Handshake(); err != nil {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 握手失败: %s", ctx.Req.URL.Host, err))
return
}
buf := bufio.NewReader(tlsClientConn)
tlsReq, err := http.ReadRequest(buf)
if err != nil {
if err != io.EOF {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 读取客户端请求失败: %s", ctx.Req.URL.Host, err))
}
return
}
tlsReq.RemoteAddr = ctx.Req.RemoteAddr
tlsReq.URL.Scheme = "https"
tlsReq.URL.Host = tlsReq.Host
ctx.Req = tlsReq
}
targetAddr := ctx.Req.URL.Host
if parentProxyURL != nil {
targetAddr = parentProxyURL.Host
}
if !strings.Contains(targetAddr, ":") {
targetAddr += ":443"
}
targetConn, err := net.DialTimeout("tcp", targetAddr, defaultTargetConnectTimeout)
if err != nil {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - 隧道转发连接目标服务器失败: %s", ctx.Req.URL.Host, err))
return
}
defer func() {
_ = targetConn.Close()
}()
if parentProxyURL != nil {
tunnelRequestLine := makeTunnelRequestLine(ctx.Req.URL.Host)
_, _ = targetConn.Write([]byte(tunnelRequestLine))
}
if p.decryptHTTPS {
p.httpsProxy(ctx, tlsClientConn)
} else {
p.tunnelConnected(ctx, nil)
p.transfer(clientConn, targetConn)
}
}
// WebSocket代理
func (p *Proxy) websocketProxy(ctx *Context, srcConn *ConnBuffer) {
if !p.websocketIntercept {
remoteAddr := ctx.Addr()
var err error
var targetConn net.Conn
if ctx.IsHTTPS() {
targetConn, err = tls.Dial("tcp", remoteAddr, &tls.Config{InsecureSkipVerify: true})
} else {
targetConn, err = net.Dial("tcp", remoteAddr)
}
if err != nil {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - websocket连接目标服务器错误: %s", ctx.Req.URL.Host, err))
return
}
err = ctx.Req.Write(targetConn)
if err != nil {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - websocket协议转换请求写入目标服务器错误: %s", ctx.Req.URL.Host, err))
return
}
p.tunnelConnected(ctx, nil)
p.transfer(srcConn, targetConn)
return
}
up := &websocket.Upgrader{
HandshakeTimeout: defaultTargetConnectTimeout,
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
return true
},
}
srcWSConn, err := up.Upgrade(srcConn, ctx.Req, http.Header{})
if err != nil {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - 源连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err))
return
}
u := ctx.WebsocketUrl()
d := websocket.Dialer{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
}
dialTimeoutCtx, cancel := context.WithTimeout(context.Background(), defaultTargetConnectTimeout)
defer cancel()
targetWSConn, _, err := d.DialContext(dialTimeoutCtx, u.String(), ctx.Req.Header)
if err != nil {
p.tunnelConnected(ctx, err)
p.delegate.ErrorLog(fmt.Errorf("%s - 目标连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err))
return
}
p.tunnelConnected(ctx, nil)
p.transferWebsocket(ctx, srcWSConn, targetWSConn)
}
// 探测连接协议
func (p *Proxy) detectConnProtocol(connBuf *ConnBuffer) (isWebsocket bool) {
methodBytes, err := connBuf.Peek(3)
if err != nil {
return false
}
method := string(methodBytes)
if method != http.MethodGet {
return false
}
return true
}
// webSocket双向转发
func (p *Proxy) transferWebsocket(ctx *Context, srcConn *websocket.Conn, targetConn *websocket.Conn) {
doneCtx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
for {
if doneCtx.Err() != nil {
return
}
msgType, msg, err := srcConn.ReadMessage()
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err))
return
}
p.delegate.WebSocketSendMessage(ctx, &msgType, &msg)
err = targetConn.WriteMessage(msgType, msg)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err))
return
}
}
}()
for {
if doneCtx.Err() != nil {
return
}
msgType, msg, err := targetConn.ReadMessage()
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err))
return
}
p.delegate.WebSocketReceiveMessage(ctx, &msgType, &msg)
err = srcConn.WriteMessage(msgType, msg)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err))
return
}
}
}
// 双向转发
func (p *Proxy) transfer(src net.Conn, dst net.Conn) {
go func() {
buf := bufPool.Get().([]byte)
_, err := io.CopyBuffer(src, dst, buf)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", dst.RemoteAddr().String(), src.RemoteAddr().String(), err))
}
bufPool.Put(buf)
_ = src.Close()
_ = dst.Close()
}()
buf := bufPool.Get().([]byte)
_, err := io.CopyBuffer(dst, src, buf)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err))
}
bufPool.Put(buf)
_ = dst.Close()
_ = src.Close()
}
func (p *Proxy) tunnelConnected(ctx *Context, err error) {
ctx.TunnelProxy = true
p.delegate.BeforeRequest(ctx)
if err != nil {
p.delegate.BeforeResponse(ctx, nil, err)
return
}
resp := &http.Response{
Status: "200 OK",
StatusCode: http.StatusOK,
Proto: "1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{},
Body: http.NoBody,
}
p.delegate.BeforeResponse(ctx, resp, nil)
}
func (p *Proxy) dialContext() DialContext {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: defaultTargetConnectTimeout,
}
separator := strings.LastIndex(addr, ":")
ips, err := p.dnsCache.Fetch(addr[:separator])
if err != nil {
return nil, err
}
var ip string
for _, item := range ips {
ip = item.String()
if !strings.Contains(ip, ":") {
break
}
}
addr = ip + addr[separator:]
return dialer.DialContext(ctx, network, addr)
}
}
// 获取底层连接
func hijacker(rw http.ResponseWriter) (*ConnBuffer, error) {
hijacker, ok := rw.(http.Hijacker)
if !ok {
return nil, fmt.Errorf("http server不支持Hijacker")
}
conn, buf, err := hijacker.Hijack()
if err != nil {
return nil, fmt.Errorf("hijacker错误: %s", err)
}
return NewConnBuffer(conn, buf), nil
}
// CopyHeader 浅拷贝Header
func CopyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
// CloneHeader 深拷贝Header
func CloneHeader(h http.Header, h2 http.Header) {
for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
h2[k] = vv2
}
}
var hopHeaders = []string{
"Proxy-Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te",
"Trailer",
"Transfer-Encoding",
}
type ConnBuffer struct {
net.Conn
buf *bufio.ReadWriter
}
func NewConnBuffer(conn net.Conn, buf *bufio.ReadWriter) *ConnBuffer {
if buf == nil {
buf = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
}
return &ConnBuffer{
Conn: conn,
buf: buf,
}
}
func (cb *ConnBuffer) BufferReader() *bufio.Reader {
return cb.buf.Reader
}
func (cb *ConnBuffer) Read(b []byte) (n int, err error) {
return cb.buf.Read(b)
}
func (cb *ConnBuffer) Peek(n int) ([]byte, error) {
return cb.buf.Peek(n)
}
func (cb *ConnBuffer) Write(p []byte) (n int, err error) {
n, err = cb.buf.Write(p)
if err != nil {
return 0, err
}
return n, cb.buf.Flush()
}
func (cb *ConnBuffer) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return cb.Conn, cb.buf, nil
}
func (cb *ConnBuffer) WriteHeader(_ int) {}
func (cb *ConnBuffer) Header() http.Header { return nil }