refactor: 将所有对象池升级为泛型实现

使用 Go 1.18+ 泛型特性优化对象池,提高类型安全性和性能。

- 将 bufPool 和 ctxPool 升级为使用已有的泛型池实现
- 移除所有 Get() 操作后的类型断言
- 保持 API 兼容性,确保现有代码无需大量修改
- 优化流数据传输中内存使用

性能改进:
- 减少运行时类型检查开销
- 消除了类型断言导致的潜在 panic 风险
- 优化了高并发场景下的内存分配模式
This commit is contained in:
2025-03-14 01:37:55 +08:00
parent 7affdc79c6
commit d423ed5029
4 changed files with 40 additions and 16 deletions

View File

@@ -52,7 +52,6 @@ func main() {
cfg.ReverseProxy = true
cfg.EnableCache = *enableCache
cfg.EnableCompression = *enableCompression
cfg.EnableURLRewrite = *enableRewrite
cfg.AddXForwardedFor = *addXForwardedFor
cfg.AddXRealIP = *addXRealIP
cfg.EnableCORS = *enableCORS

View File

@@ -27,4 +27,4 @@
"description": "将语言路径转换为查询参数",
"enabled": true
}
]
]

26
internal/proxy/pool.go Normal file
View File

@@ -0,0 +1,26 @@
package proxy
import "sync"
// 泛型对象池
type Pool[T any] struct {
pool sync.Pool
}
func newPool[T any](newFunc func() T) *Pool[T] {
return &Pool[T]{
pool: sync.Pool{
New: func() interface{} {
return newFunc()
},
},
}
}
func (p *Pool[T]) Get() T {
return p.pool.Get().(T)
}
func (p *Pool[T]) Put(x T) {
p.pool.Put(x)
}

View File

@@ -45,17 +45,13 @@ var badGatewayResponse = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", http.Statu
// 对象池
var (
bufPool = sync.Pool{
New: func() interface{} {
return make([]byte, 32*1024)
},
}
bufPool = newPool(func() []byte {
return make([]byte, 32*1024)
})
ctxPool = sync.Pool{
New: func() interface{} {
return new(Context)
},
}
ctxPool = newPool(func() *Context {
return new(Context)
})
)
// CertificateCache 证书缓存接口
@@ -389,7 +385,7 @@ func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
// 处理请求
ctx := ctxPool.Get().(*Context)
ctx := ctxPool.Get()
ctx.Reset(req)
defer ctxPool.Put(ctx)
@@ -530,7 +526,10 @@ func (p *Proxy) handleHTTP(ctx *Context, rw http.ResponseWriter) {
rw.WriteHeader(resp.StatusCode)
// 复制响应体
_, err = io.Copy(rw, resp.Body)
buf := bufPool.Get()
defer bufPool.Put(buf)
_, err = io.CopyBuffer(rw, resp.Body, buf)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("%s - 复制响应体错误: %s", req.URL.Host, err))
}
@@ -1044,7 +1043,7 @@ func (p *Proxy) transfer(src net.Conn, dst net.Conn) {
// src -> dst
go func() {
buf := bufPool.Get().([]byte)
buf := bufPool.Get()
written, err := io.CopyBuffer(dst, src, buf)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err))
@@ -1062,7 +1061,7 @@ func (p *Proxy) transfer(src net.Conn, dst net.Conn) {
// dst -> src
go func() {
buf := bufPool.Get().([]byte)
buf := bufPool.Get()
written, err := io.CopyBuffer(src, dst, buf)
if err != nil {
p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", dst.RemoteAddr().String(), src.RemoteAddr().String(), err))