This commit is contained in:
gospider
2024-12-17 13:30:05 +08:00
parent b0f620036d
commit 71a69f933a
4 changed files with 43 additions and 36 deletions

View File

@@ -46,8 +46,6 @@ type roundTripper struct {
cnl context.CancelFunc
connPools *connPools
dialer *DialClient
getProxy func(ctx context.Context, url *url.URL) (string, error)
getProxys func(ctx context.Context, url *url.URL) ([]string, error)
}
func newRoundTripper(preCtx context.Context, option ClientOption) *roundTripper {
@@ -67,12 +65,10 @@ func newRoundTripper(preCtx context.Context, option ClientOption) *roundTripper
ctx: ctx,
cnl: cnl,
dialer: dialClient,
getProxy: option.GetProxy,
getProxys: option.GetProxys,
connPools: newConnPools(),
}
}
func (obj *roundTripper) newConnPool(conn *connecotr, key string) *connPool {
func (obj *roundTripper) newConnPool(done chan struct{}, conn *connecotr, key string) *connPool {
pool := new(connPool)
pool.connKey = key
pool.forceCtx, pool.forceCnl = context.WithCancelCause(obj.ctx)
@@ -81,17 +77,19 @@ func (obj *roundTripper) newConnPool(conn *connecotr, key string) *connPool {
pool.connPools = obj.connPools
pool.total.Add(1)
go pool.rwMain(conn)
go pool.rwMain(done, conn)
return pool
}
func (obj *roundTripper) putConnPool(key string, conn *connecotr) {
pool := obj.connPools.get(key)
done := make(chan struct{})
if pool != nil {
pool.total.Add(1)
go pool.rwMain(conn)
go pool.rwMain(done, conn)
} else {
obj.connPools.set(key, obj.newConnPool(conn, key))
obj.connPools.set(key, obj.newConnPool(done, conn, key))
}
<-done
}
func (obj *roundTripper) newConnecotr() *connecotr {
conne := new(connecotr)
@@ -227,19 +225,21 @@ func (obj *roundTripper) initProxys(option *RequestOption, req *http.Request) ([
proxys[i] = cloneUrl(proxy)
}
}
if len(proxys) == 0 && obj.getProxy != nil {
proxyStr, err := obj.getProxy(req.Context(), req.URL)
if len(proxys) == 0 && option.GetProxy != nil {
proxyStr, err := option.GetProxy(req.Context(), req.URL)
if err != nil {
return proxys, err
}
proxy, err := gtls.VerifyProxy(proxyStr)
if err != nil {
return proxys, err
if proxyStr != "" {
proxy, err := gtls.VerifyProxy(proxyStr)
if err != nil {
return proxys, err
}
proxys = []*url.URL{proxy}
}
proxys = []*url.URL{proxy}
}
if len(proxys) == 0 && obj.getProxys != nil {
proxyStrs, err := obj.getProxys(req.Context(), req.URL)
if len(proxys) == 0 && option.GetProxys != nil {
proxyStrs, err := option.GetProxys(req.Context(), req.URL)
if err != nil {
return proxys, err
}
@@ -257,13 +257,6 @@ func (obj *roundTripper) initProxys(option *RequestOption, req *http.Request) ([
return proxys, nil
}
func (obj *roundTripper) setGetProxy(getProxy func(ctx context.Context, url *url.URL) (string, error)) {
obj.getProxy = getProxy
}
func (obj *roundTripper) setGetProxys(getProxys func(ctx context.Context, url *url.URL) ([]string, error)) {
obj.getProxys = getProxys
}
func (obj *roundTripper) poolRoundTrip(option *RequestOption, pool *connPool, task *reqTask, key string) (isOk bool, err error) {
task.ctx, task.cnl = context.WithTimeout(task.req.Context(), option.ResponseHeaderTimeout)
select {