From 36a4432d712372c5bcf64d5df14315b9806fa70b Mon Sep 17 00:00:00 2001 From: bxd <2216403312@qq.com> Date: Fri, 1 Dec 2023 15:57:21 +0800 Subject: [PATCH] optimize code --- client.go | 12 +-- conn.go | 125 +++++++++++----------- dial.go | 28 +++-- go.mod | 6 +- go.sum | 6 +- option.go | 11 +- pip.go | 10 -- requests.go | 104 ++++++------------ response.go | 59 ++++------- roundTripper.go | 215 +++++++++++++------------------------- rw.go | 2 +- test/protocol/sse_test.go | 1 + tools.go | 3 + 13 files changed, 227 insertions(+), 355 deletions(-) diff --git a/client.go b/client.go index 7d0e6bc..ef0c64d 100644 --- a/client.go +++ b/client.go @@ -86,21 +86,21 @@ func (obj *Client) Close() { obj.cnl() } -func (obj *Client) getClient(option *RequestOption) *http.Client { +func (obj *Client) send(option *RequestOption, reqs *http.Request) (*http.Response, error) { if option.DisCookie { - return &http.Client{ + return (&http.Client{ Transport: obj.client.Transport, CheckRedirect: obj.client.CheckRedirect, Timeout: obj.client.Timeout, - } + }).Do(reqs) } if option.Jar != nil { - return &http.Client{ + return (&http.Client{ Transport: obj.client.Transport, CheckRedirect: obj.client.CheckRedirect, Timeout: obj.client.Timeout, Jar: option.Jar.jar, - } + }).Do(reqs) } - return obj.client + return obj.client.Do(reqs) } diff --git a/conn.go b/conn.go index a5a910c..b7da29c 100644 --- a/conn.go +++ b/conn.go @@ -19,9 +19,9 @@ type connecotr struct { connKey connKey deleteCtx context.Context //force close deleteCnl context.CancelCauseFunc - - closeCtx context.Context //safe close - closeCnl context.CancelCauseFunc + afterTime *time.Timer + closeCtx context.Context //safe close + closeCnl context.CancelCauseFunc bodyCtx context.Context //body close bodyCnl context.CancelCauseFunc @@ -35,12 +35,21 @@ type connecotr struct { inPool bool } +func newConnecotr(ctx context.Context, netConn net.Conn) *connecotr { + conne := new(connecotr) + conne.withCancel(ctx, ctx) + conne.rawConn = netConn + return conne +} func (obj *connecotr) withCancel(deleteCtx context.Context, closeCtx context.Context) { obj.deleteCtx, obj.deleteCnl = context.WithCancelCause(deleteCtx) obj.closeCtx, obj.closeCnl = context.WithCancelCause(closeCtx) } func (obj *connecotr) Close() error { obj.deleteCnl(errors.New("connecotr close")) + if obj.afterTime != nil { + obj.afterTime.Stop() + } if obj.h2RawConn != nil { obj.h2RawConn.Close() } @@ -52,14 +61,12 @@ func (obj *connecotr) read() (err error) { } var pw *pipCon obj.pr, pw = pipe(obj.deleteCtx) - defer func() { - pw.cnl(err) - obj.pr.cnl(err) - obj.Close() - }() if _, err = io.Copy(pw, obj.rawConn); err == nil { err = io.EOF } + pw.cnl(err) + obj.pr.cnl(err) + obj.Close() return } func (obj *connecotr) Read(b []byte) (i int, err error) { @@ -87,17 +94,10 @@ func (obj *connecotr) wrapBody(task *reqTask) { task.res.Body = body } func (obj *connecotr) http1Req(task *reqTask) { - defer task.cnl() - if task.debug { - debugPrint(task.requestId, "http1 req start") - } task.err = httpWrite(task.req, obj.w, task.orderHeaders) // if task.err = task.req.Write(obj); task.err == nil { // task.err = obj.w.Flush() // } - if task.debug { - debugPrint(task.requestId, "http1 req write ok ,err: ", task.err) - } if task.err == nil { if task.res, task.err = http.ReadResponse(obj.r, task.req); task.res != nil && task.err == nil { obj.wrapBody(task) @@ -107,24 +107,16 @@ func (obj *connecotr) http1Req(task *reqTask) { } else { task.err = tools.WrapError(task.err, "http1 write error") } - if task.debug { - debugPrint(task.requestId, "http1 req ok ,err: ", task.err) - } + task.cnl() } func (obj *connecotr) http2Req(task *reqTask) { - defer task.cnl() - if task.debug { - debugPrint(task.requestId, "http2 req start") - } if task.res, task.err = obj.h2RawConn.RoundTrip(task.req); task.res != nil && task.err == nil { obj.wrapBody(task) } else if task.err != nil { task.err = tools.WrapError(task.err, "http2 roundTrip error") } - if task.debug { - debugPrint(task.requestId, "http2 req ok,err: ", task.err) - } + task.cnl() } func (obj *connecotr) waitBodyClose() error { select { @@ -135,44 +127,57 @@ func (obj *connecotr) waitBodyClose() error { } } -func (obj *connecotr) taskMain(task *reqTask, afterTime *time.Timer, waitBody bool) (*http.Response, error, bool) { +func (obj *connecotr) taskMain(task *reqTask, waitBody bool) (retry bool) { if obj.h2Closed() { - if task.debug { - debugPrint(task.requestId, "h2 con is closed") - } - return nil, errors.New("conn is closed"), true + obj.Close() + task.err = errors.New("conn is closed") + return true } - select { - case <-obj.closeCtx.Done(): - if task.debug { - debugPrint(task.requestId, "connecotr closeCnl") + if !waitBody { + select { + case <-obj.closeCtx.Done(): + obj.Close() + task.err = tools.WrapError(obj.closeCtx.Err(), "conn close ctx error: ") + return true + default: } - return nil, tools.WrapError(obj.closeCtx.Err(), "close ctx error: "), true - default: } if obj.h2RawConn != nil { go obj.http2Req(task) } else { go obj.http1Req(task) } - if afterTime == nil { - afterTime = time.NewTimer(task.responseHeaderTimeout) - defer afterTime.Stop() + if obj.afterTime == nil { + obj.afterTime = time.NewTimer(task.responseHeaderTimeout) } else { - afterTime.Reset(task.responseHeaderTimeout) + obj.afterTime.Reset(task.responseHeaderTimeout) } select { case <-task.ctx.Done(): - if waitBody && task.res != nil && task.err == nil { + if task.err != nil { + obj.Close() + return false + } + if task.res == nil { + obj.Close() + task.err = errors.New("response is nil") + return false + } + if waitBody { task.err = obj.waitBodyClose() } + return false case <-obj.deleteCtx.Done(): //force conn close + task.cnl() task.err = tools.WrapError(obj.deleteCtx.Err(), "delete ctx error: ") - case <-afterTime.C: + obj.Close() + return false + case <-obj.afterTime.C: + task.cnl() task.err = errors.New("response Header is Timeout") + obj.Close() + return false } - task.cnl() - return task.res, task.err, false } type connPool struct { @@ -189,12 +194,21 @@ type connPools struct { connPools sync.Map } +func newConnPools() *connPools { + return new(connPools) +} func (obj *connPools) get(key connKey) *connPool { val, ok := obj.connPools.Load(key) if !ok { return nil } - return val.(*connPool) + pool := val.(*connPool) + select { + case <-pool.closeCtx.Done(): + return nil + default: + return pool + } } func (obj *connPools) set(key connKey, pool *connPool) { obj.connPools.Store(key, pool) @@ -202,8 +216,10 @@ func (obj *connPools) set(key connKey, pool *connPool) { func (obj *connPools) del(key connKey) { obj.connPools.Delete(key) } -func (obj *connPools) iter(f func(key any, value any) bool) { - obj.connPools.Range(f) +func (obj *connPools) iter(f func(key connKey, value *connPool) bool) { + obj.connPools.Range(func(key, value any) bool { + return f(key.(connKey), value.(*connPool)) + }) } func (obj *connPool) notice(task *reqTask) { @@ -215,11 +231,7 @@ func (obj *connPool) notice(task *reqTask) { func (obj *connPool) rwMain(conn *connecotr) { conn.withCancel(obj.deleteCtx, obj.closeCtx) - var afterTime *time.Timer defer func() { - if afterTime != nil { - afterTime.Stop() - } conn.Close() obj.total.Add(-1) if obj.total.Load() <= 0 { @@ -234,21 +246,14 @@ func (obj *connPool) rwMain(conn *connecotr) { case <-conn.closeCtx.Done(): //safe close conn return case task := <-obj.tasks: //recv task - if task.debug { - debugPrint(task.requestId, "recv task") - } if task == nil { - if task.debug { - debugPrint(task.requestId, "recv task is nil") - } return } - res, err, notice := conn.taskMain(task, afterTime, true) - if notice { + if conn.taskMain(task, true) { obj.notice(task) return } - if res == nil || err != nil { + if task.err != nil { return } } diff --git a/dial.go b/dial.go index 4a84450..0bb120f 100644 --- a/dial.go +++ b/dial.go @@ -76,12 +76,11 @@ func NewDail(option DialOption) *DialClient { getAddrType: option.GetAddrType, } } -func (obj *DialClient) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { +func (obj *DialClient) DialContext(ctx context.Context, ctxData *reqCtxData, network string, addr string) (net.Conn, error) { host, port, err := net.SplitHostPort(addr) if err != nil { return nil, tools.WrapError(err, "addrToIp error,SplitHostPort") } - ctxData := GetReqCtxData(ctx) var dialer *net.Dialer if _, ipInt := gtls.ParseHost(host); ipInt == 0 { //domain host, ok := obj.loadHost(host) @@ -108,9 +107,9 @@ func (obj *DialClient) DialContext(ctx context.Context, network string, addr str } return dialer.DialContext(ctx, network, addr) } -func (obj *DialClient) DialContextWithProxy(ctx context.Context, network string, scheme string, addr string, host string, proxyUrl *url.URL, tlsConfig *tls.Config) (net.Conn, error) { +func (obj *DialClient) DialContextWithProxy(ctx context.Context, ctxData *reqCtxData, network string, scheme string, addr string, host string, proxyUrl *url.URL, tlsConfig *tls.Config) (net.Conn, error) { if proxyUrl == nil { - return obj.DialContext(ctx, network, addr) + return obj.DialContext(ctx, ctxData, network, addr) } if proxyUrl.Port() == "" { if proxyUrl.Scheme == "http" { @@ -121,7 +120,7 @@ func (obj *DialClient) DialContextWithProxy(ctx context.Context, network string, } switch proxyUrl.Scheme { case "http", "https": - conn, err := obj.DialContext(ctx, network, net.JoinHostPort(proxyUrl.Hostname(), proxyUrl.Port())) + conn, err := obj.DialContext(ctx, ctxData, network, net.JoinHostPort(proxyUrl.Hostname(), proxyUrl.Port())) if err != nil { return conn, err } else if proxyUrl.Scheme == "https" { @@ -131,7 +130,7 @@ func (obj *DialClient) DialContextWithProxy(ctx context.Context, network string, } return conn, obj.clientVerifyHttps(ctx, scheme, proxyUrl, addr, host, conn) case "socks5": - return obj.socks5Proxy(ctx, network, addr, proxyUrl) + return obj.socks5Proxy(ctx, ctxData, network, addr, proxyUrl) default: return nil, errors.New("proxyUrl Scheme error") } @@ -354,19 +353,16 @@ func (obj *DialClient) addJa3Tls(ctx context.Context, conn net.Conn, host string } return ja3.NewClient(ctx, conn, ja3Spec, disHttp2, tlsConfig) } -func (obj *DialClient) socks5Proxy(ctx context.Context, network string, addr string, proxyUrl *url.URL) (conn net.Conn, err error) { - defer func() { - if err != nil && conn != nil { - conn.Close() - } - }() - if conn, err = obj.DialContext(ctx, network, net.JoinHostPort(proxyUrl.Hostname(), proxyUrl.Port())); err != nil { +func (obj *DialClient) socks5Proxy(ctx context.Context, ctxData *reqCtxData, network string, addr string, proxyUrl *url.URL) (conn net.Conn, err error) { + if conn, err = obj.DialContext(ctx, ctxData, network, net.JoinHostPort(proxyUrl.Hostname(), proxyUrl.Port())); err != nil { return } didVerify := make(chan struct{}) go func() { - defer close(didVerify) - err = obj.clientVerifySocks5(ctx, proxyUrl, addr, conn) + if err = obj.clientVerifySocks5(ctx, proxyUrl, addr, conn); err != nil { + conn.Close() + } + close(didVerify) }() select { case <-ctx.Done(): @@ -401,7 +397,6 @@ func (obj *DialClient) clientVerifyHttps(ctx context.Context, scheme string, pro var resp *http.Response didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails go func() { - defer close(didReadResponse) connectReq := &http.Request{ Method: http.MethodConnect, URL: &url.URL{Opaque: addr}, @@ -412,6 +407,7 @@ func (obj *DialClient) clientVerifyHttps(ctx context.Context, scheme string, pro return } resp, err = http.ReadResponse(bufio.NewReader(conn), connectReq) + close(didReadResponse) }() select { case <-ctx.Done(): diff --git a/go.mod b/go.mod index 9a5914f..7f11444 100644 --- a/go.mod +++ b/go.mod @@ -3,15 +3,14 @@ module github.com/gospider007/requests go 1.21.3 require ( - github.com/gospider007/bar v0.0.0-20231121084140-33c7b6797626 - github.com/gospider007/blog v0.0.0-20231121084103-59a004dafccf + github.com/gospider007/bar v0.0.0-20231201075546-252b6e7b6a54 github.com/gospider007/bs4 v0.0.0-20231123090151-001db0b91941 github.com/gospider007/gson v0.0.0-20231119141525-66095080057d github.com/gospider007/gtls v0.0.0-20231120122450-e763299259db github.com/gospider007/ja3 v0.0.0-20231029025157-38fc2f8f2d91 github.com/gospider007/net v0.0.0-20231028084010-313c148cf0a1 github.com/gospider007/re v0.0.0-20231024115818-adfd03636256 - github.com/gospider007/tools v0.0.0-20231128142841-23217c299fc2 + github.com/gospider007/tools v0.0.0-20231201075443-f0a4bc8cd616 github.com/gospider007/websocket v0.0.0-20231128065110-6296f87425c4 github.com/refraction-networking/utls v1.5.4 golang.org/x/exp v0.0.0-20231127185646-65229373498e @@ -25,6 +24,7 @@ require ( github.com/caddyserver/certmagic v0.19.2 // indirect github.com/cloudflare/circl v1.3.6 // indirect github.com/gaukas/godicttls v0.0.4 // indirect + github.com/gospider007/blog v0.0.0-20231121084103-59a004dafccf // indirect github.com/gospider007/kinds v0.0.0-20231024093643-7a4424f2d30e // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.3 // indirect diff --git a/go.sum b/go.sum index 0276208..fb02129 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,8 @@ github.com/google/pprof v0.0.0-20230926050212-f7f687d19a98 h1:pUa4ghanp6q4IJHwE9 github.com/google/pprof v0.0.0-20230926050212-f7f687d19a98/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= github.com/gospider007/bar v0.0.0-20231121084140-33c7b6797626 h1:zDK4PcXQmAX37JdGUp45gFDMolnBGzWVXgemm5ekG1Y= github.com/gospider007/bar v0.0.0-20231121084140-33c7b6797626/go.mod h1:aYPgmG9340i9x9VQZhf34/XtIj7PHDTq0wSO+7zU/8s= +github.com/gospider007/bar v0.0.0-20231201075546-252b6e7b6a54 h1:3rtF5ZK6b9L8pIsn9AUtSbRBiSpcWW/3Os9XYNyDvKI= +github.com/gospider007/bar v0.0.0-20231201075546-252b6e7b6a54/go.mod h1:aYPgmG9340i9x9VQZhf34/XtIj7PHDTq0wSO+7zU/8s= github.com/gospider007/blog v0.0.0-20231121084103-59a004dafccf h1:1laTsuH/wl5pZ5QlHzacX09QzvwQw0DFENoRMpGBK8Y= github.com/gospider007/blog v0.0.0-20231121084103-59a004dafccf/go.mod h1:CCJ+hvQ0kxL+qB/Wfr1xt7xspsG4XiczhnAPVxG2m3M= github.com/gospider007/bs4 v0.0.0-20231123090151-001db0b91941 h1:Aik3aBqnpujF5LA+JyIm3LNxivobnqAOPr+VVlTbqds= @@ -43,10 +45,10 @@ github.com/gospider007/net v0.0.0-20231028084010-313c148cf0a1 h1:tYOQEvELrV+USjK github.com/gospider007/net v0.0.0-20231028084010-313c148cf0a1/go.mod h1:3ggAwYdh0NB0OvtiX0l5AfHdBjgsIt9MGsXCQ3iCzQc= github.com/gospider007/re v0.0.0-20231024115818-adfd03636256 h1:Z6kHRANoWB+/4rDzq51vBts0rIXilDrF8pdRNmbMJi4= github.com/gospider007/re v0.0.0-20231024115818-adfd03636256/go.mod h1:X58uk0/F3mVskuQOZng0ZKJiAt3ETn0wxuLN//rVZrE= -github.com/gospider007/tools v0.0.0-20231122021245-1cafbac3ef46 h1:vskdS8WLAveNSDHsAAdwiD+LBLMHq3AND1nGnVydwfM= -github.com/gospider007/tools v0.0.0-20231122021245-1cafbac3ef46/go.mod h1:myK4kDqDx4TlplDVnfYMI7Xi5VUbFZ3fxwAh2Cwm7ks= github.com/gospider007/tools v0.0.0-20231128142841-23217c299fc2 h1:io2bXntt5LSwWIBPztZUQrGnXlqrphVi5zdGBgHuyb8= github.com/gospider007/tools v0.0.0-20231128142841-23217c299fc2/go.mod h1:wiILK6EotceHz/Rnb6ux8PzY3sr5OV+mYuIcbtxpkYI= +github.com/gospider007/tools v0.0.0-20231201075443-f0a4bc8cd616 h1:Ix8hbbbaIX9REGs0qqU1b48L3BlleDyNkCPdv297LF8= +github.com/gospider007/tools v0.0.0-20231201075443-f0a4bc8cd616/go.mod h1:wiILK6EotceHz/Rnb6ux8PzY3sr5OV+mYuIcbtxpkYI= github.com/gospider007/websocket v0.0.0-20231128065110-6296f87425c4 h1:h+74nkhhTDN2tiaDjHwR4CjqBTHgh+t1pqE2IAWHN3k= github.com/gospider007/websocket v0.0.0-20231128065110-6296f87425c4/go.mod h1:OncvZIlq9TzwD/tQS/BYY/RKBqbW4+gGY3Ere1K7s24= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= diff --git a/option.go b/option.go index 57e3dd9..bab6a49 100644 --- a/option.go +++ b/option.go @@ -102,7 +102,6 @@ type RequestOption struct { Stream bool //disable auto read WsOption websocket.Option //websocket option DisProxy bool //force disable proxy - Debug bool //enable debugger once bool } @@ -153,17 +152,17 @@ func (obj *RequestOption) initBody(ctx context.Context) (body io.Reader, err err } return } -func (obj *RequestOption) initParams() (string, error) { +func (obj *RequestOption) initParams() (*url.URL, error) { if obj.Params == nil { - return obj.Url.String(), nil + return obj.Url, nil } _, dataMap, _, err := obj.newBody(obj.Params, mapType) if err != nil { - return obj.Url.String(), err + return obj.Url, err } query := dataMap.parseParams().String() if query == "" { - return obj.Url.String(), nil + return obj.Url, nil } pu := cloneUrl(obj.Url) pquery := pu.Query().Encode() @@ -172,7 +171,7 @@ func (obj *RequestOption) initParams() (string, error) { } else { pu.RawQuery = pquery + "&" + query } - return pu.String(), nil + return pu, nil } func (obj *Client) newRequestOption(option RequestOption) RequestOption { tools.Merge(&option, obj.option) diff --git a/pip.go b/pip.go index 48df6d9..8548fce 100644 --- a/pip.go +++ b/pip.go @@ -16,11 +16,6 @@ type pipCon struct { } func (obj *pipCon) Read(b []byte) (n int, err error) { - defer func() { - if err != nil { - obj.Close(err) - } - }() select { case con := <-obj.reader: n = copy(b, con) @@ -35,11 +30,6 @@ func (obj *pipCon) Read(b []byte) (n int, err error) { } } func (obj *pipCon) Write(b []byte) (n int, err error) { - defer func() { - if err != nil { - obj.Close(err) - } - }() obj.lock.Lock() defer obj.lock.Unlock() for once := true; once || len(b) > 0; once = false { diff --git a/requests.go b/requests.go index e523309..8290e76 100644 --- a/requests.go +++ b/requests.go @@ -5,9 +5,7 @@ import ( "errors" "fmt" "io" - "log" "net" - "runtime" "time" "net/textproto" @@ -17,7 +15,6 @@ import ( "net/http" - "github.com/gospider007/blog" "github.com/gospider007/gtls" "github.com/gospider007/ja3" "github.com/gospider007/re" @@ -53,10 +50,7 @@ type reqCtxData struct { localAddr *net.TCPAddr //network card ip addrType gtls.AddrType //first ip type dns *net.UDPAddr - - isNewConn bool - debug bool - requestId string + isNewConn bool } func NewReqCtxData(ctx context.Context, option *RequestOption) (*reqCtxData, error) { @@ -70,15 +64,13 @@ func NewReqCtxData(ctx context.Context, option *RequestOption) (*reqCtxData, err ctxData.requestCallBack = option.RequestCallBack ctxData.responseHeaderTimeout = option.ResponseHeaderTimeout ctxData.addrType = option.AddrType - ctxData.dialTimeout = option.DialTimeout ctxData.keepAlive = option.KeepAlive ctxData.localAddr = option.LocalAddr ctxData.dns = option.Dns - ctxData.debug = option.Debug - if option.Debug { - ctxData.requestId = tools.NaoId() - } + ctxData.disProxy = option.DisProxy + ctxData.tlsHandshakeTimeout = option.TlsHandshakeTimeout + ctxData.orderHeaders = option.OrderHeaders //init scheme if option.Url != nil { switch option.Url.Scheme { @@ -94,15 +86,13 @@ func NewReqCtxData(ctx context.Context, option *RequestOption) (*reqCtxData, err //init tls timeout if option.TlsHandshakeTimeout == 0 { ctxData.tlsHandshakeTimeout = time.Second * 15 - } else { - ctxData.tlsHandshakeTimeout = option.TlsHandshakeTimeout } //init orderHeaders,this must after init headers - if option.OrderHeaders == nil { + if ctxData.orderHeaders == nil { ctxData.orderHeaders = ja3.DefaultH1OrderHeaders() } else { orderHeaders := []string{} - for _, key := range option.OrderHeaders { + for _, key := range ctxData.orderHeaders { key = textproto.CanonicalMIMEHeaderKey(key) if !slices.Contains(orderHeaders, key) { orderHeaders = append(orderHeaders, key) @@ -116,7 +106,6 @@ func NewReqCtxData(ctx context.Context, option *RequestOption) (*reqCtxData, err ctxData.orderHeaders = orderHeaders } //init proxy - ctxData.disProxy = option.DisProxy if option.Proxy != "" { tempProxy, err := gtls.VerifyProxy(option.Proxy) @@ -242,6 +231,15 @@ func (obj *Client) Request(ctx context.Context, method string, href string, opti rawOption = options[0] } optionBak := obj.newRequestOption(rawOption) + if optionBak.Method == "" { + optionBak.Method = method + } + if optionBak.Url == nil { + if optionBak.Url, err = url.Parse(href); err != nil { + err = tools.WrapError(err, "url parse error") + return + } + } for maxRetries := 0; maxRetries <= optionBak.MaxRetries; maxRetries++ { select { case <-obj.ctx.Done(): @@ -251,15 +249,6 @@ func (obj *Client) Request(ctx context.Context, method string, href string, opti return nil, tools.WrapError(ctx.Err(), "request ctx 错误") default: option := optionBak - if option.Method == "" { - option.Method = method - } - if option.Url == nil { - if option.Url, err = url.Parse(href); err != nil { - err = tools.WrapError(err, "url parse error") - return - } - } resp, err = obj.request(ctx, &option) if err == nil || errors.Is(err, errFatal) || option.once { return @@ -271,20 +260,11 @@ func (obj *Client) Request(ctx context.Context, method string, href string, opti } return resp, err } -func debugPrint(requestId string, content ...any) { - _, f, l, _ := runtime.Caller(1) - contents := []any{} - for _, cont := range content { - contents = append(contents, cont, " ") - } - log.Printf("%s:%d, %s>>>%s, %s>>>%s", f, l, blog.Color(2, "requestId"), blog.Color(1, requestId), blog.Color(2, "content"), blog.Color(3, contents...)) -} func (obj *Client) request(ctx context.Context, option *RequestOption) (response *Response, err error) { response = new(Response) defer func() { if err == nil && !response.IsStream() { err = response.ReadBody() - defer response.CloseBody() } if err == nil && option.ResultCallBack != nil { err = option.ResultCallBack(ctx, obj, response) @@ -296,6 +276,8 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response err = tools.WrapError(errFatal, err2) } } + } else if !response.IsStream() { + response.CloseBody() } }() if option.OptionCallBack != nil { @@ -311,14 +293,17 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response if headers == nil { headers = defaultHeaders() } + response.bar = option.Bar response.disUnzip = option.DisUnZip response.disDecode = option.DisDecode response.stream = option.Stream method := strings.ToUpper(option.Method) + if method == "" { + method = http.MethodGet + } - var reqs *http.Request //init ctxData ctxData, err := NewReqCtxData(ctx, option) if err != nil { @@ -341,19 +326,11 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response if err != nil { return response, tools.WrapError(err, errors.New("tempRequest init body error"), err) } - if ctxData.debug { - debugPrint(ctxData.requestId, "create request with ctx") - } //create request - if body != nil { - reqs, err = http.NewRequestWithContext(response.ctx, method, href, body) - } else { - reqs, err = http.NewRequestWithContext(response.ctx, method, href, nil) - } + reqs, err := http.NewRequestWithContext(response.ctx, method, href.String(), body) if err != nil { return response, tools.WrapError(errFatal, errors.New("tempRequest 构造request失败"), err) } - //init headers reqs.Header = headers //add Referer @@ -366,29 +343,22 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response } //set ContentType - if reqs.Header.Get("Content-Type") == "" && reqs.Header.Get("content-type") == "" && option.ContentType != "" { + if option.ContentType != "" && reqs.Header.Get("Content-Type") == "" { reqs.Header.Set("Content-Type", option.ContentType) } //init ws if ctxData.isWs { - if ctxData.debug { - debugPrint(ctxData.requestId, "init websocket headers") - } websocket.SetClientHeadersOption(reqs.Header, option.WsOption) } - switch reqs.URL.Scheme { - case "file": + + if reqs.URL.Scheme == "file" { response.filePath = re.Sub(`^/+`, "", reqs.URL.Path) response.content, err = os.ReadFile(response.filePath) if err != nil { err = tools.WrapError(errFatal, errors.New("read filePath data error"), err) } return - case "http", "https": - default: - err = tools.WrapError(errFatal, fmt.Errorf("url scheme error: %s", reqs.URL.Scheme)) - return } //add host if option.Host != "" { @@ -409,42 +379,30 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response reqs.AddCookie(vv) } } - if ctxData.debug { - debugPrint(ctxData.requestId, "send request start") - } //send req - response.response, err = obj.getClient(option).Do(reqs) - if ctxData.debug { - debugPrint(ctxData.requestId, "send request end, err: ", err) - } + response.response, err = obj.send(option, reqs) response.isNewConn = ctxData.isNewConn if err != nil { err = tools.WrapError(err, "roundTripper error") return - } else if response.response == nil { + } + if response.response == nil { err = errors.New("response is nil") return } - response.rawConn = response.response.Body.(*readWriteCloser) + if response.response.Body != nil { + response.rawConn = response.response.Body.(*readWriteCloser) + } if !response.disUnzip { response.disUnzip = response.response.Uncompressed } if response.response.StatusCode == 101 { response.webSocket, err = websocket.NewClientConn(response.rawConn.Conn(), response.response.Header, response.ForceCloseConn) - if ctxData.debug { - debugPrint(ctxData.requestId, "new websocket client, err: ", err) - } } else if response.response.Header.Get("Content-Type") == "text/event-stream" { response.sse = newSse(response.response.Body, response.ForceCloseConn) - if ctxData.debug { - debugPrint(ctxData.requestId, "new sse client") - } } else if !response.disUnzip { var unCompressionBody io.ReadCloser unCompressionBody, err = tools.CompressionDecode(response.response.Body, response.ContentEncoding()) - if ctxData.debug { - debugPrint(ctxData.requestId, "unCompressionBody, err: ", err) - } if err != nil { if err != io.ErrUnexpectedEOF && err != io.EOF { return diff --git a/response.go b/response.go index cf7af61..8c1afd0 100644 --- a/response.go +++ b/response.go @@ -238,20 +238,9 @@ type barBody struct { func (obj *barBody) Write(con []byte) (int, error) { l, err := obj.body.Write(con) - obj.bar.Print(int64(l)) + obj.bar.Add(int64(l)) return l, err } -func (obj *Response) barRead() (*bytes.Buffer, error) { - barData := &barBody{ - bar: bar.NewClient(obj.response.ContentLength), - body: bytes.NewBuffer(nil), - } - err := tools.CopyWitchContext(obj.response.Request.Context(), barData, obj.response.Body, true) - if err != nil { - return nil, err - } - return barData.body, nil -} func (obj *Response) defaultDecode() bool { return strings.Contains(obj.ContentType(), "html") } @@ -259,21 +248,16 @@ func (obj *Response) defaultDecode() bool { func (obj *Response) Read(con []byte) (i int, err error) { done := make(chan struct{}) go func() { - defer close(done) - defer func() { - if recErr := recover(); recErr != nil && err == nil { - err, _ = recErr.(error) - } - }() i, err = obj.response.Body.Read(con) + close(done) }() select { - case <-obj.ctx.Done(): - obj.response.Body.Close() - return 0, obj.ctx.Err() + case <-obj.response.Request.Context().Done(): + obj.ForceCloseConn() + err = obj.response.Request.Context().Err() case <-done: - return } + return } // return true if response is stream @@ -283,28 +267,26 @@ func (obj *Response) IsStream() bool { // read body func (obj *Response) ReadBody() error { - if obj.webSocket != nil || obj.sse != nil { - return errors.New("ws or sse can not read") + if obj.IsStream() { + return errors.New("can not read stream") } if obj.readBody { return errors.New("already read body") } - var bBody *bytes.Buffer - var err error - defer obj.response.Body.Close() - if obj.bar && obj.ContentLength() > 0 { - bBody, err = obj.barRead() - } else { - bBody = bytes.NewBuffer(nil) - err = tools.CopyWitchContext(obj.response.Request.Context(), bBody, obj.response.Body, true) - } obj.readBody = true - if err != nil { - obj.CloseConn() - return errors.New("response read content error: " + err.Error()) + var err error + bBody := bytes.NewBuffer(nil) + if obj.bar && obj.ContentLength() > 0 { + err = tools.CopyWitchContext(obj.response.Request.Context(), &barBody{ + bar: bar.NewClient(obj.response.ContentLength), + body: bBody, + }, obj.response.Body) + } else { + err = tools.CopyWitchContext(obj.response.Request.Context(), bBody, obj.response.Body) } - if obj.IsStream() { - obj.CloseBody() + if err != nil { + obj.ForceCloseConn() + return errors.New("response read content error: " + err.Error()) } if !obj.disDecode && obj.defaultDecode() { if content, encoding, err := tools.Charset(bBody.Bytes(), obj.ContentType()); err == nil { @@ -315,6 +297,7 @@ func (obj *Response) ReadBody() error { } else { obj.content = bBody.Bytes() } + obj.response.Body.Close() return nil } diff --git a/roundTripper.go b/roundTripper.go index ea4121c..b7ba9ff 100644 --- a/roundTripper.go +++ b/roundTripper.go @@ -26,22 +26,19 @@ type reqTask struct { err error orderHeaders []string responseHeaderTimeout time.Duration - debug bool - requestId string } -func newReqTask(req *http.Request, ctxData *reqCtxData) *reqTask { +func newReqTask(ctx context.Context, req *http.Request, ctxData *reqCtxData) *reqTask { if ctxData.responseHeaderTimeout == 0 { ctxData.responseHeaderTimeout = time.Second * 30 } - return &reqTask{ - req: req, - debug: ctxData.debug, - requestId: ctxData.requestId, - emptyPool: make(chan struct{}), - orderHeaders: ctxData.orderHeaders, - responseHeaderTimeout: ctxData.responseHeaderTimeout, - } + task := new(reqTask) + task.req = req + task.emptyPool = make(chan struct{}) + task.orderHeaders = ctxData.orderHeaders + task.responseHeaderTimeout = ctxData.responseHeaderTimeout + task.ctx, task.cnl = context.WithCancel(ctx) + return task } func (obj *reqTask) inPool() bool { return obj.err == nil && obj.res != nil && obj.res.StatusCode != 101 && obj.res.Header.Get("Content-Type") != "text/event-stream" @@ -53,14 +50,13 @@ type connKey struct { } func getKey(ctxData *reqCtxData, req *http.Request) connKey { - var proxy string + key := connKey{ + addr: getAddr(req.URL), + } if ctxData.proxy != nil { - proxy = ctxData.proxy.String() - } - return connKey{ - proxy: proxy, - addr: getAddr(req.URL), + key.proxy = ctxData.proxy.String() } + return key } type roundTripper struct { @@ -70,7 +66,7 @@ type roundTripper struct { dialer *DialClient tlsConfig *tls.Config utlsConfig *utls.Config - proxy func(ctx context.Context, url *url.URL) (string, error) + getProxy func(ctx context.Context, url *url.URL) (string, error) } type roundTripperOption struct { @@ -115,8 +111,8 @@ func newRoundTripper(preCtx context.Context, option ClientOption) *roundTripper ctx: ctx, cnl: cnl, dialer: dialClient, - proxy: option.GetProxy, - connPools: new(connPools), + getProxy: option.GetProxy, + connPools: newConnPools(), } } func (obj *roundTripper) newConnPool(conn *connecotr, key connKey) *connPool { @@ -140,13 +136,8 @@ func (obj *roundTripper) putConnPool(key connKey, conn *connecotr) { } pool := obj.connPools.get(key) if pool != nil { - select { - case <-pool.closeCtx.Done(): - obj.connPools.set(key, obj.newConnPool(conn, key)) - default: - pool.total.Add(1) - go pool.rwMain(conn) - } + pool.total.Add(1) + go pool.rwMain(conn) } else { obj.connPools.set(key, obj.newConnPool(conn, key)) } @@ -158,27 +149,22 @@ func (obj *roundTripper) utlsConfigClone() *utls.Config { return obj.utlsConfig.Clone() } func (obj *roundTripper) dial(ctxData *reqCtxData, key *connKey, req *http.Request) (conn *connecotr, err error) { - proxy := ctxData.proxy - if !ctxData.disProxy && proxy == nil { - if proxy, err = obj.getProxy(req.Context(), req.URL); err != nil { + proxy := cloneUrl(ctxData.proxy) + if proxy != nil { + key.proxy = proxy.String() + } else if !ctxData.disProxy && obj.getProxy != nil { + proxyStr, err := obj.getProxy(req.Context(), proxy) + if err != nil { + return conn, err + } + if proxy, err = gtls.VerifyProxy(proxyStr); err != nil { return conn, err } } - if proxy != nil { - key.proxy = proxy.String() - } - var netConn net.Conn - host := getHost(req) - if proxy == nil { - netConn, err = obj.dialer.DialContext(req.Context(), "tcp", key.addr) - } else { - netConn, err = obj.dialer.DialContextWithProxy(req.Context(), "tcp", req.URL.Scheme, key.addr, host, proxy, obj.tlsConfigClone()) - } + netConn, err := obj.dialer.DialContextWithProxy(req.Context(), ctxData, "tcp", req.URL.Scheme, key.addr, getHost(req), proxy, obj.tlsConfigClone()) if err != nil { return conn, err } - conne := new(connecotr) - conne.withCancel(obj.ctx, obj.ctx) var h2 bool if req.URL.Scheme == "https" { ctx, cnl := context.WithTimeout(req.Context(), ctxData.tlsHandshakeTimeout) @@ -188,22 +174,22 @@ func (obj *roundTripper) dial(ctxData *reqCtxData, key *connKey, req *http.Reque if ctxData.forceHttp1 { tlsConfig.NextProtos = []string{"http/1.1"} } - tlsConn, err := obj.dialer.addJa3Tls(ctx, netConn, host, ctxData.isWs || ctxData.forceHttp1, ctxData.ja3Spec, tlsConfig) + tlsConn, err := obj.dialer.addJa3Tls(ctx, netConn, getHost(req), ctxData.isWs || ctxData.forceHttp1, ctxData.ja3Spec, tlsConfig) if err != nil { - return conne, tools.WrapError(err, "add ja3 tls error") + return conn, tools.WrapError(err, "add ja3 tls error") } h2 = tlsConn.ConnectionState().NegotiatedProtocol == "h2" netConn = tlsConn } else { - tlsConn, err := obj.dialer.addTls(ctx, netConn, host, ctxData.isWs || ctxData.forceHttp1, obj.tlsConfigClone()) + tlsConn, err := obj.dialer.addTls(ctx, netConn, getHost(req), ctxData.isWs || ctxData.forceHttp1, obj.tlsConfigClone()) if err != nil { - return conne, tools.WrapError(err, "add tls error") + return conn, tools.WrapError(err, "add tls error") } h2 = tlsConn.ConnectionState().NegotiatedProtocol == "h2" netConn = tlsConn } } - conne.rawConn = netConn + conne := newConnecotr(obj.ctx, netConn) if h2 { if conne.h2RawConn, err = http2.NewClientConn(func() { conne.closeCnl(errors.New("http2 client close")) @@ -211,140 +197,89 @@ func (obj *roundTripper) dial(ctxData *reqCtxData, key *connKey, req *http.Reque return conne, err } } else { - conne.r = bufio.NewReader(conne) - conne.w = bufio.NewWriter(conne) + conne.r, conne.w = bufio.NewReader(conne), bufio.NewWriter(conne) } return conne, err } func (obj *roundTripper) setGetProxy(getProxy func(ctx context.Context, url *url.URL) (string, error)) { - obj.proxy = getProxy -} -func (obj *roundTripper) getProxy(ctx context.Context, proxyUrl *url.URL) (*url.URL, error) { - if obj.proxy == nil { - return nil, nil - } - proxy, err := obj.proxy(ctx, proxyUrl) - if err != nil { - return nil, err - } - return gtls.VerifyProxy(proxy) + obj.getProxy = getProxy } -func (obj *roundTripper) poolRoundTrip(task *reqTask, key connKey) (bool, error) { - if task.debug { - debugPrint(task.requestId, "poolRoundTrip start") - } +func (obj *roundTripper) poolRoundTrip(task *reqTask, key connKey) (newConn bool) { pool := obj.getConnPool(key) if pool == nil { - if task.debug { - debugPrint(task.requestId, "poolRoundTrip not found conn pool") - } - return false, nil + return true } select { - case <-obj.ctx.Done(): - if task.debug { - debugPrint(task.requestId, "RoundTripper already cloed") - } - return false, tools.WrapError(obj.ctx.Err(), "roundTripper close ctx error: ") - case <-pool.closeCtx.Done(): - if task.debug { - debugPrint(task.requestId, "pool already cloed") - } - return false, pool.closeCtx.Err() case pool.tasks <- task: - if task.debug { - debugPrint(task.requestId, "poolRoundTrip tasks <- task") - } select { case <-task.emptyPool: - if task.debug { - debugPrint(task.requestId, "poolRoundTrip emptyPool") - } + return true case <-task.ctx.Done(): - if task.err == nil && task.res == nil { - task.err = tools.WrapError(task.ctx.Err(), "task close ctx error: ") - } - if task.debug { - debugPrint(task.requestId, "task ctx done, err: ", task.err) - } - return true, task.err + return false } default: - if task.debug { - debugPrint(task.requestId, "conn pool not idle") - } + return true } - return false, nil +} +func (obj *roundTripper) connRoundTrip(ctxData *reqCtxData, task *reqTask, key connKey) (retry bool) { + ckey := key + conn, err := obj.dial(ctxData, &ckey, task.req) + if err != nil { + task.err = err + return + } + retry = conn.taskMain(task, false) + if retry || task.err != nil { + return retry + } + conn.connKey = ckey + if task.inPool() && !ctxData.disAlive { + obj.putConnPool(key, conn) + } + return retry } func (obj *roundTripper) closeConns() { - obj.connPools.iter(func(key, value any) bool { - pool := value.(*connPool) + obj.connPools.iter(func(key connKey, pool *connPool) bool { pool.close() - obj.connPools.del(key.(connKey)) + obj.connPools.del(key) return true }) } func (obj *roundTripper) forceCloseConns() { - obj.connPools.iter(func(key, value any) bool { - pool := value.(*connPool) + obj.connPools.iter(func(key connKey, pool *connPool) bool { pool.forceClose() - obj.connPools.del(key.(connKey)) + obj.connPools.del(key) return true }) } -func (obj *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { +func (obj *roundTripper) RoundTrip(req *http.Request) (response *http.Response, err error) { ctxData := GetReqCtxData(req.Context()) if ctxData.requestCallBack != nil { - if err := ctxData.requestCallBack(req.Context(), req, nil); err != nil { + if err = ctxData.requestCallBack(req.Context(), req, nil); err != nil { return nil, err } } key := getKey(ctxData, req) //pool key - task := newReqTask(req, ctxData) - task.ctx, task.cnl = context.WithCancel(obj.ctx) - defer task.cnl() + task := newReqTask(obj.ctx, req, ctxData) //get pool conn + var isNewConn bool if !ctxData.disAlive { - if ok, err := obj.poolRoundTrip(task, key); err != nil { - return nil, err - } else if ok { //is conn multi - if ctxData.requestCallBack != nil { - if err = ctxData.requestCallBack(task.req.Context(), req, task.res); err != nil { - task.err = err - } + isNewConn = obj.poolRoundTrip(task, key) + } + if ctxData.disAlive || isNewConn { + ctxData.isNewConn = true + for { + retry := obj.connRoundTrip(ctxData, task, key) + if !retry { + break } - return task.res, task.err } } - ctxData.isNewConn = true - if task.debug { - debugPrint(ctxData.requestId, "new Conn") - } -newConn: - ckey := key - conn, err := obj.dial(ctxData, &ckey, req) - if err != nil { - return nil, err - } - if _, _, notice := conn.taskMain(task, nil, false); notice { - goto newConn - } - if task.err == nil && task.res == nil { - task.err = obj.ctx.Err() - } - conn.connKey = ckey - if task.inPool() && !ctxData.disAlive { - if task.debug { - debugPrint(ctxData.requestId, "conn put conn pool") - } - obj.putConnPool(key, conn) - } - if ctxData.requestCallBack != nil { - if err = ctxData.requestCallBack(task.req.Context(), req, task.res); err != nil { + if task.err == nil && ctxData.requestCallBack != nil { + if err = ctxData.requestCallBack(task.req.Context(), task.req, task.res); err != nil { task.err = err - conn.Close() } } return task.res, task.err diff --git a/rw.go b/rw.go index c27ddb2..b5fa100 100644 --- a/rw.go +++ b/rw.go @@ -21,7 +21,7 @@ func (obj *readWriteCloser) Close() (err error) { if !obj.InPool() { obj.ForceCloseConn() } else { - obj.conn.bodyCnl(errors.New("readWriteCloser close")) + obj.conn.bodyCnl(errors.New("body close")) } return } diff --git a/test/protocol/sse_test.go b/test/protocol/sse_test.go index d90f944..fa12611 100644 --- a/test/protocol/sse_test.go +++ b/test/protocol/sse_test.go @@ -46,6 +46,7 @@ func TestSse(t *testing.T) { t.Error(err) } }() + time.Sleep(time.Second * 3) response, err := requests.Get(nil, "http://127.0.0.1:3333/events") // Send WebSocket request if err != nil { t.Error(err) diff --git a/tools.go b/tools.go index e7b4e19..ca86d2b 100644 --- a/tools.go +++ b/tools.go @@ -45,6 +45,9 @@ func getAddr(uurl *url.URL) string { return uurl.Host } func cloneUrl(u *url.URL) *url.URL { + if u == nil { + return nil + } r := *u return &r }