package middleware import ( "net/http" "sync" "time" "golang.org/x/time/rate" ) // RateLimiter 限流器接口 type RateLimiter interface { // Allow 检查请求是否允许通过 Allow(key string) bool } // SimpleRateLimiter 简单限流器 type SimpleRateLimiter struct { limiter *rate.Limiter } // NewSimpleRateLimiter 创建简单限流器 func NewSimpleRateLimiter(r float64, b int) *SimpleRateLimiter { return &SimpleRateLimiter{ limiter: rate.NewLimiter(rate.Limit(r), b), } } // Allow 检查请求是否允许通过 func (rl *SimpleRateLimiter) Allow(key string) bool { return rl.limiter.Allow() } // IPRateLimiter 按IP限流 type IPRateLimiter struct { ips map[string]*rate.Limiter mu sync.RWMutex rate rate.Limit burst int cleanupInterval time.Duration lastSeen map[string]time.Time } // NewIPRateLimiter 创建IP限流器 func NewIPRateLimiter(r float64, b int, cleanup time.Duration) *IPRateLimiter { limiter := &IPRateLimiter{ ips: make(map[string]*rate.Limiter), rate: rate.Limit(r), burst: b, cleanupInterval: cleanup, lastSeen: make(map[string]time.Time), } // 启动过期清理 if cleanup > 0 { go limiter.startCleanup() } return limiter } // startCleanup 启动过期清理 func (rl *IPRateLimiter) startCleanup() { ticker := time.NewTicker(rl.cleanupInterval) defer ticker.Stop() for range ticker.C { rl.cleanup() } } // cleanup 清理过期限流器 func (rl *IPRateLimiter) cleanup() { rl.mu.Lock() defer rl.mu.Unlock() now := time.Now() for ip, lastSeen := range rl.lastSeen { if now.Sub(lastSeen) > rl.cleanupInterval { delete(rl.ips, ip) delete(rl.lastSeen, ip) } } } // AddIP 添加IP限流器 func (rl *IPRateLimiter) AddIP(ip string) *rate.Limiter { rl.mu.Lock() defer rl.mu.Unlock() limiter := rate.NewLimiter(rl.rate, rl.burst) rl.ips[ip] = limiter rl.lastSeen[ip] = time.Now() return limiter } // GetLimiter 获取IP限流器 func (rl *IPRateLimiter) GetLimiter(ip string) *rate.Limiter { rl.mu.RLock() limiter, exists := rl.ips[ip] rl.mu.RUnlock() if !exists { return rl.AddIP(ip) } // 更新最后访问时间 rl.mu.Lock() rl.lastSeen[ip] = time.Now() rl.mu.Unlock() return limiter } // Allow 检查请求是否允许通过 func (rl *IPRateLimiter) Allow(ip string) bool { limiter := rl.GetLimiter(ip) return limiter.Allow() } // RateLimitMiddleware 限流中间件 type RateLimitMiddleware struct { limiter RateLimiter } // NewRateLimitMiddleware 创建限流中间件 func NewRateLimitMiddleware(limiter RateLimiter) *RateLimitMiddleware { return &RateLimitMiddleware{ limiter: limiter, } } // Middleware 中间件处理函数 func (m *RateLimitMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 获取客户端IP ip := getClientIP(r) // 检查是否允许通过 if !m.limiter.Allow(ip) { http.Error(w, "Too Many Requests", http.StatusTooManyRequests) return } // 继续处理请求 next.ServeHTTP(w, r) }) } // getClientIP 获取客户端IP func getClientIP(r *http.Request) string { // 检查 X-Forwarded-For 头 ip := r.Header.Get("X-Forwarded-For") if ip != "" { // 取第一个IP for i := 0; i < len(ip) && i < 15; i++ { if ip[i] == ',' { ip = ip[:i] break } } return ip } // 检查 X-Real-IP 头 ip = r.Header.Get("X-Real-IP") if ip != "" { return ip } // 从 RemoteAddr 获取 if r.RemoteAddr != "" { // 去掉端口部分 for i := 0; i < len(r.RemoteAddr); i++ { if r.RemoteAddr[i] == ':' { return r.RemoteAddr[:i] } } return r.RemoteAddr } return "unknown" }