diff --git a/Makefile b/Makefile index 7b1e7fa6..8e3bbdc9 100644 --- a/Makefile +++ b/Makefile @@ -93,4 +93,4 @@ container-local: kubevpn-linux-amd64 .PHONY: container-test container-test: kubevpn-linux-amd64 - docker buildx build --platform linux/amd64 -t ${IMAGE} -f $(BUILD_DIR)/test.Dockerfile . \ No newline at end of file + docker buildx build --platform linux/amd64 -t ${IMAGE} -f $(BUILD_DIR)/test.Dockerfile --push . \ No newline at end of file diff --git a/build/test.Dockerfile b/build/test.Dockerfile index 2253bf4e..bcc815ff 100644 --- a/build/test.Dockerfile +++ b/build/test.Dockerfile @@ -1,4 +1,4 @@ -FROM naison/kubevpn:v1.1.19 +FROM naison/kubevpn:latest WORKDIR /app diff --git a/pkg/core/tcphandler.go b/pkg/core/tcphandler.go index 043cb576..57d236dd 100644 --- a/pkg/core/tcphandler.go +++ b/pkg/core/tcphandler.go @@ -26,8 +26,14 @@ func (c *fakeUDPTunnelConnector) ConnectContext(ctx context.Context, conn net.Co if err != nil { return nil, err } - con.SetKeepAlive(true) - con.SetKeepAlivePeriod(30 * time.Second) + err = con.SetKeepAlive(true) + if err != nil { + return nil, err + } + err = con.SetKeepAlivePeriod(15 * time.Second) + if err != nil { + return nil, err + } } return newFakeUDPTunnelConnOverTCP(ctx, conn) } @@ -68,7 +74,7 @@ func (h *fakeUdpHandler) Handle(ctx context.Context, tcpConn net.Conn) { for { dgram, err := readDatagramPacket(tcpConn, b[:]) if err != nil { - log.Debugf("[udp-tun] %s -> 0 : %v", tcpConn.RemoteAddr(), err) + log.Debugf("[tcpserver] %s -> 0 : %v", tcpConn.RemoteAddr(), err) errChan <- err return } @@ -89,7 +95,7 @@ func (h *fakeUdpHandler) Handle(ctx context.Context, tcpConn net.Conn) { for { n, err := udpConn.Read(b[:]) if err != nil { - log.Debugf("[udp-tun] %s : %s", tcpConn.RemoteAddr(), err) + log.Debugf("[tcpserver] %s : %s", tcpConn.RemoteAddr(), err) errChan <- err return } @@ -145,5 +151,11 @@ func (c *fakeUDPTunnelConn) WriteTo(b []byte, _ net.Addr) (int, error) { } func (c *fakeUDPTunnelConn) Close() error { + if cc, ok := c.Conn.(interface{ CloseRead() error }); ok { + _ = cc.CloseRead() + } + if cc, ok := c.Conn.(interface{ CloseWrite() error }); ok { + _ = cc.CloseWrite() + } return c.Conn.Close() } diff --git a/pkg/core/tunhandler.go b/pkg/core/tunhandler.go index 672445a2..a8f20549 100644 --- a/pkg/core/tunhandler.go +++ b/pkg/core/tunhandler.go @@ -256,25 +256,27 @@ func (h *tunHandler) HandleServer(ctx context.Context, tunConn net.Conn) { tun.Start() for { - var lc net.ListenConfig - packetConn, err := lc.ListenPacket(ctx, "udp", h.node.Addr) - if err != nil { - log.Debugf("[udp] can not listen %s, err: %v", h.node.Addr, err) - goto errH - } - - err = h.transportTun(ctx, tun, packetConn) - if err != nil { - log.Debugf("[tun] %s: %v", tunConn.LocalAddr(), err) - } - errH: select { case <-h.chExit: + return case <-ctx.Done(): return default: - log.Debugf("next loop, err: %v", err) } + func() { + cancel, cancelFunc := context.WithCancel(ctx) + defer cancelFunc() + var lc net.ListenConfig + packetConn, err := lc.ListenPacket(cancel, "udp", h.node.Addr) + if err != nil { + log.Debugf("[udp] can not listen %s, err: %v", h.node.Addr, err) + return + } + err = h.transportTun(cancel, tun, packetConn) + if err != nil { + log.Debugf("[tun] %s: %v", tunConn.LocalAddr(), err) + } + }() } } @@ -416,34 +418,29 @@ func (h *tunHandler) transportTun(ctx context.Context, tun *Device, conn net.Pac p.Start() go func() { - var err error for e := range tun.tunInbound { - retry: + select { + case <-ctx.Done(): + return + default: + } + addr := h.routes.RouteTo(e.dst) if addr == nil { + config.LPool.Put(e.data[:]) log.Debug(fmt.Errorf("[tun] no route for %s -> %s", e.src, e.dst)) continue } log.Debugf("[tun] find route: %s -> %s", e.dst, addr) - _, err = conn.WriteTo(e.data[:e.length], addr) - // err should never nil, so retry is not work - if err != nil { - h.routes.Remove(e.dst, addr) - log.Debugf("[tun] remove invalid route: %s -> %s", e.dst, addr) - goto retry - } + _, err := conn.WriteTo(e.data[:e.length], addr) config.LPool.Put(e.data[:]) - if err != nil { - goto errH + log.Debugf("[tun] can not route: %s -> %s", e.dst, addr) + errChan <- err + return } } - errH: - if err != nil { - errChan <- err - return - } }() select { diff --git a/pkg/core/tunhandlercli.go b/pkg/core/tunhandlercli.go index 3680d814..94f26adc 100644 --- a/pkg/core/tunhandlercli.go +++ b/pkg/core/tunhandlercli.go @@ -34,36 +34,48 @@ func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn) { for i := 0; i < MaxThread; i++ { go func() { for { - if ctx.Err() != nil { + select { + case <-ctx.Done(): return + default: } - var packetConn net.PacketConn - if !h.chain.IsEmpty() { - cc, errs := h.chain.DialContext(ctx) + + func() { + cancel, cancelFunc := context.WithCancel(ctx) + defer cancelFunc() + var packetConn net.PacketConn + defer func() { + if packetConn != nil { + _ = packetConn.Close() + } + }() + if !h.chain.IsEmpty() { + cc, errs := h.chain.DialContext(cancel) + if errs != nil { + log.Debug(errs) + time.Sleep(time.Second * 5) + return + } + var ok bool + if packetConn, ok = cc.(net.PacketConn); !ok { + errs = errors.New("not a packet connection") + log.Errorf("[tun] %s - %s: %s", tun.LocalAddr(), remoteAddr, errs) + return + } + } else { + var errs error + var lc net.ListenConfig + packetConn, errs = lc.ListenPacket(cancel, "udp", "") + if errs != nil { + log.Error(errs) + return + } + } + errs := h.transportTunCli(cancel, d, packetConn, remoteAddr) if errs != nil { - log.Debug(errs) - time.Sleep(time.Second * 5) - continue + log.Debugf("[tun] %s: %v", tun.LocalAddr(), errs) } - var ok bool - if packetConn, ok = cc.(net.PacketConn); !ok { - errs = errors.New("not a packet connection") - log.Errorf("[tun] %s - %s: %s", tun.LocalAddr(), remoteAddr, errs) - continue - } - } else { - var errs error - var lc net.ListenConfig - packetConn, errs = lc.ListenPacket(ctx, "udp", "") - if errs != nil { - log.Error(err) - continue - } - } - errs := h.transportTunCli(ctx, d, packetConn, remoteAddr) - if errs != nil { - log.Debugf("[tun] %s: %v", tun.LocalAddr(), errs) - } + }() } }() } @@ -82,8 +94,13 @@ func (h *tunHandler) transportTunCli(ctx context.Context, d *Device, conn net.Pa defer conn.Close() go func() { - var err error for e := range d.tunInbound { + select { + case <-ctx.Done(): + return + default: + } + if e.src.Equal(e.dst) { if d.closed.Load() { return @@ -91,7 +108,7 @@ func (h *tunHandler) transportTunCli(ctx context.Context, d *Device, conn net.Pa d.tunOutbound <- e continue } - _, err = conn.WriteTo(e.data[:e.length], remoteAddr) + _, err := conn.WriteTo(e.data[:e.length], remoteAddr) config.LPool.Put(e.data[:]) if err != nil { errChan <- err @@ -102,6 +119,12 @@ func (h *tunHandler) transportTunCli(ctx context.Context, d *Device, conn net.Pa go func() { for { + select { + case <-ctx.Done(): + return + default: + } + b := config.LPool.Get().([]byte) n, _, err := conn.ReadFrom(b[:]) if err != nil { diff --git a/pkg/util/portforward.go b/pkg/util/portforward.go index 4203e9c8..0f85c329 100644 --- a/pkg/util/portforward.go +++ b/pkg/util/portforward.go @@ -381,10 +381,16 @@ func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) { select { case <-remoteDone: case <-localError: + // wait for interrupt or conn closure + case <-pf.stopChan: + runtime.HandleError(errors.New("lost connection to pod")) } // always expect something on errorChan (it may be nil) - err = <-errorChan + select { + case err = <-errorChan: + default: + } if err != nil { if strings.Contains(err.Error(), "failed to find socat") { select {