Mirror of https://git.zx2c4.com/wireguard-go (synced 2025-10-31 03:46:20 +08:00)
			
		
		
		
	
		
			
				
	
	
		
			229 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			229 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package main
 | |
| 
 | |
import (
	"errors"
	"math/bits"
	"net"
)
 | |
| 
 | |
| /* Binary trie
 | |
|  *
 | |
|  * The net.IPs used here are not formatted the
 | |
|  * same way as those created by the "net" functions.
 | |
|  * Here the IPs are slices of either 4 or 16 byte (not always 16)
 | |
|  *
 | |
|  * Synchronization done separately
 | |
|  * See: routing.go
 | |
|  */
 | |
| 
 | |
| type Trie struct {
 | |
| 	cidr  uint
 | |
| 	child [2]*Trie
 | |
| 	bits  []byte
 | |
| 	peer  *Peer
 | |
| 
 | |
| 	// index of "branching" bit
 | |
| 
 | |
| 	bit_at_byte  uint
 | |
| 	bit_at_shift uint
 | |
| }
 | |
| 
 | |
// commonBits returns the length, in bits, of the longest common prefix of
// ip1 and ip2, reading every byte from its most significant bit down.
//
// The two slices are expected to have the same length (4 bytes for IPv4, 16
// for IPv6). If ip2 is shorter than ip1 the call panics with an index out of
// range. When the slices are identical the result is len(ip1)*8.
//
// TODO: only use this during insertion, and switch lookups to an xor plus a
// precomputed prefix mask. See
// prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits)
// at https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match
func commonBits(ip1 []byte, ip2 []byte) uint {
	for i := range ip1 {
		if diff := ip1[i] ^ ip2[i]; diff != 0 {
			// The leading zeros of the xor are exactly the high-order bits on
			// which the two bytes still agree.
			return uint(i)*8 + uint(bits.LeadingZeros8(diff))
		}
	}
	return uint(len(ip1)) * 8
}
 | |
| 
 | |
| func (node *Trie) RemovePeer(p *Peer) *Trie {
 | |
| 	if node == nil {
 | |
| 		return node
 | |
| 	}
 | |
| 
 | |
| 	// walk recursively
 | |
| 
 | |
| 	node.child[0] = node.child[0].RemovePeer(p)
 | |
| 	node.child[1] = node.child[1].RemovePeer(p)
 | |
| 
 | |
| 	if node.peer != p {
 | |
| 		return node
 | |
| 	}
 | |
| 
 | |
| 	// remove peer & merge
 | |
| 
 | |
| 	node.peer = nil
 | |
| 	if node.child[0] == nil {
 | |
| 		return node.child[1]
 | |
| 	}
 | |
| 	return node.child[0]
 | |
| }
 | |
| 
 | |
| func (node *Trie) choose(ip net.IP) byte {
 | |
| 	return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
 | |
| }
 | |
| 
 | |
| func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
 | |
| 
 | |
| 	// at leaf
 | |
| 
 | |
| 	if node == nil {
 | |
| 		return &Trie{
 | |
| 			bits:         ip,
 | |
| 			peer:         peer,
 | |
| 			cidr:         cidr,
 | |
| 			bit_at_byte:  cidr / 8,
 | |
| 			bit_at_shift: 7 - (cidr % 8),
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// traverse deeper
 | |
| 
 | |
| 	common := commonBits(node.bits, ip)
 | |
| 	if node.cidr <= cidr && common >= node.cidr {
 | |
| 		if node.cidr == cidr {
 | |
| 			node.peer = peer
 | |
| 			return node
 | |
| 		}
 | |
| 		bit := node.choose(ip)
 | |
| 		node.child[bit] = node.child[bit].Insert(ip, cidr, peer)
 | |
| 		return node
 | |
| 	}
 | |
| 
 | |
| 	// split node
 | |
| 
 | |
| 	newNode := &Trie{
 | |
| 		bits:         ip,
 | |
| 		peer:         peer,
 | |
| 		cidr:         cidr,
 | |
| 		bit_at_byte:  cidr / 8,
 | |
| 		bit_at_shift: 7 - (cidr % 8),
 | |
| 	}
 | |
| 
 | |
| 	cidr = min(cidr, common)
 | |
| 
 | |
| 	// check for shorter prefix
 | |
| 
 | |
| 	if newNode.cidr == cidr {
 | |
| 		bit := newNode.choose(node.bits)
 | |
| 		newNode.child[bit] = node
 | |
| 		return newNode
 | |
| 	}
 | |
| 
 | |
| 	// create new parent for node & newNode
 | |
| 
 | |
| 	parent := &Trie{
 | |
| 		bits:         ip,
 | |
| 		peer:         nil,
 | |
| 		cidr:         cidr,
 | |
| 		bit_at_byte:  cidr / 8,
 | |
| 		bit_at_shift: 7 - (cidr % 8),
 | |
| 	}
 | |
| 
 | |
| 	bit := parent.choose(ip)
 | |
| 	parent.child[bit] = newNode
 | |
| 	parent.child[bit^1] = node
 | |
| 
 | |
| 	return parent
 | |
| }
 | |
| 
 | |
| func (node *Trie) Lookup(ip net.IP) *Peer {
 | |
| 	var found *Peer
 | |
| 	size := uint(len(ip))
 | |
| 	for node != nil && commonBits(node.bits, ip) >= node.cidr {
 | |
| 		if node.peer != nil {
 | |
| 			found = node.peer
 | |
| 		}
 | |
| 		if node.bit_at_byte == size {
 | |
| 			break
 | |
| 		}
 | |
| 		bit := node.choose(ip)
 | |
| 		node = node.child[bit]
 | |
| 	}
 | |
| 	return found
 | |
| }
 | |
| 
 | |
| func (node *Trie) Count() uint {
 | |
| 	if node == nil {
 | |
| 		return 0
 | |
| 	}
 | |
| 	l := node.child[0].Count()
 | |
| 	r := node.child[1].Count()
 | |
| 	return l + r
 | |
| }
 | |
| 
 | |
| func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
 | |
| 	if node == nil {
 | |
| 		return results
 | |
| 	}
 | |
| 	if node.peer == p {
 | |
| 		var mask net.IPNet
 | |
| 		mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
 | |
| 		if len(node.bits) == net.IPv4len {
 | |
| 			mask.IP = net.IPv4(
 | |
| 				node.bits[0],
 | |
| 				node.bits[1],
 | |
| 				node.bits[2],
 | |
| 				node.bits[3],
 | |
| 			)
 | |
| 		} else if len(node.bits) == net.IPv6len {
 | |
| 			mask.IP = node.bits
 | |
| 		} else {
 | |
| 			panic(errors.New("bug: unexpected address length"))
 | |
| 		}
 | |
| 		results = append(results, mask)
 | |
| 	}
 | |
| 	results = node.child[0].AllowedIPs(p, results)
 | |
| 	results = node.child[1].AllowedIPs(p, results)
 | |
| 	return results
 | |
| }
 | 
