mirror of
https://git.zx2c4.com/wireguard-go
synced 2025-10-05 16:47:02 +08:00
device: remove recursion from insertion and connect parent pointers
This makes the insertion algorithm a bit more efficient, while also now taking on the additional task of connecting up parent pointers. This will be handy in the following commit. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
@@ -14,9 +14,15 @@ import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type parentIndirection struct {
|
||||
parentBit **trieEntry
|
||||
parentBitType uint8
|
||||
}
|
||||
|
||||
type trieEntry struct {
|
||||
peer *Peer
|
||||
child [2]*trieEntry
|
||||
parent parentIndirection
|
||||
cidr uint8
|
||||
bitAtByte uint8
|
||||
bitAtShift uint8
|
||||
@@ -114,43 +120,45 @@ func (node *trieEntry) maskSelf() {
|
||||
}
|
||||
}
|
||||
|
||||
func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
|
||||
func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
|
||||
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
|
||||
parent = node
|
||||
if parent.cidr == cidr {
|
||||
exact = true
|
||||
return
|
||||
}
|
||||
bit := node.choose(ip)
|
||||
node = node.child[bit]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// at leaf
|
||||
|
||||
if node == nil {
|
||||
func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
|
||||
if *trie.parentBit == nil {
|
||||
node := &trieEntry{
|
||||
bits: ip,
|
||||
peer: peer,
|
||||
parent: trie,
|
||||
bits: ip,
|
||||
cidr: cidr,
|
||||
bitAtByte: cidr / 8,
|
||||
bitAtShift: 7 - (cidr % 8),
|
||||
}
|
||||
node.maskSelf()
|
||||
node.addToPeerEntries()
|
||||
return node
|
||||
*trie.parentBit = node
|
||||
return
|
||||
}
|
||||
|
||||
// traverse deeper
|
||||
|
||||
common := commonBits(node.bits, ip)
|
||||
if node.cidr <= cidr && common >= node.cidr {
|
||||
if node.cidr == cidr {
|
||||
node.removeFromPeerEntries()
|
||||
node.peer = peer
|
||||
node.addToPeerEntries()
|
||||
return node
|
||||
}
|
||||
bit := node.choose(ip)
|
||||
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
|
||||
return node
|
||||
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
|
||||
if exact {
|
||||
node.removeFromPeerEntries()
|
||||
node.peer = peer
|
||||
node.addToPeerEntries()
|
||||
return
|
||||
}
|
||||
|
||||
// split node
|
||||
|
||||
newNode := &trieEntry{
|
||||
bits: ip,
|
||||
peer: peer,
|
||||
bits: ip,
|
||||
cidr: cidr,
|
||||
bitAtByte: cidr / 8,
|
||||
bitAtShift: 7 - (cidr % 8),
|
||||
@@ -158,34 +166,61 @@ func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
|
||||
newNode.maskSelf()
|
||||
newNode.addToPeerEntries()
|
||||
|
||||
var down *trieEntry
|
||||
if node == nil {
|
||||
down = *trie.parentBit
|
||||
} else {
|
||||
bit := node.choose(ip)
|
||||
down = node.child[bit]
|
||||
if down == nil {
|
||||
newNode.parent = parentIndirection{&node.child[bit], bit}
|
||||
node.child[bit] = newNode
|
||||
return
|
||||
}
|
||||
}
|
||||
common := commonBits(down.bits, ip)
|
||||
if common < cidr {
|
||||
cidr = common
|
||||
}
|
||||
|
||||
// check for shorter prefix
|
||||
parent := node
|
||||
|
||||
if newNode.cidr == cidr {
|
||||
bit := newNode.choose(node.bits)
|
||||
newNode.child[bit] = node
|
||||
return newNode
|
||||
bit := newNode.choose(down.bits)
|
||||
down.parent = parentIndirection{&newNode.child[bit], bit}
|
||||
newNode.child[bit] = down
|
||||
if parent == nil {
|
||||
newNode.parent = trie
|
||||
*trie.parentBit = newNode
|
||||
} else {
|
||||
bit := parent.choose(newNode.bits)
|
||||
newNode.parent = parentIndirection{&parent.child[bit], bit}
|
||||
parent.child[bit] = newNode
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// create new parent for node & newNode
|
||||
|
||||
parent := &trieEntry{
|
||||
bits: append([]byte{}, ip...),
|
||||
peer: nil,
|
||||
node = &trieEntry{
|
||||
bits: append([]byte{}, newNode.bits...),
|
||||
cidr: cidr,
|
||||
bitAtByte: cidr / 8,
|
||||
bitAtShift: 7 - (cidr % 8),
|
||||
}
|
||||
parent.maskSelf()
|
||||
node.maskSelf()
|
||||
|
||||
bit := parent.choose(ip)
|
||||
parent.child[bit] = newNode
|
||||
parent.child[bit^1] = node
|
||||
|
||||
return parent
|
||||
bit := node.choose(down.bits)
|
||||
down.parent = parentIndirection{&node.child[bit], bit}
|
||||
node.child[bit] = down
|
||||
bit = node.choose(newNode.bits)
|
||||
newNode.parent = parentIndirection{&node.child[bit], bit}
|
||||
node.child[bit] = newNode
|
||||
if parent == nil {
|
||||
node.parent = trie
|
||||
*trie.parentBit = node
|
||||
} else {
|
||||
bit := parent.choose(node.bits)
|
||||
node.parent = parentIndirection{&parent.child[bit], bit}
|
||||
parent.child[bit] = node
|
||||
}
|
||||
}
|
||||
|
||||
func (node *trieEntry) lookup(ip net.IP) *Peer {
|
||||
@@ -236,9 +271,9 @@ func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
|
||||
|
||||
switch len(ip) {
|
||||
case net.IPv6len:
|
||||
table.IPv6 = table.IPv6.insert(ip, cidr, peer)
|
||||
parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
|
||||
case net.IPv4len:
|
||||
table.IPv4 = table.IPv4.insert(ip, cidr, peer)
|
||||
parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
|
||||
default:
|
||||
panic(errors.New("inserting unknown address type"))
|
||||
}
|
||||
|
Reference in New Issue
Block a user