diff --git a/pkg/daemon/action/sshdaemon.go b/pkg/daemon/action/sshdaemon.go index 2ec795da..6ba7cff6 100644 --- a/pkg/daemon/action/sshdaemon.go +++ b/pkg/daemon/action/sshdaemon.go @@ -58,7 +58,7 @@ func (svr *Server) SshStart(ctx context.Context, req *rpc.SshStartRequest) (*rpc ctx2, cancelF := context.WithCancel(ctx) wait.UntilWithContext(ctx2, func(ctx context.Context) { ip, _, _ := net.ParseCIDR(DefaultServerIP) - ok, err := util.Ping(ip.String()) + ok, err := util.Ping(ctx2, ip.String()) if err != nil { } else if ok { cancelF() diff --git a/pkg/daemon/handler/ssh.go b/pkg/daemon/handler/ssh.go index 42d63735..68fc0f0c 100644 --- a/pkg/daemon/handler/ssh.go +++ b/pkg/daemon/handler/ssh.go @@ -131,15 +131,17 @@ func (w *wsHandler) handle(ctx context.Context) { } log.Info("tunnel connected") go func() { - ticker := time.NewTicker(time.Second * 2) for { select { case <-ctx.Done(): return - case <-ticker.C: - _, _ = util.Ping(clientIP.IP.String()) - _, _ = util.Ping(ip.String()) - _ = exec.CommandContext(ctx, "ping", "-c", "4", "-b", tun.Name, ip.String()).Run() + default: + _, _ = util.Ping(ctx, clientIP.IP.String()) + _, _ = util.Ping(ctx, ip.String()) + ctx2, cancelFunc := context.WithTimeout(ctx, time.Second*5) + _ = exec.CommandContext(ctx2, "ping", "-c", "4", "-b", tun.Name, ip.String()).Run() + cancelFunc() + time.Sleep(time.Second * 5) } } }() diff --git a/pkg/handler/connect.go b/pkg/handler/connect.go index 12823c91..87d6d59c 100644 --- a/pkg/handler/connect.go +++ b/pkg/handler/connect.go @@ -904,7 +904,7 @@ RetryWithDNSClient: ctx2, cancelFunc := context.WithTimeout(ctx, time.Second*10) wait.UntilWithContext(ctx2, func(context.Context) { for _, ip := range ips { - pong, err2 := util.Ping(ip) + pong, err2 := util.Ping(ctx2, ip) if err2 == nil && pong { ips = []string{ip} cancelFunc() @@ -1152,7 +1152,7 @@ func (c *ConnectOptions) heartbeats(ctx context.Context) { err := c.dhcp.ForEach(func(ip net.IP) { go func() { - _, _ = util.Ping(ip.String()) + _, _ = util.Ping(ctx, ip.String()) }() }) if err != nil { diff --git a/pkg/util/net.go b/pkg/util/net.go index b533946c..de9ad6cf 100644 --- a/pkg/util/net.go +++ b/pkg/util/net.go @@ -1,6 +1,7 @@ package util import ( + "context" "fmt" "net" "strings" @@ -59,7 +60,7 @@ func GetTunDeviceByConn(tun net.Conn) (*net.Interface, error) { return nil, fmt.Errorf("can not found any interface with ip %v", ip) } -func Ping(targetIP string) (bool, error) { +func Ping(ctx context.Context, targetIP string) (bool, error) { pinger, err := probing.NewPinger(targetIP) if err != nil { return false, err @@ -68,7 +69,7 @@ func Ping(targetIP string) (bool, error) { pinger.SetPrivileged(true) pinger.Count = 3 pinger.Timeout = time.Millisecond * 1500 - err = pinger.Run() // Blocks until finished. + err = pinger.RunWithContext(ctx) // Blocks until finished. if err != nil { return false, err } diff --git a/pkg/util/pod.go b/pkg/util/pod.go index 5346ed0b..b24e583b 100644 --- a/pkg/util/pod.go +++ b/pkg/util/pod.go @@ -142,7 +142,7 @@ func Heartbeats() { for ; true; <-ticker.C { for _, ip := range []net.IP{config.RouterIP, config.RouterIP6} { time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000))) - _, _ = Ping(ip.String()) + _, _ = Ping(context.Background(), ip.String()) } } }