diff --git a/pkg/handler/connect.go b/pkg/handler/connect.go index b94203df..76506bc7 100644 --- a/pkg/handler/connect.go +++ b/pkg/handler/connect.go @@ -508,35 +508,42 @@ func (c *ConnectOptions) addRouteDynamic(ctx context.Context) error { return nil } -func (c *ConnectOptions) addRoute(ipStr string) error { +func (c *ConnectOptions) addRoute(ipStrList ...string) error { if c.tunName == "" { return nil } - - ip := net.ParseIP(ipStr) - if ip == nil { - return nil - } - for _, p := range c.apiServerIPs { - // if pod ip or service ip is equal to apiServer ip, can not add it to route table - if p.Equal(ip) { - return nil + var routes []types.Route + for _, ipStr := range ipStrList { + ip := net.ParseIP(ipStr) + if ip == nil { + continue } - } - - var mask net.IPMask - if ip.To4() != nil { - mask = net.CIDRMask(32, 32) - } else { - mask = net.CIDRMask(128, 128) - } - if r, err := netroute.New(); err == nil { - ifi, _, _, err := r.Route(ip) - if err == nil && ifi.Name == c.tunName { - return nil + var match bool + for _, p := range c.apiServerIPs { + // if pod ip or service ip is equal to apiServer ip, can not add it to route table + if p.Equal(ip) { + match = true + break + } } + if match { + continue + } + var mask net.IPMask + if ip.To4() != nil { + mask = net.CIDRMask(32, 32) + } else { + mask = net.CIDRMask(128, 128) + } + if r, err := netroute.New(); err == nil { + ifi, _, _, err := r.Route(ip) + if err == nil && ifi.Name == c.tunName { + continue + } + } + routes = append(routes, types.Route{Dst: net.IPNet{IP: ip, Mask: mask}}) } - err := tun.AddRoutes(c.tunName, types.Route{Dst: net.IPNet{IP: ip, Mask: mask}}) + err := tun.AddRoutes(c.tunName, routes...) return err } diff --git a/pkg/handler/function_test.go b/pkg/handler/function_test.go index c65e6900..a95345c1 100644 --- a/pkg/handler/function_test.go +++ b/pkg/handler/function_test.go @@ -40,12 +40,14 @@ var ( func TestFunctions(t *testing.T) { Init() kubevpnConnect(t) + kubevpnStatus(t) t.Run(runtime.FuncForPC(reflect.ValueOf(pingPodIP).Pointer()).Name(), pingPodIP) t.Run(runtime.FuncForPC(reflect.ValueOf(dialUDP).Pointer()).Name(), dialUDP) t.Run(runtime.FuncForPC(reflect.ValueOf(healthCheckPod).Pointer()).Name(), healthCheckPod) t.Run(runtime.FuncForPC(reflect.ValueOf(healthCheckService).Pointer()).Name(), healthCheckService) t.Run(runtime.FuncForPC(reflect.ValueOf(shortDomain).Pointer()).Name(), shortDomain) t.Run(runtime.FuncForPC(reflect.ValueOf(fullDomain).Pointer()).Name(), fullDomain) + kubevpnStatus(t) } func pingPodIP(t *testing.T) { @@ -338,6 +340,19 @@ func kubevpnConnect(t *testing.T) { } } +func kubevpnStatus(t *testing.T) { + cmd := exec.Command("kubevpn", "status") + stdout, stderr, err := util.RunWithRollingOutWithChecker(cmd, nil) + if err != nil { + t.Log(stdout, stderr) + t.Error(err) + t.Fail() + return + } + t.Log(stdout) + t.Log(stderr) +} + func Init() { var err error diff --git a/pkg/tun/route_darwin.go b/pkg/tun/route_darwin.go index 558e0a01..24ff1f35 100644 --- a/pkg/tun/route_darwin.go +++ b/pkg/tun/route_darwin.go @@ -10,33 +10,49 @@ import ( "golang.org/x/sys/unix" ) -func addRoute(seq int, r netip.Prefix, gw route.Addr) error { +func addRoute(gw route.Addr, r ...netip.Prefix) error { + if len(r) == 0 { + return nil + } return withRouteSocket(func(routeSocket int) error { - m := newRouteMessage(unix.RTM_ADD, seq, r, gw) - rb, err := m.Marshal() - if err != nil { - return err + for i, prefix := range r { + m := newRouteMessage(unix.RTM_ADD, i+1, prefix, gw) + rb, err := m.Marshal() + if err != nil { + return err + } + _, err = unix.Write(routeSocket, rb) + if errors.Is(err, unix.EEXIST) { + err = nil + } + if err != nil { + return err + } } - _, err = unix.Write(routeSocket, rb) - if errors.Is(err, unix.EEXIST) { - err = nil - } - return err + return nil }) } -func deleteRoute(seq int, r netip.Prefix, gw route.Addr) error { +func deleteRoute(gw route.Addr, r ...netip.Prefix) error { + if len(r) == 0 { + return nil + } return withRouteSocket(func(routeSocket int) error { - m := newRouteMessage(unix.RTM_DELETE, seq, r, gw) - rb, err := m.Marshal() - if err != nil { - return err + for i, prefix := range r { + m := newRouteMessage(unix.RTM_DELETE, i+1, prefix, gw) + rb, err := m.Marshal() + if err != nil { + return err + } + _, err = unix.Write(routeSocket, rb) + if errors.Is(err, unix.ESRCH) { + err = nil + } + if err != nil { + return err + } } - _, err = unix.Write(routeSocket, rb) - if errors.Is(err, unix.ESRCH) { - err = nil - } - return err + return nil }) } @@ -45,12 +61,11 @@ func withRouteSocket(f func(routeSocket int) error) error { if err != nil { return err } - + defer unix.Close(routeSocket) // Avoid the overhead of echoing messages back to sender if err = unix.SetsockoptInt(routeSocket, unix.SOL_SOCKET, unix.SO_USELOOPBACK, 0); err != nil { return err } - defer unix.Close(routeSocket) return f(routeSocket) } diff --git a/pkg/tun/tun_darwin.go b/pkg/tun/tun_darwin.go index c002d8cf..bbf0f55d 100644 --- a/pkg/tun/tun_darwin.go +++ b/pkg/tun/tun_darwin.go @@ -99,6 +99,7 @@ func addTunRoutes(ifName string, routes ...types.Route) error { } gw := &route.LinkAddr{Index: tunIfi.Index} + var prefixList []netip.Prefix for _, r := range routes { if r.Dst.String() == "" { continue @@ -108,10 +109,14 @@ func addTunRoutes(ifName string, routes ...types.Route) error { if err != nil { return err } - err = addRoute(1, prefix, gw) - if err != nil { - return fmt.Errorf("failed to add route: %v", err) - } + prefixList = append(prefixList, prefix) + } + if len(prefixList) == 0 { + return nil + } + err = addRoute(gw, prefixList...) + if err != nil { + return fmt.Errorf("failed to add route: %v", err) } return nil } diff --git a/pkg/util/route.go b/pkg/util/route.go index 2e3165ed..d0f7ed84 100644 --- a/pkg/util/route.go +++ b/pkg/util/route.go @@ -50,18 +50,20 @@ func GetNsForListPodAndSvc(ctx context.Context, clientset *kubernetes.Clientset, return } -func ListService(ctx context.Context, lister v12.ServiceInterface, addRouteFunc func(ipStr string) error) error { +func ListService(ctx context.Context, lister v12.ServiceInterface, addRouteFunc func(ipStr ...string) error) error { opts := metav1.ListOptions{Limit: 100, Continue: ""} for { serviceList, err := lister.List(ctx, opts) if err != nil { return err } + var ips []string for _, service := range serviceList.Items { - err = addRouteFunc(service.Spec.ClusterIP) - if err != nil { - log.Errorf("Failed to add service: %s IP: %s to route table: %v", service.Name, service.Spec.ClusterIP, err) - } + ips = append(ips, service.Spec.ClusterIP) + } + err = addRouteFunc(ips...) + if err != nil { + log.Errorf("Failed to add service IP: %s to route table: %v", ips, err) } if serviceList.Continue == "" { return nil @@ -70,7 +72,7 @@ func ListService(ctx context.Context, lister v12.ServiceInterface, addRouteFunc } } -func WatchServiceToAddRoute(ctx context.Context, watcher v12.ServiceInterface, routeFunc func(ipStr string) error) error { +func WatchServiceToAddRoute(ctx context.Context, watcher v12.ServiceInterface, routeFunc func(ipStr ...string) error) error { defer func() { if er := recover(); er != nil { log.Error(er) @@ -99,21 +101,23 @@ func WatchServiceToAddRoute(ctx context.Context, watcher v12.ServiceInterface, r } } -func ListPod(ctx context.Context, lister v12.PodInterface, addRouteFunc func(ipStr string) error) error { +func ListPod(ctx context.Context, lister v12.PodInterface, addRouteFunc func(ipStr ...string) error) error { opts := metav1.ListOptions{Limit: 100, Continue: ""} for { podList, err := lister.List(ctx, opts) if err != nil { return err } + var ips []string for _, pod := range podList.Items { if pod.Spec.HostNetwork { continue } - err = addRouteFunc(pod.Status.PodIP) - if err != nil { - log.Errorf("Failed to add pod: %s IP: %s to route table: %v", pod.Name, pod.Status.PodIP, err) - } + ips = append(ips, pod.Status.PodIP) + } + err = addRouteFunc(ips...) + if err != nil { + log.Errorf("Failed to add Pod IP: %v route table: %v", ips, err) } if podList.Continue == "" { return nil @@ -122,7 +126,7 @@ func ListPod(ctx context.Context, lister v12.PodInterface, addRouteFunc func(ipS } } -func WatchPodToAddRoute(ctx context.Context, watcher v12.PodInterface, addRouteFunc func(ipStr string) error) error { +func WatchPodToAddRoute(ctx context.Context, watcher v12.PodInterface, addRouteFunc func(ipStrList ...string) error) error { defer func() { if er := recover(); er != nil { log.Errorln(er)