diff --git a/vpn/iface/iface.go b/vpn/iface/iface.go index 96908fc..521ee8c 100644 --- a/vpn/iface/iface.go +++ b/vpn/iface/iface.go @@ -28,8 +28,7 @@ var _ RoutingTable = (*TunInterface)(nil) type TunInterface struct { dev tun.Device ifName string - ipv6 *lru.Cache[string, []*net.IPNet] - ipv4 *lru.Cache[string, []*net.IPNet] + routing *lru.Cache[string, []*net.IPNet] peers *lru.Cache[string, net.Addr] peersMutex sync.RWMutex } @@ -46,11 +45,10 @@ func Create(tunName string, cfg Config) (*TunInterface, error) { link.SetupLink(tunName, cfg.IPv6) } return &TunInterface{ - dev: device, - ifName: tunName, - ipv6: lru.New[string, []*net.IPNet](256), - ipv4: lru.New[string, []*net.IPNet](256), - peers: lru.New[string, net.Addr](1024), + dev: device, + ifName: tunName, + routing: lru.New[string, []*net.IPNet](512), + peers: lru.New[string, net.Addr](1024), }, nil } @@ -62,18 +60,7 @@ func (r *TunInterface) GetPeer(ip string) (net.Addr, bool) { return peerID, true } dstIP := net.ParseIP(ip) - if dstIP.To4() != nil { - k, _, _ := r.ipv4.Find(func(k string, v []*net.IPNet) bool { - for _, cidr := range v { - if cidr.Contains(dstIP) { - return true - } - } - return false - }) - return r.peers.Get(k) - } - k, _, _ := r.ipv6.Find(func(k string, v []*net.IPNet) bool { + k, _, _ := r.routing.Find(func(k string, v []*net.IPNet) bool { for _, cidr := range v { if cidr.Contains(dstIP) { return true @@ -98,26 +85,14 @@ func (r *TunInterface) AddPeer(ipv4, ipv6 string, peer net.Addr) { func (r *TunInterface) AddRoute(cidr *net.IPNet, via net.IP) { r.peersMutex.Lock() defer r.peersMutex.Unlock() - var cidrs []*net.IPNet - if via.To4() != nil { - cidrs, _ = r.ipv4.Get(via.String()) - for _, cmp := range cidrs { - if cmp.String() == cidr.String() { - return - } + cidrs, _ := r.routing.Get(via.String()) + for _, cmp := range cidrs { + if cmp.String() == cidr.String() { + return } - cidrs = append(cidrs, cidr) - r.ipv4.Put(via.String(), cidrs) - } else { - cidrs, _ := r.ipv6.Get(via.String()) - for _, cmp := range cidrs { - if cmp.String() == cidr.String() { - return - } - } - cidrs = append(cidrs, cidr) - r.ipv6.Put(via.String(), cidrs) } + cidrs = append(cidrs, cidr) + r.routing.Put(via.String(), cidrs) r.updateRoute(via, cidrs[:len(cidrs)-1], cidrs) } diff --git a/vpn/iface/iface_unix.go b/vpn/iface/iface_unix.go index f7058a0..01f4088 100644 --- a/vpn/iface/iface_unix.go +++ b/vpn/iface/iface_unix.go @@ -22,9 +22,8 @@ func CreateFD(tunFD int, cfg Config) (*TunInterface, error) { return nil, err } return &TunInterface{ - dev: device, - ipv6: lru.New[string, []*net.IPNet](256), - ipv4: lru.New[string, []*net.IPNet](256), - peers: lru.New[string, net.Addr](1024), + dev: device, + routing: lru.New[string, []*net.IPNet](512), + peers: lru.New[string, net.Addr](1024), }, nil }