// File: goproxy/internal/middleware/retry.go
// Last modified: 2025-03-13 15:56:33 +08:00
// Size: 157 lines, 3.5 KiB (Go)

package middleware
import (
	"bytes"
	"errors"
	"io"
	"math"
	"net"
	"net/http"
	"time"
)
// RetryPolicy configures how a RetryRoundTripper retries requests.
type RetryPolicy struct {
	// MaxRetries is the maximum number of retries performed after the
	// initial attempt.
	MaxRetries int
	// BaseBackoff is the base delay of the exponential backoff between
	// attempts.
	BaseBackoff time.Duration
	// MaxBackoff caps the delay between attempts.
	MaxBackoff time.Duration
	// ShouldRetry reports whether the request should be sent again, given the
	// response and/or error produced by the most recent attempt.
	ShouldRetry func(req *http.Request, resp *http.Response, err error) bool
}
// DefaultRetryPolicy returns a newly allocated RetryPolicy with conservative
// defaults. It allows up to 3 retries and uses exponential backoff that
// starts at 100ms and is capped at 2s. defaultShouldRetry decides which
// outcomes are retryable.
func DefaultRetryPolicy() *RetryPolicy {
	p := new(RetryPolicy)
	p.MaxRetries = 3
	p.BaseBackoff = 100 * time.Millisecond
	p.MaxBackoff = 2 * time.Second
	p.ShouldRetry = defaultShouldRetry
	return p
}
// defaultShouldRetry is the ShouldRetry predicate used by DefaultRetryPolicy.
//
// Only idempotent methods are retried: GET, HEAD and OPTIONS. An empty Method
// is treated as GET, as net/http does for client requests.
//
// A transport error is retryable when some error in its chain (found via
// errors.As) is a net.Error reporting Timeout or Temporary. The exception is
// a request whose own context is already cancelled or expired, because
// another attempt cannot succeed.
//
// Without an error, the request is retried when the response has a 5xx
// status code.
func defaultShouldRetry(req *http.Request, resp *http.Response, err error) bool {
	switch req.Method {
	case "", http.MethodGet, http.MethodHead, http.MethodOptions:
		// Idempotent: safe to send again.
	default:
		return false
	}

	if err != nil {
		// context.DeadlineExceeded itself implements net.Error with
		// Timeout() == true. Check the caller's context first so an
		// abandoned request is not retried.
		if req.Context().Err() != nil {
			return false
		}
		var netErr net.Error
		if errors.As(err, &netErr) {
			//lint:ignore SA1019 Temporary is deprecated but kept so transient errors such as connection resets stay retryable.
			return netErr.Timeout() || netErr.Temporary()
		}
		return false
	}

	// Retry server errors only.
	return resp != nil && resp.StatusCode >= 500 && resp.StatusCode < 600
}
// RetryRoundTripper is an http.RoundTripper that re-sends requests according
// to a RetryPolicy.
type RetryRoundTripper struct {
	// Next is the underlying RoundTripper that actually performs each attempt.
	Next http.RoundTripper
	// Policy controls how many times and when requests are retried.
	Policy *RetryPolicy
}
// NewRetryRoundTripper wraps next with retry behaviour governed by policy.
// A nil next falls back to http.DefaultTransport, and a nil policy falls
// back to DefaultRetryPolicy().
func NewRetryRoundTripper(next http.RoundTripper, policy *RetryPolicy) *RetryRoundTripper {
	rt := &RetryRoundTripper{Next: http.DefaultTransport, Policy: policy}
	if next != nil {
		rt.Next = next
	}
	if rt.Policy == nil {
		rt.Policy = DefaultRetryPolicy()
	}
	return rt
}
// RoundTrip implements http.RoundTripper.
//
// The request is sent through rt.Next. While rt.Policy.ShouldRetry reports
// the outcome as retryable, it is re-sent up to rt.Policy.MaxRetries more
// times, waiting rt.calculateBackoff(attempt) between attempts.
//
// Behaviour details:
//   - The caller's request is never modified: each attempt is sent on a clone.
//   - A body is read into memory once so every attempt can replay it. The
//     original body is always closed, even if reading it fails, as the
//     http.RoundTripper contract requires.
//   - A response discarded in favour of a retry is drained (up to a small
//     limit) and closed so its connection can be reused.
//   - The backoff wait stops as soon as req's context is done. The context's
//     error is then returned.
//   - A MaxRetries <= 0 or a nil ShouldRetry means no retries. The request is
//     then passed straight through to rt.Next.
func (rt *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	// Upper bound on bytes read from a discarded response before closing it.
	const maxDrainBytes = 4 << 10

	maxRetries := rt.Policy.MaxRetries
	shouldRetry := rt.Policy.ShouldRetry
	if maxRetries <= 0 || shouldRetry == nil {
		// Exactly one attempt: no need to buffer or clone anything.
		return rt.Next.RoundTrip(req)
	}

	// Buffer the body once so that every attempt can replay it.
	var body []byte
	hasBody := req.Body != nil && req.Body != http.NoBody
	if hasBody {
		var err error
		body, err = io.ReadAll(req.Body)
		// Close unconditionally: RoundTrip must close the body even on error.
		req.Body.Close()
		if err != nil {
			return nil, err
		}
	}

	ctx := req.Context()
	for attempt := 0; ; attempt++ {
		// Send a clone so the caller's request is left untouched.
		attemptReq := req.Clone(ctx)
		if hasBody {
			attemptReq.Body = newReplayBody(body)
			attemptReq.GetBody = func() (io.ReadCloser, error) {
				return newReplayBody(body), nil
			}
		}

		resp, err := rt.Next.RoundTrip(attemptReq)
		if attempt >= maxRetries || !shouldRetry(req, resp, err) {
			return resp, err
		}

		if resp != nil && resp.Body != nil {
			// Best effort: draining lets the transport reuse the connection;
			// a failure here only costs that reuse, so it is ignored.
			_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrainBytes))
			resp.Body.Close()
		}

		// Wait before the next attempt. Give up promptly if the request's
		// context is cancelled or its deadline passes.
		timer := time.NewTimer(rt.calculateBackoff(attempt))
		select {
		case <-ctx.Done():
			timer.Stop()
			return nil, ctx.Err()
		case <-timer.C:
		}
	}
}

// newReplayBody returns a fresh request body that reads buf from the start,
// or http.NoBody when buf is empty.
func newReplayBody(buf []byte) io.ReadCloser {
	if len(buf) == 0 {
		return http.NoBody
	}
	return io.NopCloser(bytes.NewReader(buf))
}
// calculateBackoff returns the delay to wait after the given zero-based
// attempt. The delay is BaseBackoff * 2^attempt, capped at MaxBackoff.
//
// The delay is doubled step by step and the loop stops once the cap is
// reached, so it cannot overflow int64. The previous math.Pow-based product
// wrapped to a negative duration for large attempts and escaped the cap.
// A non-positive BaseBackoff is returned unchanged (capped at MaxBackoff).
func (rt *RetryRoundTripper) calculateBackoff(attempt int) time.Duration {
	maxBackoff := rt.Policy.MaxBackoff
	backoff := rt.Policy.BaseBackoff
	if backoff > 0 {
		for i := 0; i < attempt && backoff < maxBackoff; i++ {
			if backoff > math.MaxInt64/2 {
				// Doubling would overflow; the true value exceeds the cap anyway.
				return maxBackoff
			}
			backoff *= 2
		}
	}
	if backoff > maxBackoff {
		backoff = maxBackoff
	}
	return backoff
}
// RetryMiddleware wraps transports with a RetryRoundTripper using a fixed
// RetryPolicy.
type RetryMiddleware struct {
	// policy is the retry policy applied to every wrapped transport.
	policy *RetryPolicy
}
// NewRetryMiddleware creates a RetryMiddleware that applies policy. It falls
// back to DefaultRetryPolicy() when policy is nil.
func NewRetryMiddleware(policy *RetryPolicy) *RetryMiddleware {
	m := &RetryMiddleware{policy: policy}
	if m.policy == nil {
		m.policy = DefaultRetryPolicy()
	}
	return m
}
// Transport wraps next in a RetryRoundTripper that uses the middleware's
// policy. NewRetryRoundTripper substitutes http.DefaultTransport when next
// is nil.
func (m *RetryMiddleware) Transport(next http.RoundTripper) http.RoundTripper {
	var wrapped http.RoundTripper = NewRetryRoundTripper(next, m.policy)
	return wrapped
}