package middleware import ( "bytes" "io" "math" "net" "net/http" "time" ) // RetryPolicy 重试策略 type RetryPolicy struct { // 最大重试次数 MaxRetries int // 基础退避时间 BaseBackoff time.Duration // 最大退避时间 MaxBackoff time.Duration // 重试判断函数 ShouldRetry func(req *http.Request, resp *http.Response, err error) bool } // DefaultRetryPolicy 默认重试策略 func DefaultRetryPolicy() *RetryPolicy { return &RetryPolicy{ MaxRetries: 3, BaseBackoff: 100 * time.Millisecond, MaxBackoff: 2 * time.Second, ShouldRetry: defaultShouldRetry, } } // defaultShouldRetry 默认重试判断 func defaultShouldRetry(req *http.Request, resp *http.Response, err error) bool { // 不重试非幂等请求 if req.Method != http.MethodGet && req.Method != http.MethodHead && req.Method != http.MethodOptions { return false } // 检查错误 if err != nil { // 重试网络错误 if netErr, ok := err.(net.Error); ok { return netErr.Temporary() || netErr.Timeout() } return false } // 检查响应状态码 if resp != nil { // 重试服务器错误 return resp.StatusCode >= 500 && resp.StatusCode < 600 } return false } // RetryRoundTripper 重试HTTP传输 type RetryRoundTripper struct { // 下一级传输 Next http.RoundTripper // 重试策略 Policy *RetryPolicy } // NewRetryRoundTripper 创建重试HTTP传输 func NewRetryRoundTripper(next http.RoundTripper, policy *RetryPolicy) *RetryRoundTripper { if next == nil { next = http.DefaultTransport } if policy == nil { policy = DefaultRetryPolicy() } return &RetryRoundTripper{ Next: next, Policy: policy, } } // RoundTrip 执行HTTP请求 func (rt *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { // 需要保留原始请求体,以便重试 var reqBodyBytes []byte if req.Body != nil { var err error reqBodyBytes, err = io.ReadAll(req.Body) if err != nil { return nil, err } req.Body.Close() } var resp *http.Response var err error // 尝试请求直到成功或达到最大重试次数 for attempt := 0; attempt <= rt.Policy.MaxRetries; attempt++ { // 复制请求体 if len(reqBodyBytes) > 0 { req.Body = io.NopCloser(bytes.NewBuffer(reqBodyBytes)) } // 发送请求 resp, err = rt.Next.RoundTrip(req) // 检查是否需要重试 if attempt < rt.Policy.MaxRetries && rt.Policy.ShouldRetry(req, resp, err) { // 如果需要重试,先关闭当前响应 if resp != nil { resp.Body.Close() } // 计算退避时间 backoff := rt.calculateBackoff(attempt) time.Sleep(backoff) continue } // 不需要重试,返回响应 return resp, err } // 所有重试都失败 return resp, err } // calculateBackoff 计算退避时间 func (rt *RetryRoundTripper) calculateBackoff(attempt int) time.Duration { // 指数退避: baseBackoff * 2^attempt backoff := rt.Policy.BaseBackoff * time.Duration(math.Pow(2, float64(attempt))) if backoff > rt.Policy.MaxBackoff { backoff = rt.Policy.MaxBackoff } return backoff } // RetryMiddleware 重试中间件 type RetryMiddleware struct { policy *RetryPolicy } // NewRetryMiddleware 创建重试中间件 func NewRetryMiddleware(policy *RetryPolicy) *RetryMiddleware { if policy == nil { policy = DefaultRetryPolicy() } return &RetryMiddleware{ policy: policy, } } // Middleware 中间件处理函数 func (m *RetryMiddleware) Transport(next http.RoundTripper) http.RoundTripper { return NewRetryRoundTripper(next, m.policy) }