diff --git a/pkg/core/gvisortcphandler.go b/pkg/core/gvisortcphandler.go index d0bfa976..52b24fd3 100644 --- a/pkg/core/gvisortcphandler.go +++ b/pkg/core/gvisortcphandler.go @@ -81,5 +81,6 @@ func GvisorTCPListener(addr string) (net.Listener, error) { _ = listener.Close() return nil, err } + plog.G(context.Background()).Debugf("Use tls mode") return tls.NewListener(&tcpKeepAliveListener{TCPListener: listener}, serverConfig), nil } diff --git a/pkg/core/tcp.go b/pkg/core/tcp.go index f3c8df96..af85247d 100644 --- a/pkg/core/tcp.go +++ b/pkg/core/tcp.go @@ -38,7 +38,7 @@ func (tr *tcpTransporter) Dial(ctx context.Context, addr string) (net.Conn, erro plog.G(ctx).Debugf("tls config not found in config, use raw tcp mode") return conn, nil } - plog.G(ctx).Debugf("use tls mode") + plog.G(ctx).Debugf("Use tls mode") return tls.Client(conn, tr.tlsConfig), nil } diff --git a/pkg/daemon/handler/ssh.go b/pkg/daemon/handler/ssh.go index 0baa17c5..050f60c2 100644 --- a/pkg/daemon/handler/ssh.go +++ b/pkg/daemon/handler/ssh.go @@ -150,8 +150,8 @@ func (w *wsHandler) createTwoWayTUNTunnel(ctx context.Context, cli *ssh.Client) plog.G(ctx).Info("Connected private safe tunnel") go func() { for ctx.Err() == nil { - util.Ping(ctx, clientIP.IP.String(), ip.String()) - time.Sleep(time.Second * 5) + util.PingOnce(ctx, clientIP.IP.String(), ip.String()) + time.Sleep(time.Second * 15) } }() return nil @@ -436,7 +436,7 @@ func init() { if errors.Is(err, io.EOF) { return } else if err != nil { - plog.G(context.Background()).Errorf("Session %s windos change w: %d h: %d failed: %v", sessionID, r.Width, r.Height, err) + plog.G(context.Background()).Errorf("Session %s windows change w: %d h: %d failed: %v", sessionID, r.Width, r.Height, err) } } })) diff --git a/pkg/handler/proxy.go b/pkg/handler/proxy.go index 06a34a96..bb962234 100644 --- a/pkg/handler/proxy.go +++ b/pkg/handler/proxy.go @@ -162,7 +162,11 @@ func (m *Mapper) Run(connectNamespace string) { local := netip.AddrPortFrom(netip.IPv4Unspecified(), uint16(containerPort)) remote := netip.AddrPortFrom(netip.IPv4Unspecified(), uint16(envoyRulePort)) for ctx.Err() == nil { - _ = ssh.ExposeLocalPortToRemote(ctx, remoteSSHServer, remote, local) + func() { + ctx2, cancelFunc := context.WithCancel(ctx) + defer cancelFunc() + _ = ssh.ExposeLocalPortToRemote(ctx2, remoteSSHServer, remote, local) + }() time.Sleep(time.Second * 2) } }(containerPort, envoyRulePort) diff --git a/pkg/ssh/reverse.go b/pkg/ssh/reverse.go index 68943b13..3e5042de 100644 --- a/pkg/ssh/reverse.go +++ b/pkg/ssh/reverse.go @@ -37,6 +37,7 @@ func ExposeLocalPortToRemote(ctx context.Context, remoteSSHServer, remotePort, l plog.G(ctx).Errorf("Dial into remote server error: %s", err) return err } + defer serverConn.Close() // Listen on remote server port listener, err := serverConn.Listen("tcp", remotePort.String()) @@ -46,6 +47,12 @@ func ExposeLocalPortToRemote(ctx context.Context, remoteSSHServer, remotePort, l } defer listener.Close() + go func() { + <-ctx.Done() + listener.Close() + serverConn.Close() + }() + // handle incoming connections on reverse forwarded tunnel for { client, err := listener.Accept() diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index b5def737..de2e41a9 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -91,10 +91,16 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr var lc net.ListenConfig localListen, e := lc.Listen(ctx, "tcp", local.String()) if e != nil { + plog.G(ctx).Errorf("failed to listen %s: %v", local.String(), e) return e } plog.G(ctx).Debugf("SSH listening on local %s forward to %s", local.String(), remote.String()) + go func() { + <-ctx.Done() + localListen.Close() + }() + go func() { defer localListen.Close() @@ -108,6 +114,7 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr plog.G(ctx).Debugf("Failed to accept ssh conn: %v", err1) continue } + plog.G(ctx).Debugf("Accepted ssh conn from %s", localConn.RemoteAddr().String()) go func() { defer localConn.Close() @@ -117,11 +124,13 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr // if ssh server not permitted ssh port-forward, do nothing until exit if errors.As(err, &openChannelError) && openChannelError.Reason == gossh.Prohibited { plog.G(ctx).Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err) + plog.G(ctx).Errorf("Failed to open ssh port-forward: %s: %v", remote.String(), err) cancelFunc1() } plog.G(ctx).Debugf("Failed to get remote conn: %v", err) return } + plog.G(ctx).Debugf("Opened ssh port-forward: %s", remote.String()) defer remoteConn.Close() copyStream(ctx, localConn, remoteConn) @@ -182,7 +191,7 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b err = errors.Wrap(err, string(stderr)) return } - if len(stdout) == 0 { + if len(bytes.TrimSpace(stdout)) == 0 { err = errors.Errorf("can not get kubeconfig %s from remote ssh server: %s", conf.RemoteKubeconfig, string(stderr)) return } @@ -217,28 +226,34 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b var rawConfig api.Config rawConfig, err = matchVersionFlags.ToRawKubeConfigLoader().RawConfig() if err != nil { + plog.G(ctx).WithError(err).Errorf("failed to build config: %v", err) return } if err = api.FlattenConfig(&rawConfig); err != nil { + plog.G(ctx).Errorf("failed to flatten config: %v", err) return } if rawConfig.Contexts == nil { err = errors.New("kubeconfig is invalid") + plog.G(ctx).Error("can not get contexts") return } kubeContext := rawConfig.Contexts[rawConfig.CurrentContext] if kubeContext == nil { err = errors.New("kubeconfig is invalid") + plog.G(ctx).Errorf("can not find kubeconfig context %s", rawConfig.CurrentContext) return } cluster := rawConfig.Clusters[kubeContext.Cluster] if cluster == nil { err = errors.New("kubeconfig is invalid") + plog.G(ctx).Errorf("can not find cluster %s", kubeContext.Cluster) return } var u *url.URL u, err = url.Parse(cluster.Server) if err != nil { + plog.G(ctx).Errorf("failed to parse cluster url: %v", err) return } @@ -252,6 +267,7 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b } else { // handle other schemes if necessary err = errors.New("kubeconfig is invalid: wrong protocol") + plog.G(ctx).Error(err) return } } @@ -263,6 +279,7 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b if len(ips) == 0 { // handle error: no IP associated with the hostname err = fmt.Errorf("kubeconfig: no IP associated with the hostname %s", serverHost) + plog.G(ctx).Error(err) return } @@ -306,11 +323,13 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b var convertedObj runtime.Object convertedObj, err = latest.Scheme.ConvertToVersion(&rawConfig, latest.ExternalVersion) if err != nil { + plog.G(ctx).Errorf("failed to build config: %v", err) return } var marshal []byte marshal, err = json.Marshal(convertedObj) if err != nil { + plog.G(ctx).Errorf("failed to marshal config: %v", err) return } var temp *os.File @@ -407,7 +426,9 @@ func JumpTo(ctx context.Context, bClient *gossh.Client, to SshConfig, stopChan < return } -func getRemoteConn(ctx context.Context, clientMap *sync.Map, conf *SshConfig, remote netip.AddrPort) (conn net.Conn, err error) { +func getRemoteConn(ctx context.Context, clientMap *sync.Map, conf *SshConfig, remote netip.AddrPort) (net.Conn, error) { + var conn net.Conn + var err error clientMap.Range(func(key, value any) bool { cli := value.(*sshClientWrap) ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10) @@ -416,40 +437,35 @@ func getRemoteConn(ctx context.Context, clientMap *sync.Map, conf *SshConfig, re if err != nil { plog.G(ctx).Debugf("Failed to dial remote address %s: %s", remote.String(), err) clientMap.Delete(key) + plog.G(ctx).Error("Delete invalid ssh client from map") _ = cli.Close() return true } return false }) if conn != nil { - return + return conn, nil } ctx1, cancelFunc1 := context.WithCancel(ctx) - defer func() { - if err != nil { - cancelFunc1() - } - }() - - ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10) - defer cancelFunc2() var client *gossh.Client - client, err = DialSshRemote(ctx2, conf, ctx1.Done()) + client, err = DialSshRemote(ctx1, conf, ctx1.Done()) if err != nil { plog.G(ctx).Debugf("Failed to dial remote ssh server: %v", err) - return nil, err - } - - ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10) - defer cancelFunc3() - conn, err = client.DialContext(ctx3, "tcp", remote.String()) - if err != nil { - plog.G(ctx).Debugf("Failed to dial remote addr: %s: %v", remote.String(), err) - _ = client.Close() + cancelFunc1() return nil, err } clientMap.Store(uuid.NewString(), newSshClientWrap(client, cancelFunc1)) + plog.G(ctx1).Debug("Connected to remote SSH server") + + ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10) + defer cancelFunc2() + conn, err = client.DialContext(ctx2, "tcp", remote.String()) + if err != nil { + plog.G(ctx).Debugf("Failed to dial remote addr: %s: %v", remote.String(), err) + return nil, err + } + plog.G(ctx).Debugf("Connected to remote addr: %s", remote.String()) return conn, nil } @@ -462,7 +478,7 @@ func copyStream(ctx context.Context, local net.Conn, remote net.Conn) { defer config.LPool.Put(buf[:]) _, err := io.CopyBuffer(local, remote, buf) if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { - plog.G(ctx).Debugf("Failed to copy remote -> local: %s", err) + plog.G(ctx).Errorf("Failed to copy remote -> local: %s", err) } chDone <- true }() @@ -473,7 +489,7 @@ func copyStream(ctx context.Context, local net.Conn, remote net.Conn) { defer config.LPool.Put(buf[:]) _, err := io.CopyBuffer(remote, local, buf) if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { - plog.G(ctx).Debugf("Failed to copy local -> remote: %s", err) + plog.G(ctx).Errorf("Failed to copy local -> remote: %s", err) } chDone <- true }()