Files
goproxy/internal/middleware/ratelimiter.go
2025-03-13 15:56:33 +08:00

185 lines
3.7 KiB
Go

package middleware
import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/time/rate"
)
// RateLimiter decides whether a request identified by a key may proceed.
type RateLimiter interface {
	// Allow reports whether a request for key should be let through right now.
	// Implementations may ignore key entirely (one global budget) or use it to
	// keep a separate budget per key, such as per client IP.
	Allow(key string) bool
}
// SimpleRateLimiter enforces one token bucket shared by every caller.
// The key passed to Allow is ignored, so all clients draw from the same budget.
type SimpleRateLimiter struct {
	limiter *rate.Limiter
}

// NewSimpleRateLimiter builds a SimpleRateLimiter whose bucket refills at
// rate r (converted to rate.Limit, i.e. events per second) and holds at most
// b tokens.
func NewSimpleRateLimiter(r float64, b int) *SimpleRateLimiter {
	bucket := rate.NewLimiter(rate.Limit(r), b)
	return &SimpleRateLimiter{limiter: bucket}
}

// Allow reports whether one more request may proceed now, consuming a token
// from the shared bucket when it does. key is not consulted.
func (rl *SimpleRateLimiter) Allow(key string) bool {
	bucket := rl.limiter
	return bucket.Allow()
}
// IPRateLimiter applies an independent token-bucket limit to each client IP.
//
// Every IP gets its own rate.Limiter configured with the same rate and burst.
// When a positive cleanup interval is configured, a background goroutine
// periodically drops limiters for IPs idle longer than that interval; call
// Stop to terminate that goroutine once the limiter is no longer needed.
type IPRateLimiter struct {
	ips             map[string]*rate.Limiter // per-IP limiters, guarded by mu
	mu              sync.RWMutex
	rate            rate.Limit
	burst           int
	cleanupInterval time.Duration
	lastSeen        map[string]time.Time // last access time per IP, guarded by mu

	stop     chan struct{} // closed by Stop to end the cleanup goroutine
	stopOnce sync.Once     // makes Stop idempotent
}

// NewIPRateLimiter creates a per-IP limiter.
//
// r is the per-IP refill rate (converted to rate.Limit, i.e. events per
// second) and b the per-IP burst. cleanup is used both as the sweep period and
// as the idle threshold after which an IP's limiter is discarded. If cleanup
// is <= 0 no sweeper is started and entries are never evicted, so memory grows
// with the number of distinct IPs seen.
func NewIPRateLimiter(r float64, b int, cleanup time.Duration) *IPRateLimiter {
	limiter := &IPRateLimiter{
		ips:             make(map[string]*rate.Limiter),
		rate:            rate.Limit(r),
		burst:           b,
		cleanupInterval: cleanup,
		lastSeen:        make(map[string]time.Time),
		stop:            make(chan struct{}),
	}
	// time.NewTicker panics on a non-positive duration, so only start the
	// sweeper for positive intervals.
	if cleanup > 0 {
		go limiter.startCleanup()
	}
	return limiter
}

// startCleanup runs cleanup every cleanupInterval until Stop is called.
func (rl *IPRateLimiter) startCleanup() {
	ticker := time.NewTicker(rl.cleanupInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			rl.cleanup()
		case <-rl.stop:
			return
		}
	}
}

// Stop terminates the background cleanup goroutine, if one was started.
// It is safe to call more than once. The limiter keeps enforcing limits
// afterwards, but idle entries are no longer evicted.
func (rl *IPRateLimiter) Stop() {
	rl.stopOnce.Do(func() {
		if rl.stop != nil {
			close(rl.stop)
		}
	})
}
// cleanup evicts every IP whose last recorded access is more than
// cleanupInterval ago, removing both its limiter and its timestamp.
// Only IPs that have a lastSeen entry are considered.
func (rl *IPRateLimiter) cleanup() {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	for addr, seen := range rl.lastSeen {
		if now.Sub(seen) <= rl.cleanupInterval {
			continue // still active; keep it
		}
		delete(rl.ips, addr)
		delete(rl.lastSeen, addr)
	}
}
// AddIP installs a brand-new limiter for ip and stamps its last-seen time.
// Any limiter already registered for ip is replaced, which discards its
// remaining token state. The new limiter is returned.
func (rl *IPRateLimiter) AddIP(ip string) *rate.Limiter {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	fresh := rate.NewLimiter(rl.rate, rl.burst)
	rl.ips[ip], rl.lastSeen[ip] = fresh, time.Now()
	return fresh
}
// GetLimiter returns the limiter for ip, creating one with the configured
// rate and burst if none exists, and records the current time as ip's last
// access.
//
// The lookup, creation and timestamp update all happen under a single write
// lock. The previous read-check / unlock / AddIP sequence let concurrent first
// requests from one IP each install their own limiter (the later one silently
// replacing, and thereby resetting, the earlier), so those requests escaped
// the burst limit. It also let cleanup evict the entry between the lookup and
// the lastSeen update. A write lock is needed on every call anyway because
// lastSeen is always written.
func (rl *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	limiter, ok := rl.ips[ip]
	if !ok {
		limiter = rate.NewLimiter(rl.rate, rl.burst)
		rl.ips[ip] = limiter
	}
	rl.lastSeen[ip] = time.Now()
	return limiter
}
// Allow reports whether a request from ip may proceed, consuming a token from
// that IP's limiter (created on first use) when it does.
func (rl *IPRateLimiter) Allow(ip string) bool {
	return rl.GetLimiter(ip).Allow()
}
// RateLimitMiddleware is HTTP middleware that rejects a request with
// 429 Too Many Requests whenever its RateLimiter refuses the client's key.
type RateLimitMiddleware struct {
	limiter RateLimiter
}

// NewRateLimitMiddleware wraps limiter in a RateLimitMiddleware.
func NewRateLimitMiddleware(limiter RateLimiter) *RateLimitMiddleware {
	m := new(RateLimitMiddleware)
	m.limiter = limiter
	return m
}

// Middleware returns a handler that keys each request by getClientIP(r),
// asks the limiter for permission, and either forwards to next or replies
// with 429 and the body "Too Many Requests".
func (m *RateLimitMiddleware) Middleware(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		if allowed := m.limiter.Allow(getClientIP(r)); allowed {
			next.ServeHTTP(w, r)
			return
		}
		http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
	}
	return http.HandlerFunc(handle)
}
// getClientIP returns the client address used as the rate-limit key.
//
// Sources are consulted in order:
//  1. the first comma-separated entry of X-Forwarded-For,
//  2. X-Real-IP,
//  3. the host part of r.RemoteAddr (port stripped, IPv6 brackets removed);
//     if RemoteAddr has no port it is returned unchanged.
//
// Header values are whitespace-trimmed and skipped when empty. If no source
// yields an address, "unknown" is returned.
//
// SECURITY NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled
// unless a trusted reverse proxy overwrites them. A client reaching this server
// directly can therefore choose its own rate-limit key and evade per-IP limits.
// Only trust these headers when the server is deployed behind such a proxy.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// The original client is the left-most entry; entries may be
		// arbitrarily long (IPv6) and padded with spaces.
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	if r.RemoteAddr != "" {
		// SplitHostPort handles "host:port" and "[ipv6]:port"; cutting at the
		// first ':' would mangle any IPv6 address.
		if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
			return host
		}
		// No port present (e.g. a hand-built request): use the value as-is.
		return r.RemoteAddr
	}

	return "unknown"
}