157 lines
3.5 KiB
Go
157 lines
3.5 KiB
Go
package middleware
|
|
|
|
import (
	"bytes"
	"errors"
	"io"
	"math"
	"net"
	"net/http"
	"time"
)
|
|
|
|
// RetryPolicy describes how many times, how long apart, and under which
// conditions a request is re-sent.
type RetryPolicy struct {
	// MaxRetries is the maximum number of retries performed after the first
	// attempt.
	MaxRetries int

	// BaseBackoff is the delay before the first retry; subsequent delays grow
	// exponentially from it. MaxBackoff is the upper bound on any single delay.
	BaseBackoff, MaxBackoff time.Duration

	// ShouldRetry reports whether the outcome of an attempt warrants another
	// try. It receives the request together with the response and the
	// transport error of that attempt, either of which may be nil.
	ShouldRetry func(req *http.Request, resp *http.Response, err error) bool
}
|
|
|
|
// DefaultRetryPolicy 默认重试策略
|
|
func DefaultRetryPolicy() *RetryPolicy {
|
|
return &RetryPolicy{
|
|
MaxRetries: 3,
|
|
BaseBackoff: 100 * time.Millisecond,
|
|
MaxBackoff: 2 * time.Second,
|
|
ShouldRetry: defaultShouldRetry,
|
|
}
|
|
}
|
|
|
|
// defaultShouldRetry is the retry predicate installed by DefaultRetryPolicy.
//
// It considers only requests whose method is GET (including the empty method,
// which net/http treats as GET), HEAD or OPTIONS, because replaying them has no
// side effects. It also never retries once the request's own context is done
// (cancelled or past its deadline), since every further attempt would fail
// immediately. Otherwise it retries:
//   - transport errors that are, or wrap, a net.Error reporting Timeout() or
//     Temporary();
//   - responses with a 5xx status code.
//
// Any other error, any non-5xx response, and a nil response with a nil error
// are not retried.
func defaultShouldRetry(req *http.Request, resp *http.Response, err error) bool {
	switch req.Method {
	case "", http.MethodGet, http.MethodHead, http.MethodOptions:
		// Safe, idempotent methods: fall through to the checks below.
	default:
		return false
	}

	// context.DeadlineExceeded itself satisfies net.Error with Timeout()==true,
	// so without this check an expired caller deadline would look retryable.
	if req.Context().Err() != nil {
		return false
	}

	if err != nil {
		// errors.As also finds net errors wrapped by an inner RoundTripper.
		var netErr net.Error
		if !errors.As(err, &netErr) {
			return false
		}
		// net.Error.Temporary is deprecated; it is still consulted so that
		// errors the standard library flags as temporary remain retryable.
		return netErr.Timeout() || netErr.Temporary()
	}

	return resp != nil && resp.StatusCode >= 500 && resp.StatusCode < 600
}
|
|
|
|
// RetryRoundTripper 重试HTTP传输
|
|
type RetryRoundTripper struct {
|
|
// 下一级传输
|
|
Next http.RoundTripper
|
|
// 重试策略
|
|
Policy *RetryPolicy
|
|
}
|
|
|
|
// NewRetryRoundTripper 创建重试HTTP传输
|
|
func NewRetryRoundTripper(next http.RoundTripper, policy *RetryPolicy) *RetryRoundTripper {
|
|
if next == nil {
|
|
next = http.DefaultTransport
|
|
}
|
|
if policy == nil {
|
|
policy = DefaultRetryPolicy()
|
|
}
|
|
return &RetryRoundTripper{
|
|
Next: next,
|
|
Policy: policy,
|
|
}
|
|
}
|
|
|
|
// RoundTrip 执行HTTP请求
|
|
func (rt *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
// 需要保留原始请求体,以便重试
|
|
var reqBodyBytes []byte
|
|
if req.Body != nil {
|
|
var err error
|
|
reqBodyBytes, err = io.ReadAll(req.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Body.Close()
|
|
}
|
|
|
|
var resp *http.Response
|
|
var err error
|
|
|
|
// 尝试请求直到成功或达到最大重试次数
|
|
for attempt := 0; attempt <= rt.Policy.MaxRetries; attempt++ {
|
|
// 复制请求体
|
|
if len(reqBodyBytes) > 0 {
|
|
req.Body = io.NopCloser(bytes.NewBuffer(reqBodyBytes))
|
|
}
|
|
|
|
// 发送请求
|
|
resp, err = rt.Next.RoundTrip(req)
|
|
|
|
// 检查是否需要重试
|
|
if attempt < rt.Policy.MaxRetries && rt.Policy.ShouldRetry(req, resp, err) {
|
|
// 如果需要重试,先关闭当前响应
|
|
if resp != nil {
|
|
resp.Body.Close()
|
|
}
|
|
|
|
// 计算退避时间
|
|
backoff := rt.calculateBackoff(attempt)
|
|
time.Sleep(backoff)
|
|
continue
|
|
}
|
|
|
|
// 不需要重试,返回响应
|
|
return resp, err
|
|
}
|
|
|
|
// 所有重试都失败
|
|
return resp, err
|
|
}
|
|
|
|
// calculateBackoff 计算退避时间
|
|
func (rt *RetryRoundTripper) calculateBackoff(attempt int) time.Duration {
|
|
// 指数退避: baseBackoff * 2^attempt
|
|
backoff := rt.Policy.BaseBackoff * time.Duration(math.Pow(2, float64(attempt)))
|
|
if backoff > rt.Policy.MaxBackoff {
|
|
backoff = rt.Policy.MaxBackoff
|
|
}
|
|
return backoff
|
|
}
|
|
|
|
// RetryMiddleware 重试中间件
|
|
type RetryMiddleware struct {
|
|
policy *RetryPolicy
|
|
}
|
|
|
|
// NewRetryMiddleware 创建重试中间件
|
|
func NewRetryMiddleware(policy *RetryPolicy) *RetryMiddleware {
|
|
if policy == nil {
|
|
policy = DefaultRetryPolicy()
|
|
}
|
|
return &RetryMiddleware{
|
|
policy: policy,
|
|
}
|
|
}
|
|
|
|
// Middleware 中间件处理函数
|
|
func (m *RetryMiddleware) Transport(next http.RoundTripper) http.RoundTripper {
|
|
return NewRetryRoundTripper(next, m.policy)
|
|
}
|