diff --git a/pkg/core/tunhandlerclient.go b/pkg/core/tunhandlerclient.go index 6806737c..4feb4d2d 100644 --- a/pkg/core/tunhandlerclient.go +++ b/pkg/core/tunhandlerclient.go @@ -26,7 +26,6 @@ func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn) { go device.handlePacket(ctx, h.forward) go device.readFromTun(ctx) go device.writeToTun(ctx) - go heartbeats(ctx, device.tun) select { case <-device.errChan: case <-ctx.Done(): @@ -73,7 +72,13 @@ func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbo go func() { defer util.HandleCrash() for packet := range tunInbound { - _, err := conn.Write(packet.data[:packet.length]) + err := conn.SetWriteDeadline(time.Now().Add(config.KeepAliveTime)) + if err != nil { + plog.G(ctx).Errorf("Failed to set write deadline: %v", err) + util.SafeWrite(errChan, errors.Wrap(err, "failed to set write deadline")) + return + } + _, err = conn.Write(packet.data[:packet.length]) config.LPool.Put(packet.data[:]) if err != nil { plog.G(ctx).Errorf("Failed to write packet to remote: %v", err) @@ -87,6 +92,12 @@ func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbo defer util.HandleCrash() for { buf := config.LPool.Get().([]byte)[:] + err := conn.SetReadDeadline(time.Now().Add(config.KeepAliveTime)) + if err != nil { + plog.G(ctx).Errorf("Failed to set read deadline: %v", err) + util.SafeWrite(errChan, errors.Wrap(err, "failed to set read deadline")) + return + } n, err := conn.Read(buf[:]) if err != nil { config.LPool.Put(buf[:]) @@ -167,31 +178,3 @@ func (d *ClientDevice) Close() { util.SafeClose(d.tunInbound) util.SafeClose(d.tunOutbound) } - -func heartbeats(ctx context.Context, tun net.Conn) { - tunIfi, err := util.GetTunDeviceByConn(tun) - if err != nil { - plog.G(ctx).Errorf("Failed to get tun device: %v", err) - return - } - srcIPv4, srcIPv6, dockerSrcIPv4, err := util.GetTunDeviceIP(tunIfi.Name) - if err != nil { - plog.G(ctx).Errorf("Failed to get tun device %s IP: %v", tunIfi.Name, err) - return - } - - ticker := time.NewTicker(config.KeepAliveTime) - defer ticker.Stop() - - for ; ctx.Err() == nil; <-ticker.C { - if srcIPv4 != nil { - util.Ping(ctx, srcIPv4.String(), config.RouterIP.String()) - } - if srcIPv6 != nil { - util.Ping(ctx, srcIPv6.String(), config.RouterIP6.String()) - } - if dockerSrcIPv4 != nil { - util.Ping(ctx, dockerSrcIPv4.String(), config.DockerRouterIP.String()) - } - } -} diff --git a/pkg/handler/connect.go b/pkg/handler/connect.go index 67423e9f..da3a1492 100644 --- a/pkg/handler/connect.go +++ b/pkg/handler/connect.go @@ -329,7 +329,9 @@ func (c *ConnectOptions) portForward(ctx context.Context, portPair []string) err podName := pod.GetName() // try to detect pod is delete event, if pod is deleted, needs to redo port-forward go util.CheckPodStatus(childCtx, cancelFunc, podName, c.clientset.CoreV1().Pods(c.Namespace)) - go healthCheck(childCtx, cancelFunc, readyChan, strings.Split(portPair[1], ":")[0], fmt.Sprintf("%s.%s", config.ConfigMapPodTrafficManager, c.Namespace), c.localTunIPv4.IP) + domain := fmt.Sprintf("%s.%s", config.ConfigMapPodTrafficManager, c.Namespace) + go healthCheckPortForward(childCtx, cancelFunc, readyChan, strings.Split(portPair[1], ":")[0], domain, c.localTunIPv4.IP) + go healthCheckTCPConn(childCtx, cancelFunc, readyChan, domain, util.GetPodIP(pod)[0]) if *first { go func() { select { @@ -1204,7 +1206,7 @@ func (c *ConnectOptions) ProxyResources() ProxyList { return c.proxyWorkloads } -func healthCheck(ctx context.Context, cancelFunc context.CancelFunc, readyChan chan struct{}, localGvisorUDPPort string, domain string, ipv4 net.IP) { +func healthCheckPortForward(ctx context.Context, cancelFunc context.CancelFunc, readyChan chan struct{}, localGvisorUDPPort string, domain string, ipv4 net.IP) { defer cancelFunc() ticker := time.NewTicker(time.Second * 60) defer ticker.Stop() @@ -1258,3 +1260,40 @@ func healthCheck(ctx context.Context, cancelFunc context.CancelFunc, readyChan c } } } + +func healthCheckTCPConn(ctx context.Context, cancelFunc context.CancelFunc, readyChan chan struct{}, domain string, dnsServer string) { + defer cancelFunc() + ticker := time.NewTicker(time.Second * 60) + defer ticker.Stop() + + select { + case <-readyChan: + case <-ticker.C: + plog.G(ctx).Debugf("Wait port-forward to be ready timeout") + return + case <-ctx.Done(): + return + } + + var healthChecker = func() error { + msg := new(miekgdns.Msg) + msg.SetQuestion(miekgdns.Fqdn(domain), miekgdns.TypeA) + client := miekgdns.Client{Net: "udp", Timeout: time.Second * 10} + _, _, err := client.ExchangeContext(ctx, msg, net.JoinHostPort(dnsServer, "53")) + return err + } + + newTicker := time.NewTicker(config.KeepAliveTime / 2) + defer newTicker.Stop() + for ; ctx.Err() == nil; <-newTicker.C { + err := retry.OnError(wait.Backoff{Duration: time.Second * 10, Steps: 6}, func(err error) bool { + return err != nil + }, func() error { + return healthChecker() + }) + if err != nil { + plog.G(ctx).Errorf("Failed to query DNS: %v", err) + return + } + } +}