mirror of
https://git.zx2c4.com/wireguard-go
synced 2025-10-05 16:47:02 +08:00
global: use netip where possible now
There are more places where we'll need to add it later, when Go 1.18 comes out with support for it in the "net" package. Also, allowedips still uses slices internally, which might be suboptimal. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
@@ -12,6 +12,8 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"golang.zx2c4.com/go118/netip"
|
||||
)
|
||||
|
||||
type parentIndirection struct {
|
||||
@@ -26,7 +28,7 @@ type trieEntry struct {
|
||||
cidr uint8
|
||||
bitAtByte uint8
|
||||
bitAtShift uint8
|
||||
bits net.IP
|
||||
bits []byte
|
||||
perPeerElem *list.Element
|
||||
}
|
||||
|
||||
@@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 {
|
||||
return bits.ReverseBytes64(i)
|
||||
}
|
||||
|
||||
func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
|
||||
func commonBits(ip1, ip2 []byte) uint8 {
|
||||
size := len(ip1)
|
||||
if size == net.IPv4len {
|
||||
a := (*uint32)(unsafe.Pointer(&ip1[0]))
|
||||
@@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() {
|
||||
}
|
||||
}
|
||||
|
||||
func (node *trieEntry) choose(ip net.IP) byte {
|
||||
func (node *trieEntry) choose(ip []byte) byte {
|
||||
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
|
||||
}
|
||||
|
||||
@@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() {
|
||||
node.parent.parentBit = nil
|
||||
}
|
||||
|
||||
func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
|
||||
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
|
||||
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
|
||||
parent = node
|
||||
if parent.cidr == cidr {
|
||||
@@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry,
|
||||
return
|
||||
}
|
||||
|
||||
func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
|
||||
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
|
||||
if *trie.parentBit == nil {
|
||||
node := &trieEntry{
|
||||
peer: peer,
|
||||
@@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
|
||||
}
|
||||
}
|
||||
|
||||
func (node *trieEntry) lookup(ip net.IP) *Peer {
|
||||
func (node *trieEntry) lookup(ip []byte) *Peer {
|
||||
var found *Peer
|
||||
size := uint8(len(ip))
|
||||
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
||||
@@ -229,13 +231,13 @@ type AllowedIPs struct {
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
|
||||
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
|
||||
table.mutex.RLock()
|
||||
defer table.mutex.RUnlock()
|
||||
|
||||
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
|
||||
node := elem.Value.(*trieEntry)
|
||||
if !cb(node.bits, node.cidr) {
|
||||
if !cb(netip.PrefixFrom(netip.AddrFromSlice(node.bits), int(node.cidr))) {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -283,28 +285,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||
}
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
|
||||
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
|
||||
switch len(ip) {
|
||||
case net.IPv6len:
|
||||
parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
|
||||
case net.IPv4len:
|
||||
parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
|
||||
default:
|
||||
if prefix.Addr().Is6() {
|
||||
ip := prefix.Addr().As16()
|
||||
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
|
||||
} else if prefix.Addr().Is4() {
|
||||
ip := prefix.Addr().As4()
|
||||
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
|
||||
} else {
|
||||
panic(errors.New("inserting unknown address type"))
|
||||
}
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) Lookup(address []byte) *Peer {
|
||||
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
|
||||
table.mutex.RLock()
|
||||
defer table.mutex.RUnlock()
|
||||
switch len(address) {
|
||||
switch len(ip) {
|
||||
case net.IPv6len:
|
||||
return table.IPv6.lookup(address)
|
||||
return table.IPv6.lookup(ip)
|
||||
case net.IPv4len:
|
||||
return table.IPv4.lookup(address)
|
||||
return table.IPv4.lookup(ip)
|
||||
default:
|
||||
panic(errors.New("looking up unknown address type"))
|
||||
}
|
||||
|
Reference in New Issue
Block a user