mirror of
				https://git.zx2c4.com/wireguard-go
				synced 2025-10-31 11:56:22 +08:00 
			
		
		
		
	device: reduce size of trie struct
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
		| @@ -15,13 +15,13 @@ import ( | ||||
| ) | ||||
|  | ||||
| type trieEntry struct { | ||||
| 	child        [2]*trieEntry | ||||
| 	peer         *Peer | ||||
| 	bits         net.IP | ||||
| 	cidr         uint | ||||
| 	bit_at_byte  uint | ||||
| 	bit_at_shift uint | ||||
| 	perPeerElem  *list.Element | ||||
| 	peer        *Peer | ||||
| 	child       [2]*trieEntry | ||||
| 	cidr        uint8 | ||||
| 	bitAtByte   uint8 | ||||
| 	bitAtShift  uint8 | ||||
| 	bits        net.IP | ||||
| 	perPeerElem *list.Element | ||||
| } | ||||
|  | ||||
| func isLittleEndian() bool { | ||||
| @@ -45,24 +45,24 @@ func swapU64(i uint64) uint64 { | ||||
| 	return bits.ReverseBytes64(i) | ||||
| } | ||||
|  | ||||
| func commonBits(ip1 net.IP, ip2 net.IP) uint { | ||||
| func commonBits(ip1 net.IP, ip2 net.IP) uint8 { | ||||
| 	size := len(ip1) | ||||
| 	if size == net.IPv4len { | ||||
| 		a := (*uint32)(unsafe.Pointer(&ip1[0])) | ||||
| 		b := (*uint32)(unsafe.Pointer(&ip2[0])) | ||||
| 		x := *a ^ *b | ||||
| 		return uint(bits.LeadingZeros32(swapU32(x))) | ||||
| 		return uint8(bits.LeadingZeros32(swapU32(x))) | ||||
| 	} else if size == net.IPv6len { | ||||
| 		a := (*uint64)(unsafe.Pointer(&ip1[0])) | ||||
| 		b := (*uint64)(unsafe.Pointer(&ip2[0])) | ||||
| 		x := *a ^ *b | ||||
| 		if x != 0 { | ||||
| 			return uint(bits.LeadingZeros64(swapU64(x))) | ||||
| 			return uint8(bits.LeadingZeros64(swapU64(x))) | ||||
| 		} | ||||
| 		a = (*uint64)(unsafe.Pointer(&ip1[8])) | ||||
| 		b = (*uint64)(unsafe.Pointer(&ip2[8])) | ||||
| 		x = *a ^ *b | ||||
| 		return 64 + uint(bits.LeadingZeros64(swapU64(x))) | ||||
| 		return 64 + uint8(bits.LeadingZeros64(swapU64(x))) | ||||
| 	} else { | ||||
| 		panic("Wrong size bit string") | ||||
| 	} | ||||
| @@ -104,7 +104,7 @@ func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { | ||||
| } | ||||
|  | ||||
| func (node *trieEntry) choose(ip net.IP) byte { | ||||
| 	return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 | ||||
| 	return (ip[node.bitAtByte] >> node.bitAtShift) & 1 | ||||
| } | ||||
|  | ||||
| func (node *trieEntry) maskSelf() { | ||||
| @@ -114,17 +114,17 @@ func (node *trieEntry) maskSelf() { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { | ||||
| func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry { | ||||
|  | ||||
| 	// at leaf | ||||
|  | ||||
| 	if node == nil { | ||||
| 		node := &trieEntry{ | ||||
| 			bits:         ip, | ||||
| 			peer:         peer, | ||||
| 			cidr:         cidr, | ||||
| 			bit_at_byte:  cidr / 8, | ||||
| 			bit_at_shift: 7 - (cidr % 8), | ||||
| 			bits:       ip, | ||||
| 			peer:       peer, | ||||
| 			cidr:       cidr, | ||||
| 			bitAtByte:  cidr / 8, | ||||
| 			bitAtShift: 7 - (cidr % 8), | ||||
| 		} | ||||
| 		node.maskSelf() | ||||
| 		node.addToPeerEntries() | ||||
| @@ -149,16 +149,18 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { | ||||
| 	// split node | ||||
|  | ||||
| 	newNode := &trieEntry{ | ||||
| 		bits:         ip, | ||||
| 		peer:         peer, | ||||
| 		cidr:         cidr, | ||||
| 		bit_at_byte:  cidr / 8, | ||||
| 		bit_at_shift: 7 - (cidr % 8), | ||||
| 		bits:       ip, | ||||
| 		peer:       peer, | ||||
| 		cidr:       cidr, | ||||
| 		bitAtByte:  cidr / 8, | ||||
| 		bitAtShift: 7 - (cidr % 8), | ||||
| 	} | ||||
| 	newNode.maskSelf() | ||||
| 	newNode.addToPeerEntries() | ||||
|  | ||||
| 	cidr = min(cidr, common) | ||||
| 	if common < cidr { | ||||
| 		cidr = common | ||||
| 	} | ||||
|  | ||||
| 	// check for shorter prefix | ||||
|  | ||||
| @@ -171,11 +173,11 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { | ||||
| 	// create new parent for node & newNode | ||||
|  | ||||
| 	parent := &trieEntry{ | ||||
| 		bits:         append([]byte{}, ip...), | ||||
| 		peer:         nil, | ||||
| 		cidr:         cidr, | ||||
| 		bit_at_byte:  cidr / 8, | ||||
| 		bit_at_shift: 7 - (cidr % 8), | ||||
| 		bits:       append([]byte{}, ip...), | ||||
| 		peer:       nil, | ||||
| 		cidr:       cidr, | ||||
| 		bitAtByte:  cidr / 8, | ||||
| 		bitAtShift: 7 - (cidr % 8), | ||||
| 	} | ||||
| 	parent.maskSelf() | ||||
|  | ||||
| @@ -188,12 +190,12 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { | ||||
|  | ||||
| func (node *trieEntry) lookup(ip net.IP) *Peer { | ||||
| 	var found *Peer | ||||
| 	size := uint(len(ip)) | ||||
| 	size := uint8(len(ip)) | ||||
| 	for node != nil && commonBits(node.bits, ip) >= node.cidr { | ||||
| 		if node.peer != nil { | ||||
| 			found = node.peer | ||||
| 		} | ||||
| 		if node.bit_at_byte == size { | ||||
| 		if node.bitAtByte == size { | ||||
| 			break | ||||
| 		} | ||||
| 		bit := node.choose(ip) | ||||
| @@ -208,7 +210,7 @@ type AllowedIPs struct { | ||||
| 	mutex sync.RWMutex | ||||
| } | ||||
|  | ||||
| func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) { | ||||
| func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) { | ||||
| 	table.mutex.RLock() | ||||
| 	defer table.mutex.RUnlock() | ||||
|  | ||||
| @@ -228,7 +230,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { | ||||
| 	table.IPv6 = table.IPv6.removeByPeer(peer) | ||||
| } | ||||
|  | ||||
| func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { | ||||
| func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { | ||||
| 	table.mutex.Lock() | ||||
| 	defer table.mutex.Unlock() | ||||
|  | ||||
|   | ||||
| @@ -19,7 +19,7 @@ const ( | ||||
|  | ||||
| type SlowNode struct { | ||||
| 	peer *Peer | ||||
| 	cidr uint | ||||
| 	cidr uint8 | ||||
| 	bits []byte | ||||
| } | ||||
|  | ||||
| @@ -37,7 +37,7 @@ func (r SlowRouter) Swap(i, j int) { | ||||
| 	r[i], r[j] = r[j], r[i] | ||||
| } | ||||
|  | ||||
| func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { | ||||
| func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { | ||||
| 	for _, t := range r { | ||||
| 		if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { | ||||
| 			t.peer = peer | ||||
| @@ -80,7 +80,7 @@ func TestTrieRandomIPv4(t *testing.T) { | ||||
| 	for n := 0; n < NumberOfAddresses; n++ { | ||||
| 		var addr [AddressLength]byte | ||||
| 		rand.Read(addr[:]) | ||||
| 		cidr := uint(rand.Uint32() % (AddressLength * 8)) | ||||
| 		cidr := uint8(rand.Uint32() % (AddressLength * 8)) | ||||
| 		index := rand.Int() % NumberOfPeers | ||||
| 		trie = trie.insert(addr[:], cidr, peers[index]) | ||||
| 		slow = slow.Insert(addr[:], cidr, peers[index]) | ||||
| @@ -113,7 +113,7 @@ func TestTrieRandomIPv6(t *testing.T) { | ||||
| 	for n := 0; n < NumberOfAddresses; n++ { | ||||
| 		var addr [AddressLength]byte | ||||
| 		rand.Read(addr[:]) | ||||
| 		cidr := uint(rand.Uint32() % (AddressLength * 8)) | ||||
| 		cidr := uint8(rand.Uint32() % (AddressLength * 8)) | ||||
| 		index := rand.Int() % NumberOfPeers | ||||
| 		trie = trie.insert(addr[:], cidr, peers[index]) | ||||
| 		slow = slow.Insert(addr[:], cidr, peers[index]) | ||||
|   | ||||
| @@ -11,13 +11,10 @@ import ( | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| /* Todo: More comprehensive | ||||
|  */ | ||||
|  | ||||
| type testPairCommonBits struct { | ||||
| 	s1    []byte | ||||
| 	s2    []byte | ||||
| 	match uint | ||||
| 	match uint8 | ||||
| } | ||||
|  | ||||
| func TestCommonBits(t *testing.T) { | ||||
| @@ -57,7 +54,7 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test | ||||
| 	for n := 0; n < addressNumber; n++ { | ||||
| 		var addr [AddressLength]byte | ||||
| 		rand.Read(addr[:]) | ||||
| 		cidr := uint(rand.Uint32() % (AddressLength * 8)) | ||||
| 		cidr := uint8(rand.Uint32() % (AddressLength * 8)) | ||||
| 		index := rand.Int() % peerNumber | ||||
| 		trie = trie.insert(addr[:], cidr, peers[index]) | ||||
| 	} | ||||
| @@ -99,7 +96,7 @@ func TestTrieIPv4(t *testing.T) { | ||||
|  | ||||
| 	var trie *trieEntry | ||||
|  | ||||
| 	insert := func(peer *Peer, a, b, c, d byte, cidr uint) { | ||||
| 	insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { | ||||
| 		trie = trie.insert([]byte{a, b, c, d}, cidr, peer) | ||||
| 	} | ||||
|  | ||||
| @@ -195,7 +192,7 @@ func TestTrieIPv6(t *testing.T) { | ||||
| 		return out[:] | ||||
| 	} | ||||
|  | ||||
| 	insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { | ||||
| 	insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) { | ||||
| 		var addr []byte | ||||
| 		addr = append(addr, expand(a)...) | ||||
| 		addr = append(addr, expand(b)...) | ||||
|   | ||||
| @@ -39,10 +39,3 @@ func (a *AtomicBool) Set(val bool) { | ||||
| 	} | ||||
| 	atomic.StoreInt32(&a.int32, flag) | ||||
| } | ||||
|  | ||||
| func min(a, b uint) uint { | ||||
| 	if a > b { | ||||
| 		return b | ||||
| 	} | ||||
| 	return a | ||||
| } | ||||
|   | ||||
| @@ -121,7 +121,7 @@ func (device *Device) IpcGetOperation(w io.Writer) error { | ||||
| 			sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)) | ||||
| 			sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)) | ||||
|  | ||||
| 			device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool { | ||||
| 			device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool { | ||||
| 				sendf("allowed_ip=%s/%d", ip.String(), cidr) | ||||
| 				return true | ||||
| 			}) | ||||
| @@ -379,7 +379,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error | ||||
| 			return nil | ||||
| 		} | ||||
| 		ones, _ := network.Mask.Size() | ||||
| 		device.allowedips.Insert(network.IP, uint(ones), peer.Peer) | ||||
| 		device.allowedips.Insert(network.IP, uint8(ones), peer.Peer) | ||||
|  | ||||
| 	case "protocol_version": | ||||
| 		if value != "1" { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jason A. Donenfeld
					Jason A. Donenfeld