From 90356e076cd34b5177764ee82266a9ff1946d675 Mon Sep 17 00:00:00 2001 From: bxd <2216403312@qq.com> Date: Sat, 2 Dec 2023 16:51:46 +0800 Subject: [PATCH] optimize code --- body.go | 18 ++++----- client.go | 2 + conn.go | 82 ++++++++++++++------------------------- go.mod | 6 +-- go.sum | 6 +++ headers.go | 12 ------ option.go | 97 ++++++++++++++++++++++++++++++++++++++++++---- requests.go | 95 ++++++++++++++++++++++----------------------- response.go | 29 ++++++++------ roundTripper.go | 101 +++++++++++++++++++++++++----------------------- rw.go | 13 ++++--- tools.go | 68 ++++++++++++++++++++++++++------ 12 files changed, 316 insertions(+), 213 deletions(-) diff --git a/body.go b/body.go index 881c904..c75262f 100644 --- a/body.go +++ b/body.go @@ -256,15 +256,12 @@ func (obj *orderMap) MarshalJSON() ([]byte, error) { } func any2Map(val any) map[string]any { - mapType := reflect.TypeOf(val) - if mapType.Kind() != reflect.Map { + if reflect.TypeOf(val).Kind() != reflect.Map { return nil } mapValue := reflect.ValueOf(val) - keys := mapValue.MapKeys() result := make(map[string]any) - for _, key := range keys { - keyData := key.Interface() + for _, key := range mapValue.MapKeys() { valueData := mapValue.MapIndex(key).Interface() sliceValue := reflect.ValueOf(valueData) if sliceValue.Kind() == reflect.Slice { @@ -273,13 +270,14 @@ func any2Map(val any) map[string]any { valueData2 = append(valueData2, sliceValue.Index(i).Interface()) } valueData = valueData2 + } else { + result[fmt.Sprint(key.Interface())] = valueData } - result[fmt.Sprint(keyData)] = valueData } return result } -func (obj *RequestOption) newBody(val any, valType int) (io.Reader, *orderMap, []string, error) { +func (obj *RequestOption) newBody(val any, valType int) (reader io.Reader, parseOrderMap *orderMap, orderKey []string, err error) { if reader, ok := val.(io.Reader); ok { obj.once = true return reader, nil, nil, nil @@ -301,6 +299,7 @@ func (obj *RequestOption) newBody(val any, valType int) (io.Reader, *orderMap, [ if mapData := any2Map(val); mapData != nil { val = mapData } +mapL: switch value := val.(type) { case *gson.Client: if !value.IsObject() { @@ -336,9 +335,8 @@ func (obj *RequestOption) newBody(val any, valType int) (io.Reader, *orderMap, [ } return nil, orderMap, nil, nil } - result, err := gson.Decode(val) - if err != nil { + if val, err = gson.Decode(val); err != nil { return nil, nil, nil, err } - return obj.newBody(result, valType) + goto mapL } diff --git a/client.go b/client.go index ef0c64d..3e21eb8 100644 --- a/client.go +++ b/client.go @@ -16,6 +16,7 @@ type Client struct { ctx context.Context cnl context.CancelFunc transport *roundTripper + closed bool } var defaultClient, _ = NewClient(nil) @@ -82,6 +83,7 @@ func (obj *Client) ForceCloseConns() { // Close the client and cannot be used again after shutdown func (obj *Client) Close() { + obj.closed = true obj.ForceCloseConns() obj.cnl() } diff --git a/conn.go b/conn.go index b7da29c..3846624 100644 --- a/conn.go +++ b/conn.go @@ -9,17 +9,14 @@ import ( "net/http" "sync" "sync/atomic" - "time" "github.com/gospider007/net/http2" "github.com/gospider007/tools" ) type connecotr struct { - connKey connKey deleteCtx context.Context //force close deleteCnl context.CancelCauseFunc - afterTime *time.Timer closeCtx context.Context //safe close closeCnl context.CancelCauseFunc @@ -28,27 +25,21 @@ type connecotr struct { rawConn net.Conn h2RawConn *http2.ClientConn - - r *bufio.Reader - w *bufio.Writer - pr *pipCon - inPool bool + proxy string + r *bufio.Reader + w *bufio.Writer + pr *pipCon + inPool bool } -func newConnecotr(ctx context.Context, netConn net.Conn) *connecotr { - conne := new(connecotr) - conne.withCancel(ctx, ctx) - conne.rawConn = netConn - return conne -} func (obj *connecotr) withCancel(deleteCtx context.Context, closeCtx context.Context) { obj.deleteCtx, obj.deleteCnl = context.WithCancelCause(deleteCtx) obj.closeCtx, obj.closeCnl = context.WithCancelCause(closeCtx) } func (obj *connecotr) Close() error { obj.deleteCnl(errors.New("connecotr close")) - if obj.afterTime != nil { - obj.afterTime.Stop() + if obj.pr != nil { + obj.pr.Close(errors.New("connecotr close")) } if obj.h2RawConn != nil { obj.h2RawConn.Close() @@ -64,8 +55,7 @@ func (obj *connecotr) read() (err error) { if _, err = io.Copy(pw, obj.rawConn); err == nil { err = io.EOF } - pw.cnl(err) - obj.pr.cnl(err) + pw.Close(err) obj.Close() return } @@ -95,17 +85,18 @@ func (obj *connecotr) wrapBody(task *reqTask) { } func (obj *connecotr) http1Req(task *reqTask) { task.err = httpWrite(task.req, obj.w, task.orderHeaders) - // if task.err = task.req.Write(obj); task.err == nil { + // if task.err = task.req.Write(obj.w); task.err == nil { // task.err = obj.w.Flush() // } if task.err == nil { - if task.res, task.err = http.ReadResponse(obj.r, task.req); task.res != nil && task.err == nil { - obj.wrapBody(task) - } else if task.err != nil { + task.res, task.err = http.ReadResponse(obj.r, task.req) + if task.err != nil { task.err = tools.WrapError(task.err, "http1 read error") + } else if task.res == nil { + task.err = errors.New("response is nil") + } else { + obj.wrapBody(task) } - } else { - task.err = tools.WrapError(task.err, "http1 write error") } task.cnl() } @@ -130,14 +121,12 @@ func (obj *connecotr) waitBodyClose() error { func (obj *connecotr) taskMain(task *reqTask, waitBody bool) (retry bool) { if obj.h2Closed() { obj.Close() - task.err = errors.New("conn is closed") return true } if !waitBody { select { case <-obj.closeCtx.Done(): obj.Close() - task.err = tools.WrapError(obj.closeCtx.Err(), "conn close ctx error: ") return true default: } @@ -147,11 +136,6 @@ func (obj *connecotr) taskMain(task *reqTask, waitBody bool) (retry bool) { } else { go obj.http1Req(task) } - if obj.afterTime == nil { - obj.afterTime = time.NewTimer(task.responseHeaderTimeout) - } else { - obj.afterTime.Reset(task.responseHeaderTimeout) - } select { case <-task.ctx.Done(): if task.err != nil { @@ -159,8 +143,8 @@ func (obj *connecotr) taskMain(task *reqTask, waitBody bool) (retry bool) { return false } if task.res == nil { - obj.Close() task.err = errors.New("response is nil") + obj.Close() return false } if waitBody { @@ -168,13 +152,8 @@ func (obj *connecotr) taskMain(task *reqTask, waitBody bool) (retry bool) { } return false case <-obj.deleteCtx.Done(): //force conn close - task.cnl() task.err = tools.WrapError(obj.deleteCtx.Err(), "delete ctx error: ") - obj.Close() - return false - case <-obj.afterTime.C: task.cnl() - task.err = errors.New("response Header is Timeout") obj.Close() return false } @@ -185,11 +164,12 @@ type connPool struct { deleteCnl context.CancelCauseFunc closeCtx context.Context closeCnl context.CancelCauseFunc - connKey connKey + connKey string total atomic.Int64 tasks chan *reqTask connPools *connPools } + type connPools struct { connPools sync.Map } @@ -197,28 +177,26 @@ type connPools struct { func newConnPools() *connPools { return new(connPools) } -func (obj *connPools) get(key connKey) *connPool { + +func (obj *connPools) get(key string) *connPool { val, ok := obj.connPools.Load(key) if !ok { return nil } - pool := val.(*connPool) - select { - case <-pool.closeCtx.Done(): - return nil - default: - return pool - } + return val.(*connPool) } -func (obj *connPools) set(key connKey, pool *connPool) { + +func (obj *connPools) set(key string, pool *connPool) { obj.connPools.Store(key, pool) } -func (obj *connPools) del(key connKey) { + +func (obj *connPools) del(key string) { obj.connPools.Delete(key) } -func (obj *connPools) iter(f func(key connKey, value *connPool) bool) { + +func (obj *connPools) iter(f func(key string, value *connPool) bool) { obj.connPools.Range(func(key, value any) bool { - return f(key.(connKey), value.(*connPool)) + return f(key.(string), value.(*connPool)) }) } @@ -260,10 +238,10 @@ func (obj *connPool) rwMain(conn *connecotr) { } } func (obj *connPool) forceClose() { - obj.deleteCnl(errors.New("connPool forceClose")) obj.close() + obj.deleteCnl(errors.New("connPool forceClose")) } func (obj *connPool) close() { - obj.closeCnl(errors.New("connPool close")) obj.connPools.del(obj.connKey) + obj.closeCnl(errors.New("connPool close")) } diff --git a/go.mod b/go.mod index 7f11444..a6aa0b4 100644 --- a/go.mod +++ b/go.mod @@ -7,10 +7,10 @@ require ( github.com/gospider007/bs4 v0.0.0-20231123090151-001db0b91941 github.com/gospider007/gson v0.0.0-20231119141525-66095080057d github.com/gospider007/gtls v0.0.0-20231120122450-e763299259db - github.com/gospider007/ja3 v0.0.0-20231029025157-38fc2f8f2d91 + github.com/gospider007/ja3 v0.0.0-20231202085054-c1b92675187e github.com/gospider007/net v0.0.0-20231028084010-313c148cf0a1 github.com/gospider007/re v0.0.0-20231024115818-adfd03636256 - github.com/gospider007/tools v0.0.0-20231201075443-f0a4bc8cd616 + github.com/gospider007/tools v0.0.0-20231202084937-8b2bc66f8198 github.com/gospider007/websocket v0.0.0-20231128065110-6296f87425c4 github.com/refraction-networking/utls v1.5.4 golang.org/x/exp v0.0.0-20231127185646-65229373498e @@ -27,7 +27,7 @@ require ( github.com/gospider007/blog v0.0.0-20231121084103-59a004dafccf // indirect github.com/gospider007/kinds v0.0.0-20231024093643-7a4424f2d30e // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.17.3 // indirect + github.com/klauspost/compress v1.17.4 // indirect github.com/klauspost/cpuid/v2 v2.2.6 // indirect github.com/libdns/libdns v0.2.1 // indirect github.com/mholt/acmez v1.2.0 // indirect diff --git a/go.sum b/go.sum index fb02129..e7a136b 100644 --- a/go.sum +++ b/go.sum @@ -39,6 +39,8 @@ github.com/gospider007/gtls v0.0.0-20231120122450-e763299259db h1:8ghU2J0g7BKC1t github.com/gospider007/gtls v0.0.0-20231120122450-e763299259db/go.mod h1:yE9d9KMYxJEQawkOOHBea03dN89uo3hWQxYiv9tnz0A= github.com/gospider007/ja3 v0.0.0-20231029025157-38fc2f8f2d91 h1:qQokihfTAX+/U8GIMvZauRtE4G+/1Jq8XIJx8xLr04A= github.com/gospider007/ja3 v0.0.0-20231029025157-38fc2f8f2d91/go.mod h1:ur78/uhYDDULSy1ldA/pPpGhjk973Q1VsPnbktXGU/g= +github.com/gospider007/ja3 v0.0.0-20231202085054-c1b92675187e h1:6Brr8+E6fht8IVFBJBkUoFMr8HQ9yzsrVxt8m/wVtQ8= +github.com/gospider007/ja3 v0.0.0-20231202085054-c1b92675187e/go.mod h1:kWf9x0hQS+pgOpu1lRiVdE5nozLj71j740cWvjzKqok= github.com/gospider007/kinds v0.0.0-20231024093643-7a4424f2d30e h1:lmX6IQKkrNDbXfHsvrv1Uz0MoG2v5+4VC6Gdh9irUNY= github.com/gospider007/kinds v0.0.0-20231024093643-7a4424f2d30e/go.mod h1:nB4OMmd8Ji92yEmgjbHcqLcBHTAhSSmlGNb2JpTYK9A= github.com/gospider007/net v0.0.0-20231028084010-313c148cf0a1 h1:tYOQEvELrV+USjKGsAroC1cvsLMHgGlUPQY1TKS/PDM= @@ -49,6 +51,8 @@ github.com/gospider007/tools v0.0.0-20231128142841-23217c299fc2 h1:io2bXntt5LSwW github.com/gospider007/tools v0.0.0-20231128142841-23217c299fc2/go.mod h1:wiILK6EotceHz/Rnb6ux8PzY3sr5OV+mYuIcbtxpkYI= github.com/gospider007/tools v0.0.0-20231201075443-f0a4bc8cd616 h1:Ix8hbbbaIX9REGs0qqU1b48L3BlleDyNkCPdv297LF8= github.com/gospider007/tools v0.0.0-20231201075443-f0a4bc8cd616/go.mod h1:wiILK6EotceHz/Rnb6ux8PzY3sr5OV+mYuIcbtxpkYI= +github.com/gospider007/tools v0.0.0-20231202084937-8b2bc66f8198 h1:phk1GNobIIQWL5/G5dtgs35hotucdYv2FScMjlHHZ+Q= +github.com/gospider007/tools v0.0.0-20231202084937-8b2bc66f8198/go.mod h1:wiILK6EotceHz/Rnb6ux8PzY3sr5OV+mYuIcbtxpkYI= github.com/gospider007/websocket v0.0.0-20231128065110-6296f87425c4 h1:h+74nkhhTDN2tiaDjHwR4CjqBTHgh+t1pqE2IAWHN3k= github.com/gospider007/websocket v0.0.0-20231128065110-6296f87425c4/go.mod h1:OncvZIlq9TzwD/tQS/BYY/RKBqbW4+gGY3Ere1K7s24= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -56,6 +60,8 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA= github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= +github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc= github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= diff --git a/headers.go b/headers.go index 335a005..2cfb848 100644 --- a/headers.go +++ b/headers.go @@ -12,18 +12,6 @@ const ( AcceptLanguage = "zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6" ) -// get default headers -func defaultHeaders() http.Header { - return http.Header{ - "User-Agent": []string{UserAgent}, - "Accept": []string{"text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7"}, - "Accept-Encoding": []string{"gzip, deflate, br"}, - "Accept-Language": []string{AcceptLanguage}, - "Sec-Ch-Ua": []string{SecChUa}, - "Sec-Ch-Ua-Mobile": []string{"?0"}, - "Sec-Ch-Ua-Platform": []string{`"Windows"`}, - } -} func (obj *RequestOption) initHeaders() (http.Header, error) { if obj.Headers == nil { return nil, nil diff --git a/option.go b/option.go index bab6a49..e96191c 100644 --- a/option.go +++ b/option.go @@ -10,24 +10,22 @@ import ( "github.com/gospider007/gtls" "github.com/gospider007/ja3" - "github.com/gospider007/tools" "github.com/gospider007/websocket" ) // Connection Management Options type ClientOption struct { - ForceHttp1 bool //force use http1 send requests OrderHeaders []string //order headers with http1 - Ja3 bool //enable ja3 fingerprint Ja3Spec ja3.Ja3Spec //custom ja3Spec,use ja3.CreateSpecWithStr or ja3.CreateSpecWithId create H2Ja3Spec ja3.H2Ja3Spec //h2 fingerprint Proxy string //proxy,support https,http,socks5 + ForceHttp1 bool //force use http1 send requests + Ja3 bool //enable ja3 fingerprint DisCookie bool //disable cookies DisDecode bool //disable auto decode DisUnZip bool //disable auto zip decode DisAlive bool //disable keepalive Bar bool ////enable bar display - Timeout time.Duration //request timeout OptionCallBack func(ctx context.Context, client *Client, option *RequestOption) error //option callback,if error is returnd, break request ResultCallBack func(ctx context.Context, client *Client, response *Response) error //result callback,if error is returnd,next errCallback ErrCallBack func(ctx context.Context, client *Client, response *Response, err error) error //error callback,if error is returnd,break request @@ -35,6 +33,7 @@ type ClientOption struct { MaxRetries int //try num MaxRedirect int //redirect num ,<0 no redirect,==0 no limit Headers any //default headers + Timeout time.Duration //request timeout ResponseHeaderTimeout time.Duration //ResponseHeaderTimeout ,default:30 TlsHandshakeTimeout time.Duration //tls timeout,default:15 @@ -53,18 +52,17 @@ type ClientOption struct { // Options for sending requests type RequestOption struct { - ForceHttp1 bool //force use http1 send requests OrderHeaders []string //order headers with http1 - Ja3 bool //enable ja3 fingerprint Ja3Spec ja3.Ja3Spec //custom ja3Spec,use ja3.CreateSpecWithStr or ja3.CreateSpecWithId create H2Ja3Spec ja3.H2Ja3Spec //custom h2 fingerprint Proxy string //proxy,support http,https,socks5,example:http://127.0.0.1:7005 + ForceHttp1 bool //force use http1 send requests + Ja3 bool //enable ja3 fingerprint DisCookie bool //disable cookies,not use cookies DisDecode bool //disable auto decode DisUnZip bool //disable auto zip decode DisAlive bool //disable keepalive Bar bool //enable bar display - Timeout time.Duration //request timeout OptionCallBack func(ctx context.Context, client *Client, option *RequestOption) error //option callback,if error is returnd, break request ResultCallBack func(ctx context.Context, client *Client, response *Response) error //result callback,if error is returnd,next errCallback ErrCallBack func(ctx context.Context, client *Client, response *Response, err error) error //error callback,if error is returnd,break request @@ -73,6 +71,7 @@ type RequestOption struct { MaxRetries int //try num MaxRedirect int //redirect num ,<0 no redirect,==0 no limit Headers any //request headers:json,map,header + Timeout time.Duration //request timeout ResponseHeaderTimeout time.Duration //ResponseHeaderTimeout ,default:30 TlsHandshakeTimeout time.Duration @@ -174,7 +173,89 @@ func (obj *RequestOption) initParams() (*url.URL, error) { return pu, nil } func (obj *Client) newRequestOption(option RequestOption) RequestOption { - tools.Merge(&option, obj.option) + // start + if option.OrderHeaders == nil { + option.OrderHeaders = obj.option.OrderHeaders + } + if !option.Ja3Spec.IsSet() { + option.Ja3Spec = obj.option.Ja3Spec + } + if !option.H2Ja3Spec.IsSet() { + option.H2Ja3Spec = obj.option.H2Ja3Spec + } + if option.Proxy == "" { + option.Proxy = obj.option.Proxy + } + if !option.ForceHttp1 { + option.ForceHttp1 = obj.option.ForceHttp1 + } + if !option.Ja3 { + option.Ja3 = obj.option.Ja3 + } + if !option.DisCookie { + option.DisCookie = obj.option.DisCookie + } + if !option.DisDecode { + option.DisDecode = obj.option.DisDecode + } + if !option.DisUnZip { + option.DisUnZip = obj.option.DisUnZip + } + if !option.DisAlive { + option.DisAlive = obj.option.DisAlive + } + if !option.Bar { + option.Bar = obj.option.Bar + } + if option.OptionCallBack == nil { + option.OptionCallBack = obj.option.OptionCallBack + } + if option.ResultCallBack == nil { + option.ResultCallBack = obj.option.ResultCallBack + } + if option.ErrCallBack == nil { + option.ErrCallBack = obj.option.ErrCallBack + } + if option.RequestCallBack == nil { + option.RequestCallBack = obj.option.RequestCallBack + } + if option.MaxRetries == 0 { + option.MaxRetries = obj.option.MaxRetries + } + if option.MaxRedirect == 0 { + option.MaxRedirect = obj.option.MaxRedirect + } + if option.Headers == nil { + option.Headers = obj.option.Headers + } + if option.Timeout == 0 { + option.Timeout = obj.option.Timeout + } + if option.ResponseHeaderTimeout == 0 { + option.ResponseHeaderTimeout = obj.option.ResponseHeaderTimeout + } + if option.TlsHandshakeTimeout == 0 { + option.TlsHandshakeTimeout = obj.option.TlsHandshakeTimeout + } + if option.DialTimeout == 0 { + option.DialTimeout = obj.option.DialTimeout + } + if option.KeepAlive == 0 { + option.KeepAlive = obj.option.KeepAlive + } + if option.LocalAddr == nil { + option.LocalAddr = obj.option.LocalAddr + } + if option.Dns == nil { + option.Dns = obj.option.Dns + } + if option.AddrType == 0 { + option.AddrType = obj.option.AddrType + } + if option.Jar == nil { + option.Jar = obj.option.Jar + } + //end if option.MaxRetries < 0 { option.MaxRetries = 0 } diff --git a/requests.go b/requests.go index 8290e76..0b7a089 100644 --- a/requests.go +++ b/requests.go @@ -3,15 +3,14 @@ package requests import ( "context" "errors" - "fmt" "io" "net" + "strings" "time" "net/textproto" "net/url" "os" - "strings" "net/http" @@ -71,22 +70,22 @@ func NewReqCtxData(ctx context.Context, option *RequestOption) (*reqCtxData, err ctxData.disProxy = option.DisProxy ctxData.tlsHandshakeTimeout = option.TlsHandshakeTimeout ctxData.orderHeaders = option.OrderHeaders + //init scheme if option.Url != nil { - switch option.Url.Scheme { - case "ws": + if option.Url.Scheme == "ws" { ctxData.isWs = true option.Url.Scheme = "http" - case "wss": + } else if option.Url.Scheme == "wss" { ctxData.isWs = true option.Url.Scheme = "https" } } - //init tls timeout if option.TlsHandshakeTimeout == 0 { ctxData.tlsHandshakeTimeout = time.Second * 15 } + //init orderHeaders,this must after init headers if ctxData.orderHeaders == nil { ctxData.orderHeaders = ja3.DefaultH1OrderHeaders() @@ -106,7 +105,6 @@ func NewReqCtxData(ctx context.Context, option *RequestOption) (*reqCtxData, err ctxData.orderHeaders = orderHeaders } //init proxy - if option.Proxy != "" { tempProxy, err := gtls.VerifyProxy(option.Proxy) if err != nil { @@ -219,9 +217,9 @@ func (obj *Client) Trace(ctx context.Context, href string, options ...RequestOpt } // Define a function named Request that takes in four parameters: -func (obj *Client) Request(ctx context.Context, method string, href string, options ...RequestOption) (resp *Response, err error) { - if obj == nil { - return nil, errors.New("client is nil") +func (obj *Client) Request(ctx context.Context, method string, href string, options ...RequestOption) (response *Response, err error) { + if obj.closed { + return nil, errors.New("client is closed") } if ctx == nil { ctx = obj.ctx @@ -231,9 +229,6 @@ func (obj *Client) Request(ctx context.Context, method string, href string, opti rawOption = options[0] } optionBak := obj.newRequestOption(rawOption) - if optionBak.Method == "" { - optionBak.Method = method - } if optionBak.Url == nil { if optionBak.Url, err = url.Parse(href); err != nil { err = tools.WrapError(err, "url parse error") @@ -241,42 +236,34 @@ func (obj *Client) Request(ctx context.Context, method string, href string, opti } } for maxRetries := 0; maxRetries <= optionBak.MaxRetries; maxRetries++ { - select { - case <-obj.ctx.Done(): - obj.Close() - return nil, tools.WrapError(obj.ctx.Err(), "client ctx 错误") - case <-ctx.Done(): - return nil, tools.WrapError(ctx.Err(), "request ctx 错误") - default: - option := optionBak - resp, err = obj.request(ctx, &option) - if err == nil || errors.Is(err, errFatal) || option.once { - return - } + option := optionBak + response, err = obj.request(ctx, &option) + if err == nil || errors.Is(err, errFatal) || option.once { + return } } - if err == nil { - err = errors.New("max try num") - } - return resp, err + return } func (obj *Client) request(ctx context.Context, option *RequestOption) (response *Response, err error) { response = new(Response) defer func() { + //read body if err == nil && !response.IsStream() { err = response.ReadBody() } + //result callback if err == nil && option.ResultCallBack != nil { err = option.ResultCallBack(ctx, obj, response) } - if err != nil { + + if err != nil { //err callback, must close body response.CloseBody() if option.ErrCallBack != nil { if err2 := option.ErrCallBack(ctx, obj, response, err); err2 != nil { err = tools.WrapError(errFatal, err2) } } - } else if !response.IsStream() { + } else if !response.IsStream() { //is not is stream must close body response.CloseBody() } }() @@ -285,25 +272,11 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response return } } - //init headers and orderheaders,befor init ctxData - headers, err := option.initHeaders() - if err != nil { - return response, tools.WrapError(err, errors.New("tempRequest init headers error"), err) - } - if headers == nil { - headers = defaultHeaders() - } - response.bar = option.Bar response.disUnzip = option.DisUnZip response.disDecode = option.DisDecode response.stream = option.Stream - method := strings.ToUpper(option.Method) - if method == "" { - method = http.MethodGet - } - //init ctxData ctxData, err := NewReqCtxData(ctx, option) if err != nil { @@ -327,18 +300,42 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response return response, tools.WrapError(err, errors.New("tempRequest init body error"), err) } //create request - reqs, err := http.NewRequestWithContext(response.ctx, method, href.String(), body) + reqs, err := http.NewRequestWithContext(response.ctx, strings.ToUpper(option.Method), href.String(), body) if err != nil { return response, tools.WrapError(errFatal, errors.New("tempRequest 构造request失败"), err) } //init headers - reqs.Header = headers + + //init headers and orderheaders,befor init ctxData + headers, err := option.initHeaders() + if err != nil { + return response, tools.WrapError(err, errors.New("tempRequest init headers error"), err) + } + if headers == nil { + reqs.Header = http.Header{ + "User-Agent": []string{UserAgent}, + "Accept": []string{"text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7"}, + "Accept-Encoding": []string{"gzip, deflate, br"}, + "Accept-Language": []string{AcceptLanguage}, + "Sec-Ch-Ua": []string{SecChUa}, + "Sec-Ch-Ua-Mobile": []string{"?0"}, + "Sec-Ch-Ua-Platform": []string{`"Windows"`}, + } + } else { + reqs.Header = headers + } //add Referer if reqs.Header.Get("Referer") == "" { if option.Referer != "" { reqs.Header.Set("Referer", option.Referer) - } else { - reqs.Header.Set("Referer", fmt.Sprintf("%s://%s", reqs.URL.Scheme, reqs.URL.Host)) + } else if reqs.URL.Scheme != "" && reqs.URL.Host != "" { + referBuild := builderPool.Get().(strings.Builder) + referBuild.WriteString(reqs.URL.Scheme) + referBuild.WriteString("://") + referBuild.WriteString(reqs.URL.Host) + reqs.Header.Set("Referer", referBuild.String()) + referBuild.Reset() + builderPool.Put(referBuild) } } diff --git a/response.go b/response.go index 8c1afd0..b291dd5 100644 --- a/response.go +++ b/response.go @@ -266,7 +266,7 @@ func (obj *Response) IsStream() bool { } // read body -func (obj *Response) ReadBody() error { +func (obj *Response) ReadBody() (err error) { if obj.IsStream() { return errors.New("can not read stream") } @@ -274,8 +274,11 @@ func (obj *Response) ReadBody() error { return errors.New("already read body") } obj.readBody = true - var err error - bBody := bytes.NewBuffer(nil) + bBody := bufferPool.Get().(*bytes.Buffer) + defer func() { + bBody.Reset() + bufferPool.Put(bBody) + }() if obj.bar && obj.ContentLength() > 0 { err = tools.CopyWitchContext(obj.response.Request.Context(), &barBody{ bar: bar.NewClient(obj.response.ContentLength), @@ -289,16 +292,11 @@ func (obj *Response) ReadBody() error { return errors.New("response read content error: " + err.Error()) } if !obj.disDecode && obj.defaultDecode() { - if content, encoding, err := tools.Charset(bBody.Bytes(), obj.ContentType()); err == nil { - obj.content, obj.encoding = content, encoding - } else { - obj.content = bBody.Bytes() - } + obj.content, obj.encoding, _ = tools.Charset(bBody.Bytes(), obj.ContentType()) } else { obj.content = bBody.Bytes() } - obj.response.Body.Close() - return nil + return } // conn is new conn @@ -316,8 +314,8 @@ func (obj *Response) CloseBody() error { } if obj.IsStream() || !obj.readBody { obj.ForceCloseConn() - } else if obj.rawConn != nil { - obj.rawConn.Close() + } else { //close body + obj.closeBody() } obj.cnl() return nil @@ -339,6 +337,13 @@ func (obj *Response) InPool() bool { return false } +// close body +func (obj *Response) closeBody() { + if obj.rawConn != nil { + obj.rawConn.Close() + } +} + // safe close conn func (obj *Response) CloseConn() { if obj.rawConn != nil { diff --git a/roundTripper.go b/roundTripper.go index b7ba9ff..cadeb80 100644 --- a/roundTripper.go +++ b/roundTripper.go @@ -7,6 +7,7 @@ import ( "errors" "net" "net/url" + "strings" "time" "net/http" @@ -18,28 +19,15 @@ import ( ) type reqTask struct { - ctx context.Context - cnl context.CancelFunc - req *http.Request - res *http.Response - emptyPool chan struct{} - err error - orderHeaders []string - responseHeaderTimeout time.Duration + ctx context.Context + cnl context.CancelFunc + req *http.Request + res *http.Response + emptyPool chan struct{} + err error + orderHeaders []string } -func newReqTask(ctx context.Context, req *http.Request, ctxData *reqCtxData) *reqTask { - if ctxData.responseHeaderTimeout == 0 { - ctxData.responseHeaderTimeout = time.Second * 30 - } - task := new(reqTask) - task.req = req - task.emptyPool = make(chan struct{}) - task.orderHeaders = ctxData.orderHeaders - task.responseHeaderTimeout = ctxData.responseHeaderTimeout - task.ctx, task.cnl = context.WithCancel(ctx) - return task -} func (obj *reqTask) inPool() bool { return obj.err == nil && obj.res != nil && obj.res.StatusCode != 101 && obj.res.Header.Get("Content-Type") != "text/event-stream" } @@ -49,14 +37,15 @@ type connKey struct { addr string } -func getKey(ctxData *reqCtxData, req *http.Request) connKey { - key := connKey{ - addr: getAddr(req.URL), - } - if ctxData.proxy != nil { - key.proxy = ctxData.proxy.String() - } - return key +func getKey(ctxData *reqCtxData, req *http.Request) (key string) { + b := builderPool.Get().(strings.Builder) + b.WriteString(getAddr(ctxData.proxy)) + b.WriteString("@") + b.WriteString(getAddr(req.URL)) + key = b.String() + b.Reset() + builderPool.Put(b) + return } type roundTripper struct { @@ -115,7 +104,7 @@ func newRoundTripper(preCtx context.Context, option ClientOption) *roundTripper connPools: newConnPools(), } } -func (obj *roundTripper) newConnPool(conn *connecotr, key connKey) *connPool { +func (obj *roundTripper) newConnPool(conn *connecotr, key string) *connPool { pool := new(connPool) pool.connKey = key pool.deleteCtx, pool.deleteCnl = context.WithCancelCause(obj.ctx) @@ -126,10 +115,7 @@ func (obj *roundTripper) newConnPool(conn *connecotr, key connKey) *connPool { go pool.rwMain(conn) return pool } -func (obj *roundTripper) getConnPool(key connKey) *connPool { - return obj.connPools.get(key) -} -func (obj *roundTripper) putConnPool(key connKey, conn *connecotr) { +func (obj *roundTripper) putConnPool(key string, conn *connecotr) { conn.inPool = true if conn.h2RawConn == nil { go conn.read() @@ -148,11 +134,15 @@ func (obj *roundTripper) tlsConfigClone() *tls.Config { func (obj *roundTripper) utlsConfigClone() *utls.Config { return obj.utlsConfig.Clone() } -func (obj *roundTripper) dial(ctxData *reqCtxData, key *connKey, req *http.Request) (conn *connecotr, err error) { +func (obj *roundTripper) newConnecotr(netConn net.Conn) *connecotr { + conne := new(connecotr) + conne.withCancel(obj.ctx, obj.ctx) + conne.rawConn = netConn + return conne +} +func (obj *roundTripper) dial(ctxData *reqCtxData, req *http.Request) (conn *connecotr, err error) { proxy := cloneUrl(ctxData.proxy) - if proxy != nil { - key.proxy = proxy.String() - } else if !ctxData.disProxy && obj.getProxy != nil { + if proxy == nil && !ctxData.disProxy && obj.getProxy != nil { proxyStr, err := obj.getProxy(req.Context(), proxy) if err != nil { return conn, err @@ -161,7 +151,7 @@ func (obj *roundTripper) dial(ctxData *reqCtxData, key *connKey, req *http.Reque return conn, err } } - netConn, err := obj.dialer.DialContextWithProxy(req.Context(), ctxData, "tcp", req.URL.Scheme, key.addr, getHost(req), proxy, obj.tlsConfigClone()) + netConn, err := obj.dialer.DialContextWithProxy(req.Context(), ctxData, "tcp", req.URL.Scheme, getAddr(req.URL), getHost(req), proxy, obj.tlsConfigClone()) if err != nil { return conn, err } @@ -189,7 +179,10 @@ func (obj *roundTripper) dial(ctxData *reqCtxData, key *connKey, req *http.Reque netConn = tlsConn } } - conne := newConnecotr(obj.ctx, netConn) + conne := obj.newConnecotr(netConn) + if proxy != nil { + conne.proxy = proxy.String() + } if h2 { if conne.h2RawConn, err = http2.NewClientConn(func() { conne.closeCnl(errors.New("http2 client close")) @@ -205,11 +198,12 @@ func (obj *roundTripper) setGetProxy(getProxy func(ctx context.Context, url *url obj.getProxy = getProxy } -func (obj *roundTripper) poolRoundTrip(task *reqTask, key connKey) (newConn bool) { - pool := obj.getConnPool(key) +func (obj *roundTripper) poolRoundTrip(ctxData *reqCtxData, task *reqTask, key string) (newConn bool) { + pool := obj.connPools.get(key) if pool == nil { return true } + task.ctx, task.cnl = context.WithTimeout(obj.ctx, ctxData.responseHeaderTimeout) select { case pool.tasks <- task: select { @@ -222,18 +216,17 @@ func (obj *roundTripper) poolRoundTrip(task *reqTask, key connKey) (newConn bool return true } } -func (obj *roundTripper) connRoundTrip(ctxData *reqCtxData, task *reqTask, key connKey) (retry bool) { - ckey := key - conn, err := obj.dial(ctxData, &ckey, task.req) +func (obj *roundTripper) connRoundTrip(ctxData *reqCtxData, task *reqTask, key string) (retry bool) { + conn, err := obj.dial(ctxData, task.req) if err != nil { task.err = err return } + task.ctx, task.cnl = context.WithTimeout(obj.ctx, ctxData.responseHeaderTimeout) retry = conn.taskMain(task, false) if retry || task.err != nil { return retry } - conn.connKey = ckey if task.inPool() && !ctxData.disAlive { obj.putConnPool(key, conn) } @@ -241,19 +234,29 @@ func (obj *roundTripper) connRoundTrip(ctxData *reqCtxData, task *reqTask, key c } func (obj *roundTripper) closeConns() { - obj.connPools.iter(func(key connKey, pool *connPool) bool { + obj.connPools.iter(func(key string, pool *connPool) bool { pool.close() obj.connPools.del(key) return true }) } func (obj *roundTripper) forceCloseConns() { - obj.connPools.iter(func(key connKey, pool *connPool) bool { + obj.connPools.iter(func(key string, pool *connPool) bool { pool.forceClose() obj.connPools.del(key) return true }) } +func (obj *roundTripper) newReqTask(req *http.Request, ctxData *reqCtxData) *reqTask { + if ctxData.responseHeaderTimeout == 0 { + ctxData.responseHeaderTimeout = time.Second * 30 + } + task := new(reqTask) + task.req = req + task.emptyPool = make(chan struct{}) + task.orderHeaders = ctxData.orderHeaders + return task +} func (obj *roundTripper) RoundTrip(req *http.Request) (response *http.Response, err error) { ctxData := GetReqCtxData(req.Context()) if ctxData.requestCallBack != nil { @@ -262,11 +265,11 @@ func (obj *roundTripper) RoundTrip(req *http.Request) (response *http.Response, } } key := getKey(ctxData, req) //pool key - task := newReqTask(obj.ctx, req, ctxData) + task := obj.newReqTask(req, ctxData) //get pool conn var isNewConn bool if !ctxData.disAlive { - isNewConn = obj.poolRoundTrip(task, key) + isNewConn = obj.poolRoundTrip(ctxData, task, key) } if ctxData.disAlive || isNewConn { ctxData.isNewConn = true diff --git a/rw.go b/rw.go index b5fa100..29c7712 100644 --- a/rw.go +++ b/rw.go @@ -16,6 +16,13 @@ func (obj *readWriteCloser) Conn() *connecotr { func (obj *readWriteCloser) Read(p []byte) (n int, err error) { return obj.body.Read(p) } +func (obj *readWriteCloser) InPool() bool { + return obj.conn.inPool +} +func (obj *readWriteCloser) Proxy() string { + return obj.conn.proxy +} + func (obj *readWriteCloser) Close() (err error) { err = obj.body.Close() if !obj.InPool() { @@ -25,12 +32,6 @@ func (obj *readWriteCloser) Close() (err error) { } return } -func (obj *readWriteCloser) InPool() bool { - return obj.conn.inPool -} -func (obj *readWriteCloser) Proxy() string { - return obj.conn.connKey.proxy -} // safe close conn func (obj *readWriteCloser) CloseConn() { diff --git a/tools.go b/tools.go index ca86d2b..2562c55 100644 --- a/tools.go +++ b/tools.go @@ -2,11 +2,14 @@ package requests import ( "bufio" + "bytes" "fmt" "io" "net" "net/http" "net/url" + "strings" + "sync" _ "unsafe" "golang.org/x/exp/slices" @@ -29,18 +32,24 @@ func getHost(req *http.Request) string { } return host } -func getAddr(uurl *url.URL) string { +func getAddr(uurl *url.URL) (addr string) { if uurl == nil { return "" } _, port, _ := net.SplitHostPort(uurl.Host) if port == "" { + bs := builderPool.Get().(strings.Builder) + bs.WriteString(uurl.Host) + bs.WriteString(":") if uurl.Scheme == "https" { - port = "443" + bs.WriteString("443") } else { - port = "80" + bs.WriteString("80") } - return fmt.Sprintf("%s:%s", uurl.Host, port) + addr = bs.String() + bs.Reset() + builderPool.Put(bs) + return } return uurl.Host } @@ -103,16 +112,34 @@ func httpWrite(r *http.Request, w *bufio.Writer, orderHeaders []string) (err err if r.Header.Get("Content-Length") == "" && shouldSendContentLength(r) { r.Header.Set("Content-Length", fmt.Sprint(r.ContentLength)) } - if _, err = fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", r.Method, ruri); err != nil { + bs := builderPool.Get().(strings.Builder) + defer func() { + bs.Reset() + builderPool.Put(bs) + }() + bs.WriteString(r.Method) + bs.WriteString(" ") + bs.WriteString(ruri) + bs.WriteString(" ") + bs.WriteString(r.Proto) + bs.WriteString("\r\n") + if _, err = w.WriteString(bs.String()); err != nil { return err } for _, k := range orderHeaders { - if k2, ok := replaceMap[k]; ok { - k = k2 - } - for _, v := range r.Header.Values(k) { - if _, err = fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { - return err + if vs, ok := r.Header[k]; ok { + if k2, ok := replaceMap[k]; ok { + k = k2 + } + for _, v := range vs { + bs.Reset() + bs.WriteString(k) + bs.WriteString(": ") + bs.WriteString(v) + bs.WriteString("\r\n") + if _, err = w.WriteString(bs.String()); err != nil { + return err + } } } } @@ -122,7 +149,12 @@ func httpWrite(r *http.Request, w *bufio.Writer, orderHeaders []string) (err err k = k2 } for _, v := range vs { - if _, err = fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { + bs.Reset() + bs.WriteString(k) + bs.WriteString(": ") + bs.WriteString(v) + bs.WriteString("\r\n") + if _, err = w.WriteString(bs.String()); err != nil { return err } } @@ -138,3 +170,15 @@ func httpWrite(r *http.Request, w *bufio.Writer, orderHeaders []string) (err err } return w.Flush() } + +var bufferPool sync.Pool +var builderPool sync.Pool + +func init() { + bufferPool.New = func() interface{} { + return bytes.NewBuffer(nil) + } + builderPool.New = func() interface{} { + return strings.Builder{} + } +}