Mirror of https://git.zx2c4.com/wireguard-go (synced 2025-10-31 11:56:22 +08:00)
			
		
		
		
	
		
			
				
	
	
		
			295 lines
		
	
	
		
			6.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			295 lines
		
	
	
		
			6.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| /* SPDX-License-Identifier: MIT
 | |
|  *
 | |
|  * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 | |
|  */
 | |
| 
 | |
| package device
 | |
| 
 | |
| import (
 | |
| 	"container/list"
 | |
| 	"encoding/binary"
 | |
| 	"errors"
 | |
| 	"math/bits"
 | |
| 	"net"
 | |
| 	"net/netip"
 | |
| 	"sync"
 | |
| 	"unsafe"
 | |
| )
 | |
| 
 | |
// parentIndirection identifies the slot that holds the pointer to a
// trieEntry: either one of a parent node's two child pointers or one of the
// root pointers of an AllowedIPs table.
type parentIndirection struct {
	// parentBit points at the slot referencing the node, so the node can be
	// replaced or unlinked by writing through it.
	parentBit     **trieEntry
	// parentBitType is the child index (0 or 1) when parentBit points into a
	// parent's child array; Insert uses 2 for the table's root slots.
	parentBitType uint8
}
 | |
| 
 | |
// trieEntry is one node of the binary prefix trie used by AllowedIPs.
type trieEntry struct {
	// peer owns exactly this prefix; nil for nodes that exist only as
	// branch points.
	peer        *Peer
	// child[b] is the subtree whose addresses have value b at the bit
	// selected by bitAtByte/bitAtShift (see choose).
	child       [2]*trieEntry
	// parent identifies the slot that points at this node.
	parent      parentIndirection
	// cidr is the prefix length in bits.
	cidr        uint8
	// bitAtByte and bitAtShift locate bit number cidr within an address;
	// insert sets them to cidr/8 and 7-(cidr%8).
	bitAtByte   uint8
	bitAtShift  uint8
	// bits holds the address bytes (4 for IPv4, 16 for IPv6 as produced by
	// Insert), masked to the first cidr bits by maskSelf.
	bits        []byte
	// perPeerElem is this node's element in peer.trieEntries, or nil when
	// the node is not on any peer's list.
	perPeerElem *list.Element
}
 | |
| 
 | |
// commonBits returns the number of leading bits that ip1 and ip2 share.
// The length of ip1 selects the address family: 4 bytes (IPv4, result 0-32)
// or 16 bytes (IPv6, result 0-128). Any other length panics; ip2 must be at
// least as long as ip1.
func commonBits(ip1, ip2 []byte) uint8 {
	switch len(ip1) {
	case net.IPv4len:
		diff := binary.BigEndian.Uint32(ip1) ^ binary.BigEndian.Uint32(ip2)
		return uint8(bits.LeadingZeros32(diff))
	case net.IPv6len:
		// Compare the high half first; only fall through to the low half
		// when the first 64 bits are identical.
		high := binary.BigEndian.Uint64(ip1) ^ binary.BigEndian.Uint64(ip2)
		if high != 0 {
			return uint8(bits.LeadingZeros64(high))
		}
		low := binary.BigEndian.Uint64(ip1[8:]) ^ binary.BigEndian.Uint64(ip2[8:])
		return 64 + uint8(bits.LeadingZeros64(low))
	default:
		panic("Wrong size bit string")
	}
}
 | |
| 
 | |
| func (node *trieEntry) addToPeerEntries() {
 | |
| 	node.perPeerElem = node.peer.trieEntries.PushBack(node)
 | |
| }
 | |
| 
 | |
| func (node *trieEntry) removeFromPeerEntries() {
 | |
| 	if node.perPeerElem != nil {
 | |
| 		node.peer.trieEntries.Remove(node.perPeerElem)
 | |
| 		node.perPeerElem = nil
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (node *trieEntry) choose(ip []byte) byte {
 | |
| 	return (ip[node.bitAtByte] >> node.bitAtShift) & 1
 | |
| }
 | |
| 
 | |
| func (node *trieEntry) maskSelf() {
 | |
| 	mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
 | |
| 	for i := 0; i < len(mask); i++ {
 | |
| 		node.bits[i] &= mask[i]
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (node *trieEntry) zeroizePointers() {
 | |
| 	// Make the garbage collector's life slightly easier
 | |
| 	node.peer = nil
 | |
| 	node.child[0] = nil
 | |
| 	node.child[1] = nil
 | |
| 	node.parent.parentBit = nil
 | |
| }
 | |
| 
 | |
| func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
 | |
| 	for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
 | |
| 		parent = node
 | |
| 		if parent.cidr == cidr {
 | |
| 			exact = true
 | |
| 			return
 | |
| 		}
 | |
| 		bit := node.choose(ip)
 | |
| 		node = node.child[bit]
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
// insert adds the prefix ip/cidr for peer to the trie whose root pointer is
// *trie.parentBit (trie is the root slot itself).
//
// The ip slice becomes the new node's bits and is masked in place by
// maskSelf, so the caller's backing array is both modified and retained.
// If a node with exactly this prefix already exists, only its peer is
// replaced. Otherwise the new node is linked in as a leaf under the deepest
// containing node, as the new parent of the subtree it contains, or next to
// that subtree beneath a new peerless branch node at their common prefix.
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
	// Empty trie: the new node becomes the root.
	if *trie.parentBit == nil {
		node := &trieEntry{
			peer:       peer,
			parent:     trie,
			bits:       ip,
			cidr:       cidr,
			bitAtByte:  cidr / 8,
			bitAtShift: 7 - (cidr % 8),
		}
		node.maskSelf()
		node.addToPeerEntries()
		*trie.parentBit = node
		return
	}
	// node is the deepest existing node whose prefix contains ip/cidr (nil if
	// none); exact reports that node has exactly this prefix.
	node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
	if exact {
		// Prefix already present: move it from the old peer's entry list (a
		// no-op for a peerless node) to the new peer's.
		node.removeFromPeerEntries()
		node.peer = peer
		node.addToPeerEntries()
		return
	}

	newNode := &trieEntry{
		peer:       peer,
		bits:       ip,
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	newNode.maskSelf()
	newNode.addToPeerEntries()

	// down is the existing subtree newNode must be placed relative to: the
	// whole trie when no node contains the prefix, otherwise node's child on
	// ip's side.
	var down *trieEntry
	if node == nil {
		down = *trie.parentBit
	} else {
		bit := node.choose(ip)
		down = node.child[bit]
		if down == nil {
			// Empty slot: attach newNode there as a leaf.
			newNode.parent = parentIndirection{&node.child[bit], bit}
			node.child[bit] = newNode
			return
		}
	}
	// From here on cidr is the length of the prefix shared by down and
	// newNode, capped at newNode's own prefix length.
	common := commonBits(down.bits, ip)
	if common < cidr {
		cidr = common
	}
	parent := node

	if newNode.cidr == cidr {
		// down agrees with newNode on all of newNode's prefix: make newNode
		// down's parent, then link newNode into parent (or the root slot).
		bit := newNode.choose(down.bits)
		down.parent = parentIndirection{&newNode.child[bit], bit}
		newNode.child[bit] = down
		if parent == nil {
			newNode.parent = trie
			*trie.parentBit = newNode
		} else {
			bit := parent.choose(newNode.bits)
			newNode.parent = parentIndirection{&parent.child[bit], bit}
			parent.child[bit] = newNode
		}
		return
	}

	// Neither prefix contains the other: create a peerless branch node at the
	// shared prefix length with down and newNode as its two children. Its bits
	// are a copy so masking them does not truncate newNode.bits.
	node = &trieEntry{
		bits:       append([]byte{}, newNode.bits...),
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	node.maskSelf()

	bit := node.choose(down.bits)
	down.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = down
	bit = node.choose(newNode.bits)
	newNode.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = newNode
	// Link the branch node into parent (or the root slot) on its own bits' side.
	if parent == nil {
		node.parent = trie
		*trie.parentBit = node
	} else {
		bit := parent.choose(node.bits)
		node.parent = parentIndirection{&parent.child[bit], bit}
		parent.child[bit] = node
	}
}
 | |
| 
 | |
| func (node *trieEntry) lookup(ip []byte) *Peer {
 | |
| 	var found *Peer
 | |
| 	size := uint8(len(ip))
 | |
| 	for node != nil && commonBits(node.bits, ip) >= node.cidr {
 | |
| 		if node.peer != nil {
 | |
| 			found = node.peer
 | |
| 		}
 | |
| 		if node.bitAtByte == size {
 | |
| 			break
 | |
| 		}
 | |
| 		bit := node.choose(ip)
 | |
| 		node = node.child[bit]
 | |
| 	}
 | |
| 	return found
 | |
| }
 | |
| 
 | |
// AllowedIPs maps IP prefixes to peers, keeping a separate prefix trie for
// IPv4 and IPv6 addresses.
type AllowedIPs struct {
	IPv4  *trieEntry   // root of the trie of 4-byte addresses; nil when empty
	IPv6  *trieEntry   // root of the trie of 16-byte addresses; nil when empty
	mutex sync.RWMutex // guards both tries: readers RLock, mutators Lock
}
 | |
| 
 | |
| func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
 | |
| 	table.mutex.RLock()
 | |
| 	defer table.mutex.RUnlock()
 | |
| 
 | |
| 	for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
 | |
| 		node := elem.Value.(*trieEntry)
 | |
| 		a, _ := netip.AddrFromSlice(node.bits)
 | |
| 		if !cb(netip.PrefixFrom(a, int(node.cidr))) {
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
// RemoveByPeer removes every prefix associated with peer. It also prunes
// trie nodes that are left without a peer and are no longer needed as
// branch points.
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
	table.mutex.Lock()
	defer table.mutex.Unlock()

	var next *list.Element
	for elem := peer.trieEntries.Front(); elem != nil; elem = next {
		// Fetch the successor before removeFromPeerEntries unlinks elem.
		next = elem.Next()
		node := elem.Value.(*trieEntry)

		node.removeFromPeerEntries()
		node.peer = nil
		// With two children the node is still a branch point: keep it as a
		// peerless node.
		if node.child[0] != nil && node.child[1] != nil {
			continue
		}
		// Splice node out by storing its only child (or nil) in the slot that
		// pointed at node.
		bit := 0
		if node.child[0] == nil {
			bit = 1
		}
		child := node.child[bit]
		if child != nil {
			child.parent = node.parent
		}
		*node.parent.parentBit = child
		// Stop pruning in two cases. If node had a child, its parent keeps the
		// same number of children. If node sat in a table root slot
		// (parentBitType 2, as set by Insert), there is no enclosing trieEntry.
		if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
			node.zeroizePointers()
			continue
		}
		// parentBit points at parent.child[parentBitType]. Subtracting that
		// element's offset within trieEntry recovers the parent node itself.
		parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
		// A parent that still carries a peer must stay.
		if parent.peer != nil {
			node.zeroizePointers()
			continue
		}
		// The parent is peerless and now has at most one child (node's
		// sibling). Hoist the sibling into the parent's slot and drop the
		// parent as well.
		child = parent.child[node.parent.parentBitType^1]
		if child != nil {
			child.parent = parent.parent
		}
		*parent.parent.parentBit = child
		node.zeroizePointers()
		parent.zeroizePointers()
	}
}
 | |
| 
 | |
| func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
 | |
| 	table.mutex.Lock()
 | |
| 	defer table.mutex.Unlock()
 | |
| 
 | |
| 	if prefix.Addr().Is6() {
 | |
| 		ip := prefix.Addr().As16()
 | |
| 		parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
 | |
| 	} else if prefix.Addr().Is4() {
 | |
| 		ip := prefix.Addr().As4()
 | |
| 		parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
 | |
| 	} else {
 | |
| 		panic(errors.New("inserting unknown address type"))
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (table *AllowedIPs) Lookup(ip []byte) *Peer {
 | |
| 	table.mutex.RLock()
 | |
| 	defer table.mutex.RUnlock()
 | |
| 	switch len(ip) {
 | |
| 	case net.IPv6len:
 | |
| 		return table.IPv6.lookup(ip)
 | |
| 	case net.IPv4len:
 | |
| 		return table.IPv4.lookup(ip)
 | |
| 	default:
 | |
| 		panic(errors.New("looking up unknown address type"))
 | |
| 	}
 | |
| }
 | 
