mirror of
				https://git.zx2c4.com/wireguard-go
				synced 2025-10-31 03:46:20 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			274 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			274 lines
		
	
	
		
			5.0 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| /* SPDX-License-Identifier: GPL-2.0
 | |
|  *
 | |
|  * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 | |
|  */
 | |
| 
 | |
| package main
 | |
| 
 | |
import (
	"errors"
	"math/bits"
	"net"
	"sync"
)
 | |
| 
 | |
| type trieEntry struct {
 | |
| 	cidr  uint
 | |
| 	child [2]*trieEntry
 | |
| 	bits  []byte
 | |
| 	peer  *Peer
 | |
| 
 | |
| 	// index of "branching" bit
 | |
| 
 | |
| 	bit_at_byte  uint
 | |
| 	bit_at_shift uint
 | |
| }
 | |
| 
 | |
// commonBits returns the length, in bits, of the longest common prefix of ip1
// and ip2.
//
// Only the first len(ip1) bytes are compared, so ip2 must be at least as long
// as ip1 (callers are expected to pass equal-length addresses); a shorter ip2
// panics with an index out of range.
//
// TODO: Only use during insertion (xor + prefix mask for lookup)
//
//	Check out
//	prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits)
//	https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match
func commonBits(ip1 []byte, ip2 []byte) uint {
	for i := range ip1 {
		// The first differing byte decides the answer: the leading zero bits
		// of the XOR are the bits of that byte that still agree.
		if diff := ip1[i] ^ ip2[i]; diff != 0 {
			return uint(i)*8 + uint(bits.LeadingZeros8(diff))
		}
	}
	return uint(len(ip1)) * 8
}
 | |
| 
 | |
| func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
 | |
| 	if node == nil {
 | |
| 		return node
 | |
| 	}
 | |
| 
 | |
| 	// walk recursively
 | |
| 
 | |
| 	node.child[0] = node.child[0].removeByPeer(p)
 | |
| 	node.child[1] = node.child[1].removeByPeer(p)
 | |
| 
 | |
| 	if node.peer != p {
 | |
| 		return node
 | |
| 	}
 | |
| 
 | |
| 	// remove peer & merge
 | |
| 
 | |
| 	node.peer = nil
 | |
| 	if node.child[0] == nil {
 | |
| 		return node.child[1]
 | |
| 	}
 | |
| 	return node.child[0]
 | |
| }
 | |
| 
 | |
| func (node *trieEntry) choose(ip net.IP) byte {
 | |
| 	return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
 | |
| }
 | |
| 
 | |
| func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
 | |
| 
 | |
| 	// at leaf
 | |
| 
 | |
| 	if node == nil {
 | |
| 		return &trieEntry{
 | |
| 			bits:         ip,
 | |
| 			peer:         peer,
 | |
| 			cidr:         cidr,
 | |
| 			bit_at_byte:  cidr / 8,
 | |
| 			bit_at_shift: 7 - (cidr % 8),
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// traverse deeper
 | |
| 
 | |
| 	common := commonBits(node.bits, ip)
 | |
| 	if node.cidr <= cidr && common >= node.cidr {
 | |
| 		if node.cidr == cidr {
 | |
| 			node.peer = peer
 | |
| 			return node
 | |
| 		}
 | |
| 		bit := node.choose(ip)
 | |
| 		node.child[bit] = node.child[bit].insert(ip, cidr, peer)
 | |
| 		return node
 | |
| 	}
 | |
| 
 | |
| 	// split node
 | |
| 
 | |
| 	newNode := &trieEntry{
 | |
| 		bits:         ip,
 | |
| 		peer:         peer,
 | |
| 		cidr:         cidr,
 | |
| 		bit_at_byte:  cidr / 8,
 | |
| 		bit_at_shift: 7 - (cidr % 8),
 | |
| 	}
 | |
| 
 | |
| 	cidr = min(cidr, common)
 | |
| 
 | |
| 	// check for shorter prefix
 | |
| 
 | |
| 	if newNode.cidr == cidr {
 | |
| 		bit := newNode.choose(node.bits)
 | |
| 		newNode.child[bit] = node
 | |
| 		return newNode
 | |
| 	}
 | |
| 
 | |
| 	// create new parent for node & newNode
 | |
| 
 | |
| 	parent := &trieEntry{
 | |
| 		bits:         ip,
 | |
| 		peer:         nil,
 | |
| 		cidr:         cidr,
 | |
| 		bit_at_byte:  cidr / 8,
 | |
| 		bit_at_shift: 7 - (cidr % 8),
 | |
| 	}
 | |
| 
 | |
| 	bit := parent.choose(ip)
 | |
| 	parent.child[bit] = newNode
 | |
| 	parent.child[bit^1] = node
 | |
| 
 | |
| 	return parent
 | |
| }
 | |
| 
 | |
| func (node *trieEntry) lookup(ip net.IP) *Peer {
 | |
| 	var found *Peer
 | |
| 	size := uint(len(ip))
 | |
| 	for node != nil && commonBits(node.bits, ip) >= node.cidr {
 | |
| 		if node.peer != nil {
 | |
| 			found = node.peer
 | |
| 		}
 | |
| 		if node.bit_at_byte == size {
 | |
| 			break
 | |
| 		}
 | |
| 		bit := node.choose(ip)
 | |
| 		node = node.child[bit]
 | |
| 	}
 | |
| 	return found
 | |
| }
 | |
| 
 | |
| func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
 | |
| 	if node == nil {
 | |
| 		return results
 | |
| 	}
 | |
| 	if node.peer == p {
 | |
| 		var mask net.IPNet
 | |
| 		mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
 | |
| 		if len(node.bits) == net.IPv4len {
 | |
| 			mask.IP = net.IPv4(
 | |
| 				node.bits[0],
 | |
| 				node.bits[1],
 | |
| 				node.bits[2],
 | |
| 				node.bits[3],
 | |
| 			)
 | |
| 		} else if len(node.bits) == net.IPv6len {
 | |
| 			mask.IP = node.bits
 | |
| 		} else {
 | |
| 			panic(errors.New("unexpected address length"))
 | |
| 		}
 | |
| 		results = append(results, mask)
 | |
| 	}
 | |
| 	results = node.child[0].entriesForPeer(p, results)
 | |
| 	results = node.child[1].entriesForPeer(p, results)
 | |
| 	return results
 | |
| }
 | |
| 
 | |
| type AllowedIPs struct {
 | |
| 	IPv4  *trieEntry
 | |
| 	IPv6  *trieEntry
 | |
| 	mutex sync.RWMutex
 | |
| }
 | |
| 
 | |
| func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
 | |
| 	table.mutex.RLock()
 | |
| 	defer table.mutex.RUnlock()
 | |
| 
 | |
| 	allowed := make([]net.IPNet, 0, 10)
 | |
| 	allowed = table.IPv4.entriesForPeer(peer, allowed)
 | |
| 	allowed = table.IPv6.entriesForPeer(peer, allowed)
 | |
| 	return allowed
 | |
| }
 | |
| 
 | |
| func (table *AllowedIPs) Reset() {
 | |
| 	table.mutex.Lock()
 | |
| 	defer table.mutex.Unlock()
 | |
| 
 | |
| 	table.IPv4 = nil
 | |
| 	table.IPv6 = nil
 | |
| }
 | |
| 
 | |
| func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
 | |
| 	table.mutex.Lock()
 | |
| 	defer table.mutex.Unlock()
 | |
| 
 | |
| 	table.IPv4 = table.IPv4.removeByPeer(peer)
 | |
| 	table.IPv6 = table.IPv6.removeByPeer(peer)
 | |
| }
 | |
| 
 | |
| func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
 | |
| 	table.mutex.Lock()
 | |
| 	defer table.mutex.Unlock()
 | |
| 
 | |
| 	switch len(ip) {
 | |
| 	case net.IPv6len:
 | |
| 		table.IPv6 = table.IPv6.insert(ip, cidr, peer)
 | |
| 	case net.IPv4len:
 | |
| 		table.IPv4 = table.IPv4.insert(ip, cidr, peer)
 | |
| 	default:
 | |
| 		panic(errors.New("inserting unknown address type"))
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
 | |
| 	table.mutex.RLock()
 | |
| 	defer table.mutex.RUnlock()
 | |
| 	return table.IPv4.lookup(address)
 | |
| }
 | |
| 
 | |
| func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
 | |
| 	table.mutex.RLock()
 | |
| 	defer table.mutex.RUnlock()
 | |
| 	return table.IPv6.lookup(address)
 | |
| }
 | 
