mirror of
https://github.com/gospider007/requests.git
synced 2025-12-24 13:57:52 +08:00
sync
This commit is contained in:
181
roundTripper.go
181
roundTripper.go
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"net/http"
|
||||
@@ -20,27 +21,31 @@ import (
|
||||
)
|
||||
|
||||
type reqTask struct {
|
||||
option *RequestOption
|
||||
ctx context.Context
|
||||
cnl context.CancelFunc
|
||||
req *http.Request
|
||||
res *http.Response
|
||||
reqCtx *Response
|
||||
emptyPool chan struct{}
|
||||
err error
|
||||
retry int
|
||||
}
|
||||
|
||||
func (obj *reqTask) suppertRetry() bool {
|
||||
if obj.req.Body == nil {
|
||||
if obj.reqCtx.request.Body == nil {
|
||||
return true
|
||||
} else if body, ok := obj.req.Body.(io.Seeker); ok {
|
||||
} else if body, ok := obj.reqCtx.request.Body.(io.Seeker); ok {
|
||||
if i, err := body.Seek(0, io.SeekStart); i == 0 && err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
func getKey(option *RequestOption, req *http.Request) (key string) {
|
||||
return fmt.Sprintf("%s@%s", getAddr(option.proxy), getAddr(req.URL))
|
||||
func getKey(ctx *Response) (key string) {
|
||||
adds := []string{}
|
||||
for _, p := range ctx.proxys {
|
||||
adds = append(adds, getAddr(p))
|
||||
}
|
||||
adds = append(adds, getAddr(ctx.Request().URL))
|
||||
return strings.Join(adds, "@")
|
||||
}
|
||||
|
||||
type roundTripper struct {
|
||||
@@ -91,37 +96,37 @@ func (obj *roundTripper) newConnecotr() *connecotr {
|
||||
return conne
|
||||
}
|
||||
|
||||
func (obj *roundTripper) http3Dial(ctx context.Context, option *RequestOption, remtoeAddress Address, proxyAddress ...Address) (udpConn net.PacketConn, err error) {
|
||||
func (obj *roundTripper) http3Dial(ctx *Response, remtoeAddress Address, proxyAddress ...Address) (udpConn net.PacketConn, err error) {
|
||||
if len(proxyAddress) > 0 {
|
||||
if proxyAddress[len(proxyAddress)-1].Scheme != "socks5" {
|
||||
err = errors.New("http3 last proxy must socks5 proxy")
|
||||
return
|
||||
}
|
||||
udpConn, _, err = obj.dialer.DialProxyContext(ctx, option, "tcp", option.TlsConfig.Clone(), append(proxyAddress, remtoeAddress)...)
|
||||
udpConn, _, err = obj.dialer.DialProxyContext(ctx, "tcp", ctx.option.TlsConfig.Clone(), append(proxyAddress, remtoeAddress)...)
|
||||
} else {
|
||||
udpConn, err = net.ListenUDP("udp", nil)
|
||||
}
|
||||
return
|
||||
}
|
||||
func (obj *roundTripper) ghttp3Dial(ctx context.Context, option *RequestOption, remoteAddress Address, proxyAddress ...Address) (conn *connecotr, err error) {
|
||||
udpConn, err := obj.http3Dial(ctx, option, remoteAddress, proxyAddress...)
|
||||
func (obj *roundTripper) ghttp3Dial(ctx *Response, remoteAddress Address, proxyAddress ...Address) (conn *connecotr, err error) {
|
||||
udpConn, err := obj.http3Dial(ctx, remoteAddress, proxyAddress...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig := option.TlsConfig.Clone()
|
||||
tlsConfig := ctx.option.TlsConfig.Clone()
|
||||
tlsConfig.NextProtos = []string{http3.NextProtoH3}
|
||||
tlsConfig.ServerName = remoteAddress.Host
|
||||
if remoteAddress.IP == nil {
|
||||
remoteAddress.IP, err = obj.dialer.loadHost(ctx, remoteAddress.Name, option)
|
||||
remoteAddress.IP, err = obj.dialer.loadHost(ctx, remoteAddress.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var quicConfig *quic.Config
|
||||
if option.UquicConfig != nil {
|
||||
quicConfig = option.QuicConfig.Clone()
|
||||
if ctx.option.UquicConfig != nil {
|
||||
quicConfig = ctx.option.QuicConfig.Clone()
|
||||
}
|
||||
netConn, err := quic.DialEarly(ctx, udpConn, &net.UDPAddr{IP: remoteAddress.IP, Port: remoteAddress.Port}, tlsConfig, quicConfig)
|
||||
netConn, err := quic.DialEarly(ctx.Context(), udpConn, &net.UDPAddr{IP: remoteAddress.IP, Port: remoteAddress.Port}, tlsConfig, quicConfig)
|
||||
conn = obj.newConnecotr()
|
||||
conn.Conn = http3.NewClient(netConn, func() {
|
||||
conn.forceCnl(errors.New("http3 client close"))
|
||||
@@ -129,34 +134,34 @@ func (obj *roundTripper) ghttp3Dial(ctx context.Context, option *RequestOption,
|
||||
return
|
||||
}
|
||||
|
||||
func (obj *roundTripper) uhttp3Dial(ctx context.Context, option *RequestOption, remoteAddress Address, proxyAddress ...Address) (conn *connecotr, err error) {
|
||||
spec, err := ja3.CreateSpecWithUSpec(option.UJa3Spec)
|
||||
func (obj *roundTripper) uhttp3Dial(ctx *Response, remoteAddress Address, proxyAddress ...Address) (conn *connecotr, err error) {
|
||||
spec, err := ja3.CreateSpecWithUSpec(ctx.option.UJa3Spec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
udpConn, err := obj.http3Dial(ctx, option, remoteAddress, proxyAddress...)
|
||||
udpConn, err := obj.http3Dial(ctx, remoteAddress, proxyAddress...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig := option.UtlsConfig.Clone()
|
||||
tlsConfig := ctx.option.UtlsConfig.Clone()
|
||||
tlsConfig.NextProtos = []string{http3.NextProtoH3}
|
||||
tlsConfig.ServerName = remoteAddress.Host
|
||||
if remoteAddress.IP == nil {
|
||||
remoteAddress.IP, err = obj.dialer.loadHost(ctx, remoteAddress.Name, option)
|
||||
remoteAddress.IP, err = obj.dialer.loadHost(ctx, remoteAddress.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var quicConfig *uquic.Config
|
||||
if option.UquicConfig != nil {
|
||||
quicConfig = option.UquicConfig.Clone()
|
||||
if ctx.option.UquicConfig != nil {
|
||||
quicConfig = ctx.option.UquicConfig.Clone()
|
||||
}
|
||||
netConn, err := (&uquic.UTransport{
|
||||
Transport: &uquic.Transport{
|
||||
Conn: udpConn,
|
||||
},
|
||||
QUICSpec: &spec,
|
||||
}).DialEarly(ctx, &net.UDPAddr{IP: remoteAddress.IP, Port: remoteAddress.Port}, tlsConfig, quicConfig)
|
||||
}).DialEarly(ctx.Context(), &net.UDPAddr{IP: remoteAddress.IP, Port: remoteAddress.Port}, tlsConfig, quicConfig)
|
||||
conn = obj.newConnecotr()
|
||||
conn.Conn = http3.NewUClient(netConn, func() {
|
||||
conn.forceCnl(errors.New("http3 client close"))
|
||||
@@ -164,32 +169,32 @@ func (obj *roundTripper) uhttp3Dial(ctx context.Context, option *RequestOption,
|
||||
return
|
||||
}
|
||||
|
||||
func (obj *roundTripper) dial(option *RequestOption, req *http.Request) (conn *connecotr, err error) {
|
||||
proxys, err := obj.initProxys(option, req)
|
||||
func (obj *roundTripper) dial(ctx *Response) (conn *connecotr, err error) {
|
||||
proxys, err := obj.initProxys(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
remoteAddress, err := GetAddressWithUrl(req.URL)
|
||||
remoteAddress, err := GetAddressWithUrl(ctx.request.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if option.H3 {
|
||||
if option.UJa3Spec.IsSet() {
|
||||
return obj.uhttp3Dial(req.Context(), option, remoteAddress, proxys...)
|
||||
if ctx.option.H3 {
|
||||
if ctx.option.UJa3Spec.IsSet() {
|
||||
return obj.uhttp3Dial(ctx, remoteAddress, proxys...)
|
||||
} else {
|
||||
return obj.ghttp3Dial(req.Context(), option, remoteAddress, proxys...)
|
||||
return obj.ghttp3Dial(ctx, remoteAddress, proxys...)
|
||||
}
|
||||
}
|
||||
var netConn net.Conn
|
||||
if len(proxys) > 0 {
|
||||
_, netConn, err = obj.dialer.DialProxyContext(req.Context(), option, "tcp", option.TlsConfig.Clone(), append(proxys, remoteAddress)...)
|
||||
_, netConn, err = obj.dialer.DialProxyContext(ctx, "tcp", ctx.option.TlsConfig.Clone(), append(proxys, remoteAddress)...)
|
||||
} else {
|
||||
var remoteAddress Address
|
||||
remoteAddress, err = GetAddressWithUrl(req.URL)
|
||||
remoteAddress, err = GetAddressWithUrl(ctx.request.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
netConn, err = obj.dialer.DialContext(req.Context(), option, "tcp", remoteAddress)
|
||||
netConn, err = obj.dialer.DialContext(ctx, "tcp", remoteAddress)
|
||||
}
|
||||
defer func() {
|
||||
if err != nil && netConn != nil {
|
||||
@@ -200,14 +205,14 @@ func (obj *roundTripper) dial(option *RequestOption, req *http.Request) (conn *c
|
||||
return nil, err
|
||||
}
|
||||
var h2 bool
|
||||
if req.URL.Scheme == "https" {
|
||||
netConn, h2, err = obj.dialAddTls(option, req, netConn)
|
||||
if option.Logger != nil {
|
||||
option.Logger(Log{
|
||||
Id: option.requestId,
|
||||
if ctx.request.URL.Scheme == "https" {
|
||||
netConn, h2, err = obj.dialAddTls(ctx.option, ctx.request, netConn)
|
||||
if ctx.option.Logger != nil {
|
||||
ctx.option.Logger(Log{
|
||||
Id: ctx.requestId,
|
||||
Time: time.Now(),
|
||||
Type: LogType_TLSHandshake,
|
||||
Msg: fmt.Sprintf("host:%s, h2:%t", getHost(req), h2),
|
||||
Msg: fmt.Sprintf("host:%s, h2:%t", getHost(ctx.request), h2),
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
@@ -217,7 +222,7 @@ func (obj *roundTripper) dial(option *RequestOption, req *http.Request) (conn *c
|
||||
conne := obj.newConnecotr()
|
||||
conne.proxys = proxys
|
||||
conne.c = netConn
|
||||
err = obj.dialConnecotr(option, req, conne, h2)
|
||||
err = obj.dialConnecotr(ctx.option, ctx.request, conne, h2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -257,21 +262,14 @@ func (obj *roundTripper) dialAddTls(option *RequestOption, req *http.Request, ne
|
||||
}
|
||||
}
|
||||
}
|
||||
func (obj *roundTripper) initProxys(option *RequestOption, req *http.Request) ([]Address, error) {
|
||||
func (obj *roundTripper) initProxys(ctx *Response) ([]Address, error) {
|
||||
var proxys []Address
|
||||
if option.DisProxy {
|
||||
if ctx.option.DisProxy {
|
||||
return nil, nil
|
||||
}
|
||||
if option.proxy != nil {
|
||||
proxyAddress, err := GetAddressWithUrl(option.proxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
proxys = []Address{proxyAddress}
|
||||
}
|
||||
if len(proxys) == 0 && len(option.proxys) > 0 {
|
||||
proxys = make([]Address, len(option.proxys))
|
||||
for i, proxy := range option.proxys {
|
||||
if len(ctx.proxys) > 0 {
|
||||
proxys = make([]Address, len(ctx.proxys))
|
||||
for i, proxy := range ctx.proxys {
|
||||
proxyAddress, err := GetAddressWithUrl(proxy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -279,8 +277,8 @@ func (obj *roundTripper) initProxys(option *RequestOption, req *http.Request) ([
|
||||
proxys[i] = proxyAddress
|
||||
}
|
||||
}
|
||||
if len(proxys) == 0 && option.GetProxy != nil {
|
||||
proxyStr, err := option.GetProxy(req.Context(), req.URL)
|
||||
if len(proxys) == 0 && ctx.option.GetProxy != nil {
|
||||
proxyStr, err := ctx.option.GetProxy(ctx)
|
||||
if err != nil {
|
||||
return proxys, err
|
||||
}
|
||||
@@ -296,8 +294,8 @@ func (obj *roundTripper) initProxys(option *RequestOption, req *http.Request) ([
|
||||
proxys = []Address{proxyAddress}
|
||||
}
|
||||
}
|
||||
if len(proxys) == 0 && option.GetProxys != nil {
|
||||
proxyStrs, err := option.GetProxys(req.Context(), req.URL)
|
||||
if len(proxys) == 0 && ctx.option.GetProxys != nil {
|
||||
proxyStrs, err := ctx.option.GetProxys(ctx)
|
||||
if err != nil {
|
||||
return proxys, err
|
||||
}
|
||||
@@ -319,30 +317,31 @@ func (obj *roundTripper) initProxys(option *RequestOption, req *http.Request) ([
|
||||
return proxys, nil
|
||||
}
|
||||
|
||||
func (obj *roundTripper) poolRoundTrip(option *RequestOption, pool *connPool, task *reqTask, key string) (isOk bool, err error) {
|
||||
task.ctx, task.cnl = context.WithTimeout(task.req.Context(), option.ResponseHeaderTimeout)
|
||||
func (obj *roundTripper) poolRoundTrip(pool *connPool, task *reqTask, key string) (isOk bool, err error) {
|
||||
task.ctx, task.cnl = context.WithTimeout(task.reqCtx.Context(), task.reqCtx.option.ResponseHeaderTimeout)
|
||||
select {
|
||||
case pool.tasks <- task:
|
||||
select {
|
||||
case <-task.emptyPool:
|
||||
return false, nil
|
||||
case <-task.ctx.Done():
|
||||
if task.err == nil && task.res == nil {
|
||||
if task.err == nil && task.reqCtx.response == nil {
|
||||
task.err = context.Cause(task.ctx)
|
||||
}
|
||||
return true, task.err
|
||||
}
|
||||
default:
|
||||
return obj.createPool(option, task, key)
|
||||
return obj.createPool(task, key)
|
||||
}
|
||||
}
|
||||
|
||||
func (obj *roundTripper) createPool(option *RequestOption, task *reqTask, key string) (isOk bool, err error) {
|
||||
option.isNewConn = true
|
||||
conn, err := obj.dial(option, task.req)
|
||||
func (obj *roundTripper) createPool(task *reqTask, key string) (isOk bool, err error) {
|
||||
task.reqCtx.isNewConn = true
|
||||
conn, err := obj.dial(task.reqCtx)
|
||||
if err != nil {
|
||||
if task.option.ErrCallBack != nil {
|
||||
if err2 := task.option.ErrCallBack(task.req.Context(), task.option, nil, err); err2 != nil {
|
||||
if task.reqCtx.option.ErrCallBack != nil {
|
||||
task.reqCtx.err = err
|
||||
if err2 := task.reqCtx.option.ErrCallBack(task.reqCtx); err2 != nil {
|
||||
return true, err2
|
||||
}
|
||||
}
|
||||
@@ -364,50 +363,46 @@ func (obj *roundTripper) forceCloseConns() {
|
||||
obj.connPools.del(key)
|
||||
}
|
||||
}
|
||||
func (obj *roundTripper) newReqTask(req *http.Request, option *RequestOption) *reqTask {
|
||||
if option.ResponseHeaderTimeout == 0 {
|
||||
option.ResponseHeaderTimeout = time.Second * 300
|
||||
func (obj *roundTripper) newReqTask(ctx *Response) *reqTask {
|
||||
if ctx.option.ResponseHeaderTimeout == 0 {
|
||||
ctx.option.ResponseHeaderTimeout = time.Second * 300
|
||||
}
|
||||
task := new(reqTask)
|
||||
task.req = req
|
||||
task.option = option
|
||||
task.reqCtx = ctx
|
||||
task.emptyPool = make(chan struct{})
|
||||
return task
|
||||
}
|
||||
func (obj *roundTripper) RoundTrip(req *http.Request) (response *http.Response, err error) {
|
||||
option := GetRequestOption(req.Context())
|
||||
if option.RequestCallBack != nil {
|
||||
if err = option.RequestCallBack(req.Context(), req, nil); err != nil {
|
||||
func (obj *roundTripper) RoundTrip(ctx *Response) (err error) {
|
||||
if ctx.option.RequestCallBack != nil {
|
||||
if err = ctx.option.RequestCallBack(ctx); err != nil {
|
||||
if err == http.ErrUseLastResponse {
|
||||
if req.Response == nil {
|
||||
return nil, errors.New("errUseLastResponse response is nil")
|
||||
if ctx.response == nil {
|
||||
return errors.New("errUseLastResponse response is nil")
|
||||
} else {
|
||||
return req.Response, nil
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
key := getKey(option, req) //pool key
|
||||
task := obj.newReqTask(req, option)
|
||||
maxRetry := 10
|
||||
var errNum int
|
||||
key := getKey(ctx) //pool key
|
||||
task := obj.newReqTask(ctx)
|
||||
var isOk bool
|
||||
for {
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
return nil, context.Cause(req.Context())
|
||||
case <-ctx.Context().Done():
|
||||
return context.Cause(ctx.Context())
|
||||
default:
|
||||
}
|
||||
if errNum >= maxRetry {
|
||||
task.err = fmt.Errorf("roundTrip retry %d times", maxRetry)
|
||||
if task.retry >= maxRetryCount {
|
||||
task.err = fmt.Errorf("roundTrip retry %d times", maxRetryCount)
|
||||
break
|
||||
}
|
||||
pool := obj.connPools.get(key)
|
||||
if pool == nil {
|
||||
isOk, err = obj.createPool(option, task, key)
|
||||
isOk, err = obj.createPool(task, key)
|
||||
} else {
|
||||
isOk, err = obj.poolRoundTrip(option, pool, task, key)
|
||||
isOk, err = obj.poolRoundTrip(pool, task, key)
|
||||
}
|
||||
if isOk {
|
||||
if err != nil {
|
||||
@@ -416,13 +411,13 @@ func (obj *roundTripper) RoundTrip(req *http.Request) (response *http.Response,
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
errNum++
|
||||
task.retry++
|
||||
}
|
||||
}
|
||||
if task.err == nil && option.RequestCallBack != nil {
|
||||
if err = option.RequestCallBack(task.req.Context(), task.req, task.res); err != nil {
|
||||
if task.err == nil && ctx.option.RequestCallBack != nil {
|
||||
if err = ctx.option.RequestCallBack(ctx); err != nil {
|
||||
task.err = err
|
||||
}
|
||||
}
|
||||
return task.res, task.err
|
||||
return task.err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user