mirror of
				https://git.zx2c4.com/wireguard-go
				synced 2025-10-31 20:02:37 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			173 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			173 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| /* SPDX-License-Identifier: MIT
 | |
|  *
 | |
|  * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
 | |
|  */
 | |
| 
 | |
| package ratelimiter
 | |
| 
 | |
| import (
 | |
| 	"net"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| )
 | |
| 
 | |
// Token-bucket parameters. Tokens are denominated in nanoseconds of elapsed
// time: a bucket gains one token per nanosecond since its last refill, and
// each admitted packet spends packetCost tokens.
const (
	packetsPerSecond   = 20          // steady-state refill rate per address
	packetsBurstable   = 5           // bucket depth, in packets
	garbageCollectTime = time.Second // idle age after which cleanup drops an entry

	packetCost = 1000000000 / packetsPerSecond // tokens charged per packet
	maxTokens  = packetsBurstable * packetCost // bucket capacity, in tokens
)
 | |
| 
 | |
// RatelimiterEntry is the token bucket kept for one source address.
// Its tokens are nanoseconds of accumulated elapsed time, and each admitted
// packet spends packetCost of them.
type RatelimiterEntry struct {
	// mu serialises reads and writes of lastTime and tokens.
	mu sync.Mutex

	// lastTime is the instant the bucket was last refilled.
	lastTime time.Time

	// tokens is the current balance, capped at maxTokens on refill.
	tokens int64
}
 | |
| 
 | |
| type Ratelimiter struct {
 | |
| 	mu      sync.RWMutex
 | |
| 	timeNow func() time.Time
 | |
| 
 | |
| 	stopReset chan struct{} // send to reset, close to stop
 | |
| 	tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
 | |
| 	tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
 | |
| }
 | |
| 
 | |
| func (rate *Ratelimiter) Close() {
 | |
| 	rate.mu.Lock()
 | |
| 	defer rate.mu.Unlock()
 | |
| 
 | |
| 	if rate.stopReset != nil {
 | |
| 		close(rate.stopReset)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (rate *Ratelimiter) Init() {
 | |
| 	rate.mu.Lock()
 | |
| 	defer rate.mu.Unlock()
 | |
| 
 | |
| 	if rate.timeNow == nil {
 | |
| 		rate.timeNow = time.Now
 | |
| 	}
 | |
| 
 | |
| 	// stop any ongoing garbage collection routine
 | |
| 	if rate.stopReset != nil {
 | |
| 		close(rate.stopReset)
 | |
| 	}
 | |
| 
 | |
| 	rate.stopReset = make(chan struct{})
 | |
| 	rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
 | |
| 	rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
 | |
| 
 | |
| 	stopReset := rate.stopReset // store in case Init is called again.
 | |
| 
 | |
| 	// Start garbage collection routine.
 | |
| 	go func() {
 | |
| 		ticker := time.NewTicker(time.Second)
 | |
| 		ticker.Stop()
 | |
| 		for {
 | |
| 			select {
 | |
| 			case _, ok := <-stopReset:
 | |
| 				ticker.Stop()
 | |
| 				if !ok {
 | |
| 					return
 | |
| 				}
 | |
| 				ticker = time.NewTicker(time.Second)
 | |
| 			case <-ticker.C:
 | |
| 				if rate.cleanup() {
 | |
| 					ticker.Stop()
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	}()
 | |
| }
 | |
| 
 | |
| func (rate *Ratelimiter) cleanup() (empty bool) {
 | |
| 	rate.mu.Lock()
 | |
| 	defer rate.mu.Unlock()
 | |
| 
 | |
| 	for key, entry := range rate.tableIPv4 {
 | |
| 		entry.mu.Lock()
 | |
| 		if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
 | |
| 			delete(rate.tableIPv4, key)
 | |
| 		}
 | |
| 		entry.mu.Unlock()
 | |
| 	}
 | |
| 
 | |
| 	for key, entry := range rate.tableIPv6 {
 | |
| 		entry.mu.Lock()
 | |
| 		if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
 | |
| 			delete(rate.tableIPv6, key)
 | |
| 		}
 | |
| 		entry.mu.Unlock()
 | |
| 	}
 | |
| 
 | |
| 	return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
 | |
| }
 | |
| 
 | |
| func (rate *Ratelimiter) Allow(ip net.IP) bool {
 | |
| 	var entry *RatelimiterEntry
 | |
| 	var keyIPv4 [net.IPv4len]byte
 | |
| 	var keyIPv6 [net.IPv6len]byte
 | |
| 
 | |
| 	// lookup entry
 | |
| 
 | |
| 	IPv4 := ip.To4()
 | |
| 	IPv6 := ip.To16()
 | |
| 
 | |
| 	rate.mu.RLock()
 | |
| 
 | |
| 	if IPv4 != nil {
 | |
| 		copy(keyIPv4[:], IPv4)
 | |
| 		entry = rate.tableIPv4[keyIPv4]
 | |
| 	} else {
 | |
| 		copy(keyIPv6[:], IPv6)
 | |
| 		entry = rate.tableIPv6[keyIPv6]
 | |
| 	}
 | |
| 
 | |
| 	rate.mu.RUnlock()
 | |
| 
 | |
| 	// make new entry if not found
 | |
| 
 | |
| 	if entry == nil {
 | |
| 		entry = new(RatelimiterEntry)
 | |
| 		entry.tokens = maxTokens - packetCost
 | |
| 		entry.lastTime = rate.timeNow()
 | |
| 		rate.mu.Lock()
 | |
| 		if IPv4 != nil {
 | |
| 			rate.tableIPv4[keyIPv4] = entry
 | |
| 			if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
 | |
| 				rate.stopReset <- struct{}{}
 | |
| 			}
 | |
| 		} else {
 | |
| 			rate.tableIPv6[keyIPv6] = entry
 | |
| 			if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
 | |
| 				rate.stopReset <- struct{}{}
 | |
| 			}
 | |
| 		}
 | |
| 		rate.mu.Unlock()
 | |
| 		return true
 | |
| 	}
 | |
| 
 | |
| 	// add tokens to entry
 | |
| 
 | |
| 	entry.mu.Lock()
 | |
| 	now := rate.timeNow()
 | |
| 	entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
 | |
| 	entry.lastTime = now
 | |
| 	if entry.tokens > maxTokens {
 | |
| 		entry.tokens = maxTokens
 | |
| 	}
 | |
| 
 | |
| 	// subtract cost of packet
 | |
| 
 | |
| 	if entry.tokens > packetCost {
 | |
| 		entry.tokens -= packetCost
 | |
| 		entry.mu.Unlock()
 | |
| 		return true
 | |
| 	}
 | |
| 	entry.mu.Unlock()
 | |
| 	return false
 | |
| }
 | 
