From 1cae5d270b8363609b66a92ddc0b135277589f88 Mon Sep 17 00:00:00 2001 From: naison <895703375@qq.com> Date: Mon, 21 Apr 2025 22:19:31 +0800 Subject: [PATCH] refactor: optimize ssh logic (#555) --- pkg/ssh/config.go | 391 ++++++++++++++++++++++++++ pkg/ssh/ssh.go | 680 ++++++++++------------------------------------ 2 files changed, 542 insertions(+), 529 deletions(-) create mode 100644 pkg/ssh/config.go diff --git a/pkg/ssh/config.go b/pkg/ssh/config.go new file mode 100644 index 00000000..eb46324c --- /dev/null +++ b/pkg/ssh/config.go @@ -0,0 +1,391 @@ +package ssh + +import ( + "context" + "fmt" + "io" + "net" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" + + "github.com/kevinburke/ssh_config" + "github.com/pkg/errors" + "github.com/spf13/pflag" + "golang.org/x/crypto/ssh" + "k8s.io/client-go/util/homedir" + + "github.com/wencaiwulue/kubevpn/v2/pkg/config" + "github.com/wencaiwulue/kubevpn/v2/pkg/daemon/rpc" +) + +type SshConfig struct { + Addr string + User string + Password string + Keyfile string + Jump string + ConfigAlias string + RemoteKubeconfig string + // GSSAPI + GSSAPIKeytabConf string + GSSAPIPassword string + GSSAPICacheFile string +} + +func (conf SshConfig) Clone() SshConfig { + return SshConfig{ + Addr: conf.Addr, + User: conf.User, + Password: conf.Password, + Keyfile: conf.Keyfile, + Jump: conf.Jump, + ConfigAlias: conf.ConfigAlias, + RemoteKubeconfig: conf.RemoteKubeconfig, + GSSAPIKeytabConf: conf.GSSAPIKeytabConf, + GSSAPIPassword: conf.GSSAPIPassword, + GSSAPICacheFile: conf.GSSAPICacheFile, + } +} + +func ParseSshFromRPC(sshJump *rpc.SshJump) *SshConfig { + if sshJump == nil { + return &SshConfig{} + } + return &SshConfig{ + Addr: sshJump.Addr, + User: sshJump.User, + Password: sshJump.Password, + Keyfile: sshJump.Keyfile, + Jump: sshJump.Jump, + ConfigAlias: sshJump.ConfigAlias, + RemoteKubeconfig: sshJump.RemoteKubeconfig, + GSSAPIKeytabConf: sshJump.GSSAPIKeytabConf, + GSSAPIPassword: sshJump.GSSAPIPassword, + GSSAPICacheFile: sshJump.GSSAPICacheFile, + } +} + +func (conf SshConfig) ToRPC() *rpc.SshJump { + return &rpc.SshJump{ + Addr: conf.Addr, + User: conf.User, + Password: conf.Password, + Keyfile: conf.Keyfile, + Jump: conf.Jump, + ConfigAlias: conf.ConfigAlias, + RemoteKubeconfig: conf.RemoteKubeconfig, + GSSAPIKeytabConf: conf.GSSAPIKeytabConf, + GSSAPIPassword: conf.GSSAPIPassword, + GSSAPICacheFile: conf.GSSAPICacheFile, + } +} + +func (conf SshConfig) IsEmpty() bool { + return conf.ConfigAlias == "" && conf.Addr == "" && conf.Jump == "" +} + +func (conf SshConfig) GetAuth() ([]ssh.AuthMethod, error) { + host, _, _ := net.SplitHostPort(conf.Addr) + var auth []ssh.AuthMethod + var c Krb5InitiatorClient + var err error + var krb5Conf = GetKrb5Path() + if conf.Password != "" { + auth = append(auth, ssh.Password(conf.Password)) + } else if conf.GSSAPIPassword != "" { + c, err = NewKrb5InitiatorClientWithPassword(conf.User, conf.GSSAPIPassword, krb5Conf) + if err != nil { + return nil, err + } + auth = append(auth, ssh.GSSAPIWithMICAuthMethod(&c, host)) + } else if conf.GSSAPIKeytabConf != "" { + c, err = NewKrb5InitiatorClientWithKeytab(conf.User, krb5Conf, conf.GSSAPIKeytabConf) + if err != nil { + return nil, err + } + } else if conf.GSSAPICacheFile != "" { + c, err = NewKrb5InitiatorClientWithCache(krb5Conf, conf.GSSAPICacheFile) + if err != nil { + return nil, err + } + auth = append(auth, ssh.GSSAPIWithMICAuthMethod(&c, host)) + } else { + if conf.Keyfile == "" { + conf.Keyfile = filepath.Join(homedir.HomeDir(), ".ssh", "id_rsa") + } + var keyFile ssh.AuthMethod + keyFile, err = publicKeyFile(conf.Keyfile) + if err != nil { + return nil, err + } + auth = append(auth, keyFile) + } + return auth, nil +} + +func publicKeyFile(file string) (ssh.AuthMethod, error) { + var err error + if len(file) != 0 && file[0] == '~' { + file = filepath.Join(homedir.HomeDir(), file[1:]) + } + file, err = filepath.Abs(file) + if err != nil { + err = errors.Wrap(err, fmt.Sprintf("Cannot read SSH public key file %s", file)) + return nil, err + } + buffer, err := os.ReadFile(file) + if err != nil { + err = errors.Wrap(err, fmt.Sprintf("Cannot read SSH public key file %s", file)) + return nil, err + } + + key, err := ssh.ParsePrivateKey(buffer) + if err != nil { + err = errors.Wrap(err, fmt.Sprintf("Cannot parse SSH public key file %s", file)) + return nil, err + } + return ssh.PublicKeys(key), nil +} + +func (conf SshConfig) AliasRecursion(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) { + var name = conf.ConfigAlias + var jumper = "ProxyJump" + var bastionList = []SshConfig{GetBastion(name, conf)} + for { + value := defaultSshConfigList.Get(name, jumper) + if value != "" { + bastionList = append(bastionList, GetBastion(value, conf)) + name = value + continue + } + break + } + for i := len(bastionList) - 1; i >= 0; i-- { + if client == nil { + client, err = bastionList[i].Dial(ctx, stopChan) + if err != nil { + err = errors.Wrap(err, fmt.Sprintf("Failed to connect to %s", bastionList[i])) + return + } + } else { + client, err = JumpTo(ctx, client, bastionList[i], stopChan) + if err != nil { + err = errors.Wrap(err, fmt.Sprintf("Failed to jump to %s", bastionList[i])) + return + } + } + } + return +} + +func (conf SshConfig) JumpRecursion(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) { + flags := pflag.NewFlagSet("", pflag.ContinueOnError) + var sshConf = &SshConfig{} + AddSshFlags(flags, sshConf) + err = flags.Parse(strings.Split(conf.Jump, " ")) + if err != nil { + return nil, err + } + var baseClient *ssh.Client + baseClient, err = DialSshRemote(ctx, sshConf, stopChan) + if err != nil { + return nil, err + } + + var bastionList []SshConfig + if conf.ConfigAlias != "" { + var name = conf.ConfigAlias + var jumper = "ProxyJump" + bastionList = append(bastionList, GetBastion(name, conf)) + for { + value := defaultSshConfigList.Get(name, jumper) + if value != "" { + bastionList = append(bastionList, GetBastion(value, conf)) + name = value + continue + } + break + } + } + if conf.Addr != "" { + bastionList = append(bastionList, conf) + } + + for _, sshConfig := range bastionList { + client, err = JumpTo(ctx, baseClient, sshConfig, stopChan) + if err != nil { + err = errors.Wrap(err, fmt.Sprintf("Failed to jump to %s", sshConfig)) + return + } + } + return +} + +func (conf SshConfig) Dial(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) { + if _, _, err = net.SplitHostPort(conf.Addr); err != nil { + // use default ssh port 22 + conf.Addr = net.JoinHostPort(conf.Addr, "22") + err = nil + } + // connect to the bastion host + authMethod, err := conf.GetAuth() + if err != nil { + return nil, err + } + d := net.Dialer{Timeout: time.Second * 10, KeepAlive: config.KeepAliveTime} + conn, err := d.DialContext(ctx, "tcp", conf.Addr) + if err != nil { + return nil, err + } + go func() { + if stopChan != nil { + <-stopChan + conn.Close() + if client != nil { + client.Close() + } + } + }() + defer func() { + if err != nil { + if conn != nil { + conn.Close() + } + if client != nil { + client.Close() + } + } + }() + c, chans, reqs, err := ssh.NewClientConn(conn, conf.Addr, &ssh.ClientConfig{ + User: conf.User, + Auth: authMethod, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + //BannerCallback: ssh.BannerDisplayStderr(), + Timeout: time.Second * 10, + }) + if err != nil { + return nil, err + } + return ssh.NewClient(c, chans, reqs), nil +} + +func GetBastion(name string, defaultValue SshConfig) SshConfig { + var host, port string + conf := SshConfig{ + ConfigAlias: name, + } + var propertyList = []string{"ProxyJump", "Hostname", "User", "Port", "IdentityFile"} + for i, s := range propertyList { + value := defaultSshConfigList.Get(name, s) + switch i { + case 0: + + case 1: + host = value + case 2: + conf.User = value + case 3: + if port = value; port == "" { + port = strconv.Itoa(22) + } + case 4: + if value == "" { + conf.Keyfile = defaultValue.Keyfile + conf.Password = defaultValue.Password + conf.GSSAPIKeytabConf = defaultValue.GSSAPIKeytabConf + conf.GSSAPIPassword = defaultValue.GSSAPIPassword + conf.GSSAPICacheFile = defaultValue.GSSAPICacheFile + } else { + conf.Keyfile = value + } + } + } + conf.Addr = net.JoinHostPort(host, port) + return conf +} + +type defaultSshConf []*ssh_config.Config + +func (c defaultSshConf) Get(alias string, key string) string { + for _, s := range c { + if v, err := s.Get(alias, key); err == nil { + return v + } + } + return ssh_config.Get(alias, key) +} + +var once sync.Once + +var defaultSshConfigList defaultSshConf + +func init() { + once.Do(func() { + paths := []string{ + filepath.Join(homedir.HomeDir(), ".ssh", "config"), + filepath.Join("/", "etc", "ssh", "ssh_config"), + } + for _, path := range paths { + file, err := os.ReadFile(path) + if err != nil { + continue + } + cfg, err := ssh_config.DecodeBytes(file) + if err != nil { + continue + } + defaultSshConfigList = append(defaultSshConfigList, cfg) + } + }) +} + +func newSshClientWrap(client *ssh.Client, cancel context.CancelFunc) *sshClientWrap { + return &sshClientWrap{Client: client, cancel: cancel} +} + +type sshClientWrap struct { + cancel context.CancelFunc + *ssh.Client +} + +func (c *sshClientWrap) Close() error { + c.cancel() + return c.Client.Close() +} + +func AddSshFlags(flags *pflag.FlagSet, sshConf *SshConfig) { + // for ssh jumper host + flags.StringVar(&sshConf.Addr, "ssh-addr", "", "Optional ssh jump server address to dial as :, eg: 127.0.0.1:22") + flags.StringVar(&sshConf.User, "ssh-username", "", "Optional username for ssh jump server") + flags.StringVar(&sshConf.Password, "ssh-password", "", "Optional password for ssh jump server") + flags.StringVar(&sshConf.Keyfile, "ssh-keyfile", "", "Optional file with private key for SSH authentication") + flags.StringVar(&sshConf.ConfigAlias, "ssh-alias", "", "Optional config alias with ~/.ssh/config for SSH authentication") + flags.StringVar(&sshConf.Jump, "ssh-jump", "", "Optional bastion jump config string, eg: '--ssh-addr jumpe.naison.org --ssh-username naison --gssapi-password xxx'") + flags.StringVar(&sshConf.GSSAPIPassword, "gssapi-password", "", "GSSAPI password") + flags.StringVar(&sshConf.GSSAPIKeytabConf, "gssapi-keytab", "", "GSSAPI keytab file path") + flags.StringVar(&sshConf.GSSAPICacheFile, "gssapi-cache", "", "GSSAPI cache file path, use command `kinit -c /path/to/cache USERNAME@RELAM` to generate") + flags.StringVar(&sshConf.RemoteKubeconfig, "remote-kubeconfig", "", "Remote kubeconfig abstract path of ssh server, default is /home/$USERNAME/.kube/config") + lookup := flags.Lookup("remote-kubeconfig") + lookup.NoOptDefVal = "~/.kube/config" +} + +func keepAlive(cl *ssh.Client, conn net.Conn, done <-chan struct{}) error { + const keepAliveInterval = time.Second * 10 + t := time.NewTicker(keepAliveInterval) + defer t.Stop() + for { + select { + case <-t.C: + _, _, err := cl.SendRequest("keepalive@golang.org", true, nil) + if err != nil && err != io.EOF { + return errors.Wrap(err, "failed to send keep alive") + } + case <-done: + return nil + } + } +} diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index f7da47b2..b5def737 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -12,114 +12,29 @@ import ( "os" "path/filepath" "strconv" - "strings" "sync" "time" - "github.com/kevinburke/ssh_config" + "github.com/google/uuid" "github.com/pkg/errors" "github.com/spf13/pflag" - "golang.org/x/crypto/ssh" + gossh "golang.org/x/crypto/ssh" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/cli-runtime/pkg/genericclioptions" "k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/clientcmd/api" "k8s.io/client-go/tools/clientcmd/api/latest" - "k8s.io/client-go/util/homedir" "k8s.io/kubectl/pkg/cmd/util" "k8s.io/utils/pointer" "github.com/wencaiwulue/kubevpn/v2/pkg/config" - "github.com/wencaiwulue/kubevpn/v2/pkg/daemon/rpc" plog "github.com/wencaiwulue/kubevpn/v2/pkg/log" pkgutil "github.com/wencaiwulue/kubevpn/v2/pkg/util" ) -type SshConfig struct { - Addr string - User string - Password string - Keyfile string - Jump string - ConfigAlias string - RemoteKubeconfig string - // GSSAPI - GSSAPIKeytabConf string - GSSAPIPassword string - GSSAPICacheFile string -} - -func (s SshConfig) Clone() SshConfig { - return SshConfig{ - Addr: s.Addr, - User: s.User, - Password: s.Password, - Keyfile: s.Keyfile, - Jump: s.Jump, - ConfigAlias: s.ConfigAlias, - RemoteKubeconfig: s.RemoteKubeconfig, - GSSAPIKeytabConf: s.GSSAPIKeytabConf, - GSSAPIPassword: s.GSSAPIPassword, - GSSAPICacheFile: s.GSSAPICacheFile, - } -} - -func ParseSshFromRPC(sshJump *rpc.SshJump) *SshConfig { - if sshJump == nil { - return &SshConfig{} - } - return &SshConfig{ - Addr: sshJump.Addr, - User: sshJump.User, - Password: sshJump.Password, - Keyfile: sshJump.Keyfile, - Jump: sshJump.Jump, - ConfigAlias: sshJump.ConfigAlias, - RemoteKubeconfig: sshJump.RemoteKubeconfig, - GSSAPIKeytabConf: sshJump.GSSAPIKeytabConf, - GSSAPIPassword: sshJump.GSSAPIPassword, - GSSAPICacheFile: sshJump.GSSAPICacheFile, - } -} - -func (config *SshConfig) ToRPC() *rpc.SshJump { - return &rpc.SshJump{ - Addr: config.Addr, - User: config.User, - Password: config.Password, - Keyfile: config.Keyfile, - Jump: config.Jump, - ConfigAlias: config.ConfigAlias, - RemoteKubeconfig: config.RemoteKubeconfig, - GSSAPIKeytabConf: config.GSSAPIKeytabConf, - GSSAPIPassword: config.GSSAPIPassword, - GSSAPICacheFile: config.GSSAPICacheFile, - } -} - -func (config *SshConfig) IsEmpty() bool { - return config.ConfigAlias == "" && config.Addr == "" && config.Jump == "" -} - -func AddSshFlags(flags *pflag.FlagSet, sshConf *SshConfig) { - // for ssh jumper host - flags.StringVar(&sshConf.Addr, "ssh-addr", "", "Optional ssh jump server address to dial as :, eg: 127.0.0.1:22") - flags.StringVar(&sshConf.User, "ssh-username", "", "Optional username for ssh jump server") - flags.StringVar(&sshConf.Password, "ssh-password", "", "Optional password for ssh jump server") - flags.StringVar(&sshConf.Keyfile, "ssh-keyfile", "", "Optional file with private key for SSH authentication") - flags.StringVar(&sshConf.ConfigAlias, "ssh-alias", "", "Optional config alias with ~/.ssh/config for SSH authentication") - flags.StringVar(&sshConf.Jump, "ssh-jump", "", "Optional bastion jump config string, eg: '--ssh-addr jumpe.naison.org --ssh-username naison --gssapi-password xxx'") - flags.StringVar(&sshConf.GSSAPIPassword, "gssapi-password", "", "GSSAPI password") - flags.StringVar(&sshConf.GSSAPIKeytabConf, "gssapi-keytab", "", "GSSAPI keytab file path") - flags.StringVar(&sshConf.GSSAPICacheFile, "gssapi-cache", "", "GSSAPI cache file path, use command `kinit -c /path/to/cache USERNAME@RELAM` to generate") - flags.StringVar(&sshConf.RemoteKubeconfig, "remote-kubeconfig", "", "Remote kubeconfig abstract path of ssh server, default is /home/$USERNAME/.kube/config") - lookup := flags.Lookup("remote-kubeconfig") - lookup.NoOptDefVal = "~/.kube/config" -} - // DialSshRemote https://github.com/golang/go/issues/21478 -func DialSshRemote(ctx context.Context, conf *SshConfig, stopChan <-chan struct{}) (remote *ssh.Client, err error) { +func DialSshRemote(ctx context.Context, conf *SshConfig, stopChan <-chan struct{}) (remote *gossh.Client, err error) { defer func() { if err != nil { if remote != nil { @@ -148,64 +63,8 @@ func DialSshRemote(ctx context.Context, conf *SshConfig, stopChan <-chan struct{ return remote, err } -func keepAlive(cl *ssh.Client, conn net.Conn, done <-chan struct{}) error { - const keepAliveInterval = time.Second * 10 - t := time.NewTicker(keepAliveInterval) - defer t.Stop() - for { - select { - case <-t.C: - _, _, err := cl.SendRequest("keepalive@golang.org", true, nil) - if err != nil && err != io.EOF { - return errors.Wrap(err, "failed to send keep alive") - } - case <-done: - return nil - } - } -} - -func (config SshConfig) GetAuth() ([]ssh.AuthMethod, error) { - host, _, _ := net.SplitHostPort(config.Addr) - var auth []ssh.AuthMethod - var c Krb5InitiatorClient - var err error - var krb5Conf = GetKrb5Path() - if config.Password != "" { - auth = append(auth, ssh.Password(config.Password)) - } else if config.GSSAPIPassword != "" { - c, err = NewKrb5InitiatorClientWithPassword(config.User, config.GSSAPIPassword, krb5Conf) - if err != nil { - return nil, err - } - auth = append(auth, ssh.GSSAPIWithMICAuthMethod(&c, host)) - } else if config.GSSAPIKeytabConf != "" { - c, err = NewKrb5InitiatorClientWithKeytab(config.User, krb5Conf, config.GSSAPIKeytabConf) - if err != nil { - return nil, err - } - } else if config.GSSAPICacheFile != "" { - c, err = NewKrb5InitiatorClientWithCache(krb5Conf, config.GSSAPICacheFile) - if err != nil { - return nil, err - } - auth = append(auth, ssh.GSSAPIWithMICAuthMethod(&c, host)) - } else { - if config.Keyfile == "" { - config.Keyfile = filepath.Join(homedir.HomeDir(), ".ssh", "id_rsa") - } - var keyFile ssh.AuthMethod - keyFile, err = publicKeyFile(config.Keyfile) - if err != nil { - return nil, err - } - auth = append(auth, keyFile) - } - return auth, nil -} - -func RemoteRun(client *ssh.Client, cmd string, env map[string]string) (output []byte, errOut []byte, err error) { - var session *ssh.Session +func RemoteRun(client *gossh.Client, cmd string, env map[string]string) (output []byte, errOut []byte, err error) { + var session *gossh.Session session, err = client.NewSession() if err != nil { return @@ -227,322 +86,6 @@ func RemoteRun(client *ssh.Client, cmd string, env map[string]string) (output [] return out.Bytes(), er.Bytes(), err } -func publicKeyFile(file string) (ssh.AuthMethod, error) { - var err error - if len(file) != 0 && file[0] == '~' { - file = filepath.Join(homedir.HomeDir(), file[1:]) - } - file, err = filepath.Abs(file) - if err != nil { - err = errors.Wrap(err, fmt.Sprintf("Cannot read SSH public key file %s", file)) - return nil, err - } - buffer, err := os.ReadFile(file) - if err != nil { - err = errors.Wrap(err, fmt.Sprintf("Cannot read SSH public key file %s", file)) - return nil, err - } - - key, err := ssh.ParsePrivateKey(buffer) - if err != nil { - err = errors.Wrap(err, fmt.Sprintf("Cannot parse SSH public key file %s", file)) - return nil, err - } - return ssh.PublicKeys(key), nil -} - -func copyStream(ctx context.Context, local net.Conn, remote net.Conn) { - chDone := make(chan bool, 2) - - // start remote -> local data transfer - go func() { - buf := config.LPool.Get().([]byte)[:] - defer config.LPool.Put(buf[:]) - _, err := io.CopyBuffer(local, remote, buf) - if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { - plog.G(ctx).Debugf("Failed to copy remote -> local: %s", err) - } - chDone <- true - }() - - // start local -> remote data transfer - go func() { - buf := config.LPool.Get().([]byte)[:] - defer config.LPool.Put(buf[:]) - _, err := io.CopyBuffer(remote, local, buf) - if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { - plog.G(ctx).Debugf("Failed to copy local -> remote: %s", err) - } - chDone <- true - }() - - select { - case <-chDone: - return - case <-ctx.Done(): - return - } -} - -func (config SshConfig) AliasRecursion(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) { - var name = config.ConfigAlias - var jumper = "ProxyJump" - var bastionList = []SshConfig{GetBastion(name, config)} - for { - value := confList.Get(name, jumper) - if value != "" { - bastionList = append(bastionList, GetBastion(value, config)) - name = value - continue - } - break - } - for i := len(bastionList) - 1; i >= 0; i-- { - if client == nil { - client, err = bastionList[i].Dial(ctx, stopChan) - if err != nil { - return - } - } else { - client, err = JumpTo(ctx, client, bastionList[i], stopChan) - if err != nil { - return - } - } - } - return -} - -func (config SshConfig) JumpRecursion(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) { - flags := pflag.NewFlagSet("", pflag.ContinueOnError) - var sshConf = &SshConfig{} - AddSshFlags(flags, sshConf) - err = flags.Parse(strings.Split(config.Jump, " ")) - if err != nil { - return nil, err - } - var baseClient *ssh.Client - baseClient, err = DialSshRemote(ctx, sshConf, stopChan) - if err != nil { - return nil, err - } - - var bastionList []SshConfig - if config.ConfigAlias != "" { - var name = config.ConfigAlias - var jumper = "ProxyJump" - bastionList = append(bastionList, GetBastion(name, config)) - for { - value := confList.Get(name, jumper) - if value != "" { - bastionList = append(bastionList, GetBastion(value, config)) - name = value - continue - } - break - } - } - if config.Addr != "" { - bastionList = append(bastionList, config) - } - - for _, sshConfig := range bastionList { - client, err = JumpTo(ctx, baseClient, sshConfig, stopChan) - if err != nil { - return - } - } - return -} - -func GetBastion(name string, defaultValue SshConfig) SshConfig { - var host, port string - config := SshConfig{ - ConfigAlias: name, - } - var propertyList = []string{"ProxyJump", "Hostname", "User", "Port", "IdentityFile"} - for i, s := range propertyList { - value := confList.Get(name, s) - switch i { - case 0: - - case 1: - host = value - case 2: - config.User = value - case 3: - if port = value; port == "" { - port = strconv.Itoa(22) - } - case 4: - if value == "" { - config.Keyfile = defaultValue.Keyfile - config.Password = defaultValue.Password - config.GSSAPIKeytabConf = defaultValue.GSSAPIKeytabConf - config.GSSAPIPassword = defaultValue.GSSAPIPassword - config.GSSAPICacheFile = defaultValue.GSSAPICacheFile - } else { - config.Keyfile = value - } - } - } - config.Addr = net.JoinHostPort(host, port) - return config -} - -func (config SshConfig) Dial(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) { - if _, _, err = net.SplitHostPort(config.Addr); err != nil { - // use default ssh port 22 - config.Addr = net.JoinHostPort(config.Addr, "22") - err = nil - } - // connect to the bastion host - authMethod, err := config.GetAuth() - if err != nil { - return nil, err - } - d := net.Dialer{Timeout: time.Second * 10} - conn, err := d.DialContext(ctx, "tcp", config.Addr) - if err != nil { - return nil, err - } - go func() { - if stopChan != nil { - <-stopChan - conn.Close() - if client != nil { - client.Close() - } - } - }() - defer func() { - if err != nil { - if conn != nil { - conn.Close() - } - if client != nil { - client.Close() - } - } - }() - c, chans, reqs, err := ssh.NewClientConn(conn, config.Addr, &ssh.ClientConfig{ - User: config.User, - Auth: authMethod, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - //BannerCallback: ssh.BannerDisplayStderr(), - Timeout: time.Second * 10, - }) - if err != nil { - return nil, err - } - return ssh.NewClient(c, chans, reqs), nil -} - -func JumpTo(ctx context.Context, bClient *ssh.Client, to SshConfig, stopChan <-chan struct{}) (client *ssh.Client, err error) { - if _, _, err = net.SplitHostPort(to.Addr); err != nil { - // use default ssh port 22 - to.Addr = net.JoinHostPort(to.Addr, "22") - err = nil - } - - var authMethod []ssh.AuthMethod - authMethod, err = to.GetAuth() - if err != nil { - return nil, err - } - // Dial a connection to the service host, from the bastion - var conn net.Conn - conn, err = bClient.DialContext(ctx, "tcp", to.Addr) - if err != nil { - return - } - go func() { - if stopChan != nil { - <-stopChan - conn.Close() - if client != nil { - client.Close() - } - bClient.Close() - } - }() - defer func() { - if err != nil { - if client != nil { - client.Close() - } - if conn != nil { - conn.Close() - } - } - }() - var ncc ssh.Conn - var chans <-chan ssh.NewChannel - var reqs <-chan *ssh.Request - ncc, chans, reqs, err = ssh.NewClientConn(conn, to.Addr, &ssh.ClientConfig{ - User: to.User, - Auth: authMethod, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - //BannerCallback: ssh.BannerDisplayStderr(), - Timeout: time.Second * 10, - }) - if err != nil { - return - } - - client = ssh.NewClient(ncc, chans, reqs) - return -} - -type conf []*ssh_config.Config - -func (c conf) Get(alias string, key string) string { - for _, s := range c { - if v, err := s.Get(alias, key); err == nil { - return v - } - } - return ssh_config.Get(alias, key) -} - -var once sync.Once - -var confList conf - -func init() { - once.Do(func() { - strings := []string{ - filepath.Join(homedir.HomeDir(), ".ssh", "config"), - filepath.Join("/", "etc", "ssh", "ssh_config"), - } - for _, s := range strings { - file, err := os.ReadFile(s) - if err != nil { - continue - } - cfg, err := ssh_config.DecodeBytes(file) - if err != nil { - continue - } - confList = append(confList, cfg) - } - }) -} - -func newSshClient(client *ssh.Client, cancel context.CancelFunc) *sshClient { - return &sshClient{Client: client, cancel: cancel} -} - -type sshClient struct { - cancel context.CancelFunc - *ssh.Client -} - -func (c *sshClient) Close() error { - c.cancel() - return c.Client.Close() -} - func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.AddrPort) error { // Listen on remote server port var lc net.ListenConfig @@ -555,7 +98,7 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr go func() { defer localListen.Close() - var sshClientChan = make(chan *sshClient, 1000*1000) + var clientMap = &sync.Map{} ctx1, cancelFunc1 := context.WithCancel(ctx) defer cancelFunc1() @@ -568,11 +111,11 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr go func() { defer localConn.Close() - remoteConn, err := getRemoteConn(ctx, sshClientChan, conf, remote) + remoteConn, err := getRemoteConn(ctx1, clientMap, conf, remote) if err != nil { - var openChannelError *ssh.OpenChannelError + var openChannelError *gossh.OpenChannelError // if ssh server not permitted ssh port-forward, do nothing until exit - if errors.As(err, &openChannelError) && openChannelError.Reason == ssh.Prohibited { + if errors.As(err, &openChannelError) && openChannelError.Reason == gossh.Prohibited { plog.G(ctx).Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err) cancelFunc1() } @@ -588,61 +131,6 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr return nil } -func getRemoteConn(ctx context.Context, sshClientChan chan *sshClient, conf *SshConfig, remote netip.AddrPort) (conn net.Conn, err error) { - select { - case cli, ok := <-sshClientChan: - if !ok { - return nil, errors.New("ssh client chan closed") - } - ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10) - defer cancelFunc1() - conn, err = cli.DialContext(ctx1, "tcp", remote.String()) - if err != nil { - plog.G(ctx).Debugf("Failed to dial remote address %s: %s", remote.String(), err) - _ = cli.Close() - return nil, err - } - safeWrite(ctx, sshClientChan, cli) - return conn, nil - default: - ctx1, cancelFunc1 := context.WithCancel(ctx) - defer func() { - if err != nil { - cancelFunc1() - } - }() - ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10) - defer cancelFunc2() - var client *ssh.Client - client, err = DialSshRemote(ctx2, conf, ctx1.Done()) - if err != nil { - plog.G(ctx).Debugf("Failed to dial remote ssh server: %v", err) - return nil, err - } - ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10) - defer cancelFunc3() - conn, err = client.DialContext(ctx3, "tcp", remote.String()) - if err != nil { - plog.G(ctx).Debugf("Failed to dial remote addr: %s: %v", remote.String(), err) - client.Close() - return nil, err - } - cli := newSshClient(client, cancelFunc1) - safeWrite(ctx1, sshClientChan, cli) - return conn, nil - } -} - -func safeWrite(ctx context.Context, sshClientChan chan *sshClient, cli *sshClient) { - write := pkgutil.SafeWrite(sshClientChan, cli) - if !write { - go func() { - <-ctx.Done() - cli.Close() - }() - } -} - func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print bool) (path string, err error) { if conf.Addr == "" && conf.ConfigAlias == "" { if flags != nil { @@ -665,14 +153,6 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b } }() - // pre-check network ip connect - var cli *ssh.Client - cli, err = DialSshRemote(ctx, conf, ctx.Done()) - if err != nil { - return - } - defer cli.Close() - configFlags := genericclioptions.NewConfigFlags(true).WithDeprecatedPasswordFlag() if conf.RemoteKubeconfig != "" || (flags != nil && flags.Changed("remote-kubeconfig")) { @@ -685,6 +165,13 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b // if `--remote-kubeconfig` is parsed then Entrypoint is reset conf.RemoteKubeconfig = filepath.Join("/home", conf.User, clientcmd.RecommendedHomeDir, clientcmd.RecommendedFileName) } + // pre-check network ip connect + var cli *gossh.Client + cli, err = DialSshRemote(ctx, conf, ctx.Done()) + if err != nil { + return + } + defer cli.Close() stdout, stderr, err = RemoteRun(cli, fmt.Sprintf("sh -c 'kubectl config view --flatten --raw --kubeconfig %s || minikube kubectl -- config view --flatten --raw --kubeconfig %s'", conf.RemoteKubeconfig, @@ -863,3 +350,138 @@ func SshJumpAndSetEnv(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet } return os.Setenv(config.EnvSSHJump, path) } + +func JumpTo(ctx context.Context, bClient *gossh.Client, to SshConfig, stopChan <-chan struct{}) (client *gossh.Client, err error) { + if _, _, err = net.SplitHostPort(to.Addr); err != nil { + // use default ssh port 22 + to.Addr = net.JoinHostPort(to.Addr, "22") + err = nil + } + + var authMethod []gossh.AuthMethod + authMethod, err = to.GetAuth() + if err != nil { + return nil, err + } + // Dial a connection to the service host, from the bastion + var conn net.Conn + conn, err = bClient.DialContext(ctx, "tcp", to.Addr) + if err != nil { + return + } + go func() { + if stopChan != nil { + <-stopChan + conn.Close() + if client != nil { + client.Close() + } + bClient.Close() + } + }() + defer func() { + if err != nil { + if client != nil { + client.Close() + } + if conn != nil { + conn.Close() + } + } + }() + var ncc gossh.Conn + var chans <-chan gossh.NewChannel + var reqs <-chan *gossh.Request + ncc, chans, reqs, err = gossh.NewClientConn(conn, to.Addr, &gossh.ClientConfig{ + User: to.User, + Auth: authMethod, + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + //BannerCallback: ssh.BannerDisplayStderr(), + Timeout: time.Second * 10, + }) + if err != nil { + return + } + + client = gossh.NewClient(ncc, chans, reqs) + return +} + +func getRemoteConn(ctx context.Context, clientMap *sync.Map, conf *SshConfig, remote netip.AddrPort) (conn net.Conn, err error) { + clientMap.Range(func(key, value any) bool { + cli := value.(*sshClientWrap) + ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10) + conn, err = cli.DialContext(ctx1, "tcp", remote.String()) + cancelFunc1() + if err != nil { + plog.G(ctx).Debugf("Failed to dial remote address %s: %s", remote.String(), err) + clientMap.Delete(key) + _ = cli.Close() + return true + } + return false + }) + if conn != nil { + return + } + + ctx1, cancelFunc1 := context.WithCancel(ctx) + defer func() { + if err != nil { + cancelFunc1() + } + }() + + ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10) + defer cancelFunc2() + var client *gossh.Client + client, err = DialSshRemote(ctx2, conf, ctx1.Done()) + if err != nil { + plog.G(ctx).Debugf("Failed to dial remote ssh server: %v", err) + return nil, err + } + + ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10) + defer cancelFunc3() + conn, err = client.DialContext(ctx3, "tcp", remote.String()) + if err != nil { + plog.G(ctx).Debugf("Failed to dial remote addr: %s: %v", remote.String(), err) + _ = client.Close() + return nil, err + } + clientMap.Store(uuid.NewString(), newSshClientWrap(client, cancelFunc1)) + return conn, nil +} + +func copyStream(ctx context.Context, local net.Conn, remote net.Conn) { + chDone := make(chan bool, 2) + + // start remote -> local data transfer + go func() { + buf := config.LPool.Get().([]byte)[:] + defer config.LPool.Put(buf[:]) + _, err := io.CopyBuffer(local, remote, buf) + if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { + plog.G(ctx).Debugf("Failed to copy remote -> local: %s", err) + } + chDone <- true + }() + + // start local -> remote data transfer + go func() { + buf := config.LPool.Get().([]byte)[:] + defer config.LPool.Put(buf[:]) + _, err := io.CopyBuffer(remote, local, buf) + if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { + plog.G(ctx).Debugf("Failed to copy local -> remote: %s", err) + } + chDone <- true + }() + + select { + case <-chDone: + return + case <-ctx.Done(): + return + } +}