This commit is contained in:
gospider
2024-12-15 15:04:53 +08:00
parent 4edba914a8
commit 1c2a6b332e
2 changed files with 39 additions and 32 deletions

15
conn.go
View File

@@ -150,8 +150,14 @@ func (obj *connecotr) waitBodyClose() error {
}
}
func (obj *connecotr) taskMain(task *reqTask, waitBody bool) (retry bool) {
func (obj *connecotr) taskMain(task *reqTask) (retry bool) {
defer func() {
if task.err != nil && task.option.ErrCallBack != nil {
if err2 := task.option.ErrCallBack(task.ctx, task.option, nil, task.err); err2 != nil {
retry = false
task.err = err2
}
}
if retry {
task.err = nil
obj.rawConn.CloseWithError(errors.New("taskMain retry close"))
@@ -159,7 +165,7 @@ func (obj *connecotr) taskMain(task *reqTask, waitBody bool) (retry bool) {
task.cnl()
if task.err != nil {
obj.rawConn.CloseWithError(task.err)
} else if waitBody {
} else {
if err := obj.waitBodyClose(); err != nil {
obj.rawConn.CloseWithError(err)
}
@@ -274,9 +280,6 @@ func (obj *connPool) rwMain(conn *connecotr) {
obj.safeClose()
}
}()
if err := conn.waitBodyClose(); err != nil {
return
}
for {
select {
case <-conn.safeCtx.Done(): //safe close conn
@@ -287,7 +290,7 @@ func (obj *connPool) rwMain(conn *connecotr) {
if task == nil {
return
}
if conn.taskMain(task, true) {
if conn.taskMain(task) {
obj.notice(task)
return
}

View File

@@ -235,47 +235,37 @@ func (obj *roundTripper) setGetProxys(getProxys func(ctx context.Context, url *u
obj.getProxys = getProxys
}
func (obj *roundTripper) poolRoundTrip(option *RequestOption, pool *connPool, task *reqTask, key string) (isTry bool) {
func (obj *roundTripper) poolRoundTrip(option *RequestOption, pool *connPool, task *reqTask, key string) (isOk bool, err error) {
task.ctx, task.cnl = context.WithTimeout(task.req.Context(), option.ResponseHeaderTimeout)
select {
case pool.tasks <- task:
select {
case <-task.emptyPool:
return true
return false, nil
case <-task.ctx.Done():
if task.err == nil && task.res == nil {
task.err = context.Cause(task.ctx)
}
return false
return true, nil
}
default:
obj.connRoundTripMain(option, task, key)
return false
return obj.createPool(option, task, key)
}
}
func (obj *roundTripper) connRoundTripMain(option *RequestOption, task *reqTask, key string) {
for range 10 {
if !obj.connRoundTrip(option, task, key) {
return
}
}
task.err = errors.New("connRoundTripMain retry 5 times")
}
func (obj *roundTripper) connRoundTrip(option *RequestOption, task *reqTask, key string) (retry bool) {
func (obj *roundTripper) createPool(option *RequestOption, task *reqTask, key string) (isOk bool, err error) {
option.isNewConn = true
conn, err := obj.dial(option, task.req)
if err != nil {
task.err = err
return
}
task.ctx, task.cnl = context.WithTimeout(task.req.Context(), option.ResponseHeaderTimeout)
retry = conn.taskMain(task, false)
if retry || task.err != nil {
return retry
if task.option.ErrCallBack != nil {
if err2 := task.option.ErrCallBack(task.req.Context(), task.option, nil, err); err2 != nil {
return true, err2
}
}
return false, err
}
obj.putConnPool(key, conn)
return retry
return false, nil
}
func (obj *roundTripper) closeConns() {
@@ -316,15 +306,29 @@ func (obj *roundTripper) RoundTrip(req *http.Request) (response *http.Response,
}
key := getKey(option, req) //pool key
task := obj.newReqTask(req, option)
maxRetry := 10
var errNum int
var isOk bool
for {
pool := obj.connPools.get(key)
if pool == nil {
obj.connRoundTripMain(option, task, key)
if errNum >= maxRetry {
task.err = fmt.Errorf("roundTrip retry %d times", maxRetry)
break
}
if !obj.poolRoundTrip(option, pool, task, key) {
pool := obj.connPools.get(key)
if pool == nil {
isOk, err = obj.createPool(option, task, key)
} else {
isOk, err = obj.poolRoundTrip(option, pool, task, key)
}
if isOk {
if err != nil {
task.err = err
}
break
}
if err != nil {
errNum++
}
}
if task.err == nil && option.RequestCallBack != nil {
if err = option.RequestCallBack(task.req.Context(), task.req, task.res); err != nil {