From 3e51bf0f4dda08f866035069bcfcd43831eaaa7b Mon Sep 17 00:00:00 2001 From: naison <895703375@qq.com> Date: Mon, 13 May 2024 19:58:56 +0800 Subject: [PATCH] hotfix: close chan (#245) --- cmd/kubevpn/cmds/daemon.go | 17 ++++++- pkg/config/config.go | 3 +- pkg/core/tcphandler.go | 2 +- pkg/core/tunendpoint.go | 7 ++- pkg/core/tunhandler.go | 14 ++++-- pkg/core/tunhandlerclient.go | 9 ++-- pkg/daemon/action/connect-fork.go | 1 - pkg/daemon/action/connect.go | 1 - pkg/daemon/handler/ssh.go | 2 +- pkg/handler/cleaner.go | 11 ++-- pkg/handler/connect.go | 4 ++ pkg/util/chan.go | 29 +++++++++++ pkg/util/chan_test.go | 23 +++++++++ pkg/util/image.go | 2 +- pkg/util/ssh.go | 84 +++++++++++++++++++++++-------- 15 files changed, 163 insertions(+), 46 deletions(-) create mode 100644 pkg/util/chan.go create mode 100644 pkg/util/chan_test.go diff --git a/cmd/kubevpn/cmds/daemon.go b/cmd/kubevpn/cmds/daemon.go index 59604b85..da687963 100644 --- a/cmd/kubevpn/cmds/daemon.go +++ b/cmd/kubevpn/cmds/daemon.go @@ -3,13 +3,17 @@ package cmds import ( "crypto/rand" "encoding/base64" + "errors" + "net/http" "github.com/spf13/cobra" cmdutil "k8s.io/kubectl/pkg/cmd/util" "k8s.io/kubectl/pkg/util/i18n" "k8s.io/kubectl/pkg/util/templates" + "github.com/wencaiwulue/kubevpn/v2/pkg/config" "github.com/wencaiwulue/kubevpn/v2/pkg/daemon" + "github.com/wencaiwulue/kubevpn/v2/pkg/util" ) func CmdDaemon(_ cmdutil.Factory) *cobra.Command { @@ -24,10 +28,21 @@ func CmdDaemon(_ cmdutil.Factory) *cobra.Command { return err } opt.ID = base64.URLEncoding.EncodeToString(b) + + if opt.IsSudo { + go util.StartupPProf(config.SudoPProfPort) + } else { + go util.StartupPProf(config.PProfPort) + } return nil }, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(cmd *cobra.Command, args []string) (err error) { defer opt.Stop() + defer func() { + if errors.Is(err, http.ErrServerClosed) { + err = nil + } + }() return opt.Start(cmd.Context()) }, Hidden: true, diff --git a/pkg/config/config.go b/pkg/config/config.go index 93896d00..eb6deafb 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -83,7 +83,8 @@ const ( ManageBy = konfig.ManagedbyLabelKey // pprof port - PProfPort = 32345 + PProfPort = 32345 + SudoPProfPort = 33345 // startup by KubeVPN EnvStartSudoKubeVPNByKubeVPN = "DEPTH_SIGNED_BY_NAISON" diff --git a/pkg/core/tcphandler.go b/pkg/core/tcphandler.go index fa72a293..eca05863 100644 --- a/pkg/core/tcphandler.go +++ b/pkg/core/tcphandler.go @@ -104,7 +104,7 @@ func (h *fakeUdpHandler) Handle(ctx context.Context, tcpConn net.Conn) { } else { log.Debugf("[tcpserver] new routeConnNAT: %s -> %s-%s", src, tcpConn.LocalAddr(), tcpConn.RemoteAddr()) } - h.ch <- dgram + util.SafeWrite(h.ch, dgram) } } diff --git a/pkg/core/tunendpoint.go b/pkg/core/tunendpoint.go index f5e80bd0..168fd38e 100755 --- a/pkg/core/tunendpoint.go +++ b/pkg/core/tunendpoint.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "github.com/wencaiwulue/kubevpn/v2/pkg/config" + "github.com/wencaiwulue/kubevpn/v2/pkg/util" ) func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config.Engine, in chan<- *DataElem, out chan *DataElem) stack.LinkEndpoint { @@ -37,7 +38,7 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config i := config.LPool.Get().([]byte)[:] n := copy(i, bb) bb = nil - out <- NewDataElem(i[:], n, nil, nil) + util.SafeWrite(out, NewDataElem(i[:], n, nil, nil)) } } }() @@ -49,7 +50,6 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config read, err := tun.Read(bytes[:]) if err != nil { if errors.Is(err, os.ErrClosed) { - log.Errorf("[TUN] Error: tun device closed") return } // if context is done @@ -111,7 +111,7 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config log.Debugf("[TUN-%s] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), layers.IPProtocol(ipProtocol).String(), src.String(), dst, read) } else { log.Debugf("[TUN-RAW] IP-Protocol: %s, SRC: %s, DST: %s, Length: %d", layers.IPProtocol(ipProtocol).String(), src.String(), dst, read) - in <- NewDataElem(bytes[:], read, src, dst) + util.SafeWrite(in, NewDataElem(bytes[:], read, src, dst)) } } }() @@ -121,7 +121,6 @@ func NewTunEndpoint(ctx context.Context, tun net.Conn, mtu uint32, engine config config.LPool.Put(elem.Data()[:]) if err != nil { if errors.Is(err, os.ErrClosed) { - log.Errorf("[TUN] Error: tun device closed") return } // if context is done diff --git a/pkg/core/tunhandler.go b/pkg/core/tunhandler.go index 9c69a2a7..57d613ec 100644 --- a/pkg/core/tunhandler.go +++ b/pkg/core/tunhandler.go @@ -193,10 +193,10 @@ func (d *Device) readFromTun() { return } if n != 0 { - d.tunInboundRaw <- &DataElem{ + util.SafeWrite(d.tunInboundRaw, &DataElem{ data: b[:], length: n, - } + }) } } } @@ -238,12 +238,16 @@ func (d *Device) parseIPHeader(ctx context.Context) { } log.Debugf("[tun] %s --> %s, length: %d", e.src, e.dst, e.length) - d.tunInbound <- e + util.SafeWrite(d.tunInbound, e) } } func (d *Device) Close() { d.tun.Close() + util.SafeClose(d.tunInbound) + util.SafeClose(d.tunOutbound) + util.SafeClose(d.tunInboundRaw) + util.SafeClose(Chan) } func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) { @@ -300,12 +304,12 @@ func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) { } else { src, dst = srcIPv6, config.RouterIP6 } - in <- &DataElem{ + util.SafeWrite(in, &DataElem{ data: data[:], length: length, src: src, dst: dst, - } + }) } time.Sleep(time.Second) } diff --git a/pkg/core/tunhandlerclient.go b/pkg/core/tunhandlerclient.go index 80c3d3b8..279f92e4 100644 --- a/pkg/core/tunhandlerclient.go +++ b/pkg/core/tunhandlerclient.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/wencaiwulue/kubevpn/v2/pkg/config" + "github.com/wencaiwulue/kubevpn/v2/pkg/util" ) func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn) { @@ -24,7 +25,9 @@ func (h *tunHandler) HandleClient(ctx context.Context, tun net.Conn) { engine := h.node.Get(config.ConfigKubeVPNTransportEngine) endpoint := NewTunEndpoint(ctx, tun, uint32(config.DefaultMTU), config.Engine(engine), in, out) stack := NewStack(ctx, endpoint) - go stack.Wait() + defer stack.Destroy() + defer util.SafeClose(in) + defer util.SafeClose(out) d := &ClientDevice{ tun: tun, @@ -84,7 +87,7 @@ func transportTunClient(ctx context.Context, tunInbound <-chan *DataElem, tunOut go func() { for e := range tunInbound { if e.src.Equal(e.dst) { - tunOutbound <- e + util.SafeWrite(tunOutbound, e) continue } _, err := packetConn.WriteTo(e.data[:e.length], remoteAddr) @@ -104,7 +107,7 @@ func transportTunClient(ctx context.Context, tunInbound <-chan *DataElem, tunOut errChan <- errors.Wrap(err, fmt.Sprintf("failed to read packet from remote %s", remoteAddr)) return } - tunOutbound <- &DataElem{data: b[:], length: n} + util.SafeWrite(tunOutbound, &DataElem{data: b[:], length: n}) } }() diff --git a/pkg/daemon/action/connect-fork.go b/pkg/daemon/action/connect-fork.go index ed9e3965..a83d865f 100644 --- a/pkg/daemon/action/connect-fork.go +++ b/pkg/daemon/action/connect-fork.go @@ -41,7 +41,6 @@ func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectF var sshConf = util.ParseSshFromRPC(req.SshJump) var transferImage = req.TransferImage - go util.StartupPProf(config.PProfPort) defaultlog.Default().SetOutput(io.Discard) if transferImage { err = util.TransferImage(ctx, sshConf, config.OriginImage, req.Image, out) diff --git a/pkg/daemon/action/connect.go b/pkg/daemon/action/connect.go index 6f2fb205..ee67902e 100644 --- a/pkg/daemon/action/connect.go +++ b/pkg/daemon/action/connect.go @@ -61,7 +61,6 @@ func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServe var sshConf = util.ParseSshFromRPC(req.SshJump) var transferImage = req.TransferImage - go util.StartupPProf(config.PProfPort) defaultlog.Default().SetOutput(io.Discard) if transferImage { err := util.TransferImage(ctx, sshConf, config.OriginImage, req.Image, out) diff --git a/pkg/daemon/handler/ssh.go b/pkg/daemon/handler/ssh.go index 138c3f17..fc28b28a 100644 --- a/pkg/daemon/handler/ssh.go +++ b/pkg/daemon/handler/ssh.go @@ -48,7 +48,7 @@ func (w *wsHandler) handle(ctx context.Context) { ctx, f := context.WithCancel(ctx) defer f() - cli, err := util.DialSshRemote(w.sshConfig) + cli, err := util.DialSshRemote(ctx, w.sshConfig) if err != nil { w.Log("Dial ssh remote error: %v", err) return diff --git a/pkg/handler/cleaner.go b/pkg/handler/cleaner.go index 50b36e3b..ee7703d8 100644 --- a/pkg/handler/cleaner.go +++ b/pkg/handler/cleaner.go @@ -69,6 +69,12 @@ func (c *ConnectOptions) Cleanup() { log.Errorf("can not update ref-count: %v", err) } } + // leave proxy resources + err := c.LeaveProxyResources(ctx) + if err != nil { + log.Errorf("leave proxy resources error: %v", err) + } + for _, function := range c.getRolloutFunc() { if function != nil { if err := function(); err != nil { @@ -76,11 +82,6 @@ func (c *ConnectOptions) Cleanup() { } } } - // leave proxy resources - err := c.LeaveProxyResources(ctx) - if err != nil { - log.Errorf("leave proxy resources error: %v", err) - } if c.cancel != nil { c.cancel() } diff --git a/pkg/handler/connect.go b/pkg/handler/connect.go index 11dd5321..cff57a50 100644 --- a/pkg/handler/connect.go +++ b/pkg/handler/connect.go @@ -620,6 +620,10 @@ func Run(ctx context.Context, servers []core.Server) error { errChan <- func() error { svr := servers[i] defer svr.Listener.Close() + go func() { + <-ctx.Done() + svr.Listener.Close() + }() for ctx.Err() == nil { conn, err := svr.Listener.Accept() if err != nil { diff --git a/pkg/util/chan.go b/pkg/util/chan.go new file mode 100644 index 00000000..d6c5a527 --- /dev/null +++ b/pkg/util/chan.go @@ -0,0 +1,29 @@ +package util + +func SafeRead[T any](c chan T) (T, bool) { + defer func() { + if r := recover(); r != nil { + } + }() + tt, ok := <-c + return tt, ok +} + +func SafeWrite[T any](c chan<- T, value T) { + defer func() { + if r := recover(); r != nil { + } + }() + select { + case c <- value: + default: + } +} + +func SafeClose[T any](c chan T) { + defer func() { + if r := recover(); r != nil { + } + }() + close(c) +} diff --git a/pkg/util/chan_test.go b/pkg/util/chan_test.go new file mode 100644 index 00000000..50b8a97a --- /dev/null +++ b/pkg/util/chan_test.go @@ -0,0 +1,23 @@ +package util + +import ( + "fmt" + "testing" + "time" +) + +func TestChanClose(t *testing.T) { + c := make(chan any) + close(c) + SafeWrite(c, nil) + + c = make(chan any) + go func() { + time.AfterFunc(time.Second*3, func() { + close(c) + }) + }() + for a := range c { + fmt.Printf("%v", a) + } +} diff --git a/pkg/util/image.go b/pkg/util/image.go index 0ac22bb1..877fa405 100644 --- a/pkg/util/image.go +++ b/pkg/util/image.go @@ -118,7 +118,7 @@ func TransferImage(ctx context.Context, conf *SshConfig, imageSource, imageTarge // transfer image to remote var sshClient *ssh.Client - sshClient, err = DialSshRemote(conf) + sshClient, err = DialSshRemote(ctx, conf) if err != nil { return err } diff --git a/pkg/util/ssh.go b/pkg/util/ssh.go index 3c1a4e64..f6ff732e 100644 --- a/pkg/util/ssh.go +++ b/pkg/util/ssh.go @@ -114,7 +114,7 @@ func AddSshFlags(flags *pflag.FlagSet, sshConf *SshConfig) { } // DialSshRemote https://github.com/golang/go/issues/21478 -func DialSshRemote(conf *SshConfig) (remote *ssh.Client, err error) { +func DialSshRemote(ctx context.Context, conf *SshConfig) (remote *ssh.Client, err error) { defer func() { if err != nil { if remote != nil { @@ -124,21 +124,24 @@ func DialSshRemote(conf *SshConfig) (remote *ssh.Client, err error) { }() if conf.ConfigAlias != "" { - remote, err = conf.AliasRecursion() + remote, err = conf.AliasRecursion(ctx) } else if conf.Jump != "" { - remote, err = conf.JumpRecursion() + remote, err = conf.JumpRecursion(ctx) } else { - remote, err = conf.Dial() + remote, err = conf.Dial(ctx) } // ref: https://github.com/golang/go/issues/21478 if err == nil { go func() { - ticker := time.NewTicker(time.Second * 15) - defer ticker.Stop() defer remote.Close() - for range ticker.C { + for ctx.Err() == nil { + time.Sleep(time.Second * 15) _, _, err := remote.SendRequest("keepalive@golang.org", true, nil) + if err == nil || err.Error() == "request failed" { + // Any response is a success. + continue + } if err != nil { return } @@ -234,7 +237,7 @@ func publicKeyFile(file string) (ssh.AuthMethod, error) { return ssh.PublicKeys(key), nil } -func copyStream(local net.Conn, remote net.Conn) { +func copyStream(ctx context.Context, local net.Conn, remote net.Conn) { chDone := make(chan bool, 2) // start remote -> local data transfer @@ -265,10 +268,15 @@ func copyStream(local net.Conn, remote net.Conn) { } }() - <-chDone + select { + case <-chDone: + return + case <-ctx.Done(): + return + } } -func (config SshConfig) AliasRecursion() (client *ssh.Client, err error) { +func (config SshConfig) AliasRecursion(ctx context.Context) (client *ssh.Client, err error) { var name = config.ConfigAlias var jumper = "ProxyJump" var bastionList = []SshConfig{GetBastion(name, config)} @@ -283,12 +291,12 @@ func (config SshConfig) AliasRecursion() (client *ssh.Client, err error) { } for i := len(bastionList) - 1; i >= 0; i-- { if client == nil { - client, err = bastionList[i].Dial() + client, err = bastionList[i].Dial(ctx) if err != nil { return } } else { - client, err = JumpTo(client, bastionList[i]) + client, err = JumpTo(ctx, client, bastionList[i]) if err != nil { return } @@ -297,7 +305,7 @@ func (config SshConfig) AliasRecursion() (client *ssh.Client, err error) { return } -func (config SshConfig) JumpRecursion() (client *ssh.Client, err error) { +func (config SshConfig) JumpRecursion(ctx context.Context) (client *ssh.Client, err error) { flags := pflag.NewFlagSet("", pflag.ContinueOnError) var sshConf = &SshConfig{} AddSshFlags(flags, sshConf) @@ -306,7 +314,7 @@ func (config SshConfig) JumpRecursion() (client *ssh.Client, err error) { return nil, err } var baseClient *ssh.Client - baseClient, err = DialSshRemote(sshConf) + baseClient, err = DialSshRemote(ctx, sshConf) if err != nil { return nil, err } @@ -331,7 +339,7 @@ func (config SshConfig) JumpRecursion() (client *ssh.Client, err error) { } for _, sshConfig := range bastionList { - client, err = JumpTo(baseClient, sshConfig) + client, err = JumpTo(ctx, baseClient, sshConfig) if err != nil { return } @@ -374,7 +382,7 @@ func GetBastion(name string, defaultValue SshConfig) SshConfig { return config } -func (config SshConfig) Dial() (*ssh.Client, error) { +func (config SshConfig) Dial(ctx context.Context) (client *ssh.Client, err error) { if strings.Index(config.Addr, ":") < 0 { // use default ssh port 22 config.Addr = net.JoinHostPort(config.Addr, "22") @@ -384,16 +392,31 @@ func (config SshConfig) Dial() (*ssh.Client, error) { if err != nil { return nil, err } - return ssh.Dial("tcp", config.Addr, &ssh.ClientConfig{ + conn, err := net.DialTimeout("tcp", config.Addr, time.Second*10) + if err != nil { + return nil, err + } + go func() { + <-ctx.Done() + conn.Close() + if client != nil { + client.Close() + } + }() + c, chans, reqs, err := ssh.NewClientConn(conn, config.Addr, &ssh.ClientConfig{ User: config.User, Auth: authMethod, HostKeyCallback: ssh.InsecureIgnoreHostKey(), BannerCallback: ssh.BannerDisplayStderr(), Timeout: time.Second * 10, }) + if err != nil { + return nil, err + } + return ssh.NewClient(c, chans, reqs), nil } -func JumpTo(bClient *ssh.Client, to SshConfig) (client *ssh.Client, err error) { +func JumpTo(ctx context.Context, bClient *ssh.Client, to SshConfig) (client *ssh.Client, err error) { if strings.Index(to.Addr, ":") < 0 { // use default ssh port 22 to.Addr = net.JoinHostPort(to.Addr, "22") @@ -410,6 +433,14 @@ func JumpTo(bClient *ssh.Client, to SshConfig) (client *ssh.Client, err error) { if err != nil { return } + go func() { + <-ctx.Done() + conn.Close() + if client != nil { + client.Close() + } + bClient.Close() + }() defer func() { if err != nil { if client != nil { @@ -495,7 +526,7 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr } sshClient.Close() } - sshClient, err = DialSshRemote(conf) + sshClient, err = DialSshRemote(ctx, conf) if err != nil { log.Errorf("failed to dial remote ssh server: %v", err) return nil, err @@ -505,11 +536,20 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr go func() { defer localListen.Close() + go func() { + <-ctx.Done() + localListen.Close() + if sshClient != nil { + sshClient.Close() + } + }() for ctx.Err() == nil { localConn, err := localListen.Accept() if err != nil { - log.Errorf("failed to accept conn: %v", err) + if !errors.Is(err, net.ErrClosed) { + log.Errorf("failed to accept conn: %v", err) + } return } go func() { @@ -521,7 +561,7 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr return } defer remoteConn.Close() - copyStream(localConn, remoteConn) + copyStream(ctx, localConn, remoteConn) }() } }() @@ -551,7 +591,7 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b // pre-check network ip connect var cli *ssh.Client - cli, err = DialSshRemote(conf) + cli, err = DialSshRemote(ctx, conf) if err != nil { return }