Files
rpcx/serverplugin/req_rate_limiting_redis.go

54 lines
1.2 KiB
Go

package serverplugin
import (
"context"
"time"
"github.com/go-redis/redis_rate/v10"
"github.com/redis/go-redis/v9"
"github.com/smallnest/rpcx/protocol"
"github.com/smallnest/rpcx/server"
)
// Compile-time assertion that *RedisRateLimitingPlugin satisfies server.PostReadRequestPlugin.
var _ server.PostReadRequestPlugin = (*RedisRateLimitingPlugin)(nil)
// RedisRateLimitingPlugin is a server plugin that limits the number of
// requests accepted per unit of time, using a redis_rate limiter to do the
// accounting.
type RedisRateLimitingPlugin struct {
	// addrs is the address list the plugin was constructed with.
	addrs []string

	// limiter performs the Allow check for each request key.
	limiter redis_rate.Limiter

	// limit is the rate/burst/period setting handed to every Allow call.
	limit redis_rate.Limit
}
// NewRedisRateLimitingPlugin creates a new RedisRateLimitingPlugin.
//
// addrs is used as the address list of the Redis Cluster client that backs
// the limiter. rate, burst and period populate the redis_rate.Limit that the
// plugin applies to every request key.
func NewRedisRateLimitingPlugin(addrs []string, rate int, burst int, period time.Duration) *RedisRateLimitingPlugin {
	clusterOpts := &redis.ClusterOptions{Addrs: addrs}
	client := redis.NewClusterClient(clusterOpts)

	return &RedisRateLimitingPlugin{
		addrs: addrs,
		// The struct holds the limiter by value, so the *Limiter returned
		// by NewLimiter is dereferenced here.
		limiter: *redis_rate.NewLimiter(client),
		limit: redis_rate.Limit{
			Rate:   rate,
			Burst:  burst,
			Period: period,
		},
	}
}
// PostReadRequest decides whether a request that has just been read may be
// processed.
//
// The limiter key is "<ServicePath>/<ServiceMethod>" taken from r, so each
// service method is counted separately. The incoming error e is not
// inspected.
//
// It returns the limiter's error if the Allow call fails, nil when the
// request is allowed, and server.ErrReqReachLimit otherwise.
func (plugin *RedisRateLimitingPlugin) PostReadRequest(ctx context.Context, r *protocol.Message, e error) error {
	key := r.ServicePath + "/" + r.ServiceMethod

	res, err := plugin.limiter.Allow(ctx, key, plugin.limit)
	switch {
	case err != nil:
		// Checked first: res is only read once the call has succeeded.
		return err
	case res.Allowed > 0:
		return nil
	default:
		return server.ErrReqReachLimit
	}
}