diff --git a/client.go b/client.go index 4d9a3ab..9b3e333 100644 --- a/client.go +++ b/client.go @@ -194,7 +194,11 @@ func (obj *Client) SetGetProxy(getProxy func(ctx context.Context, url *url.URL) } // 关闭客户端 +func (obj *Client) CloseIdleConnections() { + obj.transport.CloseIdleConnections() +} func (obj *Client) Close() { + obj.CloseIdleConnections() obj.cnl() } diff --git a/roundTripper.go b/roundTripper.go index 3b0bce7..8a3e87a 100644 --- a/roundTripper.go +++ b/roundTripper.go @@ -113,9 +113,10 @@ func (obj *RoundTripper) newConnPool(key string, conn *connecotr) *connPool { go pool.rwMain(conn) return pool } -func (obj *RoundTripper) delConnPool(key string) { +func (obj *RoundTripper) delConnPool(key string, pool *connPool) { obj.connsLock.Lock() defer obj.connsLock.Unlock() + pool.Close() delete(obj.connPools, key) } func (obj *RoundTripper) getConnPool(key string) *connPool { @@ -303,13 +304,14 @@ func (obj *connPool) rwMain(conn *connecotr) { conn.ctx, conn.cnl = context.WithCancel(obj.ctx) defer func() { if obj.total.Load() == 0 { - obj.rt.delConnPool(obj.key) + obj.rt.delConnPool(obj.key, obj) } }() defer obj.total.Add(-1) defer conn.Close() wait := time.NewTimer(0) defer wait.Stop() + defer conn.cnl() go func() { defer conn.cnl() for { @@ -460,6 +462,12 @@ func (obj *RoundTripper) poolRoundTrip(task *reqTask, key string) (bool, error) } return false, nil } + +func (obj *RoundTripper) CloseIdleConnections() { + for key, pool := range obj.connPools { + obj.delConnPool(key, pool) + } +} func (obj *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { ctxData := req.Context().Value(keyPrincipalID).(*reqCtxData) if ctxData.requestCallBack != nil {