diff --git a/cmd/reverse_proxy_example/main.go b/cmd/reverse_proxy_example/main.go index ebbd172..f6d20b9 100644 --- a/cmd/reverse_proxy_example/main.go +++ b/cmd/reverse_proxy_example/main.go @@ -52,7 +52,6 @@ func main() { cfg.ReverseProxy = true cfg.EnableCache = *enableCache cfg.EnableCompression = *enableCompression - cfg.EnableURLRewrite = *enableRewrite cfg.AddXForwardedFor = *addXForwardedFor cfg.AddXRealIP = *addXRealIP cfg.EnableCORS = *enableCORS diff --git a/examples/rewriter/rules.json b/examples/rewriter/rules.json index 95a3997..db88ab9 100644 --- a/examples/rewriter/rules.json +++ b/examples/rewriter/rules.json @@ -27,4 +27,4 @@ "description": "将语言路径转换为查询参数", "enabled": true } -] \ No newline at end of file +] \ No newline at end of file diff --git a/internal/proxy/pool.go b/internal/proxy/pool.go new file mode 100644 index 0000000..8edd86c --- /dev/null +++ b/internal/proxy/pool.go @@ -0,0 +1,26 @@ +package proxy + +import "sync" + +// 泛型对象池 +type Pool[T any] struct { + pool sync.Pool +} + +func newPool[T any](newFunc func() T) *Pool[T] { + return &Pool[T]{ + pool: sync.Pool{ + New: func() interface{} { + return newFunc() + }, + }, + } +} + +func (p *Pool[T]) Get() T { + return p.pool.Get().(T) +} + +func (p *Pool[T]) Put(x T) { + p.pool.Put(x) +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index e6cde52..019ca44 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -45,17 +45,13 @@ var badGatewayResponse = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", http.Statu // 对象池 var ( - bufPool = sync.Pool{ - New: func() interface{} { - return make([]byte, 32*1024) - }, - } + bufPool = newPool(func() []byte { + return make([]byte, 32*1024) + }) - ctxPool = sync.Pool{ - New: func() interface{} { - return new(Context) - }, - } + ctxPool = newPool(func() *Context { + return new(Context) + }) ) // CertificateCache 证书缓存接口 @@ -389,7 +385,7 @@ func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } // 处理请求 - ctx := ctxPool.Get().(*Context) + ctx := ctxPool.Get() ctx.Reset(req) defer ctxPool.Put(ctx) @@ -530,7 +526,10 @@ func (p *Proxy) handleHTTP(ctx *Context, rw http.ResponseWriter) { rw.WriteHeader(resp.StatusCode) // 复制响应体 - _, err = io.Copy(rw, resp.Body) + buf := bufPool.Get() + defer bufPool.Put(buf) + + _, err = io.CopyBuffer(rw, resp.Body, buf) if err != nil { p.delegate.ErrorLog(fmt.Errorf("%s - 复制响应体错误: %s", req.URL.Host, err)) } @@ -1044,7 +1043,7 @@ func (p *Proxy) transfer(src net.Conn, dst net.Conn) { // src -> dst go func() { - buf := bufPool.Get().([]byte) + buf := bufPool.Get() written, err := io.CopyBuffer(dst, src, buf) if err != nil { p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", src.RemoteAddr().String(), dst.RemoteAddr().String(), err)) @@ -1062,7 +1061,7 @@ func (p *Proxy) transfer(src net.Conn, dst net.Conn) { // dst -> src go func() { - buf := bufPool.Get().([]byte) + buf := bufPool.Get() written, err := io.CopyBuffer(src, dst, buf) if err != nil { p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", dst.RemoteAddr().String(), src.RemoteAddr().String(), err))