185 lines
3.7 KiB
Go
185 lines
3.7 KiB
Go
package middleware
|
|
|
|
import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/time/rate"
)
|
|
|
|
// RateLimiter decides, request by request, whether traffic may proceed.
// Implementations may use key to keep separate budgets per client, or ignore
// it and apply a single shared budget.
type RateLimiter interface {
	// Allow reports whether a request identified by key may proceed now.
	Allow(key string) bool
}
|
|
|
|
// SimpleRateLimiter 简单限流器
|
|
type SimpleRateLimiter struct {
|
|
limiter *rate.Limiter
|
|
}
|
|
|
|
// NewSimpleRateLimiter 创建简单限流器
|
|
func NewSimpleRateLimiter(r float64, b int) *SimpleRateLimiter {
|
|
return &SimpleRateLimiter{
|
|
limiter: rate.NewLimiter(rate.Limit(r), b),
|
|
}
|
|
}
|
|
|
|
// Allow 检查请求是否允许通过
|
|
func (rl *SimpleRateLimiter) Allow(key string) bool {
|
|
return rl.limiter.Allow()
|
|
}
|
|
|
|
// IPRateLimiter 按IP限流
|
|
type IPRateLimiter struct {
|
|
ips map[string]*rate.Limiter
|
|
mu sync.RWMutex
|
|
rate rate.Limit
|
|
burst int
|
|
cleanupInterval time.Duration
|
|
lastSeen map[string]time.Time
|
|
}
|
|
|
|
// NewIPRateLimiter 创建IP限流器
|
|
func NewIPRateLimiter(r float64, b int, cleanup time.Duration) *IPRateLimiter {
|
|
limiter := &IPRateLimiter{
|
|
ips: make(map[string]*rate.Limiter),
|
|
rate: rate.Limit(r),
|
|
burst: b,
|
|
cleanupInterval: cleanup,
|
|
lastSeen: make(map[string]time.Time),
|
|
}
|
|
|
|
// 启动过期清理
|
|
if cleanup > 0 {
|
|
go limiter.startCleanup()
|
|
}
|
|
|
|
return limiter
|
|
}
|
|
|
|
// startCleanup 启动过期清理
|
|
func (rl *IPRateLimiter) startCleanup() {
|
|
ticker := time.NewTicker(rl.cleanupInterval)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
rl.cleanup()
|
|
}
|
|
}
|
|
|
|
// cleanup 清理过期限流器
|
|
func (rl *IPRateLimiter) cleanup() {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
now := time.Now()
|
|
for ip, lastSeen := range rl.lastSeen {
|
|
if now.Sub(lastSeen) > rl.cleanupInterval {
|
|
delete(rl.ips, ip)
|
|
delete(rl.lastSeen, ip)
|
|
}
|
|
}
|
|
}
|
|
|
|
// AddIP 添加IP限流器
|
|
func (rl *IPRateLimiter) AddIP(ip string) *rate.Limiter {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
limiter := rate.NewLimiter(rl.rate, rl.burst)
|
|
rl.ips[ip] = limiter
|
|
rl.lastSeen[ip] = time.Now()
|
|
|
|
return limiter
|
|
}
|
|
|
|
// GetLimiter 获取IP限流器
|
|
func (rl *IPRateLimiter) GetLimiter(ip string) *rate.Limiter {
|
|
rl.mu.RLock()
|
|
limiter, exists := rl.ips[ip]
|
|
rl.mu.RUnlock()
|
|
|
|
if !exists {
|
|
return rl.AddIP(ip)
|
|
}
|
|
|
|
// 更新最后访问时间
|
|
rl.mu.Lock()
|
|
rl.lastSeen[ip] = time.Now()
|
|
rl.mu.Unlock()
|
|
|
|
return limiter
|
|
}
|
|
|
|
// Allow 检查请求是否允许通过
|
|
func (rl *IPRateLimiter) Allow(ip string) bool {
|
|
limiter := rl.GetLimiter(ip)
|
|
return limiter.Allow()
|
|
}
|
|
|
|
// RateLimitMiddleware 限流中间件
|
|
type RateLimitMiddleware struct {
|
|
limiter RateLimiter
|
|
}
|
|
|
|
// NewRateLimitMiddleware 创建限流中间件
|
|
func NewRateLimitMiddleware(limiter RateLimiter) *RateLimitMiddleware {
|
|
return &RateLimitMiddleware{
|
|
limiter: limiter,
|
|
}
|
|
}
|
|
|
|
// Middleware 中间件处理函数
|
|
func (m *RateLimitMiddleware) Middleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// 获取客户端IP
|
|
ip := getClientIP(r)
|
|
|
|
// 检查是否允许通过
|
|
if !m.limiter.Allow(ip) {
|
|
http.Error(w, "Too Many Requests", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
|
|
// 继续处理请求
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// getClientIP returns the client address used as the rate-limit key.
//
// Sources are tried in order, skipping any that are empty after trimming:
//  1. the first entry of X-Forwarded-For;
//  2. X-Real-IP;
//  3. the host part of r.RemoteAddr (the address as-is if it has no port);
//  4. the literal "unknown".
//
// NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled unless a
// trusted reverse proxy overwrites them. If this service is reachable
// directly, a client can vary these headers to get a fresh bucket per request
// and bypass rate limiting. Confirm the deployment sits behind such a proxy.
func getClientIP(r *http.Request) string {
	// Split on the first comma without assuming a length for the first hop.
	// IPv6 addresses, and even "123.123.123.123", are longer than a fixed
	// 15-byte scan can handle.
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ",")
	if ip := strings.TrimSpace(first); ip != "" {
		return ip
	}

	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	if r.RemoteAddr == "" {
		return "unknown"
	}
	// SplitHostPort understands bracketed IPv6 ("[::1]:8080" -> "::1").
	// Cutting at the first ':' would yield "[" instead.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	// No port present (or otherwise unparseable): use the address verbatim.
	return r.RemoteAddr
}
|