diff --git a/cmd/kubevpn/cmds/upgrade.go b/cmd/kubevpn/cmds/upgrade.go index dd0e7bd0..94056a41 100644 --- a/cmd/kubevpn/cmds/upgrade.go +++ b/cmd/kubevpn/cmds/upgrade.go @@ -49,8 +49,8 @@ func CmdUpgrade(cmdutil.Factory) *cobra.Command { } _, _ = fmt.Fprintln(os.Stdout, fmt.Sprintf("Current version is: %s less than latest version: %s, needs to upgrade", config.Version, latestVersion)) _ = os.Setenv(envLatestUrl, url) - _ = quit(cmd.Context(), false) _ = quit(cmd.Context(), true) + _ = quit(cmd.Context(), false) } return upgrade.Main(cmd.Context(), client, url) }, diff --git a/pkg/core/ssh.go b/pkg/core/ssh.go index e2889b33..230d0e6c 100644 --- a/pkg/core/ssh.go +++ b/pkg/core/ssh.go @@ -38,7 +38,7 @@ func (s *sshHandler) Handle(ctx context.Context, conn net.Conn) { }), Handler: ssh.Handler(func(s ssh.Session) { io.WriteString(s, "Remote forwarding available...\n") - select {} + <-s.Context().Done() }), ReversePortForwardingCallback: ssh.ReversePortForwardingCallback(func(ctx ssh.Context, host string, port uint32) bool { plog.G(ctx).Infoln("attempt to bind", host, port, "granted") diff --git a/pkg/core/tcphandler.go b/pkg/core/tcphandler.go index dd60690d..1c5eaf18 100644 --- a/pkg/core/tcphandler.go +++ b/pkg/core/tcphandler.go @@ -128,7 +128,7 @@ func (h *UDPOverTCPHandler) removeFromRouteMapTCP(ctx context.Context, tcpConn n }) } -var _ net.PacketConn = (*UDPConnOverTCP)(nil) +var _ net.Conn = (*UDPConnOverTCP)(nil) // UDPConnOverTCP fake udp connection over tcp connection type UDPConnOverTCP struct { @@ -141,20 +141,20 @@ func newUDPConnOverTCP(ctx context.Context, conn net.Conn) (net.Conn, error) { return &UDPConnOverTCP{ctx: ctx, Conn: conn}, nil } -func (c *UDPConnOverTCP) ReadFrom(b []byte) (int, net.Addr, error) { +func (c *UDPConnOverTCP) Read(b []byte) (int, error) { select { case <-c.ctx.Done(): - return 0, nil, c.ctx.Err() + return 0, c.ctx.Err() default: datagram, err := readDatagramPacket(c.Conn, b) if err != nil { - return 0, nil, err + return 0, err } - return int(datagram.DataLength), nil, nil + return int(datagram.DataLength), nil } } -func (c *UDPConnOverTCP) WriteTo(b []byte, _ net.Addr) (int, error) { +func (c *UDPConnOverTCP) Write(b []byte) (int, error) { buf := config.LPool.Get().([]byte)[:] n := copy(buf, b) defer config.LPool.Put(buf) diff --git a/pkg/core/tunhandler.go b/pkg/core/tunhandler.go index 62dd83b7..e5770477 100644 --- a/pkg/core/tunhandler.go +++ b/pkg/core/tunhandler.go @@ -37,12 +37,7 @@ func TunHandler(forward *Forwarder, node *Node) Handler { func (h *tunHandler) Handle(ctx context.Context, tun net.Conn) { if remote := h.node.Remote; remote != "" { - remoteAddr, err := net.ResolveUDPAddr("udp", remote) - if err != nil { - plog.G(ctx).Errorf("Failed to resolve udp addr %s: %v", remote, err) - return - } - h.HandleClient(ctx, tun, remoteAddr) + h.HandleClient(ctx, tun) } else { h.HandleServer(ctx, tun) } diff --git a/pkg/core/tunhandlerclient.go b/pkg/core/tunhandlerclient.go index 746070ac..9ea8e85a 100644 --- a/pkg/core/tunhandlerclient.go +++ b/pkg/core/tunhandlerclient.go @@ -14,7 +14,7 @@ import ( "github.com/wencaiwulue/kubevpn/v2/pkg/util" ) -func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn, remoteAddr *net.UDPAddr) { +func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn) { device := &ClientDevice{ tun: tun, tunInbound: make(chan *Packet, MaxSize), @@ -23,7 +23,7 @@ func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn, remoteAddr } defer device.Close() - go device.handlePacket(ctx, remoteAddr, h.forward) + go device.handlePacket(ctx, h.forward) go device.readFromTun(ctx) go device.writeToTun(ctx) go heartbeats(ctx, device.tun) @@ -43,56 +43,40 @@ type ClientDevice struct { forward *Forwarder } -func (d *ClientDevice) handlePacket(ctx context.Context, remoteAddr *net.UDPAddr, forward *Forwarder) { +func (d *ClientDevice) handlePacket(ctx context.Context, forward *Forwarder) { for ctx.Err() == nil { - packetConn, err := getRemotePacketConn(ctx, forward) + conn, err := forwardConn(ctx, forward) if err != nil { - plog.G(ctx).Errorf("Failed to get remote conn from %s -> %s: %s", d.tun.LocalAddr(), remoteAddr, err) + plog.G(ctx).Errorf("Failed to get remote conn from %s -> %s: %s", d.tun.LocalAddr(), forward.node.Remote, err) time.Sleep(time.Second * 1) continue } - err = handlePacketClient(ctx, d.tunInbound, d.tunOutbound, packetConn, remoteAddr) + err = handlePacketClient(ctx, d.tunInbound, d.tunOutbound, conn) if err != nil { - plog.G(ctx).Errorf("Failed to transport data to remote %s: %v", remoteAddr, err) + plog.G(ctx).Errorf("Failed to transport data to remote %s: %v", conn.RemoteAddr(), err) } } } -func getRemotePacketConn(ctx context.Context, forwarder *Forwarder) (net.PacketConn, error) { +func forwardConn(ctx context.Context, forwarder *Forwarder) (net.Conn, error) { conn, err := forwarder.DialContext(ctx) if err != nil { return nil, errors.Wrap(err, "failed to dial forwarder") } - - if packetConn, ok := conn.(net.PacketConn); !ok { - return nil, errors.Errorf("failed to cast packet conn to PacketConn") - } else { - return packetConn, nil - } + return conn, nil } -func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbound chan<- *Packet, packetConn net.PacketConn, remoteAddr net.Addr) error { +func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbound chan<- *Packet, conn net.Conn) error { errChan := make(chan error, 2) - defer packetConn.Close() + defer conn.Close() go func() { defer util.HandleCrash() for packet := range tunInbound { - if packet.src.Equal(packet.dst) { - util.SafeWrite(tunOutbound, packet, func(v *Packet) { - var p = "unknown" - if _, _, protocol, err := util.ParseIP(v.data[:v.length]); err == nil { - p = layers.IPProtocol(protocol).String() - } - config.LPool.Put(v.data[:]) - plog.G(context.Background()).Errorf("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d", v.src, v.dst, p, v.length) - }) - continue - } - _, err := packetConn.WriteTo(packet.data[:packet.length], remoteAddr) + _, err := conn.Write(packet.data[:packet.length]) config.LPool.Put(packet.data[:]) if err != nil { - util.SafeWrite(errChan, errors.Wrap(err, fmt.Sprintf("failed to write packet to remote %s", remoteAddr))) + util.SafeWrite(errChan, errors.Wrap(err, "failed to write packet to remote")) return } } @@ -102,10 +86,10 @@ func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbo defer util.HandleCrash() for { buf := config.LPool.Get().([]byte)[:] - n, _, err := packetConn.ReadFrom(buf[:]) + n, err := conn.Read(buf[:]) if err != nil { config.LPool.Put(buf[:]) - util.SafeWrite(errChan, errors.Wrap(err, fmt.Sprintf("failed to read packet from remote %s", remoteAddr))) + util.SafeWrite(errChan, errors.Wrap(err, fmt.Sprintf("failed to read packet from remote %s", conn.RemoteAddr()))) return } if n == 0 { @@ -115,7 +99,7 @@ func handlePacketClient(ctx context.Context, tunInbound <-chan *Packet, tunOutbo } util.SafeWrite(tunOutbound, NewPacket(buf[:], n, nil, nil), func(v *Packet) { config.LPool.Put(v.data[:]) - plog.G(context.Background()).Errorf("Drop packet, LocalAddr: %s, Remote: %s, Length: %d", packetConn.LocalAddr(), remoteAddr, v.length) + plog.G(context.Background()).Errorf("Drop packet, LocalAddr: %s, Remote: %s, Length: %d", conn.LocalAddr(), conn.RemoteAddr(), v.length) }) } }() @@ -150,10 +134,16 @@ func (d *ClientDevice) readFromTun(ctx context.Context) { continue } plog.G(context.Background()).Debugf("SRC: %s, DST: %s, Protocol: %s, Length: %d", src, dst, layers.IPProtocol(protocol).String(), n) - util.SafeWrite(d.tunInbound, NewPacket(buf[:], n, src, dst), func(v *Packet) { + packet := NewPacket(buf[:], n, src, dst) + f := func(v *Packet) { config.LPool.Put(v.data[:]) plog.G(context.Background()).Errorf("Drop packet, SRC: %s, DST: %s, Protocol: %s, Length: %d", v.src, v.dst, layers.IPProtocol(protocol).String(), v.length) - }) + } + if packet.src.Equal(packet.dst) { + util.SafeWrite(d.tunOutbound, packet, f) + continue + } + util.SafeWrite(d.tunInbound, packet, f) } } @@ -188,7 +178,7 @@ func heartbeats(ctx context.Context, tun net.Conn) { return } - ticker := time.NewTicker(time.Second * 60) + ticker := time.NewTicker(config.KeepAliveTime) defer ticker.Stop() for ; ctx.Err() == nil; <-ticker.C { diff --git a/pkg/handler/connect.go b/pkg/handler/connect.go index 422daf41..c70c94ea 100644 --- a/pkg/handler/connect.go +++ b/pkg/handler/connect.go @@ -955,18 +955,28 @@ func (c *ConnectOptions) upgradeDeploy(ctx context.Context) error { if len(deploy.Spec.Template.Spec.Containers) == 0 { return fmt.Errorf("can not found any container in deploy %s", deploy.Name) } + // check running pod, sometime deployment is rolling back, so need to check running pod + list, err := c.GetRunningPodList(ctx) + if err != nil { + return err + } clientVer := config.Version clientImg := config.Image serverImg := deploy.Spec.Template.Spec.Containers[0].Image + runningPodImg := list[0].Spec.Containers[0].Image isNeedUpgrade, err := util.IsNewer(clientVer, clientImg, serverImg) - if !isNeedUpgrade { + isPodNeedUpgrade, err1 := util.IsNewer(clientVer, clientImg, runningPodImg) + if !isNeedUpgrade && !isPodNeedUpgrade { return nil } if err != nil { return err } + if err1 != nil { + return err1 + } // 1) update secret err = upgradeSecretSpec(ctx, c.factory, c.Namespace)