This commit is contained in:
gospider
2025-01-10 14:34:01 +08:00
parent b2654b8c61
commit ec697ef837
16 changed files with 394 additions and 400 deletions

View File

@@ -5,7 +5,6 @@ import (
"crypto/tls"
"fmt"
"io"
"net/url"
"net/http"
@@ -74,11 +73,6 @@ func (obj *Client) SetProxy(proxyUrl string) (err error) {
return
}
// Modify the proxy method of the client
func (obj *Client) SetGetProxy(getProxy func(ctx context.Context, url *url.URL) (string, error)) {
obj.option.GetProxy = getProxy
}
// Modifying the client's proxy
func (obj *Client) SetProxys(proxyUrls []string) (err error) {
for _, proxy := range proxyUrls {
@@ -91,11 +85,6 @@ func (obj *Client) SetProxys(proxyUrls []string) (err error) {
return
}
// Modify the proxy method of the client
func (obj *Client) SetGetProxys(getProxys func(ctx context.Context, url *url.URL) ([]string, error)) {
obj.option.GetProxys = getProxys
}
// Close idle connections. If the connection is in use, wait until it ends before closing
func (obj *Client) CloseConns() {
obj.transport.closeConns()
@@ -113,72 +102,71 @@ func (obj *Client) Close() {
obj.cnl()
}
func (obj *Client) do(req *http.Request, option *RequestOption) (resp *http.Response, err error) {
func (obj *Client) do(ctx *Response) (err error) {
var redirectNum int
for {
redirectNum++
resp, err = obj.send(req, option)
if req.Body != nil {
req.Body.Close()
err = obj.send(ctx)
if ctx.Request().Body != nil {
ctx.Request().Body.Close()
}
if err != nil {
return
}
if option.MaxRedirect < 0 { //dis redirect
if ctx.Option().MaxRedirect < 0 { //dis redirect
return
}
if option.MaxRedirect > 0 && redirectNum > option.MaxRedirect {
if ctx.Option().MaxRedirect > 0 && redirectNum > ctx.Option().MaxRedirect {
return
}
loc := resp.Header.Get("Location")
loc := ctx.response.Header.Get("Location")
if loc == "" {
return resp, nil
return nil
}
u, err := req.URL.Parse(loc)
u, err := ctx.Request().URL.Parse(loc)
if err != nil {
return resp, fmt.Errorf("failed to parse Location header %q: %v", loc, err)
return fmt.Errorf("failed to parse Location header %q: %v", loc, err)
}
ireq, err := NewRequestWithContext(req.Context(), http.MethodGet, u, nil)
ctx.request, err = NewRequestWithContext(ctx.Context(), http.MethodGet, u, nil)
if err != nil {
return resp, err
return err
}
var shouldRedirect bool
ireq.Method, shouldRedirect, _ = redirectBehavior(req.Method, resp, ireq)
ctx.request.Method, shouldRedirect, _ = redirectBehavior(ctx.Request().Method, ctx.response, ctx.request)
if !shouldRedirect {
return resp, nil
return nil
}
ireq.Response = resp
ireq.Header = defaultHeaders()
ireq.Header.Set("Referer", req.URL.String())
for key := range ireq.Header {
if val := req.Header.Get(key); val != "" {
ireq.Header.Set(key, val)
ctx.request.Response = ctx.response
ctx.request.Header = defaultHeaders()
ctx.request.Header.Set("Referer", ctx.Request().URL.String())
for key := range ctx.request.Header {
if val := ctx.Request().Header.Get(key); val != "" {
ctx.request.Header.Set(key, val)
}
}
if getDomain(u) == getDomain(req.URL) {
if Authorization := req.Header.Get("Authorization"); Authorization != "" {
ireq.Header.Set("Authorization", Authorization)
if getDomain(u) == getDomain(ctx.Request().URL) {
if Authorization := ctx.Request().Header.Get("Authorization"); Authorization != "" {
ctx.request.Header.Set("Authorization", Authorization)
}
cookies := Cookies(req.Cookies()).String()
cookies := Cookies(ctx.Request().Cookies()).String()
if cookies != "" {
ireq.Header.Set("Cookie", cookies)
ctx.request.Header.Set("Cookie", cookies)
}
addCookie(ireq, resp.Cookies())
addCookie(ctx.request, ctx.response.Cookies())
}
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
req = ireq
io.Copy(io.Discard, ctx.response.Body)
ctx.response.Body.Close()
}
}
func (obj *Client) send(req *http.Request, option *RequestOption) (resp *http.Response, err error) {
if option.Jar != nil {
addCookie(req, option.Jar.GetCookies(req.URL))
func (obj *Client) send(ctx *Response) (err error) {
if ctx.Option().Jar != nil {
addCookie(ctx.Request(), ctx.Option().Jar.GetCookies(ctx.Request().URL))
}
resp, err = obj.transport.RoundTrip(req)
if option.Jar != nil && resp != nil {
if rc := resp.Cookies(); len(rc) > 0 {
option.Jar.SetCookies(req.URL, rc)
err = obj.transport.RoundTrip(ctx)
if ctx.Option().Jar != nil && ctx.response != nil {
if rc := ctx.response.Cookies(); len(rc) > 0 {
ctx.Option().Jar.SetCookies(ctx.Request().URL, rc)
}
}
return resp, err
return err
}

43
conn.go
View File

@@ -15,6 +15,8 @@ import (
"github.com/gospider007/tools"
)
var maxRetryCount = 10
type Conn interface {
CloseWithError(err error) error
DoRequest(*http.Request, []string) (*http.Response, error)
@@ -130,13 +132,13 @@ func (obj *connecotr) CloseWithError(err error) error {
func (obj *connecotr) wrapBody(task *reqTask) {
body := new(readWriteCloser)
obj.bodyCtx, obj.bodyCnl = context.WithCancelCause(task.req.Context())
body.body = task.res.Body
obj.bodyCtx, obj.bodyCnl = context.WithCancelCause(task.reqCtx.Context())
body.body = task.reqCtx.response.Body
body.conn = obj
task.res.Body = body
task.reqCtx.response.Body = body
}
func (obj *connecotr) httpReq(task *reqTask, done chan struct{}) {
if task.res, task.err = obj.Conn.DoRequest(task.req, task.option.OrderHeaders); task.res != nil && task.err == nil {
if task.reqCtx.response, task.err = obj.Conn.DoRequest(task.reqCtx.request, task.reqCtx.option.OrderHeaders); task.reqCtx.response != nil && task.err == nil {
obj.wrapBody(task)
} else if task.err != nil {
task.err = tools.WrapError(task.err, "roundTrip error")
@@ -146,8 +148,15 @@ func (obj *connecotr) httpReq(task *reqTask, done chan struct{}) {
func (obj *connecotr) taskMain(task *reqTask) (retry bool) {
defer func() {
if task.err != nil && task.option.ErrCallBack != nil {
if err2 := task.option.ErrCallBack(task.ctx, task.option, nil, task.err); err2 != nil {
if retry {
task.retry++
if task.retry > maxRetryCount {
retry = false
}
}
if task.err != nil && task.reqCtx.option.ErrCallBack != nil {
task.reqCtx.err = task.err
if err2 := task.reqCtx.option.ErrCallBack(task.reqCtx); err2 != nil {
retry = false
task.err = err2
}
@@ -155,12 +164,12 @@ func (obj *connecotr) taskMain(task *reqTask) (retry bool) {
if retry {
task.err = nil
obj.CloseWithError(errors.New("taskMain retry close"))
if task.res != nil && task.res.Body != nil {
task.res.Body.Close()
if task.reqCtx.response != nil && task.reqCtx.response.Body != nil {
task.reqCtx.response.Body.Close()
}
} else {
task.cnl()
if task.err == nil && task.res != nil && task.res.Body != nil {
if task.err == nil && task.reqCtx.response != nil && task.reqCtx.response.Body != nil {
select {
case <-obj.bodyCtx.Done(): //wait body close
if task.err = context.Cause(obj.bodyCtx); !errors.Is(task.err, errGospiderBodyClose) {
@@ -168,16 +177,16 @@ func (obj *connecotr) taskMain(task *reqTask) (retry bool) {
} else {
task.err = nil
}
case <-task.req.Context().Done(): //wait request close
task.err = tools.WrapError(context.Cause(task.req.Context()), "requestCtx close")
case <-task.reqCtx.Context().Done(): //wait request close
task.err = tools.WrapError(context.Cause(task.reqCtx.Context()), "requestCtx close")
case <-obj.forceCtx.Done(): //force conn close
task.err = tools.WrapError(context.Cause(obj.forceCtx), "connecotr force close")
}
}
if task.err != nil {
obj.CloseWithError(task.err)
if task.res != nil && task.res.Body != nil {
task.res.Body.Close()
if task.reqCtx.response != nil && task.reqCtx.response.Body != nil {
task.reqCtx.response.Body.Close()
}
}
}
@@ -199,16 +208,16 @@ func (obj *connecotr) taskMain(task *reqTask) (retry bool) {
if task.err != nil {
return task.suppertRetry()
}
if task.res == nil {
if task.reqCtx.response == nil {
task.err = context.Cause(task.ctx)
if task.err == nil {
task.err = errors.New("response is nil")
}
return task.suppertRetry()
}
if task.option.Logger != nil {
task.option.Logger(Log{
Id: task.option.requestId,
if task.reqCtx.option.Logger != nil {
task.reqCtx.option.Logger(Log{
Id: task.reqCtx.requestId,
Time: time.Now(),
Type: LogType_ResponseHeader,
Msg: "response header",

118
dial.go
View File

@@ -91,25 +91,22 @@ func newDialer(option DialOption) dialer {
dialer.dialer.SetMultipathTCP(true)
return &dialer
}
func (obj *Dialer) dialContext(ctx context.Context, option *RequestOption, network string, addr Address, isProxy bool) (net.Conn, error) {
if option == nil {
option = &RequestOption{}
}
func (obj *Dialer) dialContext(ctx *Response, network string, addr Address, isProxy bool) (net.Conn, error) {
var err error
if addr.IP == nil {
addr.IP, err = obj.loadHost(ctx, addr.Name, option)
addr.IP, err = obj.loadHost(ctx, addr.Name)
}
if option.Logger != nil {
if ctx.option != nil && ctx.option.Logger != nil {
if isProxy {
option.Logger(Log{
Id: option.requestId,
ctx.option.Logger(Log{
Id: ctx.requestId,
Time: time.Now(),
Type: LogType_ProxyDNSLookup,
Msg: addr.Name,
})
} else {
option.Logger(Log{
Id: option.requestId,
ctx.option.Logger(Log{
Id: ctx.requestId,
Time: time.Now(),
Type: LogType_DNSLookup,
Msg: addr.Name,
@@ -119,18 +116,18 @@ func (obj *Dialer) dialContext(ctx context.Context, option *RequestOption, netwo
if err != nil {
return nil, err
}
con, err := newDialer(option.DialOption).DialContext(ctx, network, addr.String())
if option.Logger != nil {
con, err := newDialer(ctx.option.DialOption).DialContext(ctx.Context(), network, addr.String())
if ctx.option != nil && ctx.option.Logger != nil {
if isProxy {
option.Logger(Log{
Id: option.requestId,
ctx.option.Logger(Log{
Id: ctx.requestId,
Time: time.Now(),
Type: LogType_ProxyTCPConnect,
Msg: addr,
})
} else {
option.Logger(Log{
Id: option.requestId,
ctx.option.Logger(Log{
Id: ctx.requestId,
Time: time.Now(),
Type: LogType_TCPConnect,
Msg: addr,
@@ -139,14 +136,14 @@ func (obj *Dialer) dialContext(ctx context.Context, option *RequestOption, netwo
}
return con, err
}
func (obj *Dialer) DialContext(ctx context.Context, ctxData *RequestOption, network string, addr Address) (net.Conn, error) {
return obj.dialContext(ctx, ctxData, network, addr, false)
func (obj *Dialer) DialContext(ctx *Response, network string, addr Address) (net.Conn, error) {
return obj.dialContext(ctx, network, addr, false)
}
func (obj *Dialer) ProxyDialContext(ctx context.Context, ctxData *RequestOption, network string, addr Address) (net.Conn, error) {
return obj.dialContext(ctx, ctxData, network, addr, true)
func (obj *Dialer) ProxyDialContext(ctx *Response, network string, addr Address) (net.Conn, error) {
return obj.dialContext(ctx, network, addr, true)
}
func (obj *Dialer) DialProxyContext(ctx context.Context, ctxData *RequestOption, network string, proxyTlsConfig *tls.Config, proxyUrls ...Address) (net.PacketConn, net.Conn, error) {
func (obj *Dialer) DialProxyContext(ctx *Response, network string, proxyTlsConfig *tls.Config, proxyUrls ...Address) (net.PacketConn, net.Conn, error) {
proxyLen := len(proxyUrls)
if proxyLen < 2 {
return nil, nil, errors.New("proxyUrls is nil")
@@ -158,33 +155,30 @@ func (obj *Dialer) DialProxyContext(ctx context.Context, ctxData *RequestOption,
oneProxy := proxyUrls[index]
remoteUrl := proxyUrls[index+1]
if index == 0 {
conn, err = obj.dialProxyContext(ctx, ctxData, network, oneProxy)
conn, err = obj.dialProxyContext(ctx, network, oneProxy)
if err != nil {
return packCon, conn, err
}
}
packCon, conn, err = obj.verifyProxyToRemote(ctx, ctxData, conn, proxyTlsConfig, oneProxy, remoteUrl, index == proxyLen-2)
packCon, conn, err = obj.verifyProxyToRemote(ctx, conn, proxyTlsConfig, oneProxy, remoteUrl, index == proxyLen-2)
}
return packCon, conn, err
}
func (obj *Dialer) dialProxyContext(ctx context.Context, ctxData *RequestOption, network string, proxyUrl Address) (net.Conn, error) {
if ctxData == nil {
ctxData = &RequestOption{}
}
return obj.ProxyDialContext(ctx, ctxData, network, proxyUrl)
func (obj *Dialer) dialProxyContext(ctx *Response, network string, proxyUrl Address) (net.Conn, error) {
return obj.ProxyDialContext(ctx, network, proxyUrl)
}
func (obj *Dialer) verifyProxyToRemote(ctx context.Context, option *RequestOption, conn net.Conn, proxyTlsConfig *tls.Config, proxyAddress Address, remoteAddress Address, isLast bool) (net.PacketConn, net.Conn, error) {
func (obj *Dialer) verifyProxyToRemote(ctx *Response, conn net.Conn, proxyTlsConfig *tls.Config, proxyAddress Address, remoteAddress Address, isLast bool) (net.PacketConn, net.Conn, error) {
var err error
var packCon net.PacketConn
if proxyAddress.Scheme == "https" {
if conn, err = obj.addTls(ctx, conn, proxyAddress.Host, true, proxyTlsConfig); err != nil {
if conn, err = obj.addTls(ctx.Context(), conn, proxyAddress.Host, true, proxyTlsConfig); err != nil {
return packCon, conn, err
}
if option.Logger != nil {
option.Logger(Log{
Id: option.requestId,
if ctx.option.Logger != nil {
ctx.option.Logger(Log{
Id: ctx.requestId,
Time: time.Now(),
Type: LogType_ProxyTLSHandshake,
Msg: proxyAddress.String(),
@@ -195,24 +189,24 @@ func (obj *Dialer) verifyProxyToRemote(ctx context.Context, option *RequestOptio
go func() {
switch proxyAddress.Scheme {
case "http", "https":
err = obj.clientVerifyHttps(ctx, conn, proxyAddress, remoteAddress)
if option.Logger != nil {
option.Logger(Log{
Id: option.requestId,
err = obj.clientVerifyHttps(ctx.Context(), conn, proxyAddress, remoteAddress)
if ctx.option.Logger != nil {
ctx.option.Logger(Log{
Id: ctx.requestId,
Time: time.Now(),
Type: LogType_ProxyConnectRemote,
Msg: remoteAddress.String(),
})
}
case "socks5":
if isLast && option.H3 {
packCon, err = obj.verifyUDPSocks5(ctx, conn, proxyAddress, remoteAddress)
if isLast && ctx.option.H3 {
packCon, err = obj.verifyUDPSocks5(ctx.Context(), conn, proxyAddress, remoteAddress)
} else {
err = obj.verifyTCPSocks5(conn, proxyAddress, remoteAddress)
}
if option.Logger != nil {
option.Logger(Log{
Id: option.requestId,
if ctx.option.Logger != nil {
ctx.option.Logger(Log{
Id: ctx.requestId,
Time: time.Now(),
Type: LogType_ProxyConnectRemote,
Msg: remoteAddress.String(),
@@ -222,13 +216,13 @@ func (obj *Dialer) verifyProxyToRemote(ctx context.Context, option *RequestOptio
close(done)
}()
select {
case <-ctx.Done():
return packCon, conn, context.Cause(ctx)
case <-ctx.Context().Done():
return packCon, conn, context.Cause(ctx.Context())
case <-done:
return packCon, conn, err
}
}
func (obj *Dialer) loadHost(ctx context.Context, host string, option *RequestOption) (net.IP, error) {
func (obj *Dialer) loadHost(ctx *Response, host string) (net.IP, error) {
msgDataAny, ok := obj.dnsIpData.Load(host)
if ok {
msgdata := msgDataAny.(msgClient)
@@ -241,12 +235,12 @@ func (obj *Dialer) loadHost(ctx context.Context, host string, option *RequestOpt
return ip, nil
}
var addrType gtls.AddrType
if option.DialOption.AddrType != 0 {
addrType = option.DialOption.AddrType
} else if option.DialOption.GetAddrType != nil {
addrType = option.DialOption.GetAddrType(host)
if ctx.option.DialOption.AddrType != 0 {
addrType = ctx.option.DialOption.AddrType
} else if ctx.option.DialOption.GetAddrType != nil {
addrType = ctx.option.DialOption.GetAddrType(host)
}
ips, err := newDialer(option.DialOption).LookupIPAddr(ctx, host)
ips, err := newDialer(ctx.option.DialOption).LookupIPAddr(ctx.Context(), host)
if err != nil {
return net.IP{}, err
}
@@ -414,8 +408,8 @@ func (obj *Dialer) addJa3Tls(ctx context.Context, conn net.Conn, host string, h2
}
return ja3.NewClient(ctx, conn, spec, h2, tlsConfig)
}
func (obj *Dialer) Socks5TcpProxy(ctx context.Context, ctxData *RequestOption, proxyAddr Address, remoteAddr Address) (conn net.Conn, err error) {
if conn, err = obj.DialContext(ctx, ctxData, "tcp", proxyAddr); err != nil {
func (obj *Dialer) Socks5TcpProxy(ctx *Response, proxyAddr Address, remoteAddr Address) (conn net.Conn, err error) {
if conn, err = obj.DialContext(ctx, "tcp", proxyAddr); err != nil {
return
}
defer func() {
@@ -429,14 +423,14 @@ func (obj *Dialer) Socks5TcpProxy(ctx context.Context, ctxData *RequestOption, p
err = obj.verifyTCPSocks5(conn, proxyAddr, remoteAddr)
}()
select {
case <-ctx.Done():
return conn, context.Cause(ctx)
case <-ctx.Context().Done():
return conn, context.Cause(ctx.Context())
case <-didVerify:
return
}
}
func (obj *Dialer) Socks5UdpProxy(ctx context.Context, ctxData *RequestOption, proxyAddress Address, remoteAddress Address) (udpConn net.PacketConn, err error) {
conn, err := obj.ProxyDialContext(ctx, ctxData, "tcp", proxyAddress)
func (obj *Dialer) Socks5UdpProxy(ctx *Response, proxyAddress Address, remoteAddress Address) (udpConn net.PacketConn, err error) {
conn, err := obj.ProxyDialContext(ctx, "tcp", proxyAddress)
if err != nil {
return nil, err
}
@@ -453,10 +447,10 @@ func (obj *Dialer) Socks5UdpProxy(ctx context.Context, ctxData *RequestOption, p
didVerify := make(chan struct{})
go func() {
defer close(didVerify)
udpConn, err = obj.verifyUDPSocks5(ctx, conn, proxyAddress, remoteAddress)
if ctxData.Logger != nil {
ctxData.Logger(Log{
Id: ctxData.requestId,
udpConn, err = obj.verifyUDPSocks5(ctx.Context(), conn, proxyAddress, remoteAddress)
if ctx.option.Logger != nil {
ctx.option.Logger(Log{
Id: ctx.requestId,
Time: time.Now(),
Type: LogType_ProxyConnectRemote,
Msg: remoteAddress.String(),
@@ -464,8 +458,8 @@ func (obj *Dialer) Socks5UdpProxy(ctx context.Context, ctxData *RequestOption, p
}
}()
select {
case <-ctx.Done():
return udpConn, context.Cause(ctx)
case <-ctx.Context().Done():
return udpConn, context.Cause(ctx.Context())
case <-didVerify:
return
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"crypto/tls"
"io"
"net/http"
"net/url"
"time"
@@ -47,33 +46,33 @@ type Log struct {
// Connection Management Options
type ClientOption struct {
Logger func(Log) //debuggable
H3 bool //开启http3
OrderHeaders []string //order headers
Ja3Spec ja3.Spec //custom ja3Spec,use ja3.CreateSpecWithStr or ja3.CreateSpecWithId create
H2Ja3Spec ja3.H2Spec //h2 fingerprint
UJa3Spec ja3.USpec //h3 fingerprint
Proxy string //proxy,support https,http,socks5
Proxys []string //proxy list,support https,http,socks5
ForceHttp1 bool //force use http1 send requests
Ja3 bool //enable ja3 fingerprint
DisCookie bool //disable cookies
DisDecode bool //disable auto decode
DisUnZip bool //disable auto zip decode
Bar bool ////enable bar display
OptionCallBack func(ctx context.Context, option *RequestOption) error //option callback,if error is returnd, break request
ResultCallBack func(ctx context.Context, option *RequestOption, response *Response) error //result callback,if error is returnd,next errCallback
ErrCallBack func(ctx context.Context, option *RequestOption, response *Response, err error) error //error callback,if error is returnd,break request
RequestCallBack func(ctx context.Context, request *http.Request, response *http.Response) error //request and response callback,if error is returnd,reponse is error
MaxRetries int //try num
MaxRedirect int //redirect num ,<0 no redirect,==0 no limit
Headers any //default headers
Timeout time.Duration //request timeout
ResponseHeaderTimeout time.Duration //ResponseHeaderTimeout ,default:300
TlsHandshakeTimeout time.Duration //tls timeout,default:15
UserAgent string //headers User-Agent value
GetProxy func(ctx context.Context, url *url.URL) (string, error) //proxy callback:support https,http,socks5 proxy
GetProxys func(ctx context.Context, url *url.URL) ([]string, error) //proxys callback:support https,http,socks5 proxy
Logger func(Log) //debuggable
H3 bool //开启http3
OrderHeaders []string //order headers
Ja3Spec ja3.Spec //custom ja3Spec,use ja3.CreateSpecWithStr or ja3.CreateSpecWithId create
H2Ja3Spec ja3.H2Spec //h2 fingerprint
UJa3Spec ja3.USpec //h3 fingerprint
Proxy string //proxy,support https,http,socks5
Proxys []string //proxy list,support https,http,socks5
ForceHttp1 bool //force use http1 send requests
Ja3 bool //enable ja3 fingerprint
DisCookie bool //disable cookies
DisDecode bool //disable auto decode
DisUnZip bool //disable auto zip decode
Bar bool ////enable bar display
OptionCallBack func(ctx *Response) error //option callback,if error is returnd, break request
ResultCallBack func(ctx *Response) error //result callback,if error is returnd,next errCallback
ErrCallBack func(ctx *Response) error //error callback,if error is returnd,break request
RequestCallBack func(ctx *Response) error //request and response callback,if error is returnd,reponse is error
MaxRetries int //try num
MaxRedirect int //redirect num ,<0 no redirect,==0 no limit
Headers any //default headers
Timeout time.Duration //request timeout
ResponseHeaderTimeout time.Duration //ResponseHeaderTimeout ,default:300
TlsHandshakeTimeout time.Duration //tls timeout,default:15
UserAgent string //headers User-Agent value
GetProxy func(ctx *Response) (string, error) //proxy callback:support https,http,socks5 proxy
GetProxys func(ctx *Response) ([]string, error) //proxys callback:support https,http,socks5 proxy
DialOption DialOption
Jar Jar //custom cookies
TlsConfig *tls.Config
@@ -105,16 +104,7 @@ type RequestOption struct {
WsOption websocket.Option //websocket option
DisProxy bool //force disable proxy
once bool
client *Client
requestId string
proxy *url.URL
proxys []*url.URL
isNewConn bool
}
func (obj *RequestOption) Client() *Client {
return obj.client
once bool
}
// Upload files with form-data,

View File

@@ -19,25 +19,10 @@ import (
"github.com/gospider007/websocket"
)
type contextKey string
const gospiderContextKey contextKey = "GospiderContextKey"
var errFatal = errors.New("ErrFatal")
var ErrUseLastResponse = http.ErrUseLastResponse
func CreateReqCtx(ctx context.Context, option *RequestOption) context.Context {
return context.WithValue(ctx, gospiderContextKey, option)
}
func GetRequestOption(ctx context.Context) *RequestOption {
option, ok := ctx.Value(gospiderContextKey).(*RequestOption)
if ok {
return option
}
return new(RequestOption)
}
// sends a GET request and returns the response.
func Get(ctx context.Context, href string, options ...RequestOption) (resp *Response, err error) {
return defaultClient.Request(ctx, http.MethodGet, href, options...)
@@ -149,7 +134,7 @@ func (obj *Client) Request(ctx context.Context, method string, href string, opti
if err != nil {
return nil, err
}
optionBak.requestId = tools.NaoId()
requestId := tools.NaoId()
if optionBak.Method == "" {
optionBak.Method = method
}
@@ -160,80 +145,84 @@ func (obj *Client) Request(ctx context.Context, method string, href string, opti
return
}
}
for maxRetries := 0; maxRetries <= optionBak.MaxRetries; maxRetries++ {
for ; optionBak.MaxRetries >= 0; optionBak.MaxRetries-- {
option := optionBak
option.Url = cloneUrl(uhref)
option.client = obj
response, err = obj.request(ctx, &option)
response = NewResponse(ctx, option)
response.client = obj
response.requestId = requestId
err = obj.request(response)
if err == nil || errors.Is(err, errFatal) || option.once {
return
}
optionBak.MaxRetries = option.MaxRedirect
}
return
}
func (obj *Client) request(ctx context.Context, option *RequestOption) (response *Response, err error) {
response = new(Response)
func (obj *Client) request(ctx *Response) (err error) {
defer func() {
//read body
if err == nil && !response.IsWebSocket() && !response.IsSSE() && !response.IsStream() {
err = response.ReadBody()
if err == nil && !ctx.IsWebSocket() && !ctx.IsSSE() && !ctx.IsStream() {
err = ctx.ReadBody()
}
//result callback
if err == nil && option.ResultCallBack != nil {
err = option.ResultCallBack(ctx, option, response)
if err == nil && ctx.option.ResultCallBack != nil {
err = ctx.option.ResultCallBack(ctx)
}
if err != nil { //err callback, must close body
response.CloseBody()
if option.ErrCallBack != nil {
if err2 := option.ErrCallBack(ctx, option, response, err); err2 != nil {
ctx.CloseBody()
if ctx.option.ErrCallBack != nil {
ctx.err = err
if err2 := ctx.option.ErrCallBack(ctx); err2 != nil {
err = tools.WrapError(errFatal, err2)
}
}
}
}()
if option.OptionCallBack != nil {
if err = option.OptionCallBack(ctx, option); err != nil {
if ctx.option.OptionCallBack != nil {
if err = ctx.option.OptionCallBack(ctx); err != nil {
return
}
}
response.requestOption = option
//init headers and orderheaders,befor init ctxData
headers, err := option.initHeaders()
headers, err := ctx.option.initHeaders()
if err != nil {
return response, tools.WrapError(err, errors.New("tempRequest init headers error"), err)
return tools.WrapError(err, errors.New("tempRequest init headers error"), err)
}
if headers != nil && option.UserAgent != "" {
headers.Set("User-Agent", option.UserAgent)
if headers != nil && ctx.option.UserAgent != "" {
headers.Set("User-Agent", ctx.option.UserAgent)
}
//设置 h2 请求头顺序
if option.OrderHeaders != nil {
if !option.H2Ja3Spec.IsSet() {
option.H2Ja3Spec = ja3.DefaultH2Spec()
option.H2Ja3Spec.OrderHeaders = option.OrderHeaders
} else if option.H2Ja3Spec.OrderHeaders == nil {
option.H2Ja3Spec.OrderHeaders = option.OrderHeaders
if ctx.option.OrderHeaders != nil {
if !ctx.option.H2Ja3Spec.IsSet() {
ctx.option.H2Ja3Spec = ja3.DefaultH2Spec()
ctx.option.H2Ja3Spec.OrderHeaders = ctx.option.OrderHeaders
} else if ctx.option.H2Ja3Spec.OrderHeaders == nil {
ctx.option.H2Ja3Spec.OrderHeaders = ctx.option.OrderHeaders
}
}
//init tls timeout
if option.TlsHandshakeTimeout == 0 {
option.TlsHandshakeTimeout = time.Second * 15
if ctx.option.TlsHandshakeTimeout == 0 {
ctx.option.TlsHandshakeTimeout = time.Second * 15
}
//init proxy
if option.Proxy != "" {
tempProxy, err := gtls.VerifyProxy(option.Proxy)
if ctx.option.Proxy != "" {
tempProxy, err := gtls.VerifyProxy(ctx.option.Proxy)
if err != nil {
return nil, tools.WrapError(errFatal, errors.New("tempRequest init proxy error"), err)
return tools.WrapError(errFatal, errors.New("tempRequest init proxy error"), err)
}
option.proxy = tempProxy
ctx.proxys = []*url.URL{tempProxy}
}
if l := len(option.Proxys); l > 0 {
option.proxys = make([]*url.URL, l)
for i, proxy := range option.Proxys {
if l := len(ctx.option.Proxys); l > 0 {
ctx.proxys = make([]*url.URL, l)
for i, proxy := range ctx.option.Proxys {
tempProxy, err := gtls.VerifyProxy(proxy)
if err != nil {
return response, tools.WrapError(errFatal, errors.New("tempRequest init proxy error"), err)
return tools.WrapError(errFatal, errors.New("tempRequest init proxy error"), err)
}
option.proxys[i] = tempProxy
ctx.proxys[i] = tempProxy
}
}
//init headers
@@ -241,35 +230,35 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response
headers = defaultHeaders()
}
//设置 h1 请求头顺序
if option.OrderHeaders == nil {
option.OrderHeaders = ja3.DefaultOrderHeaders()
if ctx.option.OrderHeaders == nil {
ctx.option.OrderHeaders = ja3.DefaultOrderHeaders()
}
//init ctx,cnl
if option.Timeout > 0 { //超时
response.ctx, response.cnl = context.WithTimeout(CreateReqCtx(ctx, option), option.Timeout)
if ctx.option.Timeout > 0 { //超时
ctx.ctx, ctx.cnl = context.WithTimeout(ctx.Context(), ctx.option.Timeout)
} else {
response.ctx, response.cnl = context.WithCancel(CreateReqCtx(ctx, option))
ctx.ctx, ctx.cnl = context.WithCancel(ctx.Context())
}
//init Scheme
switch option.Url.Scheme {
switch ctx.option.Url.Scheme {
case "file":
response.filePath = re.Sub(`^/+`, "", option.Url.Path)
response.content, err = os.ReadFile(response.filePath)
ctx.filePath = re.Sub(`^/+`, "", ctx.option.Url.Path)
ctx.content, err = os.ReadFile(ctx.filePath)
if err != nil {
err = tools.WrapError(errFatal, errors.New("read filePath data error"), err)
}
return
case "ws":
option.ForceHttp1 = true
option.Url.Scheme = "http"
websocket.SetClientHeadersWithOption(headers, option.WsOption)
ctx.option.ForceHttp1 = true
ctx.option.Url.Scheme = "http"
websocket.SetClientHeadersWithOption(headers, ctx.option.WsOption)
case "wss":
option.ForceHttp1 = true
option.Url.Scheme = "https"
websocket.SetClientHeadersWithOption(headers, option.WsOption)
ctx.option.ForceHttp1 = true
ctx.option.Url.Scheme = "https"
websocket.SetClientHeadersWithOption(headers, ctx.option.WsOption)
}
//init url
href, err := option.initParams()
href, err := ctx.option.initParams()
if err != nil {
err = tools.WrapError(err, "url init error")
return
@@ -278,30 +267,30 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response
headers.Set("Authorization", "Basic "+tools.Base64Encode(href.User.String()))
}
//init body
body, err := option.initBody(response.ctx)
body, err := ctx.option.initBody(ctx.ctx)
if err != nil {
return response, tools.WrapError(err, errors.New("tempRequest init body error"), err)
return tools.WrapError(err, errors.New("tempRequest init body error"), err)
}
//create request
reqs, err := NewRequestWithContext(response.ctx, option.Method, href, body)
reqs, err := NewRequestWithContext(ctx.Context(), ctx.option.Method, href, body)
if err != nil {
return response, tools.WrapError(errFatal, errors.New("tempRequest 构造request失败"), err)
return tools.WrapError(errFatal, errors.New("tempRequest 构造request失败"), err)
}
reqs.Header = headers
//add Referer
if reqs.Header.Get("Referer") == "" && option.Referer != "" {
reqs.Header.Set("Referer", option.Referer)
if reqs.Header.Get("Referer") == "" && ctx.option.Referer != "" {
reqs.Header.Set("Referer", ctx.option.Referer)
}
//set ContentType
if option.ContentType != "" && reqs.Header.Get("Content-Type") == "" {
reqs.Header.Set("Content-Type", option.ContentType)
if ctx.option.ContentType != "" && reqs.Header.Get("Content-Type") == "" {
reqs.Header.Set("Content-Type", ctx.option.ContentType)
}
//add host
if option.Host != "" {
reqs.Host = option.Host
if ctx.option.Host != "" {
reqs.Host = ctx.option.Host
} else if reqs.Header.Get("Host") != "" {
reqs.Host = reqs.Header.Get("Host")
} else {
@@ -309,43 +298,44 @@ func (obj *Client) request(ctx context.Context, option *RequestOption) (response
}
//init cookies
cookies, err := option.initCookies()
cookies, err := ctx.option.initCookies()
if err != nil {
return response, tools.WrapError(err, errors.New("tempRequest init cookies error"), err)
return tools.WrapError(err, errors.New("tempRequest init cookies error"), err)
}
if cookies != nil {
addCookie(reqs, cookies)
}
ctx.request = reqs
//send req
response.response, err = obj.do(reqs, option)
err = obj.do(ctx)
if err != nil && err != ErrUseLastResponse {
err = tools.WrapError(err, "client do error")
return
}
if response.response == nil {
if ctx.response == nil {
err = errors.New("response is nil")
return
}
if response.Body() != nil {
response.rawConn = response.Body().(*readWriteCloser)
if ctx.Body() != nil {
ctx.rawConn = ctx.Body().(*readWriteCloser)
}
if !response.requestOption.DisUnZip {
response.requestOption.DisUnZip = response.response.Uncompressed
if !ctx.option.DisUnZip {
ctx.option.DisUnZip = ctx.response.Uncompressed
}
if response.response.StatusCode == 101 {
response.webSocket = websocket.NewClientConn(response.rawConn.Conn(), websocket.GetResponseHeaderOption(response.response.Header))
} else if strings.Contains(response.response.Header.Get("Content-Type"), "text/event-stream") {
response.sse = newSSE(response)
} else if !response.requestOption.DisUnZip {
if ctx.response.StatusCode == 101 {
ctx.webSocket = websocket.NewClientConn(ctx.rawConn.Conn(), websocket.GetResponseHeaderOption(ctx.response.Header))
} else if strings.Contains(ctx.response.Header.Get("Content-Type"), "text/event-stream") {
ctx.sse = newSSE(ctx)
} else if !ctx.option.DisUnZip {
var unCompressionBody io.ReadCloser
unCompressionBody, err = tools.CompressionDecode(response.Body(), response.ContentEncoding())
unCompressionBody, err = tools.CompressionDecode(ctx.Body(), ctx.ContentEncoding())
if err != nil {
if err != io.ErrUnexpectedEOF && err != io.EOF {
return
}
}
if unCompressionBody != nil {
response.response.Body = unCompressionBody
ctx.response.Body = unCompressionBody
}
}
return

View File

@@ -22,18 +22,61 @@ import (
"github.com/gospider007/websocket"
)
func NewResponse(ctx context.Context, option RequestOption) *Response {
return &Response{
ctx: ctx,
option: &option,
}
}
func (obj *Response) Err() error {
if obj.err != nil {
return obj.err
}
if obj.request != nil {
return obj.request.Context().Err()
}
return obj.ctx.Err()
}
func (obj *Response) Request() *http.Request {
return obj.request
}
func (obj *Response) Response() *http.Response {
return obj.response
}
func (obj *Response) Context() context.Context {
if obj.request != nil {
return obj.request.Context()
}
return obj.ctx
}
func (obj *Response) Option() *RequestOption {
return obj.option
}
func (obj *Response) Client() *Client {
return obj.client
}
type Response struct {
rawConn *readWriteCloser
response *http.Response
webSocket *websocket.Conn
sse *SSE
ctx context.Context
cnl context.CancelFunc
requestOption *RequestOption
content []byte
encoding string
filePath string
readBody bool
err error
request *http.Request
rawConn *readWriteCloser
response *http.Response
webSocket *websocket.Conn
sse *SSE
ctx context.Context
cnl context.CancelFunc
option *RequestOption
content []byte
encoding string
filePath string
readBody bool
client *Client
requestId string
proxys []*url.URL
isNewConn bool
}
type SSE struct {
@@ -271,7 +314,7 @@ func (obj *Response) Body() io.ReadCloser {
// return true if response is stream
func (obj *Response) IsStream() bool {
return obj.requestOption.Stream
return obj.option.Stream
}
// return true if response is other stream
@@ -294,7 +337,7 @@ func (obj *Response) ReadBody() (err error) {
bBody := bytes.NewBuffer(nil)
done := make(chan struct{})
go func() {
if obj.requestOption.Bar && obj.ContentLength() > 0 {
if obj.option.Bar && obj.ContentLength() > 0 {
_, err = io.Copy(&barBody{
bar: bar.NewClient(obj.response.ContentLength),
body: bBody,
@@ -312,9 +355,9 @@ func (obj *Response) ReadBody() (err error) {
err = obj.ctx.Err()
case <-done:
}
if obj.requestOption.Logger != nil {
obj.requestOption.Logger(Log{
Id: obj.requestOption.requestId,
if obj.option.Logger != nil {
obj.option.Logger(Log{
Id: obj.requestId,
Time: time.Now(),
Type: LogType_ResponseBody,
Msg: "response body",
@@ -324,7 +367,7 @@ func (obj *Response) ReadBody() (err error) {
obj.ForceCloseConn()
return errors.New("response read content error: " + err.Error())
}
if !obj.requestOption.DisDecode && obj.defaultDecode() {
if !obj.option.DisDecode && obj.defaultDecode() {
obj.content, obj.encoding, _ = tools.Charset(bBody.Bytes(), obj.ContentType())
} else {
obj.content = bBody.Bytes()
@@ -335,7 +378,7 @@ func (obj *Response) ReadBody() (err error) {
// conn is new conn
func (obj *Response) IsNewConn() bool {
return obj.requestOption.isNewConn
return obj.isNewConn
}
// conn proxy

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"strings"
"time"
"net/http"
@@ -20,27 +21,31 @@ import (
)
type reqTask struct {
option *RequestOption
ctx context.Context
cnl context.CancelFunc
req *http.Request
res *http.Response
reqCtx *Response
emptyPool chan struct{}
err error
retry int
}
func (obj *reqTask) suppertRetry() bool {
if obj.req.Body == nil {
if obj.reqCtx.request.Body == nil {
return true
} else if body, ok := obj.req.Body.(io.Seeker); ok {
} else if body, ok := obj.reqCtx.request.Body.(io.Seeker); ok {
if i, err := body.Seek(0, io.SeekStart); i == 0 && err == nil {
return true
}
}
return false
}
func getKey(option *RequestOption, req *http.Request) (key string) {
return fmt.Sprintf("%s@%s", getAddr(option.proxy), getAddr(req.URL))
func getKey(ctx *Response) (key string) {
adds := []string{}
for _, p := range ctx.proxys {
adds = append(adds, getAddr(p))
}
adds = append(adds, getAddr(ctx.Request().URL))
return strings.Join(adds, "@")
}
type roundTripper struct {
@@ -91,37 +96,37 @@ func (obj *roundTripper) newConnecotr() *connecotr {
return conne
}
func (obj *roundTripper) http3Dial(ctx context.Context, option *RequestOption, remtoeAddress Address, proxyAddress ...Address) (udpConn net.PacketConn, err error) {
func (obj *roundTripper) http3Dial(ctx *Response, remtoeAddress Address, proxyAddress ...Address) (udpConn net.PacketConn, err error) {
if len(proxyAddress) > 0 {
if proxyAddress[len(proxyAddress)-1].Scheme != "socks5" {
err = errors.New("http3 last proxy must socks5 proxy")
return
}
udpConn, _, err = obj.dialer.DialProxyContext(ctx, option, "tcp", option.TlsConfig.Clone(), append(proxyAddress, remtoeAddress)...)
udpConn, _, err = obj.dialer.DialProxyContext(ctx, "tcp", ctx.option.TlsConfig.Clone(), append(proxyAddress, remtoeAddress)...)
} else {
udpConn, err = net.ListenUDP("udp", nil)
}
return
}
func (obj *roundTripper) ghttp3Dial(ctx context.Context, option *RequestOption, remoteAddress Address, proxyAddress ...Address) (conn *connecotr, err error) {
udpConn, err := obj.http3Dial(ctx, option, remoteAddress, proxyAddress...)
func (obj *roundTripper) ghttp3Dial(ctx *Response, remoteAddress Address, proxyAddress ...Address) (conn *connecotr, err error) {
udpConn, err := obj.http3Dial(ctx, remoteAddress, proxyAddress...)
if err != nil {
return nil, err
}
tlsConfig := option.TlsConfig.Clone()
tlsConfig := ctx.option.TlsConfig.Clone()
tlsConfig.NextProtos = []string{http3.NextProtoH3}
tlsConfig.ServerName = remoteAddress.Host
if remoteAddress.IP == nil {
remoteAddress.IP, err = obj.dialer.loadHost(ctx, remoteAddress.Name, option)
remoteAddress.IP, err = obj.dialer.loadHost(ctx, remoteAddress.Name)
if err != nil {
return nil, err
}
}
var quicConfig *quic.Config
if option.UquicConfig != nil {
quicConfig = option.QuicConfig.Clone()
if ctx.option.UquicConfig != nil {
quicConfig = ctx.option.QuicConfig.Clone()
}
netConn, err := quic.DialEarly(ctx, udpConn, &net.UDPAddr{IP: remoteAddress.IP, Port: remoteAddress.Port}, tlsConfig, quicConfig)
netConn, err := quic.DialEarly(ctx.Context(), udpConn, &net.UDPAddr{IP: remoteAddress.IP, Port: remoteAddress.Port}, tlsConfig, quicConfig)
conn = obj.newConnecotr()
conn.Conn = http3.NewClient(netConn, func() {
conn.forceCnl(errors.New("http3 client close"))
@@ -129,34 +134,34 @@ func (obj *roundTripper) ghttp3Dial(ctx context.Context, option *RequestOption,
return
}
func (obj *roundTripper) uhttp3Dial(ctx context.Context, option *RequestOption, remoteAddress Address, proxyAddress ...Address) (conn *connecotr, err error) {
spec, err := ja3.CreateSpecWithUSpec(option.UJa3Spec)
func (obj *roundTripper) uhttp3Dial(ctx *Response, remoteAddress Address, proxyAddress ...Address) (conn *connecotr, err error) {
spec, err := ja3.CreateSpecWithUSpec(ctx.option.UJa3Spec)
if err != nil {
return nil, err
}
udpConn, err := obj.http3Dial(ctx, option, remoteAddress, proxyAddress...)
udpConn, err := obj.http3Dial(ctx, remoteAddress, proxyAddress...)
if err != nil {
return nil, err
}
tlsConfig := option.UtlsConfig.Clone()
tlsConfig := ctx.option.UtlsConfig.Clone()
tlsConfig.NextProtos = []string{http3.NextProtoH3}
tlsConfig.ServerName = remoteAddress.Host
if remoteAddress.IP == nil {
remoteAddress.IP, err = obj.dialer.loadHost(ctx, remoteAddress.Name, option)
remoteAddress.IP, err = obj.dialer.loadHost(ctx, remoteAddress.Name)
if err != nil {
return nil, err
}
}
var quicConfig *uquic.Config
if option.UquicConfig != nil {
quicConfig = option.UquicConfig.Clone()
if ctx.option.UquicConfig != nil {
quicConfig = ctx.option.UquicConfig.Clone()
}
netConn, err := (&uquic.UTransport{
Transport: &uquic.Transport{
Conn: udpConn,
},
QUICSpec: &spec,
}).DialEarly(ctx, &net.UDPAddr{IP: remoteAddress.IP, Port: remoteAddress.Port}, tlsConfig, quicConfig)
}).DialEarly(ctx.Context(), &net.UDPAddr{IP: remoteAddress.IP, Port: remoteAddress.Port}, tlsConfig, quicConfig)
conn = obj.newConnecotr()
conn.Conn = http3.NewUClient(netConn, func() {
conn.forceCnl(errors.New("http3 client close"))
@@ -164,32 +169,32 @@ func (obj *roundTripper) uhttp3Dial(ctx context.Context, option *RequestOption,
return
}
func (obj *roundTripper) dial(option *RequestOption, req *http.Request) (conn *connecotr, err error) {
proxys, err := obj.initProxys(option, req)
func (obj *roundTripper) dial(ctx *Response) (conn *connecotr, err error) {
proxys, err := obj.initProxys(ctx)
if err != nil {
return nil, err
}
remoteAddress, err := GetAddressWithUrl(req.URL)
remoteAddress, err := GetAddressWithUrl(ctx.request.URL)
if err != nil {
return nil, err
}
if option.H3 {
if option.UJa3Spec.IsSet() {
return obj.uhttp3Dial(req.Context(), option, remoteAddress, proxys...)
if ctx.option.H3 {
if ctx.option.UJa3Spec.IsSet() {
return obj.uhttp3Dial(ctx, remoteAddress, proxys...)
} else {
return obj.ghttp3Dial(req.Context(), option, remoteAddress, proxys...)
return obj.ghttp3Dial(ctx, remoteAddress, proxys...)
}
}
var netConn net.Conn
if len(proxys) > 0 {
_, netConn, err = obj.dialer.DialProxyContext(req.Context(), option, "tcp", option.TlsConfig.Clone(), append(proxys, remoteAddress)...)
_, netConn, err = obj.dialer.DialProxyContext(ctx, "tcp", ctx.option.TlsConfig.Clone(), append(proxys, remoteAddress)...)
} else {
var remoteAddress Address
remoteAddress, err = GetAddressWithUrl(req.URL)
remoteAddress, err = GetAddressWithUrl(ctx.request.URL)
if err != nil {
return nil, err
}
netConn, err = obj.dialer.DialContext(req.Context(), option, "tcp", remoteAddress)
netConn, err = obj.dialer.DialContext(ctx, "tcp", remoteAddress)
}
defer func() {
if err != nil && netConn != nil {
@@ -200,14 +205,14 @@ func (obj *roundTripper) dial(option *RequestOption, req *http.Request) (conn *c
return nil, err
}
var h2 bool
if req.URL.Scheme == "https" {
netConn, h2, err = obj.dialAddTls(option, req, netConn)
if option.Logger != nil {
option.Logger(Log{
Id: option.requestId,
if ctx.request.URL.Scheme == "https" {
netConn, h2, err = obj.dialAddTls(ctx.option, ctx.request, netConn)
if ctx.option.Logger != nil {
ctx.option.Logger(Log{
Id: ctx.requestId,
Time: time.Now(),
Type: LogType_TLSHandshake,
Msg: fmt.Sprintf("host:%s, h2:%t", getHost(req), h2),
Msg: fmt.Sprintf("host:%s, h2:%t", getHost(ctx.request), h2),
})
}
if err != nil {
@@ -217,7 +222,7 @@ func (obj *roundTripper) dial(option *RequestOption, req *http.Request) (conn *c
conne := obj.newConnecotr()
conne.proxys = proxys
conne.c = netConn
err = obj.dialConnecotr(option, req, conne, h2)
err = obj.dialConnecotr(ctx.option, ctx.request, conne, h2)
if err != nil {
return nil, err
}
@@ -257,21 +262,14 @@ func (obj *roundTripper) dialAddTls(option *RequestOption, req *http.Request, ne
}
}
}
func (obj *roundTripper) initProxys(option *RequestOption, req *http.Request) ([]Address, error) {
func (obj *roundTripper) initProxys(ctx *Response) ([]Address, error) {
var proxys []Address
if option.DisProxy {
if ctx.option.DisProxy {
return nil, nil
}
if option.proxy != nil {
proxyAddress, err := GetAddressWithUrl(option.proxy)
if err != nil {
return nil, err
}
proxys = []Address{proxyAddress}
}
if len(proxys) == 0 && len(option.proxys) > 0 {
proxys = make([]Address, len(option.proxys))
for i, proxy := range option.proxys {
if len(ctx.proxys) > 0 {
proxys = make([]Address, len(ctx.proxys))
for i, proxy := range ctx.proxys {
proxyAddress, err := GetAddressWithUrl(proxy)
if err != nil {
return nil, err
@@ -279,8 +277,8 @@ func (obj *roundTripper) initProxys(option *RequestOption, req *http.Request) ([
proxys[i] = proxyAddress
}
}
if len(proxys) == 0 && option.GetProxy != nil {
proxyStr, err := option.GetProxy(req.Context(), req.URL)
if len(proxys) == 0 && ctx.option.GetProxy != nil {
proxyStr, err := ctx.option.GetProxy(ctx)
if err != nil {
return proxys, err
}
@@ -296,8 +294,8 @@ func (obj *roundTripper) initProxys(option *RequestOption, req *http.Request) ([
proxys = []Address{proxyAddress}
}
}
if len(proxys) == 0 && option.GetProxys != nil {
proxyStrs, err := option.GetProxys(req.Context(), req.URL)
if len(proxys) == 0 && ctx.option.GetProxys != nil {
proxyStrs, err := ctx.option.GetProxys(ctx)
if err != nil {
return proxys, err
}
@@ -319,30 +317,31 @@ func (obj *roundTripper) initProxys(option *RequestOption, req *http.Request) ([
return proxys, nil
}
func (obj *roundTripper) poolRoundTrip(option *RequestOption, pool *connPool, task *reqTask, key string) (isOk bool, err error) {
task.ctx, task.cnl = context.WithTimeout(task.req.Context(), option.ResponseHeaderTimeout)
func (obj *roundTripper) poolRoundTrip(pool *connPool, task *reqTask, key string) (isOk bool, err error) {
task.ctx, task.cnl = context.WithTimeout(task.reqCtx.Context(), task.reqCtx.option.ResponseHeaderTimeout)
select {
case pool.tasks <- task:
select {
case <-task.emptyPool:
return false, nil
case <-task.ctx.Done():
if task.err == nil && task.res == nil {
if task.err == nil && task.reqCtx.response == nil {
task.err = context.Cause(task.ctx)
}
return true, task.err
}
default:
return obj.createPool(option, task, key)
return obj.createPool(task, key)
}
}
func (obj *roundTripper) createPool(option *RequestOption, task *reqTask, key string) (isOk bool, err error) {
option.isNewConn = true
conn, err := obj.dial(option, task.req)
func (obj *roundTripper) createPool(task *reqTask, key string) (isOk bool, err error) {
task.reqCtx.isNewConn = true
conn, err := obj.dial(task.reqCtx)
if err != nil {
if task.option.ErrCallBack != nil {
if err2 := task.option.ErrCallBack(task.req.Context(), task.option, nil, err); err2 != nil {
if task.reqCtx.option.ErrCallBack != nil {
task.reqCtx.err = err
if err2 := task.reqCtx.option.ErrCallBack(task.reqCtx); err2 != nil {
return true, err2
}
}
@@ -364,50 +363,46 @@ func (obj *roundTripper) forceCloseConns() {
obj.connPools.del(key)
}
}
func (obj *roundTripper) newReqTask(req *http.Request, option *RequestOption) *reqTask {
if option.ResponseHeaderTimeout == 0 {
option.ResponseHeaderTimeout = time.Second * 300
func (obj *roundTripper) newReqTask(ctx *Response) *reqTask {
if ctx.option.ResponseHeaderTimeout == 0 {
ctx.option.ResponseHeaderTimeout = time.Second * 300
}
task := new(reqTask)
task.req = req
task.option = option
task.reqCtx = ctx
task.emptyPool = make(chan struct{})
return task
}
func (obj *roundTripper) RoundTrip(req *http.Request) (response *http.Response, err error) {
option := GetRequestOption(req.Context())
if option.RequestCallBack != nil {
if err = option.RequestCallBack(req.Context(), req, nil); err != nil {
func (obj *roundTripper) RoundTrip(ctx *Response) (err error) {
if ctx.option.RequestCallBack != nil {
if err = ctx.option.RequestCallBack(ctx); err != nil {
if err == http.ErrUseLastResponse {
if req.Response == nil {
return nil, errors.New("errUseLastResponse response is nil")
if ctx.response == nil {
return errors.New("errUseLastResponse response is nil")
} else {
return req.Response, nil
return nil
}
}
return nil, err
return err
}
}
key := getKey(option, req) //pool key
task := obj.newReqTask(req, option)
maxRetry := 10
var errNum int
key := getKey(ctx) //pool key
task := obj.newReqTask(ctx)
var isOk bool
for {
select {
case <-req.Context().Done():
return nil, context.Cause(req.Context())
case <-ctx.Context().Done():
return context.Cause(ctx.Context())
default:
}
if errNum >= maxRetry {
task.err = fmt.Errorf("roundTrip retry %d times", maxRetry)
if task.retry >= maxRetryCount {
task.err = fmt.Errorf("roundTrip retry %d times", maxRetryCount)
break
}
pool := obj.connPools.get(key)
if pool == nil {
isOk, err = obj.createPool(option, task, key)
isOk, err = obj.createPool(task, key)
} else {
isOk, err = obj.poolRoundTrip(option, pool, task, key)
isOk, err = obj.poolRoundTrip(pool, task, key)
}
if isOk {
if err != nil {
@@ -416,13 +411,13 @@ func (obj *roundTripper) RoundTrip(req *http.Request) (response *http.Response,
break
}
if err != nil {
errNum++
task.retry++
}
}
if task.err == nil && option.RequestCallBack != nil {
if err = option.RequestCallBack(task.req.Context(), task.req, task.res); err != nil {
if task.err == nil && ctx.option.RequestCallBack != nil {
if err = ctx.option.RequestCallBack(ctx); err != nil {
task.err = err
}
}
return task.res, task.err
return task.err
}

View File

@@ -3,7 +3,6 @@ package main
import (
"context"
"log"
"net/http"
"testing"
"github.com/gospider007/requests"
@@ -18,9 +17,8 @@ func TestSetCookies(t *testing.T) {
}
_, err = session.Get(context.TODO(), "https://www.baidu.com", requests.RequestOption{
ClientOption: requests.ClientOption{
RequestCallBack: func(ctx context.Context, request *http.Request, response *http.Response) error {
if request.Cookies() == nil {
RequestCallBack: func(ctx *requests.Response) error {
if ctx.Request().Cookies() == nil {
log.Panic("cookie is nil")
}
return nil

View File

@@ -1,7 +1,6 @@
package main
import (
"context"
"errors"
"testing"
@@ -14,10 +13,10 @@ func TestErrCallBack(t *testing.T) {
ClientOption: requests.ClientOption{
MaxRetries: 3,
ResultCallBack: func(ctx context.Context, option *requests.RequestOption, response *requests.Response) error {
ResultCallBack: func(ctx *requests.Response) error {
return errors.New("try")
},
ErrCallBack: func(ctx context.Context, option *requests.RequestOption, response *requests.Response, err error) error {
ErrCallBack: func(ctx *requests.Response) error {
if n == 0 {
n++
return nil

View File

@@ -1,7 +1,6 @@
package main
import (
"context"
"testing"
"github.com/gospider007/requests"
@@ -11,8 +10,8 @@ func TestOptionCallBack(t *testing.T) {
resp, err := requests.Get(nil, "https://httpbin.org/anything", requests.RequestOption{
ClientOption: requests.ClientOption{
OptionCallBack: func(ctx context.Context, option *requests.RequestOption) error {
option.Params = map[string]string{"name": "test"}
OptionCallBack: func(ctx *requests.Response) error {
ctx.Option().Params = map[string]string{"name": "test"}
return nil
},
},

View File

@@ -2,7 +2,6 @@ package main
import (
"context"
"net/http"
"testing"
"github.com/gospider007/requests"
@@ -11,9 +10,8 @@ import (
func TestRedirectCallBack(t *testing.T) {
response, err := requests.Get(context.TODO(), "http://www.baidu.com", requests.RequestOption{
ClientOption: requests.ClientOption{
RequestCallBack: func(ctx context.Context, request *http.Request, response *http.Response) error {
if response != nil {
RequestCallBack: func(ctx *requests.Response) error {
if ctx.Response() != nil {
return requests.ErrUseLastResponse
}
return nil

View File

@@ -1,9 +1,7 @@
package main
import (
"context"
"errors"
"net/http"
"strings"
"testing"
@@ -13,8 +11,8 @@ import (
func TestRequestCallBack(t *testing.T) {
_, err := requests.Get(nil, "https://httpbin.org/anything", requests.RequestOption{
ClientOption: requests.ClientOption{
RequestCallBack: func(ctx context.Context, request *http.Request, response *http.Response) error {
RequestCallBack: func(ctx *requests.Response) error {
response := ctx.Response()
if response != nil {
if response.ContentLength > 100 {
return errors.New("max length")

View File

@@ -1,7 +1,6 @@
package main
import (
"context"
"errors"
"testing"
@@ -12,12 +11,11 @@ func TestResultCallBack(t *testing.T) {
var code int
_, err := requests.Get(nil, "https://httpbin.org/anything", requests.RequestOption{
ClientOption: requests.ClientOption{
ResultCallBack: func(ctx context.Context, option *requests.RequestOption, response *requests.Response) error {
if response.StatusCode() != 200 {
ResultCallBack: func(ctx *requests.Response) error {
if ctx.StatusCode() != 200 {
return errors.New("resp.StatusCode!= 200")
}
code = response.StatusCode()
code = ctx.StatusCode()
return nil
},
},

View File

@@ -1,7 +1,6 @@
package main
import (
"context"
"errors"
"testing"
@@ -14,8 +13,7 @@ func TestMaxRetries(t *testing.T) {
ClientOption: requests.ClientOption{
MaxRetries: 3,
ResultCallBack: func(ctx context.Context, option *requests.RequestOption, response *requests.Response) error {
ResultCallBack: func(ctx *requests.Response) error {
if n == 0 {
n++
return errors.New("try")

View File

@@ -1,7 +1,6 @@
package main
import (
"context"
"log"
"testing"
@@ -16,8 +15,8 @@ func TestHttp1(t *testing.T) {
Logger: func(l requests.Log) {
log.Print(l)
},
ErrCallBack: func(ctx context.Context, option *requests.RequestOption, response *requests.Response, err error) error {
log.Print(err)
ErrCallBack: func(ctx *requests.Response) error {
log.Print(ctx.Err())
return nil
},
},

View File

@@ -1,8 +1,6 @@
package main
import (
"context"
"net/url"
"testing"
"github.com/gospider007/requests"
@@ -21,7 +19,7 @@ func TestProxy(t *testing.T) {
}
func TestGetProxy(t *testing.T) {
session, _ := requests.NewClient(nil, requests.ClientOption{
GetProxy: func(ctx context.Context, url *url.URL) (string, error) { //Penalty when creating a new connection
GetProxy: func(ctx *requests.Response) (string, error) { //Penalty when creating a new connection
proxy := "" //set proxy,ex:"http://127.0.0.1:8080","https://127.0.0.1:8080","socks5://127.0.0.1:8080"
return proxy, nil
},