From d8da7e2a5e5d1a3b21ed60c04aa6b8230035a786 Mon Sep 17 00:00:00 2001 From: bxd <2216403312@qq.com> Date: Tue, 15 Aug 2023 13:31:18 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A4=A7=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- client.go | 5 ++ option.go | 4 ++ requests.go | 21 +++------ response.go | 30 ++++++++---- roundTripper.go | 118 +++++++++++++++++++++++++++--------------------- 5 files changed, 103 insertions(+), 75 deletions(-) diff --git a/client.go b/client.go index 9b3e333..5026a77 100644 --- a/client.go +++ b/client.go @@ -35,6 +35,8 @@ type ClientOption struct { RedirectNum int //重定向次数,小于0为禁用,0:不限制 + DisAlive bool //关闭连接复用 + DisDecode bool //关闭自动编码 DisRead bool //关闭默认读取请求体 DisUnZip bool //关闭自动解压 @@ -67,6 +69,7 @@ type Client struct { headers any //请求头 bar bool //是否开启bar + disAlive bool disCookie bool client *http.Client noJarClient *http.Client @@ -156,6 +159,7 @@ func NewClient(preCtx context.Context, options ...ClientOption) (*Client, error) disCookie: option.DisCookie, redirectNum: option.RedirectNum, + disAlive: option.DisAlive, disDecode: option.DisDecode, disRead: option.DisRead, disUnZip: option.DisUnZip, @@ -197,6 +201,7 @@ func (obj *Client) SetGetProxy(getProxy func(ctx context.Context, url *url.URL) func (obj *Client) CloseIdleConnections() { obj.transport.CloseIdleConnections() } + func (obj *Client) Close() { obj.CloseIdleConnections() obj.cnl() diff --git a/option.go b/option.go index 9b805ed..b90139d 100644 --- a/option.go +++ b/option.go @@ -41,6 +41,7 @@ type RequestOption struct { ContentType string //headers 中Content-Type 的值 Raw any //不设置context-type,支持string,[]bytes,json,map + DisAlive bool //关闭连接复用 DisCookie bool //关闭cookies管理,这个请求不用cookies池 DisDecode bool //关闭自动解码 Bar bool //是否开启bar @@ -214,6 +215,9 @@ func (obj *Client) newRequestOption(option RequestOption) RequestOption { if option.Timeout == 0 { option.Timeout = obj.timeout } + if !option.DisAlive { + option.DisAlive = obj.disAlive + } if !option.DisCookie { option.DisCookie = obj.disCookie } diff --git a/requests.go b/requests.go index 5b92444..45b2eec 100644 --- a/requests.go +++ b/requests.go @@ -126,9 +126,11 @@ var ( ) type reqCtxData struct { - redirectNum int - proxy *url.URL - disProxy bool + redirectNum int + proxy *url.URL + disProxy bool + + disAlive bool ws bool requestCallBack func(context.Context, *http.Request) error responseCallBack func(context.Context, *http.Request, *http.Response) error @@ -295,11 +297,9 @@ func (obj *Client) request(preCtx context.Context, option RequestOption) (respon //构造ctxData ctxData := new(reqCtxData) + ctxData.disAlive = option.DisAlive ctxData.requestCallBack = option.RequestCallBack ctxData.responseCallBack = option.ResponseCallBack - // if option.Body != nil { - // ctxData.disBody = true - // } //构造代理 ctxData.disProxy = option.DisProxy if !ctxData.disProxy { @@ -416,15 +416,6 @@ func (obj *Client) request(preCtx context.Context, option RequestOption) (respon r, err = obj.getClient(option).Do(reqs) if r != nil { isSse := r.Header.Get("Content-Type") == "text/event-stream" - // if ctxData.responseCallBack != nil { - // var resp *ResponseDebug - // if resp, err = cloneResponse(r, isSse || ctxData.ws); err != nil { - // return - // } - // if err = ctxData.responseCallBack(reqCtx, resp); err != nil { - // return response, tools.WrapError(ErrFatal, "request requestCallBack 回调错误", err) - // } - // } if ctxData.ws { if r.StatusCode == 101 { option.DisRead = true diff --git a/response.go b/response.go index 4410c45..8976c26 100644 --- a/response.go +++ b/response.go @@ -28,6 +28,7 @@ type Response struct { cnl context.CancelFunc content []byte encoding string + disAlive bool disDecode bool disUnzip bool filePath string @@ -73,16 +74,16 @@ func (obj *SseClient) Recv() (Event, error) { } } -func (obj *Client) newResponse(ctx context.Context, cnl context.CancelFunc, r *http.Response, request_option RequestOption) (*Response, error) { - response := &Response{response: r, ctx: ctx, cnl: cnl, bar: request_option.Bar} - if request_option.DisRead { //是否预读 +func (obj *Client) newResponse(ctx context.Context, cnl context.CancelFunc, r *http.Response, option RequestOption) (*Response, error) { + response := &Response{response: r, ctx: ctx, cnl: cnl, bar: option.Bar, disAlive: option.DisAlive} + if option.DisRead { //是否预读 return response, nil } - if request_option.DisUnZip || r.Uncompressed { //是否解压 + if option.DisUnZip || r.Uncompressed { //是否解压 response.disUnzip = true } - response.disDecode = request_option.DisDecode //是否解码 - return response, response.read() //读取内容 + response.disDecode = option.DisDecode //是否解码 + return response, response.read() //读取内容 } type Cookies []*http.Cookie @@ -363,13 +364,26 @@ func (obj *Response) read() error { //读取body,对body 解压,解码操作 return nil } +func (obj *Response) Delete() error { + delFunc, ok := obj.response.Body.(interface{ Delete() error }) + if ok { + obj.response.Body.Close() + return delFunc.Delete() + } else { + return obj.response.Body.Close() + } +} + // 关闭response ,当disRead 为true 请一定要手动关闭 func (obj *Response) Close() error { defer obj.cnl() if obj.webSocket != nil { obj.webSocket.Close("close") - } - if obj.response != nil && obj.response.Body != nil { + obj.Delete() + } else if obj.response != nil && obj.response.Body != nil { + if obj.disAlive { + return obj.Delete() + } tools.CopyWitchContext(obj.ctx, io.Discard, obj.response.Body) return obj.response.Body.Close() } diff --git a/roundTripper.go b/roundTripper.go index 8a3e87a..06bd9e9 100644 --- a/roundTripper.go +++ b/roundTripper.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "log" "net" "net/url" "sync" @@ -25,7 +26,7 @@ type roundTripper interface { RoundTrip(*http.Request) (*http.Response, error) } -type connecotr struct { +type Connecotr struct { ctx context.Context ctx2 context.Context cnl context.CancelFunc @@ -36,7 +37,7 @@ type connecotr struct { h2Ja3RawConn *h2ja3.ClientConn } -func (obj *connecotr) Close() error { +func (obj *Connecotr) Close() error { obj.cnl() if obj.h2RawConn != nil { obj.h2RawConn.Close() @@ -52,7 +53,6 @@ type reqTask struct { cnl context.CancelFunc req *http.Request //发送的请求 res *http.Response //接收的请求 - oneConn bool emptyPool chan struct{} err error } @@ -103,7 +103,7 @@ func getHost(req *http.Request) string { func getKey(ctxData *reqCtxData, req *http.Request) string { return fmt.Sprintf("%s@%s", getAddr(ctxData.proxy), getAddr(req.URL)) } -func (obj *RoundTripper) newConnPool(key string, conn *connecotr) *connPool { +func (obj *RoundTripper) newConnPool(key string, conn *Connecotr) *connPool { pool := new(connPool) pool.ctx, pool.cnl = context.WithCancel(obj.ctx) pool.total.Add(1) @@ -124,7 +124,7 @@ func (obj *RoundTripper) getConnPool(key string) *connPool { defer obj.connsLock.Unlock() return obj.connPools[key] } -func (obj *RoundTripper) putConnPool(key string, conn *connecotr) { +func (obj *RoundTripper) putConnPool(key string, conn *Connecotr) { obj.connsLock.Lock() defer obj.connsLock.Unlock() pool, ok := obj.connPools[key] @@ -140,7 +140,8 @@ func (obj *RoundTripper) TlsConfig() *tls.Config { func (obj *RoundTripper) UtlsConfig() *utls.Config { return obj.utlsConfig.Clone() } -func (obj *RoundTripper) dial(ctxData *reqCtxData, key string, req *http.Request) (conn *connecotr, err error) { +func (obj *RoundTripper) dial(ctxData *reqCtxData, key string, req *http.Request) (conn *Connecotr, err error) { + log.Print("new conn") if !ctxData.disProxy && ctxData.proxy == nil { //确定代理 if ctxData.proxy, err = obj.GetProxy(req.Context(), req.URL); err != nil { return nil, err @@ -157,7 +158,7 @@ func (obj *RoundTripper) dial(ctxData *reqCtxData, key string, req *http.Request if err != nil { return nil, err } - conne := new(connecotr) + conne := new(Connecotr) conne.ctx, conne.cnl = context.WithCancel(obj.ctx) var h2 bool if req.URL.Scheme == "https" { @@ -226,7 +227,7 @@ type ClientConnState struct { LastIdle time.Time } -func (obj *connecotr) ping() error { +func (obj *Connecotr) ping() error { if obj.h2RawConn != nil { state := obj.h2RawConn.State() if state.Closed || state.Closing { @@ -247,11 +248,11 @@ type ReadWriteCloser struct { cnl context.CancelFunc cnl2 context.CancelFunc body io.ReadCloser - conn net.Conn + conn *Connecotr } func (obj *ReadWriteCloser) Conn() net.Conn { - return obj.conn + return obj.conn.rawConn } func (obj *ReadWriteCloser) Read(p []byte) (n int, err error) { return obj.body.Read(p) @@ -260,23 +261,28 @@ func (obj *ReadWriteCloser) Close() error { defer obj.cnl() return obj.body.Close() } -func (obj *ReadWriteCloser) Delete() { + +func (obj *ReadWriteCloser) Delete() (err error) { + err = obj.conn.Close() obj.cnl2() + return } -func wrapBody(conn *connecotr, task *reqTask) { +func wrapBody(conn *Connecotr, task *reqTask) { body := new(ReadWriteCloser) conn.ctx2, body.cnl = context.WithCancel(conn.ctx) body.cnl2 = conn.cnl body.body = task.res.Body - if task.res.StatusCode == 101 { - body.conn = conn.rawConn - task.oneConn = true - } + body.conn = conn task.res.Body = body } -func http1Req(conn *connecotr, task *reqTask) { +func http1Req(conn *Connecotr, task *reqTask) { defer task.cnl() + defer func() { + if task.res == nil || task.err != nil { + conn.Close() + } + }() err := task.req.Write(conn.rawConn) if err != nil { task.err = err @@ -288,8 +294,13 @@ func http1Req(conn *connecotr, task *reqTask) { } } -func http2Req(conn *connecotr, task *reqTask) { +func http2Req(conn *Connecotr, task *reqTask) { defer task.cnl() + defer func() { + if task.res == nil || task.err != nil { + conn.Close() + } + }() if conn.h2RawConn != nil { task.res, task.err = conn.h2RawConn.RoundTrip(task.req) } else { @@ -300,7 +311,7 @@ func http2Req(conn *connecotr, task *reqTask) { } } -func (obj *connPool) rwMain(conn *connecotr) { +func (obj *connPool) rwMain(conn *Connecotr) { conn.ctx, conn.cnl = context.WithCancel(obj.ctx) defer func() { if obj.total.Load() == 0 { @@ -317,7 +328,7 @@ func (obj *connPool) rwMain(conn *connecotr) { for { wait.Reset(time.Second * 30) select { - case <-conn.ctx.Done(): + case <-conn.ctx.Done(): //连接池通知关闭,不用再监听了 return case <-wait.C: if conn.ping() != nil { @@ -328,37 +339,38 @@ func (obj *connPool) rwMain(conn *connecotr) { }() for { select { - case <-conn.ctx.Done(): + case <-conn.ctx.Done(): //连接池通知关闭,等待连接被释放掉 + <-conn.ctx2.Done() return - case task := <-obj.tasks: - if conn.ping() != nil { + case task := <-obj.tasks: //接收到任务 + if conn.ping() != nil { //判断连接是否异常 select { - case obj.tasks <- task: - case task.emptyPool <- struct{}{}: + case obj.tasks <- task: //任务给池子里其它连接 + case task.emptyPool <- struct{}{}: //告诉提交任务方,池子没有可用连接 } - return + return //由于连接异常直接结束 } if !conn.h2 { select { - case <-conn.ctx.Done(): - return - case <-conn.ctx2.Done(): + case <-conn.ctx2.Done(): //http1.1 连接被占用 default: select { - case obj.tasks <- task: - case task.emptyPool <- struct{}{}: + case obj.tasks <- task: //任务给池子里其它连接 + case task.emptyPool <- struct{}{}: //告诉提交任务方,池子没有可用连接 } - return + continue //由于连接被占用,开始下一个循环 } } - wait.Reset(time.Hour * 24 * 365) + wait.Reset(time.Hour * 24 * 365) //停止健康检查 if conn.h2 { go http2Req(conn, task) } else { go http1Req(conn, task) } + //等待任务完成 <-task.ctx.Done() - if task.oneConn || task.req == nil { + //如果没有response返回,就认定这个连接异常,直接返回 + if task.req == nil || task.err != nil { return } wait.Reset(time.Second * 30) @@ -414,7 +426,7 @@ func NewRoundTripper(preCtx context.Context, option RoundTripperOption) *RoundTr InsecureSkipVerify: true, InsecureSkipTimeVerify: true, SessionTicketKey: [32]byte{}, - ClientSessionCache: utls.NewLRUClientSessionCache(0), + // ClientSessionCache: utls.NewLRUClientSessionCache(0), } return &RoundTripper{ tlsConfig: tlsConfig, @@ -479,19 +491,20 @@ func (obj *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { task := &reqTask{req: req, emptyPool: make(chan struct{})} task.ctx, task.cnl = context.WithCancel(obj.ctx) defer task.cnl() - ok, err := obj.poolRoundTrip(task, key) - if err != nil { - return nil, err - } - if ok { - if ctxData.responseCallBack != nil { - if err = ctxData.responseCallBack(task.req.Context(), req, task.res); err != nil { - task.err = err - } + if !ctxData.disAlive { + ok, err := obj.poolRoundTrip(task, key) + if err != nil { + return nil, err + } + if ok { + if ctxData.responseCallBack != nil { + if err = ctxData.responseCallBack(task.req.Context(), req, task.res); err != nil { + task.err = err + } + } + return task.res, task.err } - return task.res, task.err } - conn, err := obj.dial(ctxData, key, req) if err != nil { return nil, err @@ -502,16 +515,17 @@ func (obj *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { go http2Req(conn, task) } <-task.ctx.Done() - if task.err == nil && task.res != nil && !task.oneConn { - obj.putConnPool(key, conn) + if ctxData.responseCallBack != nil { + if err = ctxData.responseCallBack(task.req.Context(), req, task.res); err != nil { + task.err = err + conn.Close() + } } if task.err == nil && task.res == nil { task.err = obj.ctx.Err() } - if ctxData.responseCallBack != nil { - if err = ctxData.responseCallBack(task.req.Context(), req, task.res); err != nil { - task.err = err - } + if task.err == nil && task.res != nil && task.res.StatusCode != 101 && !ctxData.disAlive { + obj.putConnPool(key, conn) } return task.res, task.err }