From b4e48062a0397760288b185a9ca10ac121e0c35f Mon Sep 17 00:00:00 2001 From: bxd <2216403312@qq.com> Date: Tue, 28 Nov 2023 21:03:01 +0800 Subject: [PATCH] optimize code --- body.go | 46 ++++++++++++++------- conn.go | 87 ++++++++++++++++++++++++--------------- dial.go | 45 +++++++------------- requests.go | 4 +- roundTripper.go | 50 ++++++++++------------ rw.go | 4 +- test/request/file_test.go | 26 +----------- tools.go | 14 +------ 8 files changed, 126 insertions(+), 150 deletions(-) diff --git a/body.go b/body.go index b8fb5d8..881c904 100644 --- a/body.go +++ b/body.go @@ -44,13 +44,17 @@ func (obj *orderMap) Del(key string) { } func (obj *orderMap) parseHeaders() (map[string][]string, []string) { head := make(http.Header) - for kk, vv := range obj.data { - if vvs, ok := vv.([]string); ok { + data := any2Map(obj.data) + if data == nil { + data = obj.data + } + for _, kk := range obj.order { + if vvs, ok := data[kk].([]any); ok { for _, vv := range vvs { head.Add(kk, fmt.Sprint(vv)) } } else { - head.Add(kk, fmt.Sprint(vv)) + head.Add(kk, fmt.Sprint(data[kk])) } } return head, obj.order @@ -139,8 +143,12 @@ func (obj *orderMap) isformPip() bool { if len(obj.order) == 0 || len(obj.data) == 0 { return false } + data := any2Map(obj.data) + if data == nil { + data = obj.data + } for _, key := range obj.order { - if vals, ok := obj.data[key].([]any); ok { + if vals, ok := data[key].([]any); ok { for _, val := range vals { if file, ok := val.(File); ok { if _, ok := file.Content.(io.Reader); ok { @@ -149,7 +157,7 @@ func (obj *orderMap) isformPip() bool { } } } else { - if file, ok := obj.data[key].(File); ok { + if file, ok := data[key].(File); ok { if _, ok := file.Content.(io.Reader); ok { return true } @@ -159,15 +167,19 @@ func (obj *orderMap) isformPip() bool { return false } func (obj *orderMap) formWriteMain(writer *multipart.Writer) (err error) { + data := any2Map(obj.data) + if data == nil { + data = obj.data + } for _, key := range obj.order { - if vals, ok := obj.data[key].([]any); ok { + if vals, ok := data[key].([]any); ok { for _, val := range vals { if err = formWrite(writer, key, val); err != nil { return } } } else { - if err = formWrite(writer, key, obj.data[key]); err != nil { + if err = formWrite(writer, key, data[key]); err != nil { return } } @@ -184,14 +196,18 @@ func paramsWrite(buf *bytes.Buffer, key string, val any) { buf.WriteString(url.QueryEscape(fmt.Sprint(val))) } func (obj *orderMap) parseParams() *bytes.Buffer { + data := any2Map(obj.data) + if data == nil { + data = obj.data + } buf := bytes.NewBuffer(nil) for _, k := range obj.order { - if vals, ok := obj.data[k].([]any); ok { + if vals, ok := data[k].([]any); ok { for _, v := range vals { paramsWrite(buf, k, v) } } else { - paramsWrite(buf, k, obj.data[k]) + paramsWrite(buf, k, data[k]) } } return buf @@ -239,14 +255,14 @@ func (obj *orderMap) MarshalJSON() ([]byte, error) { return buf.Bytes(), nil } -func any2Map(val any) map[any]any { +func any2Map(val any) map[string]any { mapType := reflect.TypeOf(val) if mapType.Kind() != reflect.Map { return nil } mapValue := reflect.ValueOf(val) keys := mapValue.MapKeys() - result := make(map[any]any) + result := make(map[string]any) for _, key := range keys { keyData := key.Interface() valueData := mapValue.MapIndex(key).Interface() @@ -258,7 +274,7 @@ func any2Map(val any) map[any]any { } valueData = valueData2 } - result[keyData] = valueData + result[fmt.Sprint(keyData)] = valueData } return result } @@ -305,7 +321,7 @@ func (obj *RequestOption) newBody(val any, valType int) (io.Reader, *orderMap, [ return nil, orderMap, nil, nil case *orderMap: return nil, value, nil, nil - case map[any]any: + case map[string]any: orderMap := NewOrderMap() for kk, vv := range value { if vvs, ok := vv.([]any); ok { @@ -313,9 +329,9 @@ func (obj *RequestOption) newBody(val any, valType int) (io.Reader, *orderMap, [ for i, vv := range vvs { vvData[i] = vv } - orderMap.Set(fmt.Sprint(kk), vvData) + orderMap.Set(kk, vvData) } else { - orderMap.Set(fmt.Sprint(kk), vv) + orderMap.Set(kk, vv) } } return nil, orderMap, nil, nil diff --git a/conn.go b/conn.go index ff63949..a5a910c 100644 --- a/conn.go +++ b/conn.go @@ -16,7 +16,7 @@ import ( ) type connecotr struct { - key connKey + connKey connKey deleteCtx context.Context //force close deleteCnl context.CancelCauseFunc @@ -27,12 +27,12 @@ type connecotr struct { bodyCnl context.CancelCauseFunc rawConn net.Conn - h2 bool - r *bufio.Reader - w *bufio.Writer h2RawConn *http2.ClientConn - pr *pipCon - isPool bool + + r *bufio.Reader + w *bufio.Writer + pr *pipCon + inPool bool } func (obj *connecotr) withCancel(deleteCtx context.Context, closeCtx context.Context) { @@ -73,6 +73,9 @@ func (obj *connecotr) Write(b []byte) (int, error) { } func (obj *connecotr) h2Closed() bool { + if obj.h2RawConn == nil { + return false + } state := obj.h2RawConn.State() return state.Closed || state.Closing } @@ -88,11 +91,10 @@ func (obj *connecotr) http1Req(task *reqTask) { if task.debug { debugPrint(task.requestId, "http1 req start") } - if task.orderHeaders != nil && len(task.orderHeaders) > 0 { - task.err = httpWrite(task.req, obj.w, task.orderHeaders) - } else if task.err = task.req.Write(obj); task.err == nil { - task.err = obj.w.Flush() - } + task.err = httpWrite(task.req, obj.w, task.orderHeaders) + // if task.err = task.req.Write(obj); task.err == nil { + // task.err = obj.w.Flush() + // } if task.debug { debugPrint(task.requestId, "http1 req write ok ,err: ", task.err) } @@ -124,9 +126,17 @@ func (obj *connecotr) http2Req(task *reqTask) { debugPrint(task.requestId, "http2 req ok,err: ", task.err) } } +func (obj *connecotr) waitBodyClose() error { + select { + case <-obj.bodyCtx.Done(): //wait body close + return nil + case <-obj.deleteCtx.Done(): //force conn close + return tools.WrapError(context.Cause(obj.deleteCtx), "delete ctx error: ") + } +} -func (obj *connecotr) taskMain(task *reqTask, afterTime *time.Timer) (*http.Response, error, bool) { - if obj.h2 && obj.h2Closed() { +func (obj *connecotr) taskMain(task *reqTask, afterTime *time.Timer, waitBody bool) (*http.Response, error, bool) { + if obj.h2Closed() { if task.debug { debugPrint(task.requestId, "h2 con is closed") } @@ -140,36 +150,28 @@ func (obj *connecotr) taskMain(task *reqTask, afterTime *time.Timer) (*http.Resp return nil, tools.WrapError(obj.closeCtx.Err(), "close ctx error: "), true default: } - if obj.h2 { + if obj.h2RawConn != nil { go obj.http2Req(task) } else { go obj.http1Req(task) } if afterTime == nil { afterTime = time.NewTimer(task.responseHeaderTimeout) + defer afterTime.Stop() } else { afterTime.Reset(task.responseHeaderTimeout) } - if !obj.isPool { - defer afterTime.Stop() - } select { case <-task.ctx.Done(): - if task.res != nil && task.err == nil && obj.isPool { - select { - case <-obj.bodyCtx.Done(): //wait body close - task.err = tools.WrapError(obj.deleteCtx.Err(), "body ctx error: ") - case <-obj.deleteCtx.Done(): //force conn close - task.err = tools.WrapError(obj.deleteCtx.Err(), "delete ctx error: ") - } + if waitBody && task.res != nil && task.err == nil { + task.err = obj.waitBodyClose() } case <-obj.deleteCtx.Done(): //force conn close task.err = tools.WrapError(obj.deleteCtx.Err(), "delete ctx error: ") - task.cnl() case <-afterTime.C: task.err = errors.New("response Header is Timeout") - task.cnl() } + task.cnl() return task.res, task.err, false } @@ -178,11 +180,30 @@ type connPool struct { deleteCnl context.CancelCauseFunc closeCtx context.Context closeCnl context.CancelCauseFunc - key connKey + connKey connKey total atomic.Int64 tasks chan *reqTask - rt *roundTripper - lock sync.Mutex + connPools *connPools +} +type connPools struct { + connPools sync.Map +} + +func (obj *connPools) get(key connKey) *connPool { + val, ok := obj.connPools.Load(key) + if !ok { + return nil + } + return val.(*connPool) +} +func (obj *connPools) set(key connKey, pool *connPool) { + obj.connPools.Store(key, pool) +} +func (obj *connPools) del(key connKey) { + obj.connPools.Delete(key) +} +func (obj *connPools) iter(f func(key any, value any) bool) { + obj.connPools.Range(f) } func (obj *connPool) notice(task *reqTask) { @@ -205,10 +226,8 @@ func (obj *connPool) rwMain(conn *connecotr) { obj.close() } }() - select { - case <-conn.deleteCtx.Done(): //force close all conn + if err := conn.waitBodyClose(); err != nil { return - case <-conn.bodyCtx.Done(): //wait body close } for { select { @@ -224,7 +243,7 @@ func (obj *connPool) rwMain(conn *connecotr) { } return } - res, err, notice := conn.taskMain(task, afterTime) + res, err, notice := conn.taskMain(task, afterTime, true) if notice { obj.notice(task) return @@ -241,5 +260,5 @@ func (obj *connPool) forceClose() { } func (obj *connPool) close() { obj.closeCnl(errors.New("connPool close")) - obj.rt.delConnPool(obj.key) + obj.connPools.del(obj.connKey) } diff --git a/dial.go b/dial.go index 572102e..4a84450 100644 --- a/dial.go +++ b/dial.go @@ -24,7 +24,6 @@ type DialClient struct { dialer *net.Dialer dnsIpData sync.Map dns *net.UDPAddr - localAddr *net.TCPAddr getAddrType func(host string) gtls.AddrType } type msgClient struct { @@ -40,16 +39,8 @@ type DialOption struct { Dns *net.UDPAddr GetAddrType func(host string) gtls.AddrType } -type DialerOption struct { - DialTimeout time.Duration - KeepAlive time.Duration - LocalAddr *net.TCPAddr //network card ip - AddrType gtls.AddrType //first ip type - Dns *net.UDPAddr - GetAddrType func(host string) gtls.AddrType -} -func NewDialer(option DialerOption) *net.Dialer { +func NewDialer(option DialOption) *net.Dialer { if option.KeepAlive == 0 { option.KeepAlive = time.Second * 30 } @@ -80,16 +71,8 @@ func NewDialer(option DialerOption) *net.Dialer { } func NewDail(option DialOption) *DialClient { return &DialClient{ - dialer: NewDialer(DialerOption{ - DialTimeout: option.DialTimeout, - KeepAlive: option.KeepAlive, - LocalAddr: option.LocalAddr, - AddrType: option.AddrType, - Dns: option.Dns, - GetAddrType: option.GetAddrType, - }), + dialer: NewDialer(option), dns: option.Dns, - localAddr: option.LocalAddr, getAddrType: option.GetAddrType, } } @@ -306,45 +289,47 @@ func (obj *DialClient) lookupIPAddr(ctx context.Context, host string, ips []net. return nil, errors.New("dns parse host error") } func (obj *DialClient) getDialer(ctxData *reqCtxData, parseDns bool) *net.Dialer { - var dialerOption DialerOption + var dialOption DialOption var isNew bool if ctxData.dialTimeout == 0 { - dialerOption.DialTimeout = obj.dialer.Timeout + dialOption.DialTimeout = obj.dialer.Timeout } else { - dialerOption.DialTimeout = ctxData.dialTimeout + dialOption.DialTimeout = ctxData.dialTimeout if ctxData.dialTimeout != obj.dialer.Timeout { isNew = true } } if ctxData.keepAlive == 0 { - dialerOption.KeepAlive = obj.dialer.KeepAlive + dialOption.KeepAlive = obj.dialer.KeepAlive } else { - dialerOption.KeepAlive = ctxData.keepAlive + dialOption.KeepAlive = ctxData.keepAlive if ctxData.keepAlive != obj.dialer.KeepAlive { isNew = true } } if ctxData.localAddr == nil { - dialerOption.LocalAddr = obj.localAddr + if obj.dialer.LocalAddr != nil { + dialOption.LocalAddr = obj.dialer.LocalAddr.(*net.TCPAddr) + } } else { - dialerOption.LocalAddr = ctxData.localAddr - if ctxData.localAddr.String() != obj.localAddr.String() { + dialOption.LocalAddr = ctxData.localAddr + if ctxData.localAddr.String() != obj.dialer.LocalAddr.String() { isNew = true } } if ctxData.dns == nil { - dialerOption.Dns = obj.dns + dialOption.Dns = obj.dns } else { - dialerOption.Dns = ctxData.dns + dialOption.Dns = ctxData.dns if parseDns && ctxData.dns.String() != obj.dns.String() { isNew = true } } if isNew { - return NewDialer(dialerOption) + return NewDialer(dialOption) } else { return obj.dialer } diff --git a/requests.go b/requests.go index f3c667f..e523309 100644 --- a/requests.go +++ b/requests.go @@ -99,9 +99,7 @@ func NewReqCtxData(ctx context.Context, option *RequestOption) (*reqCtxData, err } //init orderHeaders,this must after init headers if option.OrderHeaders == nil { - if option.Ja3Spec.IsSet() { - ctxData.orderHeaders = ja3.DefaultH1OrderHeaders() - } + ctxData.orderHeaders = ja3.DefaultH1OrderHeaders() } else { orderHeaders := []string{} for _, key := range option.OrderHeaders { diff --git a/roundTripper.go b/roundTripper.go index d4c8cd3..ea4121c 100644 --- a/roundTripper.go +++ b/roundTripper.go @@ -7,7 +7,6 @@ import ( "errors" "net" "net/url" - "sync" "time" "net/http" @@ -67,7 +66,7 @@ func getKey(ctxData *reqCtxData, req *http.Request) connKey { type roundTripper struct { ctx context.Context cnl context.CancelFunc - connPools sync.Map + connPools *connPools dialer *DialClient tlsConfig *tls.Config utlsConfig *utls.Config @@ -117,46 +116,39 @@ func newRoundTripper(preCtx context.Context, option ClientOption) *roundTripper cnl: cnl, dialer: dialClient, proxy: option.GetProxy, + connPools: new(connPools), } } func (obj *roundTripper) newConnPool(conn *connecotr, key connKey) *connPool { pool := new(connPool) - pool.key = key + pool.connKey = key pool.deleteCtx, pool.deleteCnl = context.WithCancelCause(obj.ctx) pool.closeCtx, pool.closeCnl = context.WithCancelCause(pool.deleteCtx) pool.tasks = make(chan *reqTask) - pool.rt = obj + pool.connPools = obj.connPools pool.total.Add(1) go pool.rwMain(conn) return pool } func (obj *roundTripper) getConnPool(key connKey) *connPool { - val, ok := obj.connPools.Load(key) - if !ok { - return nil - } - return val.(*connPool) -} -func (obj *roundTripper) delConnPool(key connKey) { - obj.connPools.Delete(key) + return obj.connPools.get(key) } func (obj *roundTripper) putConnPool(key connKey, conn *connecotr) { - conn.isPool = true - if !conn.h2 { + conn.inPool = true + if conn.h2RawConn == nil { go conn.read() } - val, ok := obj.connPools.Load(key) - if ok { - pool := val.(*connPool) + pool := obj.connPools.get(key) + if pool != nil { select { case <-pool.closeCtx.Done(): - obj.connPools.Store(key, obj.newConnPool(conn, key)) + obj.connPools.set(key, obj.newConnPool(conn, key)) default: pool.total.Add(1) go pool.rwMain(conn) } } else { - obj.connPools.Store(key, obj.newConnPool(conn, key)) + obj.connPools.set(key, obj.newConnPool(conn, key)) } } func (obj *roundTripper) tlsConfigClone() *tls.Config { @@ -187,6 +179,7 @@ func (obj *roundTripper) dial(ctxData *reqCtxData, key *connKey, req *http.Reque } conne := new(connecotr) conne.withCancel(obj.ctx, obj.ctx) + var h2 bool if req.URL.Scheme == "https" { ctx, cnl := context.WithTimeout(req.Context(), ctxData.tlsHandshakeTimeout) defer cnl() @@ -199,19 +192,19 @@ func (obj *roundTripper) dial(ctxData *reqCtxData, key *connKey, req *http.Reque if err != nil { return conne, tools.WrapError(err, "add ja3 tls error") } - conne.h2 = tlsConn.ConnectionState().NegotiatedProtocol == "h2" + h2 = tlsConn.ConnectionState().NegotiatedProtocol == "h2" netConn = tlsConn } else { tlsConn, err := obj.dialer.addTls(ctx, netConn, host, ctxData.isWs || ctxData.forceHttp1, obj.tlsConfigClone()) if err != nil { return conne, tools.WrapError(err, "add tls error") } - conne.h2 = tlsConn.ConnectionState().NegotiatedProtocol == "h2" + h2 = tlsConn.ConnectionState().NegotiatedProtocol == "h2" netConn = tlsConn } } conne.rawConn = netConn - if conne.h2 { + if h2 { if conne.h2RawConn, err = http2.NewClientConn(func() { conne.closeCnl(errors.New("http2 client close")) }, netConn, ctxData.h2Ja3Spec); err != nil { @@ -286,19 +279,18 @@ func (obj *roundTripper) poolRoundTrip(task *reqTask, key connKey) (bool, error) } func (obj *roundTripper) closeConns() { - obj.connPools.Range(func(key, value any) bool { + obj.connPools.iter(func(key, value any) bool { pool := value.(*connPool) pool.close() - obj.connPools.Delete(key) + obj.connPools.del(key.(connKey)) return true }) } - func (obj *roundTripper) forceCloseConns() { - obj.connPools.Range(func(key, value any) bool { + obj.connPools.iter(func(key, value any) bool { pool := value.(*connPool) pool.forceClose() - obj.connPools.Delete(key) + obj.connPools.del(key.(connKey)) return true }) } @@ -336,13 +328,13 @@ newConn: if err != nil { return nil, err } - if _, _, notice := conn.taskMain(task, nil); notice { + if _, _, notice := conn.taskMain(task, nil, false); notice { goto newConn } if task.err == nil && task.res == nil { task.err = obj.ctx.Err() } - conn.key = ckey + conn.connKey = ckey if task.inPool() && !ctxData.disAlive { if task.debug { debugPrint(ctxData.requestId, "conn put conn pool") diff --git a/rw.go b/rw.go index 26a8e94..c27ddb2 100644 --- a/rw.go +++ b/rw.go @@ -26,10 +26,10 @@ func (obj *readWriteCloser) Close() (err error) { return } func (obj *readWriteCloser) InPool() bool { - return obj.conn.isPool + return obj.conn.inPool } func (obj *readWriteCloser) Proxy() string { - return obj.conn.key.proxy + return obj.conn.connKey.proxy } // safe close conn diff --git a/test/request/file_test.go b/test/request/file_test.go index 801eecd..f5abe8e 100644 --- a/test/request/file_test.go +++ b/test/request/file_test.go @@ -8,35 +8,11 @@ import ( "github.com/gospider007/requests" ) -func TestSendFile(t *testing.T) { - resp, err := requests.Post(nil, "https://httpbin.org/anything", requests.RequestOption{ - Form: map[string]any{ - "file": requests.File{ - Content: []byte("test"), - FileName: "test.txt", - ContentType: "text/plain", - }, - }, - }) - if err != nil { - t.Fatal(err) - } - jsonData, err := resp.Json() - if err != nil { - t.Fatal(err) - } - if !strings.HasPrefix(jsonData.Get("headers.Content-Type").String(), "multipart/form-data") { - t.Fatal("json data error") - } - if jsonData.Get("files.file").String() != "test" { - t.Fatal("json data error") - } -} func TestSendFileWithReader(t *testing.T) { resp, err := requests.Post(nil, "https://httpbin.org/anything", requests.RequestOption{ Form: map[string]any{ "file": requests.File{ - Content: bytes.NewBuffer([]byte("test")), + Content: bytes.NewBuffer([]byte("test")), //support: io.Reader, string, []byte FileName: "test.txt", ContentType: "text/plain", }, diff --git a/tools.go b/tools.go index 06251d5..e7b4e19 100644 --- a/tools.go +++ b/tools.go @@ -2,7 +2,6 @@ package requests import ( "bufio" - "errors" "fmt" "io" "net" @@ -74,9 +73,6 @@ func removeZone(host string) string //go:linkname shouldSendContentLength net/http.(*transferWriter).shouldSendContentLength func shouldSendContentLength(t *http.Request) bool -//go:linkname stringContainsCTLByte net/http.stringContainsCTLByte -func stringContainsCTLByte(s string) bool - func httpWrite(r *http.Request, w *bufio.Writer, orderHeaders []string) (err error) { host := r.Host if host == "" { @@ -86,21 +82,15 @@ func httpWrite(r *http.Request, w *bufio.Writer, orderHeaders []string) (err err if err != nil { return err } - if !httpguts.ValidHostHeader(host) { - return errors.New("http: invalid Host header") - } host = removeZone(host) ruri := r.URL.RequestURI() if r.Method == "CONNECT" && r.URL.Path == "" { - // CONNECT requests normally give just the host and port, not a full URL. - ruri = host if r.URL.Opaque != "" { ruri = r.URL.Opaque + } else { + ruri = host } } - if stringContainsCTLByte(ruri) { - return errors.New("net/http: can't write control character in Request.URL") - } if r.Header.Get("Host") == "" { r.Header.Set("Host", host) }