mirror of
https://git.zx2c4.com/wireguard-go
synced 2025-10-18 22:44:43 +08:00
device: remove nodes by peer in O(1) instead of O(n)
Now that we have parent pointers hooked up, we can simply go right to the node and remove it in place, rather than having to recursively walk the entire trie. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
@@ -85,30 +85,6 @@ func (node *trieEntry) removeFromPeerEntries() {
|
||||
}
|
||||
}
|
||||
|
||||
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
|
||||
if node == nil {
|
||||
return node
|
||||
}
|
||||
|
||||
// walk recursively
|
||||
|
||||
node.child[0] = node.child[0].removeByPeer(p)
|
||||
node.child[1] = node.child[1].removeByPeer(p)
|
||||
|
||||
if node.peer != p {
|
||||
return node
|
||||
}
|
||||
|
||||
// remove peer & merge
|
||||
|
||||
node.removeFromPeerEntries()
|
||||
node.peer = nil
|
||||
if node.child[0] == nil {
|
||||
return node.child[1]
|
||||
}
|
||||
return node.child[0]
|
||||
}
|
||||
|
||||
func (node *trieEntry) choose(ip net.IP) byte {
|
||||
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
|
||||
}
|
||||
@@ -261,8 +237,38 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
|
||||
table.IPv4 = table.IPv4.removeByPeer(peer)
|
||||
table.IPv6 = table.IPv6.removeByPeer(peer)
|
||||
var next *list.Element
|
||||
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
|
||||
next = elem.Next()
|
||||
node := elem.Value.(*trieEntry)
|
||||
|
||||
node.removeFromPeerEntries()
|
||||
node.peer = nil
|
||||
if node.child[0] != nil && node.child[1] != nil {
|
||||
continue
|
||||
}
|
||||
bit := 0
|
||||
if node.child[0] == nil {
|
||||
bit = 1
|
||||
}
|
||||
child := node.child[bit]
|
||||
if child != nil {
|
||||
child.parent = node.parent
|
||||
}
|
||||
*node.parent.parentBit = child
|
||||
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
|
||||
continue
|
||||
}
|
||||
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
|
||||
if parent.peer != nil {
|
||||
continue
|
||||
}
|
||||
child = parent.child[node.parent.parentBitType^1]
|
||||
if child != nil {
|
||||
child.parent = parent.parent
|
||||
}
|
||||
*parent.parent.parentBit = child
|
||||
}
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
|
||||
|
Reference in New Issue
Block a user