optimize code

This commit is contained in:
bxd
2023-11-24 20:26:53 +08:00
parent f6d27b3f10
commit 3f79c0117d
8 changed files with 52 additions and 42 deletions

View File

@@ -16,11 +16,11 @@ type Client struct {
forceHttp1 bool
orderHeaders []string
jar *Jar
maxRedirectNum int
disDecode bool
disUnZip bool
disAlive bool
jar *Jar
maxRedirect int
disDecode bool
disUnZip bool
disAlive bool
maxRetries int
@@ -78,7 +78,7 @@ func NewClient(preCtx context.Context, options ...ClientOption) (*Client, error)
Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
ctxData := GetReqCtxData(req.Context())
if ctxData.maxRedirectNum == 0 || ctxData.maxRedirectNum >= len(via) {
if ctxData.maxRedirect == 0 || ctxData.maxRedirect >= len(via) {
return nil
}
return http.ErrUseLastResponse
@@ -93,7 +93,7 @@ func NewClient(preCtx context.Context, options ...ClientOption) (*Client, error)
requestCallBack: option.RequestCallBack,
orderHeaders: option.OrderHeaders,
disCookie: option.DisCookie,
maxRedirectNum: option.MaxRedirectNum,
maxRedirect: option.MaxRedirect,
disDecode: option.DisDecode,
disUnZip: option.DisUnZip,
disAlive: option.DisAlive,

20
conn.go
View File

@@ -4,7 +4,6 @@ import (
"bufio"
"context"
"errors"
"log"
"net"
"net/http"
"sync"
@@ -135,7 +134,7 @@ func (obj *connecotr) wrapBody(task *reqTask) {
func (obj *connecotr) http1Req(task *reqTask) {
defer task.cnl()
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "http1 req start")
debugPrint(task.requestId, "http1 req start")
}
if task.orderHeaders != nil && len(task.orderHeaders) > 0 {
task.err = httpWrite(task.req, obj.w, task.orderHeaders)
@@ -143,7 +142,7 @@ func (obj *connecotr) http1Req(task *reqTask) {
task.err = obj.w.Flush()
}
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "http1 req write ok")
debugPrint(task.requestId, "http1 req write ok ,err: ", task.err)
}
if task.err == nil {
if task.res, task.err = http.ReadResponse(obj.r, task.req); task.res != nil && task.err == nil {
@@ -155,14 +154,14 @@ func (obj *connecotr) http1Req(task *reqTask) {
task.err = tools.WrapError(task.err, "http1 write error")
}
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "http1 req end")
debugPrint(task.requestId, "http1 req ok ,err: ", task.err)
}
}
func (obj *connecotr) http2Req(task *reqTask) {
defer task.cnl()
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "http2 req start")
debugPrint(task.requestId, "http2 req start")
}
if task.res, task.err = obj.h2RawConn.RoundTrip(task.req); task.res != nil && task.err == nil {
obj.wrapBody(task)
@@ -170,21 +169,21 @@ func (obj *connecotr) http2Req(task *reqTask) {
task.err = tools.WrapError(task.err, "http2 roundTrip error")
}
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "http2 req end")
debugPrint(task.requestId, "http2 req ok,err: ", task.err)
}
}
func (obj *connecotr) taskMain(task *reqTask, afterTime *time.Timer) (*http.Response, error, bool) {
if obj.h2 && obj.h2Closed() {
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "h2 con is closed")
debugPrint(task.requestId, "h2 con is closed")
}
return nil, errors.New("conn is closed"), true
}
select {
case <-obj.closeCtx.Done():
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "connecotr closeCnl")
debugPrint(task.requestId, "connecotr closeCnl")
}
return nil, tools.WrapError(obj.closeCtx.Err(), "close ctx error: "), true
default:
@@ -207,6 +206,7 @@ func (obj *connecotr) taskMain(task *reqTask, afterTime *time.Timer) (*http.Resp
if task.res != nil && task.err == nil && obj.isPool {
select {
case <-obj.bodyCtx.Done(): //wait body close
task.err = tools.WrapError(obj.deleteCtx.Err(), "body ctx error: ")
case <-obj.deleteCtx.Done(): //force conn close
task.err = tools.WrapError(obj.deleteCtx.Err(), "delete ctx error: ")
}
@@ -264,11 +264,11 @@ func (obj *connPool) rwMain(conn *connecotr) {
return
case task := <-obj.tasks: //recv task
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "recv task")
debugPrint(task.requestId, "recv task")
}
if task == nil {
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "recv task is nil")
debugPrint(task.requestId, "recv task is nil")
}
return
}

2
go.mod
View File

@@ -12,7 +12,7 @@ require (
github.com/gospider007/net v0.0.0-20231028084010-313c148cf0a1
github.com/gospider007/re v0.0.0-20231024115818-adfd03636256
github.com/gospider007/tools v0.0.0-20231122021245-1cafbac3ef46
github.com/gospider007/websocket v0.0.0-20231114095858-b8bc9b2033d3
github.com/gospider007/websocket v0.0.0-20231124122326-78d52f163d6c
github.com/refraction-networking/utls v1.5.4
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
golang.org/x/net v0.18.0

2
go.sum
View File

@@ -55,6 +55,8 @@ github.com/gospider007/tools v0.0.0-20231122021245-1cafbac3ef46 h1:vskdS8WLAveNS
github.com/gospider007/tools v0.0.0-20231122021245-1cafbac3ef46/go.mod h1:myK4kDqDx4TlplDVnfYMI7Xi5VUbFZ3fxwAh2Cwm7ks=
github.com/gospider007/websocket v0.0.0-20231114095858-b8bc9b2033d3 h1:HpiNfOZ9Tjo4hhP1+jmlgqykngykles3ypXa2BUuxRc=
github.com/gospider007/websocket v0.0.0-20231114095858-b8bc9b2033d3/go.mod h1:jINjCM6qIRiqn2Di1bat4Ie5gY66ae7LYT8YK4fAejY=
github.com/gospider007/websocket v0.0.0-20231124122326-78d52f163d6c h1:AVquutD7Mbb9gcq7/ciRC/Vt2StSiBuUagJviG8+vkg=
github.com/gospider007/websocket v0.0.0-20231124122326-78d52f163d6c/go.mod h1:TquIvV/QrLmSufnwdc+54DAbUd39HsNgpFcoQYthVU8=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=

View File

@@ -32,7 +32,7 @@ type ClientOption struct {
ErrCallBack func(ctx context.Context, client *Client, response *Response, err error) error //error callback,if error is returnd,break request
RequestCallBack func(ctx context.Context, request *http.Request, response *http.Response) error //request and response callback,if error is returnd,reponse is error
MaxRetries int //try num
MaxRedirectNum int //redirect num ,<0 no redirect,==0 no limit
MaxRedirect int //redirect num ,<0 no redirect,==0 no limit
Headers any //default headers
ResponseHeaderTimeout time.Duration //ResponseHeaderTimeout ,default:30
TlsHandshakeTimeout time.Duration //tls timeout,default:15
@@ -69,7 +69,7 @@ type RequestOption struct {
RequestCallBack func(ctx context.Context, request *http.Request, response *http.Response) error //request and response callback,if error is returnd,reponse is error
MaxRetries int //try num
MaxRedirectNum int //redirect num ,<0 no redirect,==0 no limit
MaxRedirect int //redirect num ,<0 no redirect,==0 no limit
Headers any //request headersjson,mapheader
ResponseHeaderTimeout time.Duration //ResponseHeaderTimeout ,default:30
TlsHandshakeTimeout time.Duration
@@ -189,8 +189,8 @@ func (obj *Client) newRequestOption(option RequestOption) RequestOption {
if !option.Bar {
option.Bar = obj.bar
}
if option.MaxRedirectNum == 0 {
option.MaxRedirectNum = obj.maxRedirectNum
if option.MaxRedirect == 0 {
option.MaxRedirect = obj.maxRedirect
}
if option.Timeout == 0 {
option.Timeout = obj.timeout

View File

@@ -35,7 +35,7 @@ var errFatal = errors.New("Fatal error")
type reqCtxData struct {
isWs bool
forceHttp1 bool
maxRedirectNum int
maxRedirect int
proxy *url.URL
disProxy bool
disAlive bool
@@ -66,7 +66,7 @@ func NewReqCtxData(ctx context.Context, option *RequestOption) (*reqCtxData, err
ctxData.h2Ja3Spec = option.H2Ja3Spec
ctxData.forceHttp1 = option.ForceHttp1
ctxData.disAlive = option.DisAlive
ctxData.maxRedirectNum = option.MaxRedirectNum
ctxData.maxRedirect = option.MaxRedirect
ctxData.requestCallBack = option.RequestCallBack
ctxData.responseHeaderTimeout = option.ResponseHeaderTimeout
ctxData.addrType = option.AddrType
@@ -279,7 +279,7 @@ func debugPrint(requestId string, content ...any) {
for _, cont := range content {
contents = append(contents, cont, " ")
}
log.Printf("%s:%d , %s: %s, %s: %s", f, l, blog.Color(2, "requestId"), blog.Color(1, requestId), blog.Color(2, "content"), blog.Color(3, contents...))
log.Printf("%s:%d, %s>>>%s, %s>>>%s", f, l, blog.Color(2, "requestId"), blog.Color(1, requestId), blog.Color(2, "content"), blog.Color(3, contents...))
}
func (obj *Client) request(ctx context.Context, option *RequestOption) (response *Response, err error) {
response = new(Response)
@@ -375,7 +375,7 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response
//init ws
if ctxData.isWs {
if ctxData.debug {
log.Printf("requestId:%s: %s", ctxData.requestId, "init websocket headers")
debugPrint(ctxData.requestId, "init websocket headers")
}
websocket.SetClientHeadersOption(reqs.Header, option.WsOption)
}
@@ -412,12 +412,12 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response
}
}
if ctxData.debug {
log.Printf("requestId:%s: %s", ctxData.requestId, "send request start")
debugPrint(ctxData.requestId, "send request start")
}
//send req
response.response, err = obj.getClient(option).Do(reqs)
if ctxData.debug {
log.Printf("requestId:%s: %s", ctxData.requestId, "send request end")
debugPrint(ctxData.requestId, "send request end, err: ", err)
}
response.isNewConn = ctxData.isNewConn
if err != nil {
@@ -432,12 +432,21 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response
response.disUnzip = response.response.Uncompressed
}
if response.response.StatusCode == 101 {
response.webSocket, err = websocket.NewClientConn(response.response, response.CloseConn)
response.webSocket, err = websocket.NewClientConn(response.response, response.ForceCloseConn)
if ctxData.debug {
debugPrint(ctxData.requestId, "new websocket client, err: ", err)
}
} else if response.response.Header.Get("Content-Type") == "text/event-stream" {
response.sse = newSse(response.response.Body, response.CloseConn)
response.sse = newSse(response.response.Body, response.ForceCloseConn)
if ctxData.debug {
debugPrint(ctxData.requestId, "new sse client")
}
} else if !response.disUnzip {
var unCompressionBody io.ReadCloser
unCompressionBody, err = tools.CompressionDecode(response.response.Body, response.ContentEncoding())
if ctxData.debug {
debugPrint(ctxData.requestId, "unCompressionBody, err: ", err)
}
if err != nil {
if err != io.ErrUnexpectedEOF && err != io.EOF {
return

View File

@@ -317,7 +317,7 @@ func (obj *Response) CloseBody() error {
obj.sse.Close()
}
if obj.IsStream() || !obj.readBody {
obj.CloseConn()
obj.ForceCloseConn()
} else if obj.rawConn != nil {
obj.rawConn.Close()
}

View File

@@ -4,7 +4,6 @@ import (
"bufio"
"context"
"crypto/tls"
"log"
"net"
"net/url"
"sync"
@@ -245,47 +244,47 @@ func (obj *RoundTripper) getProxy(ctx context.Context, proxyUrl *url.URL) (*url.
func (obj *RoundTripper) poolRoundTrip(task *reqTask, key connKey) (bool, error) {
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "poolRoundTrip get conn pool")
debugPrint(task.requestId, "poolRoundTrip start")
}
pool := obj.getConnPool(key)
if pool == nil {
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "poolRoundTrip not found conn pool")
debugPrint(task.requestId, "poolRoundTrip not found conn pool")
}
return false, nil
}
select {
case <-obj.ctx.Done():
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "RoundTripper already cloed")
debugPrint(task.requestId, "RoundTripper already cloed")
}
return false, tools.WrapError(obj.ctx.Err(), "roundTripper close ctx error: ")
case <-pool.closeCtx.Done():
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "pool already cloed")
debugPrint(task.requestId, "pool already cloed")
}
return false, pool.closeCtx.Err()
case pool.tasks <- task:
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "pool.tasks <- task")
debugPrint(task.requestId, "poolRoundTrip tasks <- task")
}
select {
case <-task.emptyPool:
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "poolRoundTrip emptyPool")
debugPrint(task.requestId, "poolRoundTrip emptyPool")
}
case <-task.ctx.Done():
if task.err == nil && task.res == nil {
task.err = tools.WrapError(task.ctx.Err(), "task close ctx error: ")
}
if task.debug {
log.Printf("requestId:%s: %v", task.requestId, task.err)
debugPrint(task.requestId, "task ctx done, err: ", task.err)
}
return true, task.err
}
default:
if task.debug {
log.Printf("requestId:%s: %s", task.requestId, "conn pool not idle")
debugPrint(task.requestId, "conn pool not idle")
}
}
return false, nil
@@ -334,7 +333,7 @@ func (obj *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
}
ctxData.isNewConn = true
if task.debug {
log.Printf("requestId:%s: %s", ctxData.requestId, "new Conn")
debugPrint(ctxData.requestId, "new Conn")
}
newConn:
ckey := key
@@ -351,7 +350,7 @@ newConn:
conn.key = ckey
if task.inPool() && !ctxData.disAlive {
if task.debug {
log.Printf("requestId:%s: %s", ctxData.requestId, "conn put conn pool")
debugPrint(ctxData.requestId, "conn put conn pool")
}
obj.putConnPool(key, conn)
}