device: remove recursion from insertion and connect parent pointers

This makes the insertion algorithm a bit more efficient, while also now
taking on the additional task of connecting up parent pointers. This
will be handy in the following commit.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld
2021-06-03 14:50:28 +02:00
parent 4a57024b94
commit b41f4cc768
3 changed files with 94 additions and 58 deletions

View File

@@ -14,9 +14,15 @@ import (
"unsafe"
)
type parentIndirection struct {
parentBit **trieEntry
parentBitType uint8
}
type trieEntry struct {
peer *Peer
child [2]*trieEntry
parent parentIndirection
cidr uint8
bitAtByte uint8
bitAtShift uint8
@@ -114,43 +120,45 @@ func (node *trieEntry) maskSelf() {
}
}
func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
exact = true
return
}
bit := node.choose(ip)
node = node.child[bit]
}
return
}
// at leaf
if node == nil {
func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry{
bits: ip,
peer: peer,
parent: trie,
bits: ip,
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
node.maskSelf()
node.addToPeerEntries()
return node
*trie.parentBit = node
return
}
// traverse deeper
common := commonBits(node.bits, ip)
if node.cidr <= cidr && common >= node.cidr {
if node.cidr == cidr {
node.removeFromPeerEntries()
node.peer = peer
node.addToPeerEntries()
return node
}
bit := node.choose(ip)
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
return node
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
if exact {
node.removeFromPeerEntries()
node.peer = peer
node.addToPeerEntries()
return
}
// split node
newNode := &trieEntry{
bits: ip,
peer: peer,
bits: ip,
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
@@ -158,34 +166,61 @@ func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
newNode.maskSelf()
newNode.addToPeerEntries()
var down *trieEntry
if node == nil {
down = *trie.parentBit
} else {
bit := node.choose(ip)
down = node.child[bit]
if down == nil {
newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
return
}
}
common := commonBits(down.bits, ip)
if common < cidr {
cidr = common
}
// check for shorter prefix
parent := node
if newNode.cidr == cidr {
bit := newNode.choose(node.bits)
newNode.child[bit] = node
return newNode
bit := newNode.choose(down.bits)
down.parent = parentIndirection{&newNode.child[bit], bit}
newNode.child[bit] = down
if parent == nil {
newNode.parent = trie
*trie.parentBit = newNode
} else {
bit := parent.choose(newNode.bits)
newNode.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = newNode
}
return
}
// create new parent for node & newNode
parent := &trieEntry{
bits: append([]byte{}, ip...),
peer: nil,
node = &trieEntry{
bits: append([]byte{}, newNode.bits...),
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
parent.maskSelf()
node.maskSelf()
bit := parent.choose(ip)
parent.child[bit] = newNode
parent.child[bit^1] = node
return parent
bit := node.choose(down.bits)
down.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = down
bit = node.choose(newNode.bits)
newNode.parent = parentIndirection{&node.child[bit], bit}
node.child[bit] = newNode
if parent == nil {
node.parent = trie
*trie.parentBit = node
} else {
bit := parent.choose(node.bits)
node.parent = parentIndirection{&parent.child[bit], bit}
parent.child[bit] = node
}
}
func (node *trieEntry) lookup(ip net.IP) *Peer {
@@ -236,9 +271,9 @@ func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
switch len(ip) {
case net.IPv6len:
table.IPv6 = table.IPv6.insert(ip, cidr, peer)
parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
case net.IPv4len:
table.IPv4 = table.IPv4.insert(ip, cidr, peer)
parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
default:
panic(errors.New("inserting unknown address type"))
}