diff --git a/cmd/kubevpn/cmds/ssh.go b/cmd/kubevpn/cmds/ssh.go index 370d3774..3b7a5117 100644 --- a/cmd/kubevpn/cmds/ssh.go +++ b/cmd/kubevpn/cmds/ssh.go @@ -1,6 +1,7 @@ package cmds import ( + "context" "encoding/json" "fmt" "io" @@ -87,10 +88,27 @@ func CmdSSH(_ cmdutil.Factory) *cobra.Command { return err } defer conn.Close() - go moniteSize(sessionID) - go io.Copy(conn, os.Stdin) - _, err = io.Copy(os.Stdout, conn) - return err + + errChan := make(chan error, 3) + go func() { + err := monitorSize(cmd.Context(), sessionID) + errChan <- err + }() + go func() { + _, err := io.Copy(conn, os.Stdin) + errChan <- err + }() + go func() { + _, err := io.Copy(os.Stdout, conn) + errChan <- err + }() + + select { + case err = <-errChan: + return err + case <-cmd.Context().Done(): + return cmd.Context().Err() + } }, } addSshFlags(cmd, sshConf) @@ -98,7 +116,7 @@ func CmdSSH(_ cmdutil.Factory) *cobra.Command { return cmd } -func moniteSize(sessionID string) error { +func monitorSize(ctx context.Context, sessionID string) error { conn := daemon.GetTCPClient(true) if conn == nil { return fmt.Errorf("conn is nil") @@ -125,14 +143,15 @@ func moniteSize(sessionID string) error { return err } encoder := json.NewEncoder(client) - for { + for ctx.Err() == nil { size := sizeQueue.Next() if size == nil { return nil } if err = encoder.Encode(&size); err != nil { log.Errorf("Encode resize: %s", err) - continue + return err } } + return nil } diff --git a/pkg/daemon/action/config.go b/pkg/daemon/action/config.go index ae27be4b..b4f8a84d 100644 --- a/pkg/daemon/action/config.go +++ b/pkg/daemon/action/config.go @@ -4,10 +4,10 @@ import ( "context" "github.com/spf13/pflag" - "github.com/wencaiwulue/kubevpn/v2/pkg/handler" - "github.com/wencaiwulue/kubevpn/v2/pkg/util" "github.com/wencaiwulue/kubevpn/v2/pkg/daemon/rpc" + "github.com/wencaiwulue/kubevpn/v2/pkg/handler" + "github.com/wencaiwulue/kubevpn/v2/pkg/util" ) var CancelFunc = make(map[string]context.CancelFunc) diff --git a/pkg/daemon/action/connect-fork.go b/pkg/daemon/action/connect-fork.go index ee5c8d72..513d6dca 100644 --- a/pkg/daemon/action/connect-fork.go +++ b/pkg/daemon/action/connect-fork.go @@ -202,6 +202,7 @@ func (svr *Server) redirectConnectForkToSudoDaemon(req *rpc.ConnectRequest, resp } log.Info(recv.Message) } + svr.secondaryConnect = append(svr.secondaryConnect[:i], svr.secondaryConnect[i+1:]...) break } } diff --git a/pkg/daemon/action/connect.go b/pkg/daemon/action/connect.go index e3b451a4..991838f7 100644 --- a/pkg/daemon/action/connect.go +++ b/pkg/daemon/action/connect.go @@ -25,7 +25,7 @@ import ( "github.com/wencaiwulue/kubevpn/v2/pkg/util" ) -func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServer) error { +func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServer) (e error) { defer func() { log.SetOutput(svr.LogFile) log.SetLevel(log.DebugLevel) @@ -45,6 +45,15 @@ func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServe // todo define already connect error? return status.Error(codes.AlreadyExists, s) } + defer func() { + if e != nil { + if svr.connect != nil { + svr.connect.Cleanup() + svr.connect = nil + } + svr.t = time.Time{} + } + }() svr.t = time.Now() svr.connect = &handler.ConnectOptions{ Namespace: req.Namespace, diff --git a/pkg/daemon/handler/ssh.go b/pkg/daemon/handler/ssh.go index ce7ef476..c7efbd84 100644 --- a/pkg/daemon/handler/ssh.go +++ b/pkg/daemon/handler/ssh.go @@ -43,14 +43,18 @@ type wsHandler struct { // 1) start remote kubevpn server // 2) start local tunnel // 3) ssh terminal -func (w *wsHandler) handle(ctx2 context.Context) { - conn := w.conn - sshConfig := w.sshConfig - cidr := w.cidr - ctx, cancelFunc := context.WithCancel(ctx2) - defer cancelFunc() +func (w *wsHandler) handle(ctx context.Context) { + ctx, f := context.WithCancel(ctx) + defer f() - err := w.remoteInstallKubevpnIfCommandNotFound(ctx, sshConfig) + cli, err := util.DialSshRemote(ctx, w.sshConfig) + if err != nil { + w.Log("Dial ssh remote error: %v", err) + return + } + defer cli.Close() + + err = w.installKubevpnOnRemote(ctx, cli) if err != nil { w.Log("Install kubevpn error: %v", err) return @@ -62,16 +66,32 @@ func (w *wsHandler) handle(ctx2 context.Context) { return } - local, err := w.portMap(ctx, sshConfig) + remotePort := 10800 + var localPort int + localPort, err = util.GetAvailableTCPPortOrDie() + if err != nil { + return + } + var remote netip.AddrPort + remote, err = netip.ParseAddrPort(net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort))) + if err != nil { + return + } + var local netip.AddrPort + local, err = netip.ParseAddrPort(net.JoinHostPort("127.0.0.1", strconv.Itoa(localPort))) + if err != nil { + return + } + err = util.PortMapUntil(ctx, cli, remote, local) if err != nil { w.Log("Port map error: %v", err) return } // startup daemon process if daemon process not start startDaemonCmd := fmt.Sprintf(`export %s=%s && kubevpn get service > /dev/null 2>&1 &`, config.EnvStartSudoKubeVPNByKubeVPN, "true") - _, _, _ = util.RemoteRun(sshConfig, startDaemonCmd, nil) + util.RemoteRun(cli, startDaemonCmd, nil) cmd := fmt.Sprintf(`export %s=%s && kubevpn ssh-daemon --client-ip %s`, config.EnvStartSudoKubeVPNByKubeVPN, "true", clientIP.String()) - serverIP, stderr, err := util.RemoteRun(sshConfig, cmd, nil) + serverIP, stderr, err := util.RemoteRun(cli, cmd, nil) if err != nil { log.Errorf("run error: %v", err) log.Errorf("run stdout: %v", string(serverIP)) @@ -86,12 +106,12 @@ func (w *wsHandler) handle(ctx2 context.Context) { } msg := fmt.Sprintf("| You can use client: %s to communicate with server: %s |", clientIP.IP.String(), ip.String()) w.PrintLine(msg) - cidr = append(cidr, string(serverIP)) + w.cidr = append(w.cidr, string(serverIP)) r := core.Route{ ServeNodes: []string{ - fmt.Sprintf("tun:/127.0.0.1:8422?net=%s&route=%s", clientIP, strings.Join(cidr, ",")), + fmt.Sprintf("tun:/127.0.0.1:8422?net=%s&route=%s", clientIP, strings.Join(w.cidr, ",")), }, - ChainNode: fmt.Sprintf("tcp://127.0.0.1:%d", local), + ChainNode: fmt.Sprintf("tcp://127.0.0.1:%d", localPort), Retries: 5, } servers, err := handler.Parse(r) @@ -124,78 +144,14 @@ func (w *wsHandler) handle(ctx2 context.Context) { } } }() - err = w.terminal(ctx, sshConfig, conn) + err = w.terminal(ctx, cli, w.conn) if err != nil { w.Log("Enter terminal error: %v", err) } return } -func (w *wsHandler) portMap(ctx context.Context, conf *util.SshConfig) (localPort int, err error) { - remotePort := 10800 - localPort, err = util.GetAvailableTCPPortOrDie() - if err != nil { - return - } - var remote netip.AddrPort - remote, err = netip.ParseAddrPort(net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort))) - if err != nil { - return - } - var local netip.AddrPort - local, err = netip.ParseAddrPort(net.JoinHostPort("127.0.0.1", strconv.Itoa(localPort))) - if err != nil { - return - } - - // pre-check network ip connect - var cli *ssh.Client - cli, err = util.DialSshRemote(conf) - if err != nil { - return - } else { - _ = cli.Close() - } - errChan := make(chan error, 1) - readyChan := make(chan struct{}, 1) - go func() { - for { - select { - case <-ctx.Done(): - return - default: - } - - err := util.Main(ctx, remote, local, conf, readyChan) - if err != nil { - if !errors.Is(err, context.Canceled) { - log.Errorf("ssh forward failed err: %v", err) - w.Log("Ssh forward failed err: %v", err) - } - select { - case errChan <- err: - default: - } - } - time.Sleep(time.Second * 2) - } - }() - select { - case <-readyChan: - return - case err = <-errChan: - w.Log("Ssh forward failed err: %v", err) - log.Errorf("ssh proxy err: %v", err) - return - } -} - -func (w *wsHandler) terminal(ctx context.Context, conf *util.SshConfig, conn *websocket.Conn) error { - cli, err := util.DialSshRemote(conf) - if err != nil { - w.Log("Dial remote error: %v", err) - return err - } +func (w *wsHandler) terminal(ctx context.Context, cli *ssh.Client, conn *websocket.Conn) error { session, err := cli.NewSession() if err != nil { w.Log("New session error: %v", err) @@ -205,7 +161,6 @@ func (w *wsHandler) terminal(ctx context.Context, conf *util.SshConfig, conn *we go func() { <-ctx.Done() session.Close() - cli.Close() }() session.Stdout = conn session.Stderr = conn @@ -232,11 +187,11 @@ func (w *wsHandler) terminal(ctx context.Context, conf *util.SshConfig, conn *we return session.Wait() } -func (w *wsHandler) remoteInstallKubevpnIfCommandNotFound(ctx context.Context, sshConfig *util.SshConfig) error { +func (w *wsHandler) installKubevpnOnRemote(ctx context.Context, sshClient *ssh.Client) error { cmd := `hash kubevpn || type kubevpn || which kubevpn || command -v kubevpn` - _, _, err := util.RemoteRun(sshConfig, cmd, nil) + _, _, err := util.RemoteRun(sshClient, cmd, nil) if err == nil { - w.Log("Remote kubevpn command found, not needs to install") + w.Log("Remote kubevpn command found, use it") return nil } log.Infof("remote kubevpn command not found, try to install it...") @@ -293,12 +248,12 @@ func (w *wsHandler) remoteInstallKubevpnIfCommandNotFound(ctx context.Context, s "chmod +x ~/.kubevpn/kubevpn", "sudo mv ~/.kubevpn/kubevpn /usr/local/bin/kubevpn", } - err = util.SCP(w.conn, w.conn, sshConfig, tempBin.Name(), "kubevpn", cmds...) + err = util.SCPAndExec(w.conn, w.conn, sshClient, tempBin.Name(), "kubevpn", cmds...) if err != nil { return err } // try to startup daemon process - go util.RemoteRun(sshConfig, "kubevpn get pods", nil) + go util.RemoteRun(sshClient, "kubevpn get pods", nil) return nil } @@ -354,12 +309,15 @@ func init() { condReady: cancelFunc, } CondReady[sessionID] = ctx + defer conn.Close() h.handle(conn.Request().Context()) })) http.Handle("/resize", websocket.Handler(func(conn *websocket.Conn) { sessionID := conn.Request().Header.Get("session-id") log.Infof("resize: %s", sessionID) + defer conn.Close() + var session *ssh.Session select { case <-conn.Request().Context().Done(): diff --git a/pkg/handler/connect.go b/pkg/handler/connect.go index ee922152..b44cd4f1 100644 --- a/pkg/handler/connect.go +++ b/pkg/handler/connect.go @@ -759,11 +759,18 @@ func SshJump(ctx context.Context, conf *util.SshConfig, flags *pflag.FlagSet, pr } }() + // pre-check network ip connect + var cli *ssh.Client + cli, err = util.DialSshRemote(ctx, conf) + if err != nil { + return + } + configFlags := genericclioptions.NewConfigFlags(true).WithDeprecatedPasswordFlag() if conf.RemoteKubeconfig != "" || (flags != nil && flags.Changed("remote-kubeconfig")) { - var stdOut []byte - var errOut []byte + var stdout []byte + var stderr []byte if len(conf.RemoteKubeconfig) != 0 && conf.RemoteKubeconfig[0] == '~' { conf.RemoteKubeconfig = filepath.Join("/home", conf.User, conf.RemoteKubeconfig[1:]) } @@ -771,18 +778,18 @@ func SshJump(ctx context.Context, conf *util.SshConfig, flags *pflag.FlagSet, pr // if `--remote-kubeconfig` is parsed then Entrypoint is reset conf.RemoteKubeconfig = filepath.Join("/home", conf.User, clientcmd.RecommendedHomeDir, clientcmd.RecommendedFileName) } - stdOut, errOut, err = util.RemoteRun(conf, + stdout, stderr, err = util.RemoteRun(cli, fmt.Sprintf("sh -c 'kubectl config view --flatten --raw --kubeconfig %s || minikube kubectl -- config view --flatten --raw --kubeconfig %s'", conf.RemoteKubeconfig, conf.RemoteKubeconfig), map[string]string{clientcmd.RecommendedConfigPathEnvVar: conf.RemoteKubeconfig}, ) if err != nil { - err = errors.Wrap(err, string(errOut)) + err = errors.Wrap(err, string(stderr)) return } - if len(stdOut) == 0 { - err = errors.Errorf("can not get kubeconfig %s from remote ssh server: %s", conf.RemoteKubeconfig, string(errOut)) + if len(stdout) == 0 { + err = errors.Errorf("can not get kubeconfig %s from remote ssh server: %s", conf.RemoteKubeconfig, string(stderr)) return } @@ -793,7 +800,7 @@ func SshJump(ctx context.Context, conf *util.SshConfig, flags *pflag.FlagSet, pr if err = temp.Close(); err != nil { return } - if err = os.WriteFile(temp.Name(), stdOut, 0644); err != nil { + if err = os.WriteFile(temp.Name(), stdout, 0644); err != nil { return } if err = os.Chmod(temp.Name(), 0644); err != nil { @@ -871,7 +878,6 @@ func SshJump(ctx context.Context, conf *util.SshConfig, flags *pflag.FlagSet, pr if err != nil { return } - var port int port, err = util.GetAvailableTCPPortOrDie() if err != nil { @@ -883,43 +889,11 @@ func SshJump(ctx context.Context, conf *util.SshConfig, flags *pflag.FlagSet, pr return } - // pre-check network ip connect - var cli *ssh.Client - cli, err = util.DialSshRemote(conf) - if err != nil { - return - } else { - _ = cli.Close() - } - errChan := make(chan error, 1) - readyChan := make(chan struct{}, 1) - go func() { - for { - select { - case <-ctx.Done(): - return - default: - } - - err := util.Main(ctx, remote, local, conf, readyChan) - if err != nil { - if !errors.Is(err, context.Canceled) { - log.Errorf("ssh forward failed err: %v", err) - } - select { - case errChan <- err: - default: - } - } - time.Sleep(time.Second * 2) - } - }() if print { log.Infof("wait jump to bastion host...") } - select { - case <-readyChan: - case err = <-errChan: + err = util.PortMapUntil(ctx, cli, remote, local) + if err != nil { log.Errorf("ssh proxy err: %v", err) return } diff --git a/pkg/util/image.go b/pkg/util/image.go index f289debd..e6e6542b 100644 --- a/pkg/util/image.go +++ b/pkg/util/image.go @@ -20,6 +20,7 @@ import ( "github.com/opencontainers/image-spec/specs-go/v1" "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ssh" ) func GetClient() (*client.Client, *command.DockerCli, error) { @@ -109,6 +110,11 @@ func TransferImage(ctx context.Context, conf *SshConfig, imageSource, imageTarge } // transfer image to remote + var sshClient *ssh.Client + sshClient, err = DialSshRemote(ctx, conf) + if err != nil { + return err + } var responseReader io.ReadCloser responseReader, err = cli.ImageSave(ctx, []string{imageTarget}) if err != nil { @@ -132,12 +138,12 @@ func TransferImage(ctx context.Context, conf *SshConfig, imageSource, imageTarge logrus.Infof("Transferring image %s", imageTarget) filename := filepath.Base(file.Name()) cmd := fmt.Sprintf( - "(docker load image -i ~/.kubevpn/%s && docker push %s) || (nerdctl image load -i ~/.kubevpn/%s && nerdctl image push %s)", + "(docker load -i ~/.kubevpn/%s && docker push %s) || (nerdctl image load -i ~/.kubevpn/%s && nerdctl image push %s)", filename, imageTarget, filename, imageTarget, ) stdout := log.StandardLogger().Out - err = SCP(stdout, stdout, conf, file.Name(), filename, []string{cmd}...) + err = SCPAndExec(stdout, stdout, sshClient, file.Name(), filename, []string{cmd}...) if err != nil { return err } diff --git a/pkg/util/scp.go b/pkg/util/scp.go index 29cc28d4..d5601814 100644 --- a/pkg/util/scp.go +++ b/pkg/util/scp.go @@ -10,29 +10,20 @@ import ( "golang.org/x/crypto/ssh" ) -// SCP copy file to remote and exec command -func SCP(stdout, stderr io.Writer, conf *SshConfig, filename, to string, commands ...string) error { - remote, err := DialSshRemote(conf) - if err != nil { - log.Errorf("Dial into remote server error: %s", err) - return err - } - - sess, err := remote.NewSession() - if err != nil { - return err - } - err = main(sess, stdout, stderr, filename, to) +// SCPAndExec copy file to remote and exec command +func SCPAndExec(stdout, stderr io.Writer, client *ssh.Client, filename, to string, commands ...string) error { + err := SCP(client, stdout, stderr, filename, to) if err != nil { log.Errorf("Copy file to remote error: %s", err) return err } for _, command := range commands { - sess, err = remote.NewSession() + var session *ssh.Session + session, err = client.NewSession() if err != nil { return err } - output, err := sess.CombinedOutput(command) + output, err := session.CombinedOutput(command) if err != nil { log.Error(string(output)) return err @@ -43,24 +34,28 @@ func SCP(stdout, stderr io.Writer, conf *SshConfig, filename, to string, command return nil } -// https://blog.neilpang.com/%E6%94%B6%E8%97%8F-scp-secure-copy%E5%8D%8F%E8%AE%AE/ -func main(sess *ssh.Session, stdout, stderr io.Writer, filename string, to string) error { - open, err := os.Open(filename) +// SCP https://blog.neilpang.com/%E6%94%B6%E8%97%8F-scp-secure-copy%E5%8D%8F%E8%AE%AE/ +func SCP(client *ssh.Client, stdout, stderr io.Writer, filename, to string) error { + file, err := os.Open(filename) if err != nil { return err } - stat, err := open.Stat() + defer file.Close() + stat, err := file.Stat() + if err != nil { + return err + } + sess, err := client.NewSession() if err != nil { return err } - defer open.Close() defer sess.Close() go func() { w, _ := sess.StdinPipe() defer w.Close() fmt.Fprintln(w, "D0755", 0, ".kubevpn") // mkdir fmt.Fprintln(w, "C0644", stat.Size(), to) - err := sCopy(w, open, stat.Size(), stdout, stderr) + err := sCopy(w, file, stat.Size(), stdout, stderr) if err != nil { log.Errorf("failed to transfer file to remote: %v", err) return @@ -72,9 +67,9 @@ func main(sess *ssh.Session, stdout, stderr io.Writer, filename string, to strin func sCopy(dst io.Writer, src io.Reader, size int64, stdout, stderr io.Writer) error { total := float64(size) / 1024 / 1024 - s := fmt.Sprintf("Length: %d (%0.2fM)\n", size, total) + s := fmt.Sprintf("Length: %d (%0.2fM)", size, total) log.Info(s) - io.WriteString(stdout, s) + io.WriteString(stdout, s+"\n") bar := progressbar.NewOptions(int(size), progressbar.OptionSetWriter(stdout), diff --git a/pkg/util/ssh.go b/pkg/util/ssh.go index ab1bcd27..71097151 100644 --- a/pkg/util/ssh.go +++ b/pkg/util/ssh.go @@ -67,31 +67,7 @@ func (s *SshConfig) ToRPC() *rpc.SshJump { } } -func Main(pctx context.Context, remoteEndpoint, localEndpoint netip.AddrPort, conf *SshConfig, done chan struct{}) error { - ctx, cancelFunc := context.WithCancel(pctx) - defer cancelFunc() - - sshClient, err := DialSshRemote(conf) - if err != nil { - log.Errorf("Dial into remote server error: %s", err) - return err - } - - // ref: https://github.com/golang/go/issues/21478 - go func() { - ticker := time.NewTicker(time.Second * 5) - defer ticker.Stop() - select { - case <-ticker.C: - _, _, err := sshClient.SendRequest("keepalive@golang.org", true, nil) - if err != nil { - log.Errorf("failed to send keep alive error: %s", err) - } - case <-ctx.Done(): - return - } - }() - +func PortMap(ctx context.Context, sshClient *ssh.Client, remoteEndpoint, localEndpoint netip.AddrPort, done chan struct{}) error { // Listen on remote server port var lc net.ListenConfig listen, err := lc.Listen(ctx, "tcp", localEndpoint.String()) @@ -105,34 +81,27 @@ func Main(pctx context.Context, remoteEndpoint, localEndpoint netip.AddrPort, co default: } // handle incoming connections on reverse forwarded tunnel - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - local, err := listen.Accept() + for ctx.Err() == nil { + localConn, err := listen.Accept() if err != nil { return err } - go func(local net.Conn) { - defer local.Close() - conn, err := sshClient.Dial("tcp", remoteEndpoint.String()) + go func(localConn net.Conn) { + defer localConn.Close() + remoteConn, err := sshClient.Dial("tcp", remoteEndpoint.String()) if err != nil { log.Errorf("Failed to dial %s: %s", remoteEndpoint.String(), err) - cancelFunc() return } - defer conn.Close() - handleClient(local, conn) - }(local) + defer remoteConn.Close() + copyStream(localConn, remoteConn) + }(localConn) } + return ctx.Err() } -// todo ssh heartbeats -// https://github.com/golang/go/issues/21478 -func DialSshRemote(conf *SshConfig) (*ssh.Client, error) { +// DialSshRemote https://github.com/golang/go/issues/21478 +func DialSshRemote(ctx context.Context, conf *SshConfig) (*ssh.Client, error) { var remote *ssh.Client var err error if conf.ConfigAlias != "" { @@ -189,22 +158,33 @@ func DialSshRemote(conf *SshConfig) (*ssh.Client, error) { // Connect to SSH remote server using serverEndpoint remote, err = ssh.Dial("tcp", conf.Addr, sshConfig) } + + // ref: https://github.com/golang/go/issues/21478 + if err == nil { + go func() { + ticker := time.NewTicker(time.Second * 5) + defer ticker.Stop() + for ctx.Err() == nil { + select { + case <-ticker.C: + _, _, er := remote.SendRequest("keepalive@golang.org", true, nil) + if er != nil { + log.Errorf("failed to send keep alive error: %s", er) + } + } + } + }() + } return remote, err } -func RemoteRun(conf *SshConfig, cmd string, env map[string]string) (output []byte, errOut []byte, err error) { - var remote *ssh.Client - remote, err = DialSshRemote(conf) - if err != nil { - log.Errorf("Dial into remote server error: %s", err) - return - } - defer remote.Close() +func RemoteRun(client *ssh.Client, cmd string, env map[string]string) (output []byte, errOut []byte, err error) { var session *ssh.Session - session, err = remote.NewSession() + session, err = client.NewSession() if err != nil { return } + defer session.Close() for k, v := range env { // /etc/ssh/sshd_config // AcceptEnv DEBIAN_FRONTEND @@ -213,7 +193,6 @@ func RemoteRun(conf *SshConfig, cmd string, env map[string]string) (output []byt err = nil } } - defer remote.Close() var out bytes.Buffer var er bytes.Buffer session.Stdout = &out @@ -246,7 +225,7 @@ func publicKeyFile(file string) (ssh.AuthMethod, error) { return ssh.PublicKeys(key), nil } -func handleClient(client net.Conn, remote net.Conn) { +func copyStream(client net.Conn, remote net.Conn) { chDone := make(chan bool, 2) // start remote -> local data transfer @@ -408,3 +387,32 @@ func init() { } }) } + +func PortMapUntil(ctx context.Context, cli *ssh.Client, remote, local netip.AddrPort) error { + errChan := make(chan error, 1) + readyChan := make(chan struct{}, 1) + go func() { + for ctx.Err() == nil { + err := PortMap(ctx, cli, remote, local, readyChan) + if err != nil { + if !errors.Is(err, context.Canceled) { + log.Errorf("Ssh forward failed err: %v", err) + } + select { + case errChan <- err: + default: + } + } + time.Sleep(time.Second * 2) + } + }() + select { + case <-readyChan: + return nil + case err := <-errChan: + log.Errorf("Ssh forward failed err: %v", err) + return err + case <-ctx.Done(): + return ctx.Err() + } +}