Files
go-libp2p/x/rate/limiter_test.go
sukun b82a39cb89 transport: rate limit new connections (#3283)
This rate limits new connections to prevent DoS attacks. 

For effectively rate limiting QUIC connections, we now gate QUIC connection attempts before the handshake, so that we don't spend compute on handshakes for connections that will eventually be cancelled. 

We can only set a single ConnContext per quic-go Transport, as there's only 1 listener per quic-go Transport. So we cannot set a different ConnContext for listeners on the same address. 

As we're now gating QUIC connections before the handshake, we use source address verification to ensure that spoofed IPs cannot DoS new connections from a particular IP. This is done by ensuring that some of the connection attempts always verify the source address. We get DoS protection at the expense of increased latency of source address verification.
2025-06-05 02:16:32 +05:30

257 lines
7.5 KiB
Go

package rate
import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)
const rateLimitErrorTolerance = 0.05
func getSleepDurationAndRequestCount(rps float64) (time.Duration, int) {
sleepDuration := 100 * time.Millisecond
requestCount := int(sleepDuration.Seconds() * float64(rps))
if requestCount < 1 {
// Adding 1ms to ensure we do get 1 request. If the rate is low enough that
// 100ms won't have a single request adding 1ms won't error here.
sleepDuration = time.Duration((1/rps)*float64(time.Second)) + 1*time.Millisecond
requestCount = 1
}
return sleepDuration, requestCount
}
func assertLimiter(t *testing.T, rl *Limiter, ipAddr netip.Addr, allowed, errorMargin int) {
t.Helper()
for i := 0; i < allowed; i++ {
require.True(t, rl.Allow(ipAddr))
}
for i := 0; i < errorMargin; i++ {
rl.Allow(ipAddr)
}
require.False(t, rl.Allow(ipAddr))
}
func TestLimiterGlobal(t *testing.T) {
addr := netip.MustParseAddr("127.0.0.1")
limits := []Limit{
{RPS: 0.0, Burst: 1},
{RPS: 0.8, Burst: 1},
{RPS: 10, Burst: 20},
{RPS: 100, Burst: 200},
{RPS: 1000, Burst: 2000},
}
for _, limit := range limits {
t.Run(fmt.Sprintf("limit %0.1f", limit.RPS), func(t *testing.T) {
rl := &Limiter{
GlobalLimit: limit,
}
if limit.RPS == 0 {
// 0 implies no rate limiting, any large number would do
for i := 0; i < 1000; i++ {
require.True(t, rl.Allow(addr))
}
return
}
assertLimiter(t, rl, addr, limit.Burst, int(limit.RPS*rateLimitErrorTolerance))
sleepDuration, requestCount := getSleepDurationAndRequestCount(limit.RPS)
time.Sleep(sleepDuration)
assertLimiter(t, rl, addr, requestCount, int(float64(requestCount)*rateLimitErrorTolerance))
})
}
}
func TestLimiterNetworkPrefix(t *testing.T) {
local := netip.MustParseAddr("127.0.0.1")
public := netip.MustParseAddr("1.1.1.1")
rl := &Limiter{
NetworkPrefixLimits: []PrefixLimit{
{Prefix: netip.MustParsePrefix("127.0.0.0/24"), Limit: Limit{}},
},
GlobalLimit: Limit{RPS: 10, Burst: 10},
}
// element within prefix is allowed even over the limit
for range rl.GlobalLimit.Burst + 100 {
require.True(t, rl.Allow(local))
}
// rate limit public ips
assertLimiter(t, rl, public, rl.GlobalLimit.Burst, int(rl.GlobalLimit.RPS*rateLimitErrorTolerance))
// public ip rejected
require.False(t, rl.Allow(public))
// local ip accepted
for range 100 {
require.True(t, rl.Allow(local))
}
}
func TestLimiterNetworkPrefixWidth(t *testing.T) {
a1 := netip.MustParseAddr("1.1.1.1")
a2 := netip.MustParseAddr("1.1.0.1")
wideLimit := 20
narrowLimit := 10
rl := &Limiter{
NetworkPrefixLimits: []PrefixLimit{
{Prefix: netip.MustParsePrefix("1.1.0.0/16"), Limit: Limit{RPS: 0.01, Burst: wideLimit}},
{Prefix: netip.MustParsePrefix("1.1.1.0/24"), Limit: Limit{RPS: 0.01, Burst: narrowLimit}},
},
}
for range 2 * wideLimit {
rl.Allow(a1)
}
// a1 rejected
require.False(t, rl.Allow(a1))
// a2 accepted
for range wideLimit - narrowLimit {
require.True(t, rl.Allow(a2))
}
}
func subnetAddrs(prefix netip.Prefix) func() netip.Addr {
next := prefix.Addr()
return func() netip.Addr {
addr := next
next = addr.Next()
if !prefix.Contains(addr) {
next = prefix.Addr()
addr = next
}
return addr
}
}
func TestSubnetLimiter(t *testing.T) {
assertOutput := func(outcome bool, rl *SubnetLimiter, subnetAddrs func() netip.Addr, n int) {
t.Helper()
for range n {
require.Equal(t, outcome, rl.Allow(subnetAddrs(), time.Now()), "%d", n)
}
}
t.Run("Simple", func(*testing.T) {
// Keep the refil rate low
v4Small := SubnetLimit{PrefixLength: 24, Limit: Limit{RPS: 0.0001, Burst: 10}}
v4Large := SubnetLimit{PrefixLength: 16, Limit: Limit{RPS: 0.0001, Burst: 19}}
v6Small := SubnetLimit{PrefixLength: 64, Limit: Limit{RPS: 0.0001, Burst: 10}}
v6Large := SubnetLimit{PrefixLength: 48, Limit: Limit{RPS: 0.0001, Burst: 17}}
rl := &SubnetLimiter{
IPv4SubnetLimits: []SubnetLimit{v4Large, v4Small},
IPv6SubnetLimits: []SubnetLimit{v6Large, v6Small},
}
v4SubnetAddr1 := subnetAddrs(netip.MustParsePrefix("192.168.1.1/24"))
v4SubnetAddr2 := subnetAddrs(netip.MustParsePrefix("192.168.2.1/24"))
v6SubnetAddr1 := subnetAddrs(netip.MustParsePrefix("2001:0:0:1::/64"))
v6SubnetAddr2 := subnetAddrs(netip.MustParsePrefix("2001:0:0:2::/64"))
assertOutput(true, rl, v4SubnetAddr1, v4Small.Burst)
assertOutput(false, rl, v4SubnetAddr1, v4Large.Burst)
assertOutput(true, rl, v4SubnetAddr2, v4Large.Burst-v4Small.Burst)
assertOutput(false, rl, v4SubnetAddr2, v4Large.Burst)
assertOutput(true, rl, v6SubnetAddr1, v6Small.Burst)
assertOutput(false, rl, v6SubnetAddr1, v6Large.Burst)
assertOutput(true, rl, v6SubnetAddr2, v6Large.Burst-v6Small.Burst)
assertOutput(false, rl, v6SubnetAddr2, v6Large.Burst)
})
t.Run("Complex", func(*testing.T) {
limits := []SubnetLimit{
{PrefixLength: 32, Limit: Limit{RPS: 0.01, Burst: 10}},
{PrefixLength: 24, Limit: Limit{RPS: 0.01, Burst: 20}},
{PrefixLength: 16, Limit: Limit{RPS: 0.01, Burst: 30}},
{PrefixLength: 8, Limit: Limit{RPS: 0.01, Burst: 40}},
}
rl := &SubnetLimiter{
IPv4SubnetLimits: limits,
}
snAddrs := []func() netip.Addr{
subnetAddrs(netip.MustParsePrefix("192.168.1.1/32")),
subnetAddrs(netip.MustParsePrefix("192.168.1.2/24")),
subnetAddrs(netip.MustParsePrefix("192.168.2.1/16")),
subnetAddrs(netip.MustParsePrefix("192.0.1.1/8")),
}
for i, addrsFunc := range snAddrs {
prev := 0
if i > 0 {
prev = limits[i-1].Burst
}
assertOutput(true, rl, addrsFunc, limits[i].Burst-prev)
assertOutput(false, rl, addrsFunc, limits[i].Burst)
}
})
t.Run("Zero", func(t *testing.T) {
sl := SubnetLimiter{}
for range 10000 {
require.True(t, sl.Allow(netip.IPv6Loopback(), time.Now()))
}
})
}
func TestSubnetLimiterCleanup(t *testing.T) {
tc := []struct {
Limit
TTL time.Duration
}{
{Limit: Limit{RPS: 1, Burst: 10}, TTL: 10 * time.Second},
{Limit: Limit{RPS: 0.1, Burst: 2}, TTL: 20 * time.Second},
{Limit: Limit{RPS: 1, Burst: 100}, TTL: 100 * time.Second},
{Limit: Limit{RPS: 3, Burst: 6}, TTL: 2 * time.Second},
}
for i, tt := range tc {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
ip1, ip2 := netip.IPv6Loopback(), netip.MustParseAddr("2001::")
sl := SubnetLimiter{IPv6SubnetLimits: []SubnetLimit{{PrefixLength: 64, Limit: tt.Limit}}}
now := time.Now()
// Empty the ip1 bucket
for range tt.Burst {
require.True(t, sl.Allow(ip1, now))
}
for range tt.Burst / 2 {
require.True(t, sl.Allow(ip2, now))
}
epsilon := 100 * time.Millisecond
// just before ip1 expiry
now = now.Add(tt.TTL).Add(-epsilon)
sl.cleanUp(now) // ip2 will be removed
require.Equal(t, 1, sl.ipv6Heaps[0].Len())
// just after ip1 expiry
now = now.Add(2 * epsilon)
require.True(t, sl.Allow(ip2, now)) // remove the ip1 bucket
require.Equal(t, 1, sl.ipv6Heaps[0].Len()) // ip2 added in the previous call
})
}
}
func TestTokenBucketFullAfter(t *testing.T) {
tc := []struct {
*rate.Limiter
FullAfter time.Duration
}{
{Limiter: rate.NewLimiter(1, 10), FullAfter: 10 * time.Second},
{Limiter: rate.NewLimiter(0.01, 10), FullAfter: 1000 * time.Second},
{Limiter: rate.NewLimiter(0.01, 1), FullAfter: 100 * time.Second},
}
for i, tt := range tc {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
b := tokenBucket{tt.Limiter}
now := time.Now()
for range b.Burst() {
tt.Allow()
}
epsilon := 10 * time.Millisecond
require.GreaterOrEqual(t, tt.FullAfter+epsilon, b.FullAt(now).Sub(now))
require.LessOrEqual(t, tt.FullAfter-epsilon, b.FullAt(now).Sub(now))
})
}
}