From c83d1aa4c2248d057124054cceb65328da1d0718 Mon Sep 17 00:00:00 2001 From: gospider <2216403312@qq.com> Date: Thu, 24 Jul 2025 16:50:47 +0800 Subject: [PATCH] sync --- client.go | 5 -- conn.go | 126 ++++++++++---------------------- go.mod | 10 +-- go.sum | 22 +++--- requests.go | 4 +- response.go | 97 +++++++++++++----------- roundTripper.go | 59 ++++++--------- rw.go | 11 +-- test/fingerprint/http2_test.go | 1 + test/protocol/websocket_test.go | 98 ++++++++++++++++++++++++- test/request/stream_test.go | 107 +++++++++++++++++++++++++++ test/session_test.go | 1 + tools.go | 15 ---- 13 files changed, 337 insertions(+), 219 deletions(-) diff --git a/client.go b/client.go index 1ef0ce5..d86a06a 100644 --- a/client.go +++ b/client.go @@ -59,13 +59,8 @@ func NewClient(preCtx context.Context, options ...ClientOption) (*Client, error) return result, err } -func (obj *Client) CloseConns() { - obj.transport.closeConns() -} - // Close the client and cannot be used again after shutdown func (obj *Client) Close() { obj.closed = true - obj.CloseConns() obj.cnl() } diff --git a/conn.go b/conn.go index 25ce808..2267b44 100644 --- a/conn.go +++ b/conn.go @@ -3,31 +3,19 @@ package requests import ( "context" "errors" - "io" - "iter" "net" - "net/http" - "sync" - "sync/atomic" "time" + "github.com/gospider007/http1" "github.com/gospider007/tools" ) -var maxRetryCount = 10 +var maxRetryCount = 5 -type Conn interface { - CloseWithError(err error) error - DoRequest(*http.Request, []interface { - Key() string - Val() any - }) (*http.Response, context.Context, error) - Stream() io.ReadWriteCloser -} type connecotr struct { forceCtx context.Context //force close forceCnl context.CancelCauseFunc - Conn Conn + Conn http1.Conn c net.Conn proxys []Address } @@ -48,11 +36,7 @@ func (obj *connecotr) CloseWithError(err error) error { func (obj *connecotr) wrapBody(task *reqTask) { body := new(wrapBody) - if task.reqCtx.response.Body == nil { - task.reqCtx.response.Body = http.NoBody - } - rawBody := task.reqCtx.response.Body - body.rawBody = rawBody + body.rawBody = task.reqCtx.response.Body.(*http1.ClientBody) body.conn = obj task.reqCtx.response.Body = body task.reqCtx.response.Request = task.reqCtx.request @@ -60,17 +44,18 @@ func (obj *connecotr) wrapBody(task *reqTask) { func (obj *connecotr) httpReq(task *reqTask, done chan struct{}) (err error) { defer close(done) - task.reqCtx.response, task.bodyCtx, err = obj.Conn.DoRequest(task.reqCtx.request, task.reqCtx.option.orderHeaders.Data()) - if task.reqCtx.response != nil { - obj.wrapBody(task) - } - if err != nil { - err = tools.WrapError(err, "roundTrip error") + response, bodyCtx, derr := obj.Conn.DoRequest(task.reqCtx.request, &http1.Option{OrderHeaders: task.reqCtx.option.orderHeaders.Data()}) + if derr != nil { + err = tools.WrapError(derr, "roundTrip error") + return } + task.reqCtx.response = response + task.bodyCtx = bodyCtx + obj.wrapBody(task) return } -func (obj *connPool) taskMain(conn *connecotr, task *reqTask) (err error) { +func (obj *connecotr) taskMain(task *reqTask) (err error) { defer func() { if err != nil && task.reqCtx.option.ErrCallBack != nil { task.reqCtx.err = err @@ -87,16 +72,16 @@ func (obj *connPool) taskMain(conn *connecotr, task *reqTask) (err error) { } if err == nil && task.reqCtx.response != nil && task.reqCtx.response.Body != nil && task.bodyCtx != nil { select { - case <-conn.forceCtx.Done(): - err = context.Cause(conn.forceCtx) + case <-obj.forceCtx.Done(): + err = context.Cause(obj.forceCtx) case <-task.reqCtx.Context().Done(): if context.Cause(task.reqCtx.Context()) != tools.ErrNoErr { err = context.Cause(task.reqCtx.Context()) } if err == nil && task.reqCtx.response.StatusCode == 101 { select { - case <-conn.forceCtx.Done(): - err = context.Cause(conn.forceCtx) + case <-obj.forceCtx.Done(): + err = context.Cause(obj.forceCtx) case <-task.bodyCtx.Done(): if context.Cause(task.bodyCtx) != tools.ErrNoErr { err = context.Cause(task.bodyCtx) @@ -110,28 +95,30 @@ func (obj *connPool) taskMain(conn *connecotr, task *reqTask) (err error) { } } if err != nil { - conn.CloseWithError(tools.WrapError(err, "taskMain close with error")) + obj.CloseWithError(tools.WrapError(err, "taskMain close with error")) } }() select { - case <-conn.forceCtx.Done(): //force conn close - err = context.Cause(conn.forceCtx) + case <-obj.forceCtx.Done(): //force conn close + err = context.Cause(obj.forceCtx) task.enableRetry = true task.isNotice = true return default: } done := make(chan struct{}) - go conn.httpReq(task, done) + go func() { + err = obj.httpReq(task, done) + }() select { - case <-conn.forceCtx.Done(): //force conn close - err = tools.WrapError(context.Cause(conn.forceCtx), "taskMain delete ctx error: ") + case <-obj.forceCtx.Done(): //force conn close + err = tools.WrapError(context.Cause(obj.forceCtx), "taskMain delete ctx error: ") case <-time.After(task.reqCtx.option.ResponseHeaderTimeout): err = errors.New("ResponseHeaderTimeout error: ") case <-task.ctx.Done(): err = context.Cause(task.ctx) case <-done: - if task.reqCtx.response == nil { + if err == nil && task.reqCtx.response == nil { err = context.Cause(task.ctx) if err == nil { err = errors.New("body done response is nil") @@ -149,71 +136,32 @@ func (obj *connPool) taskMain(conn *connecotr, task *reqTask) (err error) { return } -type connPool struct { - forceCtx context.Context - forceCnl context.CancelCauseFunc - tasks chan *reqTask - connPools *connPools - connKey string - total atomic.Int64 -} -type connPools struct { - connPools sync.Map -} - -func newConnPools() *connPools { - return new(connPools) -} -func (obj *connPools) get(task *reqTask) *connPool { - val, ok := obj.connPools.Load(task.key) - if !ok { - return nil - } - return val.(*connPool) -} -func (obj *connPools) set(task *reqTask, pool *connPool) { - obj.connPools.Store(task.key, pool) -} -func (obj *connPools) del(key string) { - obj.connPools.Delete(key) -} -func (obj *connPools) Range() iter.Seq2[string, *connPool] { - return func(yield func(string, *connPool) bool) { - obj.connPools.Range(func(key, value any) bool { - return yield(key.(string), value.(*connPool)) - }) - } -} - -func (obj *connPool) rwMain(done chan struct{}, conn *connecotr) { - conn.withCancel(obj.forceCtx) +func (obj *connecotr) rwMain(ctx context.Context, done chan struct{}, tasks chan *reqTask) (err error) { + obj.withCancel(ctx) defer func() { - conn.CloseWithError(errors.New("connPool rwMain close")) - obj.total.Add(-1) - if obj.total.Load() <= 0 { - obj.close(errors.New("conn pool close")) + if err != nil && err != tools.ErrNoErr { + obj.CloseWithError(tools.WrapError(err, "rwMain close with error")) } }() close(done) for { select { - case <-conn.forceCtx.Done(): //force close conn - return - case task := <-obj.tasks: //recv task + case <-obj.forceCtx.Done(): //force close conn + return errors.New("connecotr force close") + case task := <-tasks: //recv task if task == nil { - return + return errors.New("task is nil") } - err := obj.taskMain(conn, task) + err = obj.taskMain(task) if err != nil { return } + if task.reqCtx.response != nil && task.reqCtx.response.StatusCode == 101 { + return tools.ErrNoErr + } } } } -func (obj *connPool) close(err error) { - obj.connPools.del(obj.connKey) - obj.forceCnl(tools.WrapError(err, "connPool close")) -} func newSSHConn(sshCon net.Conn, rawCon net.Conn) *sshConn { return &sshConn{sshCon: sshCon, rawCon: rawCon} diff --git a/go.mod b/go.mod index a2c425d..4bc7580 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/gospider007/bs4 v0.0.0-20250413121342-fed910fb00c9 github.com/gospider007/gson v0.0.0-20250718004537-ff15820964bd github.com/gospider007/gtls v0.0.0-20250718003831-90cdeb97a23f - github.com/gospider007/http1 v0.0.0-20250718004641-26f982c140cf + github.com/gospider007/http1 v0.0.0-20250718091014-9ea72dfb1370 github.com/gospider007/http2 v0.0.0-20250718004700-7af5b064e352 github.com/gospider007/http3 v0.0.0-20250718004757-02ceb5fa2d6e github.com/gospider007/ja3 v0.0.0-20250627013834-1d2966014638 @@ -18,9 +18,9 @@ require ( github.com/gospider007/websocket v0.0.0-20250718010025-4c017acfd478 github.com/klauspost/compress v1.18.0 github.com/minio/minlz v1.0.1 - github.com/quic-go/quic-go v0.53.0 + github.com/quic-go/quic-go v0.54.0 github.com/refraction-networking/uquic v0.0.6 - github.com/refraction-networking/utls v1.7.4-0.20250621163342-5abccec539e6 + github.com/refraction-networking/utls v1.8.0 github.com/txthinking/socks5 v0.0.0-20230325130024-4230056ae301 golang.org/x/crypto v0.40.0 golang.org/x/net v0.42.0 @@ -37,7 +37,6 @@ require ( github.com/bodgit/windows v1.0.1 // indirect github.com/caddyserver/certmagic v0.23.0 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect - github.com/cloudflare/circl v1.6.1 // indirect github.com/dsnet/compress v0.0.2-0.20230904184137-39efe44ab707 // indirect github.com/gaukas/clienthellod v0.4.2 // indirect github.com/gaukas/godicttls v0.0.4 // indirect @@ -63,6 +62,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/nwaples/rardecode/v2 v2.1.1 // indirect github.com/onsi/ginkgo/v2 v2.23.4 // indirect + github.com/onsi/gomega v1.37.0 // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pierrec/lz4/v4 v4.1.22 // indirect github.com/quic-go/qpack v0.5.1 // indirect @@ -82,7 +82,7 @@ require ( go.uber.org/zap v1.27.0 // indirect go.uber.org/zap/exp v0.3.0 // indirect go4.org v0.0.0-20230225012048-214862532bf5 // indirect - golang.org/x/exp v0.0.0-20250717185816-542afb5b7346 // indirect + golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 // indirect golang.org/x/image v0.29.0 // indirect golang.org/x/mod v0.26.0 // indirect golang.org/x/sync v0.16.0 // indirect diff --git a/go.sum b/go.sum index 4c0e2d0..f5daefa 100644 --- a/go.sum +++ b/go.sum @@ -40,8 +40,6 @@ github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWR github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= -github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -116,8 +114,8 @@ github.com/gospider007/gson v0.0.0-20250718004537-ff15820964bd h1:aby4HnAGVJt5pI github.com/gospider007/gson v0.0.0-20250718004537-ff15820964bd/go.mod h1:GxCATDh+u/TLHTAI9p1kXfaGUkVNjHuY+Mhxdo4l5k8= github.com/gospider007/gtls v0.0.0-20250718003831-90cdeb97a23f h1:W/ug9EHRcduL40RNtsKR9Ob0VnMQF/pps6iLTmorhck= github.com/gospider007/gtls v0.0.0-20250718003831-90cdeb97a23f/go.mod h1:iTnRK0DU3YH7MlZZ9VdT5gQAFPKzHRtFi3EXpnosOAM= -github.com/gospider007/http1 v0.0.0-20250718004641-26f982c140cf h1:PUyHtMJ/mvyb0nBsgwMmEnZH0fqJabPMlC9fJ1Ks3R4= -github.com/gospider007/http1 v0.0.0-20250718004641-26f982c140cf/go.mod h1:GqV2VRFBHYl6Ovir4gnGnuUeELdw8Klg2BRVdHapeVg= +github.com/gospider007/http1 v0.0.0-20250718091014-9ea72dfb1370 h1:YXqERVMVbSH9Y/uRui3agK6ztLhOGOUIS8hP9zNunBA= +github.com/gospider007/http1 v0.0.0-20250718091014-9ea72dfb1370/go.mod h1:UwjRyGpzSwrWgNU/QVEX4VUSzo3oWTV0Oxr1nLxxAI0= github.com/gospider007/http2 v0.0.0-20250718004700-7af5b064e352 h1:E+Gy49dUc4WWfHA+ZA6BvueQ3MeimLJ4dptT+dk2rvg= github.com/gospider007/http2 v0.0.0-20250718004700-7af5b064e352/go.mod h1:TNUBmnegjmQMAdjdZqpqGvSJ1mpikVGYIAtBVvpfqhE= github.com/gospider007/http3 v0.0.0-20250718004757-02ceb5fa2d6e h1:kNwZM1/FnNLk2MASfk0VPdYk/nOR/wu+YrWlTixhYSM= @@ -179,8 +177,8 @@ github.com/nwaples/rardecode/v2 v2.1.1 h1:OJaYalXdliBUXPmC8CZGQ7oZDxzX1/5mQmgn0/ github.com/nwaples/rardecode/v2 v2.1.1/go.mod h1:7uz379lSxPe6j9nvzxUZ+n7mnJNgjsRNb6IbvGVHRmw= github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus= github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8= -github.com/onsi/gomega v1.36.3 h1:hID7cr8t3Wp26+cYnfcjR6HpJ00fdogN6dqZ1t6IylU= -github.com/onsi/gomega v1.36.3/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0= +github.com/onsi/gomega v1.37.0 h1:CdEG8g0S133B4OswTDC/5XPSzE1OeP29QOioj2PID2Y= +github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU= @@ -194,12 +192,12 @@ github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= -github.com/quic-go/quic-go v0.53.0 h1:QHX46sISpG2S03dPeZBgVIZp8dGagIaiu2FiVYvpCZI= -github.com/quic-go/quic-go v0.53.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY= +github.com/quic-go/quic-go v0.54.0 h1:6s1YB9QotYI6Ospeiguknbp2Znb/jZYjZLRXn9kMQBg= +github.com/quic-go/quic-go v0.54.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY= github.com/refraction-networking/uquic v0.0.6 h1:9ol1oOaOpHDeeDlBY7u228jK+T5oic35QrFimHVaCMM= github.com/refraction-networking/uquic v0.0.6/go.mod h1:TFgTmV/yqVCMEXVwP7z7PMAhzye02rFHLV6cRAg59jc= -github.com/refraction-networking/utls v1.7.4-0.20250621163342-5abccec539e6 h1:jN13gW+A/o/SQg1EeN0Ki0pskhzlqozpdFvU/Oothw4= -github.com/refraction-networking/utls v1.7.4-0.20250621163342-5abccec539e6/go.mod h1:TUhh27RHMGtQvjQq+RyO11P6ZNQNBb3N0v7wsEjKAIQ= +github.com/refraction-networking/utls v1.8.0 h1:L38krhiTAyj9EeiQQa2sg+hYb4qwLCqdMcpZrRfbONE= +github.com/refraction-networking/utls v1.8.0/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -287,8 +285,8 @@ golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= -golang.org/x/exp v0.0.0-20250717185816-542afb5b7346 h1:vuCObX8mQzik1tfEcYxWZBuVsmQtD1IjxCyPKM18Bh4= -golang.org/x/exp v0.0.0-20250717185816-542afb5b7346/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= +golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4= +golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.29.0 h1:HcdsyR4Gsuys/Axh0rDEmlBmB68rW1U9BUdB3UVHsas= diff --git a/requests.go b/requests.go index ab04fe6..3ddd966 100644 --- a/requests.go +++ b/requests.go @@ -145,7 +145,7 @@ func (obj *Client) retryRequest(ctx context.Context, option RequestOption, uhref if err != nil || loc == nil { return } - response.CloseBody(nil) + response.closeBody(true, nil) switch response.StatusCode() { case 307, 308: if response.Option().readOne { @@ -354,7 +354,7 @@ func (obj *Client) request(ctx *Response) (err error) { return } if ctx.response.Body != nil { - ctx.body = ctx.response.Body.(*wrapBody) + ctx.wrapBody = ctx.response.Body.(*wrapBody) } if encoding := ctx.ContentEncoding(); encoding != "" && ctx.response.Body != nil { var unCompressionBody io.ReadCloser diff --git a/response.go b/response.go index 927bc50..866d17f 100644 --- a/response.go +++ b/response.go @@ -59,7 +59,7 @@ type Response struct { err error ctx context.Context request *http.Request - body *wrapBody + wrapBody *wrapBody response *http.Response webSocket *websocket.Conn sse *SSE @@ -71,9 +71,9 @@ type Response struct { requestId string content []byte proxys []*url.URL - readBody bool readBodyLock sync.Mutex isNewConn bool + bodyErr error } type SSE struct { reader *bufio.Reader @@ -151,6 +151,13 @@ func (obj *SSE) Close() { // return websocket client func (obj *Response) WebSocket() *websocket.Conn { + if obj.webSocket != nil { + return obj.webSocket + } + if obj.StatusCode() != 101 { + return nil + } + obj.webSocket = websocket.NewConn(newFakeConn(obj.wrapBody.connStream()), func() { obj.CloseConn() }, true, obj.Headers().Get("Sec-WebSocket-Extensions")) return obj.webSocket } @@ -312,30 +319,40 @@ func (obj *Response) IsNewConn() bool { // conn proxy func (obj *Response) Proxys() []Address { - if obj.body != nil { - return obj.body.Proxys() - } - return nil + return obj.wrapBody.Proxys() } // close func (obj *Response) CloseConn() { - if obj.body != nil { - obj.body.CloseWithError(errors.New("force close conn")) - obj.body.CloseConn() + if obj.wrapBody != nil { + obj.wrapBody.CloseWithError(errors.New("force close conn")) + obj.wrapBody.CloseConn() } obj.cnl() } // close func (obj *Response) CloseBody(err error) { + obj.closeBody(true, err) +} +func (obj *Response) closeBody(i bool, err error) { + if obj.bodyErr != io.EOF { + obj.CloseConn() + return + } else if i { + if obj.StatusCode() == 101 && obj.webSocket == nil { + obj.CloseConn() + return + } + } if err == nil { err = tools.ErrNoErr - } else { - obj.cnl() } - if obj.body != nil { - obj.body.CloseWithError(err) + + if err == tools.ErrNoErr { + obj.wrapBody.CloseWithError(err) + } else { + obj.CloseConn() } } @@ -343,15 +360,14 @@ func (obj *Response) CloseBody(err error) { func (obj *Response) ReadBody() (err error) { obj.readBodyLock.Lock() defer obj.readBodyLock.Unlock() - if obj.readBody { + if obj.bodyErr != nil { return nil } - obj.readBody = true bBody := bytes.NewBuffer(nil) done := make(chan struct{}) var readErr error body := obj.Body() - defer body.Close() + defer body.close(false) go func() { defer close(done) if obj.option.Bar && obj.ContentLength() > 0 { @@ -368,7 +384,7 @@ func (obj *Response) ReadBody() (err error) { }() select { case <-obj.ctx.Done(): - if readErr == nil && body.closed && body.err == nil { + if readErr == nil && obj.bodyErr == io.EOF && body.err == nil { err = nil } else { err = tools.WrapError(obj.ctx.Err(), "response read ctx error") @@ -390,48 +406,41 @@ func (obj *Response) ReadBody() (err error) { } type body struct { - ctx *Response - closed bool - err error + ctx *Response + err error } func (obj *body) Read(p []byte) (n int, err error) { - obj.ctx.readBody = true - if obj.ctx == nil || obj.ctx.response == nil || obj.ctx.response.Body == nil { - obj.closed = true - return 0, io.EOF - } n, err = obj.ctx.response.Body.Read(p) if err != nil { - if err != io.EOF && err != io.ErrUnexpectedEOF { - obj.CloseWithError(err) + if err == io.ErrUnexpectedEOF { + err = io.EOF + } + obj.ctx.bodyErr = err + if err != io.EOF { + obj.closeWithError(false, err) } else { - obj.CloseWithError(nil) + obj.closeWithError(false, nil) } } return } func (obj *body) Close() (err error) { - return obj.CloseWithError(obj.err) + return obj.close(true) } -func (obj *body) CloseWithError(err error) error { - if obj.closed { - return obj.err - } - obj.closed = true - if err == nil { - if obj.ctx.StatusCode() == 101 && obj.ctx.webSocket == nil { - obj.ctx.webSocket = websocket.NewConn(newFakeConn(obj.ctx.body.connStream()), func() { obj.ctx.CloseConn() }, true, obj.ctx.Headers().Get("Sec-WebSocket-Extensions")) - } - obj.ctx.CloseBody(nil) - return nil - } else { +func (obj *body) close(i bool) (err error) { + return obj.closeWithError(i, obj.err) +} +func (obj *body) closeWithError(i bool, err error) error { + if obj.err == nil { obj.err = err - obj.ctx.CloseBody(err) - obj.ctx.CloseConn() - return err } + obj.ctx.closeBody(i, err) + if err != nil { + obj.ctx.CloseConn() + } + return obj.err } func (obj *Response) Body() *body { diff --git a/roundTripper.go b/roundTripper.go index 91ea290..37896ec 100644 --- a/roundTripper.go +++ b/roundTripper.go @@ -7,6 +7,7 @@ import ( "io" "net" "strings" + "sync" "time" "net/http" @@ -67,8 +68,9 @@ func getKey(ctx *Response) (string, error) { type roundTripper struct { ctx context.Context cnl context.CancelFunc - connPools *connPools + connPools sync.Map dialer *Dialer + lock sync.Mutex } var specClient = ja3.NewClient() @@ -79,32 +81,26 @@ func newRoundTripper(preCtx context.Context) *roundTripper { } ctx, cnl := context.WithCancel(preCtx) return &roundTripper{ - ctx: ctx, - cnl: cnl, - dialer: new(Dialer), - connPools: newConnPools(), + ctx: ctx, + cnl: cnl, + dialer: new(Dialer), } } -func (obj *roundTripper) newConnPool(done chan struct{}, conn *connecotr, task *reqTask) *connPool { - pool := new(connPool) - pool.connKey = task.key - pool.forceCtx, pool.forceCnl = context.WithCancelCause(obj.ctx) - pool.tasks = make(chan *reqTask) - pool.connPools = obj.connPools - pool.total.Add(1) - go pool.rwMain(done, conn) - return pool +func (obj *roundTripper) getConnPool(task *reqTask) chan *reqTask { + obj.lock.Lock() + defer obj.lock.Unlock() + val, ok := obj.connPools.Load(task.key) + if ok { + return val.(chan *reqTask) + } + tasks := make(chan *reqTask) + obj.connPools.Store(task.key, tasks) + return tasks } func (obj *roundTripper) putConnPool(task *reqTask, conn *connecotr) { - pool := obj.connPools.get(task) done := make(chan struct{}) - if pool != nil { - pool.total.Add(1) - go pool.rwMain(done, conn) - } else { - obj.connPools.set(task, obj.newConnPool(done, conn, task)) - } + go conn.rwMain(obj.ctx, done, obj.getConnPool(task)) <-done } func (obj *roundTripper) newConnecotr() *connecotr { @@ -149,7 +145,7 @@ func (obj *roundTripper) ghttp3Dial(ctx *Response, remoteAddress Address, proxyA } conn = obj.newConnecotr() - conn.Conn = http3.NewClient(netConn, udpConn, func() { + conn.Conn = http3.NewClient(conn.forceCtx, netConn, udpConn, func() { conn.forceCnl(errors.New("http3 client close")) }) if ct, ok := udpConn.(interface { @@ -194,7 +190,7 @@ func (obj *roundTripper) uhttp3Dial(ctx *Response, remoteAddress Address, proxyA return nil, err } conn = obj.newConnecotr() - conn.Conn = http3.NewClient(netConn, udpConn, func() { + conn.Conn = http3.NewClient(conn.forceCtx, netConn, udpConn, func() { conn.forceCnl(errors.New("http3 client close")) }) if ct, ok := udpConn.(interface { @@ -287,13 +283,13 @@ func (obj *roundTripper) dialConnecotr(ctx *Response, conne *connecotr, h2 bool) if ctx.option.gospiderSpec != nil { spec = ctx.option.gospiderSpec.H2Spec } - if conne.Conn, err = http2.NewClientConn(ctx.Context(), conne.c, spec, func(err error) { + if conne.Conn, err = http2.NewClientConn(conne.forceCtx, ctx.Context(), conne.c, spec, func(err error) { conne.forceCnl(tools.WrapError(err, "http2 client close")) }); err != nil { return err } } else { - conne.Conn = http1.NewClientConn(conne.c, func(err error) { + conne.Conn = http1.NewClientConn(conne.forceCtx, conne.c, func(err error) { conne.forceCnl(tools.WrapError(err, "http1 client close")) }) } @@ -345,13 +341,9 @@ func (obj *roundTripper) initProxys(ctx *Response) ([]Address, error) { } func (obj *roundTripper) poolRoundTrip(task *reqTask) error { - connPool := obj.connPools.get(task) - if connPool == nil { - return obj.newRoudTrip(task) - } task.ctx, task.cnl = context.WithCancelCause(task.reqCtx.Context()) select { - case connPool.tasks <- task: + case obj.getConnPool(task) <- task: <-task.ctx.Done() err := context.Cause(task.ctx) if errors.Is(err, tools.ErrNoErr) { @@ -383,13 +375,6 @@ func (obj *roundTripper) newRoudTrip(task *reqTask) error { return err } -func (obj *roundTripper) closeConns() { - for key, pool := range obj.connPools.Range() { - pool.close(errors.New("close all conn")) - obj.connPools.del(key) - } -} - func (obj *roundTripper) newReqTask(ctx *Response) (*reqTask, error) { if ctx.option.ResponseHeaderTimeout == 0 { ctx.option.ResponseHeaderTimeout = time.Second * 300 diff --git a/rw.go b/rw.go index aa49c6f..55d22a1 100644 --- a/rw.go +++ b/rw.go @@ -3,13 +3,13 @@ package requests import ( "errors" "io" - "net/http" + "github.com/gospider007/http1" "github.com/gospider007/tools" ) type wrapBody struct { - rawBody io.ReadCloser + rawBody *http1.ClientBody conn *connecotr } @@ -31,12 +31,7 @@ func (obj *wrapBody) CloseWithError(err error) error { if err != nil && err != tools.ErrNoErr { obj.conn.CloseWithError(err) } - if obj.rawBody == nil || obj.rawBody == http.NoBody { - return nil - } - return obj.rawBody.(interface { - CloseWithError(error) error - }).CloseWithError(err) + return obj.rawBody.CloseWithError(err) } // safe close conn diff --git a/test/fingerprint/http2_test.go b/test/fingerprint/http2_test.go index 114956a..605cb36 100644 --- a/test/fingerprint/http2_test.go +++ b/test/fingerprint/http2_test.go @@ -24,6 +24,7 @@ func TestH2(t *testing.T) { // log.Print(resp.Text()) jsonData, err := resp.Json() ja3 := jsonData.Get("http2.fingerprint") + // log.Print(ja3) if !ja3.Exists() { t.Fatal("not found http2") } diff --git a/test/protocol/websocket_test.go b/test/protocol/websocket_test.go index b82ac6d..b10a281 100644 --- a/test/protocol/websocket_test.go +++ b/test/protocol/websocket_test.go @@ -1,6 +1,7 @@ package main import ( + "io" "log" "net/http" "strings" @@ -11,7 +12,12 @@ import ( "github.com/gospider007/requests" ) +var wsOk bool + func websocketServer() { + if wsOk { + return + } var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true // 允许跨域 @@ -33,6 +39,7 @@ func websocketServer() { } }) log.Println("WebSocket 服务器启动于 ws://localhost:8080/ws") + wsOk = true log.Fatal(http.ListenAndServe(":8800", nil)) } func TestWebSocket(t *testing.T) { @@ -63,13 +70,100 @@ func TestWebSocket(t *testing.T) { log.Print(string(con)) if strings.Contains(string(con), "test1122332211") { n++ - if n > 6 { + if n > 2 { + break + } + } + if err = wsCli.WriteMessage(websocket.TextMessage, "test1122332211"); err != nil { // Send text message + log.Panic(err) + } + } +} +func TestWebSocketClose(t *testing.T) { + go websocketServer() + time.Sleep(time.Second * 1) // Send WebSocket request + response, err := requests.Get(nil, "ws://localhost:8800/ws", requests.RequestOption{DisProxy: true, Stream: true}) // Send WebSocket request + if err != nil { + log.Panic(err) + } + defer response.CloseConn() + response.CloseBody(nil) + wsCli := response.WebSocket() + if wsCli == nil { + t.Fatal("WebSocket client is nil") + } + defer wsCli.Close() + log.Print(wsCli) + log.Print(response.Headers()) + log.Print(response.StatusCode()) + if err = wsCli.WriteMessage(websocket.TextMessage, "test1122332211"); err == nil { // Send text message + t.Fatal("这里必须报错") + } +} +func TestWebSocketClose2(t *testing.T) { + go websocketServer() + time.Sleep(time.Second * 1) // Send WebSocket request + response, err := requests.Get(nil, "ws://localhost:8800/ws", requests.RequestOption{DisProxy: true, Stream: true}) // Send WebSocket request + if err != nil { + log.Panic(err) + } + defer response.CloseConn() + body := response.Body() + io.ReadAll(body) + response.CloseBody(nil) + wsCli := response.WebSocket() + if wsCli == nil { + t.Fatal("WebSocket client is nil") + } + defer wsCli.Close() + log.Print(wsCli) + log.Print(response.Headers()) + log.Print(response.StatusCode()) + if err = wsCli.WriteMessage(websocket.TextMessage, "test1122332211"); err == nil { // Send text message + t.Fatal("这里必须报错") + } +} +func TestWebSocketClose3(t *testing.T) { + go websocketServer() + time.Sleep(time.Second * 1) // Send WebSocket request + response, err := requests.Get(nil, "ws://localhost:8800/ws", requests.RequestOption{DisProxy: true, Stream: true}) // Send WebSocket request + if err != nil { + log.Panic(err) + } + defer response.CloseConn() + body := response.Body() + io.ReadAll(body) + // body.Close() + wsCli := response.WebSocket() + response.CloseBody(nil) + if wsCli == nil { + t.Fatal("WebSocket client is nil") + } + defer wsCli.Close() + log.Print(wsCli) + log.Print(response.Headers()) + log.Print(response.StatusCode()) + if err = wsCli.WriteMessage(websocket.TextMessage, "test1122332211"); err != nil { // Send text message + log.Panic(err) + } + n := 0 + for { + msgType, con, err := wsCli.ReadMessage() // Receive message + if err != nil { + log.Panic(err) + } + if msgType != websocket.TextMessage { + log.Panic("Message type is not text") + } + log.Print(string(con)) + if strings.Contains(string(con), "test1122332211") { + n++ + if n > 2 { break } } if err = wsCli.WriteMessage(websocket.TextMessage, "test1122332211"); err != nil { // Send text message log.Panic(err) } - time.Sleep(time.Second * 2) } } diff --git a/test/request/stream_test.go b/test/request/stream_test.go index 9e8f955..9171905 100644 --- a/test/request/stream_test.go +++ b/test/request/stream_test.go @@ -2,6 +2,7 @@ package main import ( "io" + "log" "testing" "github.com/gospider007/requests" @@ -49,3 +50,109 @@ func TestStreamWithConn(t *testing.T) { } } } + +func TestStreamWithConn2(t *testing.T) { + for i := 0; i < 2; i++ { + resp, err := requests.Get(nil, "https://httpbin.org/anything", requests.RequestOption{Stream: true}) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode() != 200 { + t.Fatal("resp.StatusCode()!= 200") + } + body := resp.Body() + body.Close() + // con, err := io.ReadAll(body) + // if err != nil { + // t.Fatal(err) + // } + // if len(string(con)) == 0 { + // t.Fatal("con is empty") + // } + // body.Close() + if i == 1 && !resp.IsNewConn() { + t.Fatal("con is new") + } + } +} + +func TestStreamWithConn3(t *testing.T) { + for i := 0; i < 2; i++ { + resp, err := requests.Get(nil, "https://httpbin.org/anything", requests.RequestOption{Stream: true}) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode() != 200 { + t.Fatal("resp.StatusCode()!= 200") + } + // body := resp.Body() + resp.CloseBody(nil) + // con, err := io.ReadAll(body) + // if err != nil { + // t.Fatal(err) + // } + // if len(string(con)) == 0 { + // t.Fatal("con is empty") + // } + // body.Close() + if i == 1 && !resp.IsNewConn() { + t.Fatal("con is new") + } + } +} + +func TestStreamWithConn4(t *testing.T) { + for i := 0; i < 2; i++ { + resp, err := requests.Get(nil, "https://httpbin.org/anything", requests.RequestOption{Stream: true}) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode() != 200 { + t.Fatal("resp.StatusCode()!= 200") + } + body := resp.Body() + bbb := make([]byte, 10) + body.Read(bbb) + log.Print(len(bbb)) + resp.CloseBody(nil) + // con, err := io.ReadAll(body) + // if err != nil { + // t.Fatal(err) + // } + // if len(string(con)) == 0 { + // t.Fatal("con is empty") + // } + // body.Close() + if i == 1 && !resp.IsNewConn() { + t.Fatal("con is new") + } + } +} +func TestStreamWithConn5(t *testing.T) { + for i := 0; i < 2; i++ { + resp, err := requests.Get(nil, "https://httpbin.org/anything", requests.RequestOption{Stream: true}) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode() != 200 { + t.Fatal("resp.StatusCode()!= 200") + } + body := resp.Body() + bbb := make([]byte, 10) + body.Read(bbb) + log.Print(len(bbb)) + // resp.CloseBody(nil) + body.Close() + // con, err := io.ReadAll(body) + // if err != nil { + // t.Fatal(err) + // } + // if len(string(con)) == 0 { + // t.Fatal("con is empty") + // } + // body.Close() + if i == 1 && !resp.IsNewConn() { + t.Fatal("con is new") + } + } +} diff --git a/test/session_test.go b/test/session_test.go index 613e30c..4c25c67 100644 --- a/test/session_test.go +++ b/test/session_test.go @@ -15,6 +15,7 @@ func TestSession(t *testing.T) { t.Error(err) } log.Print(resp.Proto(), resp.IsNewConn()) + resp.CloseBody(nil) if i == 0 { if !resp.IsNewConn() { //return is NewConn t.Error("new conn error: ", i) diff --git a/tools.go b/tools.go index 044367f..29483e0 100644 --- a/tools.go +++ b/tools.go @@ -141,21 +141,6 @@ func escapeQuotes(s string) string { return quoteEscaper.Replace(s) } -func removeZone(host string) string { - if !strings.HasPrefix(host, "[") { - return host - } - i := strings.LastIndex(host, "]") - if i < 0 { - return host - } - j := strings.LastIndex(host[:i], "%") - if j < 0 { - return host - } - return host[:j] + host[i:] -} - type requestBody struct { r io.Reader }