mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-12-24 11:51:13 +08:00
refactor: optimize ssh logic (#555)
This commit is contained in:
391
pkg/ssh/config.go
Normal file
391
pkg/ssh/config.go
Normal file
@@ -0,0 +1,391 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/kevinburke/ssh_config"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/spf13/pflag"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"k8s.io/client-go/util/homedir"
|
||||
|
||||
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
|
||||
"github.com/wencaiwulue/kubevpn/v2/pkg/daemon/rpc"
|
||||
)
|
||||
|
||||
type SshConfig struct {
|
||||
Addr string
|
||||
User string
|
||||
Password string
|
||||
Keyfile string
|
||||
Jump string
|
||||
ConfigAlias string
|
||||
RemoteKubeconfig string
|
||||
// GSSAPI
|
||||
GSSAPIKeytabConf string
|
||||
GSSAPIPassword string
|
||||
GSSAPICacheFile string
|
||||
}
|
||||
|
||||
func (conf SshConfig) Clone() SshConfig {
|
||||
return SshConfig{
|
||||
Addr: conf.Addr,
|
||||
User: conf.User,
|
||||
Password: conf.Password,
|
||||
Keyfile: conf.Keyfile,
|
||||
Jump: conf.Jump,
|
||||
ConfigAlias: conf.ConfigAlias,
|
||||
RemoteKubeconfig: conf.RemoteKubeconfig,
|
||||
GSSAPIKeytabConf: conf.GSSAPIKeytabConf,
|
||||
GSSAPIPassword: conf.GSSAPIPassword,
|
||||
GSSAPICacheFile: conf.GSSAPICacheFile,
|
||||
}
|
||||
}
|
||||
|
||||
func ParseSshFromRPC(sshJump *rpc.SshJump) *SshConfig {
|
||||
if sshJump == nil {
|
||||
return &SshConfig{}
|
||||
}
|
||||
return &SshConfig{
|
||||
Addr: sshJump.Addr,
|
||||
User: sshJump.User,
|
||||
Password: sshJump.Password,
|
||||
Keyfile: sshJump.Keyfile,
|
||||
Jump: sshJump.Jump,
|
||||
ConfigAlias: sshJump.ConfigAlias,
|
||||
RemoteKubeconfig: sshJump.RemoteKubeconfig,
|
||||
GSSAPIKeytabConf: sshJump.GSSAPIKeytabConf,
|
||||
GSSAPIPassword: sshJump.GSSAPIPassword,
|
||||
GSSAPICacheFile: sshJump.GSSAPICacheFile,
|
||||
}
|
||||
}
|
||||
|
||||
func (conf SshConfig) ToRPC() *rpc.SshJump {
|
||||
return &rpc.SshJump{
|
||||
Addr: conf.Addr,
|
||||
User: conf.User,
|
||||
Password: conf.Password,
|
||||
Keyfile: conf.Keyfile,
|
||||
Jump: conf.Jump,
|
||||
ConfigAlias: conf.ConfigAlias,
|
||||
RemoteKubeconfig: conf.RemoteKubeconfig,
|
||||
GSSAPIKeytabConf: conf.GSSAPIKeytabConf,
|
||||
GSSAPIPassword: conf.GSSAPIPassword,
|
||||
GSSAPICacheFile: conf.GSSAPICacheFile,
|
||||
}
|
||||
}
|
||||
|
||||
func (conf SshConfig) IsEmpty() bool {
|
||||
return conf.ConfigAlias == "" && conf.Addr == "" && conf.Jump == ""
|
||||
}
|
||||
|
||||
func (conf SshConfig) GetAuth() ([]ssh.AuthMethod, error) {
|
||||
host, _, _ := net.SplitHostPort(conf.Addr)
|
||||
var auth []ssh.AuthMethod
|
||||
var c Krb5InitiatorClient
|
||||
var err error
|
||||
var krb5Conf = GetKrb5Path()
|
||||
if conf.Password != "" {
|
||||
auth = append(auth, ssh.Password(conf.Password))
|
||||
} else if conf.GSSAPIPassword != "" {
|
||||
c, err = NewKrb5InitiatorClientWithPassword(conf.User, conf.GSSAPIPassword, krb5Conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth = append(auth, ssh.GSSAPIWithMICAuthMethod(&c, host))
|
||||
} else if conf.GSSAPIKeytabConf != "" {
|
||||
c, err = NewKrb5InitiatorClientWithKeytab(conf.User, krb5Conf, conf.GSSAPIKeytabConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if conf.GSSAPICacheFile != "" {
|
||||
c, err = NewKrb5InitiatorClientWithCache(krb5Conf, conf.GSSAPICacheFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth = append(auth, ssh.GSSAPIWithMICAuthMethod(&c, host))
|
||||
} else {
|
||||
if conf.Keyfile == "" {
|
||||
conf.Keyfile = filepath.Join(homedir.HomeDir(), ".ssh", "id_rsa")
|
||||
}
|
||||
var keyFile ssh.AuthMethod
|
||||
keyFile, err = publicKeyFile(conf.Keyfile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth = append(auth, keyFile)
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func publicKeyFile(file string) (ssh.AuthMethod, error) {
|
||||
var err error
|
||||
if len(file) != 0 && file[0] == '~' {
|
||||
file = filepath.Join(homedir.HomeDir(), file[1:])
|
||||
}
|
||||
file, err = filepath.Abs(file)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, fmt.Sprintf("Cannot read SSH public key file %s", file))
|
||||
return nil, err
|
||||
}
|
||||
buffer, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, fmt.Sprintf("Cannot read SSH public key file %s", file))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := ssh.ParsePrivateKey(buffer)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, fmt.Sprintf("Cannot parse SSH public key file %s", file))
|
||||
return nil, err
|
||||
}
|
||||
return ssh.PublicKeys(key), nil
|
||||
}
|
||||
|
||||
func (conf SshConfig) AliasRecursion(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) {
|
||||
var name = conf.ConfigAlias
|
||||
var jumper = "ProxyJump"
|
||||
var bastionList = []SshConfig{GetBastion(name, conf)}
|
||||
for {
|
||||
value := defaultSshConfigList.Get(name, jumper)
|
||||
if value != "" {
|
||||
bastionList = append(bastionList, GetBastion(value, conf))
|
||||
name = value
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
for i := len(bastionList) - 1; i >= 0; i-- {
|
||||
if client == nil {
|
||||
client, err = bastionList[i].Dial(ctx, stopChan)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, fmt.Sprintf("Failed to connect to %s", bastionList[i]))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
client, err = JumpTo(ctx, client, bastionList[i], stopChan)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, fmt.Sprintf("Failed to jump to %s", bastionList[i]))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (conf SshConfig) JumpRecursion(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) {
|
||||
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
|
||||
var sshConf = &SshConfig{}
|
||||
AddSshFlags(flags, sshConf)
|
||||
err = flags.Parse(strings.Split(conf.Jump, " "))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var baseClient *ssh.Client
|
||||
baseClient, err = DialSshRemote(ctx, sshConf, stopChan)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var bastionList []SshConfig
|
||||
if conf.ConfigAlias != "" {
|
||||
var name = conf.ConfigAlias
|
||||
var jumper = "ProxyJump"
|
||||
bastionList = append(bastionList, GetBastion(name, conf))
|
||||
for {
|
||||
value := defaultSshConfigList.Get(name, jumper)
|
||||
if value != "" {
|
||||
bastionList = append(bastionList, GetBastion(value, conf))
|
||||
name = value
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if conf.Addr != "" {
|
||||
bastionList = append(bastionList, conf)
|
||||
}
|
||||
|
||||
for _, sshConfig := range bastionList {
|
||||
client, err = JumpTo(ctx, baseClient, sshConfig, stopChan)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, fmt.Sprintf("Failed to jump to %s", sshConfig))
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (conf SshConfig) Dial(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) {
|
||||
if _, _, err = net.SplitHostPort(conf.Addr); err != nil {
|
||||
// use default ssh port 22
|
||||
conf.Addr = net.JoinHostPort(conf.Addr, "22")
|
||||
err = nil
|
||||
}
|
||||
// connect to the bastion host
|
||||
authMethod, err := conf.GetAuth()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := net.Dialer{Timeout: time.Second * 10, KeepAlive: config.KeepAliveTime}
|
||||
conn, err := d.DialContext(ctx, "tcp", conf.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go func() {
|
||||
if stopChan != nil {
|
||||
<-stopChan
|
||||
conn.Close()
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
c, chans, reqs, err := ssh.NewClientConn(conn, conf.Addr, &ssh.ClientConfig{
|
||||
User: conf.User,
|
||||
Auth: authMethod,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
//BannerCallback: ssh.BannerDisplayStderr(),
|
||||
Timeout: time.Second * 10,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.NewClient(c, chans, reqs), nil
|
||||
}
|
||||
|
||||
func GetBastion(name string, defaultValue SshConfig) SshConfig {
|
||||
var host, port string
|
||||
conf := SshConfig{
|
||||
ConfigAlias: name,
|
||||
}
|
||||
var propertyList = []string{"ProxyJump", "Hostname", "User", "Port", "IdentityFile"}
|
||||
for i, s := range propertyList {
|
||||
value := defaultSshConfigList.Get(name, s)
|
||||
switch i {
|
||||
case 0:
|
||||
|
||||
case 1:
|
||||
host = value
|
||||
case 2:
|
||||
conf.User = value
|
||||
case 3:
|
||||
if port = value; port == "" {
|
||||
port = strconv.Itoa(22)
|
||||
}
|
||||
case 4:
|
||||
if value == "" {
|
||||
conf.Keyfile = defaultValue.Keyfile
|
||||
conf.Password = defaultValue.Password
|
||||
conf.GSSAPIKeytabConf = defaultValue.GSSAPIKeytabConf
|
||||
conf.GSSAPIPassword = defaultValue.GSSAPIPassword
|
||||
conf.GSSAPICacheFile = defaultValue.GSSAPICacheFile
|
||||
} else {
|
||||
conf.Keyfile = value
|
||||
}
|
||||
}
|
||||
}
|
||||
conf.Addr = net.JoinHostPort(host, port)
|
||||
return conf
|
||||
}
|
||||
|
||||
type defaultSshConf []*ssh_config.Config
|
||||
|
||||
func (c defaultSshConf) Get(alias string, key string) string {
|
||||
for _, s := range c {
|
||||
if v, err := s.Get(alias, key); err == nil {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ssh_config.Get(alias, key)
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
|
||||
var defaultSshConfigList defaultSshConf
|
||||
|
||||
func init() {
|
||||
once.Do(func() {
|
||||
paths := []string{
|
||||
filepath.Join(homedir.HomeDir(), ".ssh", "config"),
|
||||
filepath.Join("/", "etc", "ssh", "ssh_config"),
|
||||
}
|
||||
for _, path := range paths {
|
||||
file, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
cfg, err := ssh_config.DecodeBytes(file)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
defaultSshConfigList = append(defaultSshConfigList, cfg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func newSshClientWrap(client *ssh.Client, cancel context.CancelFunc) *sshClientWrap {
|
||||
return &sshClientWrap{Client: client, cancel: cancel}
|
||||
}
|
||||
|
||||
type sshClientWrap struct {
|
||||
cancel context.CancelFunc
|
||||
*ssh.Client
|
||||
}
|
||||
|
||||
func (c *sshClientWrap) Close() error {
|
||||
c.cancel()
|
||||
return c.Client.Close()
|
||||
}
|
||||
|
||||
func AddSshFlags(flags *pflag.FlagSet, sshConf *SshConfig) {
|
||||
// for ssh jumper host
|
||||
flags.StringVar(&sshConf.Addr, "ssh-addr", "", "Optional ssh jump server address to dial as <hostname>:<port>, eg: 127.0.0.1:22")
|
||||
flags.StringVar(&sshConf.User, "ssh-username", "", "Optional username for ssh jump server")
|
||||
flags.StringVar(&sshConf.Password, "ssh-password", "", "Optional password for ssh jump server")
|
||||
flags.StringVar(&sshConf.Keyfile, "ssh-keyfile", "", "Optional file with private key for SSH authentication")
|
||||
flags.StringVar(&sshConf.ConfigAlias, "ssh-alias", "", "Optional config alias with ~/.ssh/config for SSH authentication")
|
||||
flags.StringVar(&sshConf.Jump, "ssh-jump", "", "Optional bastion jump config string, eg: '--ssh-addr jumpe.naison.org --ssh-username naison --gssapi-password xxx'")
|
||||
flags.StringVar(&sshConf.GSSAPIPassword, "gssapi-password", "", "GSSAPI password")
|
||||
flags.StringVar(&sshConf.GSSAPIKeytabConf, "gssapi-keytab", "", "GSSAPI keytab file path")
|
||||
flags.StringVar(&sshConf.GSSAPICacheFile, "gssapi-cache", "", "GSSAPI cache file path, use command `kinit -c /path/to/cache USERNAME@RELAM` to generate")
|
||||
flags.StringVar(&sshConf.RemoteKubeconfig, "remote-kubeconfig", "", "Remote kubeconfig abstract path of ssh server, default is /home/$USERNAME/.kube/config")
|
||||
lookup := flags.Lookup("remote-kubeconfig")
|
||||
lookup.NoOptDefVal = "~/.kube/config"
|
||||
}
|
||||
|
||||
func keepAlive(cl *ssh.Client, conn net.Conn, done <-chan struct{}) error {
|
||||
const keepAliveInterval = time.Second * 10
|
||||
t := time.NewTicker(keepAliveInterval)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
_, _, err := cl.SendRequest("keepalive@golang.org", true, nil)
|
||||
if err != nil && err != io.EOF {
|
||||
return errors.Wrap(err, "failed to send keep alive")
|
||||
}
|
||||
case <-done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
680
pkg/ssh/ssh.go
680
pkg/ssh/ssh.go
@@ -12,114 +12,29 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/kevinburke/ssh_config"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/spf13/pflag"
|
||||
"golang.org/x/crypto/ssh"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
"k8s.io/apimachinery/pkg/runtime"
|
||||
"k8s.io/apimachinery/pkg/runtime/schema"
|
||||
"k8s.io/cli-runtime/pkg/genericclioptions"
|
||||
"k8s.io/client-go/tools/clientcmd"
|
||||
"k8s.io/client-go/tools/clientcmd/api"
|
||||
"k8s.io/client-go/tools/clientcmd/api/latest"
|
||||
"k8s.io/client-go/util/homedir"
|
||||
"k8s.io/kubectl/pkg/cmd/util"
|
||||
"k8s.io/utils/pointer"
|
||||
|
||||
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
|
||||
"github.com/wencaiwulue/kubevpn/v2/pkg/daemon/rpc"
|
||||
plog "github.com/wencaiwulue/kubevpn/v2/pkg/log"
|
||||
pkgutil "github.com/wencaiwulue/kubevpn/v2/pkg/util"
|
||||
)
|
||||
|
||||
type SshConfig struct {
|
||||
Addr string
|
||||
User string
|
||||
Password string
|
||||
Keyfile string
|
||||
Jump string
|
||||
ConfigAlias string
|
||||
RemoteKubeconfig string
|
||||
// GSSAPI
|
||||
GSSAPIKeytabConf string
|
||||
GSSAPIPassword string
|
||||
GSSAPICacheFile string
|
||||
}
|
||||
|
||||
func (s SshConfig) Clone() SshConfig {
|
||||
return SshConfig{
|
||||
Addr: s.Addr,
|
||||
User: s.User,
|
||||
Password: s.Password,
|
||||
Keyfile: s.Keyfile,
|
||||
Jump: s.Jump,
|
||||
ConfigAlias: s.ConfigAlias,
|
||||
RemoteKubeconfig: s.RemoteKubeconfig,
|
||||
GSSAPIKeytabConf: s.GSSAPIKeytabConf,
|
||||
GSSAPIPassword: s.GSSAPIPassword,
|
||||
GSSAPICacheFile: s.GSSAPICacheFile,
|
||||
}
|
||||
}
|
||||
|
||||
func ParseSshFromRPC(sshJump *rpc.SshJump) *SshConfig {
|
||||
if sshJump == nil {
|
||||
return &SshConfig{}
|
||||
}
|
||||
return &SshConfig{
|
||||
Addr: sshJump.Addr,
|
||||
User: sshJump.User,
|
||||
Password: sshJump.Password,
|
||||
Keyfile: sshJump.Keyfile,
|
||||
Jump: sshJump.Jump,
|
||||
ConfigAlias: sshJump.ConfigAlias,
|
||||
RemoteKubeconfig: sshJump.RemoteKubeconfig,
|
||||
GSSAPIKeytabConf: sshJump.GSSAPIKeytabConf,
|
||||
GSSAPIPassword: sshJump.GSSAPIPassword,
|
||||
GSSAPICacheFile: sshJump.GSSAPICacheFile,
|
||||
}
|
||||
}
|
||||
|
||||
func (config *SshConfig) ToRPC() *rpc.SshJump {
|
||||
return &rpc.SshJump{
|
||||
Addr: config.Addr,
|
||||
User: config.User,
|
||||
Password: config.Password,
|
||||
Keyfile: config.Keyfile,
|
||||
Jump: config.Jump,
|
||||
ConfigAlias: config.ConfigAlias,
|
||||
RemoteKubeconfig: config.RemoteKubeconfig,
|
||||
GSSAPIKeytabConf: config.GSSAPIKeytabConf,
|
||||
GSSAPIPassword: config.GSSAPIPassword,
|
||||
GSSAPICacheFile: config.GSSAPICacheFile,
|
||||
}
|
||||
}
|
||||
|
||||
func (config *SshConfig) IsEmpty() bool {
|
||||
return config.ConfigAlias == "" && config.Addr == "" && config.Jump == ""
|
||||
}
|
||||
|
||||
func AddSshFlags(flags *pflag.FlagSet, sshConf *SshConfig) {
|
||||
// for ssh jumper host
|
||||
flags.StringVar(&sshConf.Addr, "ssh-addr", "", "Optional ssh jump server address to dial as <hostname>:<port>, eg: 127.0.0.1:22")
|
||||
flags.StringVar(&sshConf.User, "ssh-username", "", "Optional username for ssh jump server")
|
||||
flags.StringVar(&sshConf.Password, "ssh-password", "", "Optional password for ssh jump server")
|
||||
flags.StringVar(&sshConf.Keyfile, "ssh-keyfile", "", "Optional file with private key for SSH authentication")
|
||||
flags.StringVar(&sshConf.ConfigAlias, "ssh-alias", "", "Optional config alias with ~/.ssh/config for SSH authentication")
|
||||
flags.StringVar(&sshConf.Jump, "ssh-jump", "", "Optional bastion jump config string, eg: '--ssh-addr jumpe.naison.org --ssh-username naison --gssapi-password xxx'")
|
||||
flags.StringVar(&sshConf.GSSAPIPassword, "gssapi-password", "", "GSSAPI password")
|
||||
flags.StringVar(&sshConf.GSSAPIKeytabConf, "gssapi-keytab", "", "GSSAPI keytab file path")
|
||||
flags.StringVar(&sshConf.GSSAPICacheFile, "gssapi-cache", "", "GSSAPI cache file path, use command `kinit -c /path/to/cache USERNAME@RELAM` to generate")
|
||||
flags.StringVar(&sshConf.RemoteKubeconfig, "remote-kubeconfig", "", "Remote kubeconfig abstract path of ssh server, default is /home/$USERNAME/.kube/config")
|
||||
lookup := flags.Lookup("remote-kubeconfig")
|
||||
lookup.NoOptDefVal = "~/.kube/config"
|
||||
}
|
||||
|
||||
// DialSshRemote https://github.com/golang/go/issues/21478
|
||||
func DialSshRemote(ctx context.Context, conf *SshConfig, stopChan <-chan struct{}) (remote *ssh.Client, err error) {
|
||||
func DialSshRemote(ctx context.Context, conf *SshConfig, stopChan <-chan struct{}) (remote *gossh.Client, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if remote != nil {
|
||||
@@ -148,64 +63,8 @@ func DialSshRemote(ctx context.Context, conf *SshConfig, stopChan <-chan struct{
|
||||
return remote, err
|
||||
}
|
||||
|
||||
func keepAlive(cl *ssh.Client, conn net.Conn, done <-chan struct{}) error {
|
||||
const keepAliveInterval = time.Second * 10
|
||||
t := time.NewTicker(keepAliveInterval)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
_, _, err := cl.SendRequest("keepalive@golang.org", true, nil)
|
||||
if err != nil && err != io.EOF {
|
||||
return errors.Wrap(err, "failed to send keep alive")
|
||||
}
|
||||
case <-done:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (config SshConfig) GetAuth() ([]ssh.AuthMethod, error) {
|
||||
host, _, _ := net.SplitHostPort(config.Addr)
|
||||
var auth []ssh.AuthMethod
|
||||
var c Krb5InitiatorClient
|
||||
var err error
|
||||
var krb5Conf = GetKrb5Path()
|
||||
if config.Password != "" {
|
||||
auth = append(auth, ssh.Password(config.Password))
|
||||
} else if config.GSSAPIPassword != "" {
|
||||
c, err = NewKrb5InitiatorClientWithPassword(config.User, config.GSSAPIPassword, krb5Conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth = append(auth, ssh.GSSAPIWithMICAuthMethod(&c, host))
|
||||
} else if config.GSSAPIKeytabConf != "" {
|
||||
c, err = NewKrb5InitiatorClientWithKeytab(config.User, krb5Conf, config.GSSAPIKeytabConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if config.GSSAPICacheFile != "" {
|
||||
c, err = NewKrb5InitiatorClientWithCache(krb5Conf, config.GSSAPICacheFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth = append(auth, ssh.GSSAPIWithMICAuthMethod(&c, host))
|
||||
} else {
|
||||
if config.Keyfile == "" {
|
||||
config.Keyfile = filepath.Join(homedir.HomeDir(), ".ssh", "id_rsa")
|
||||
}
|
||||
var keyFile ssh.AuthMethod
|
||||
keyFile, err = publicKeyFile(config.Keyfile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
auth = append(auth, keyFile)
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func RemoteRun(client *ssh.Client, cmd string, env map[string]string) (output []byte, errOut []byte, err error) {
|
||||
var session *ssh.Session
|
||||
func RemoteRun(client *gossh.Client, cmd string, env map[string]string) (output []byte, errOut []byte, err error) {
|
||||
var session *gossh.Session
|
||||
session, err = client.NewSession()
|
||||
if err != nil {
|
||||
return
|
||||
@@ -227,322 +86,6 @@ func RemoteRun(client *ssh.Client, cmd string, env map[string]string) (output []
|
||||
return out.Bytes(), er.Bytes(), err
|
||||
}
|
||||
|
||||
func publicKeyFile(file string) (ssh.AuthMethod, error) {
|
||||
var err error
|
||||
if len(file) != 0 && file[0] == '~' {
|
||||
file = filepath.Join(homedir.HomeDir(), file[1:])
|
||||
}
|
||||
file, err = filepath.Abs(file)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, fmt.Sprintf("Cannot read SSH public key file %s", file))
|
||||
return nil, err
|
||||
}
|
||||
buffer, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, fmt.Sprintf("Cannot read SSH public key file %s", file))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := ssh.ParsePrivateKey(buffer)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, fmt.Sprintf("Cannot parse SSH public key file %s", file))
|
||||
return nil, err
|
||||
}
|
||||
return ssh.PublicKeys(key), nil
|
||||
}
|
||||
|
||||
func copyStream(ctx context.Context, local net.Conn, remote net.Conn) {
|
||||
chDone := make(chan bool, 2)
|
||||
|
||||
// start remote -> local data transfer
|
||||
go func() {
|
||||
buf := config.LPool.Get().([]byte)[:]
|
||||
defer config.LPool.Put(buf[:])
|
||||
_, err := io.CopyBuffer(local, remote, buf)
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
|
||||
plog.G(ctx).Debugf("Failed to copy remote -> local: %s", err)
|
||||
}
|
||||
chDone <- true
|
||||
}()
|
||||
|
||||
// start local -> remote data transfer
|
||||
go func() {
|
||||
buf := config.LPool.Get().([]byte)[:]
|
||||
defer config.LPool.Put(buf[:])
|
||||
_, err := io.CopyBuffer(remote, local, buf)
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
|
||||
plog.G(ctx).Debugf("Failed to copy local -> remote: %s", err)
|
||||
}
|
||||
chDone <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-chDone:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (config SshConfig) AliasRecursion(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) {
|
||||
var name = config.ConfigAlias
|
||||
var jumper = "ProxyJump"
|
||||
var bastionList = []SshConfig{GetBastion(name, config)}
|
||||
for {
|
||||
value := confList.Get(name, jumper)
|
||||
if value != "" {
|
||||
bastionList = append(bastionList, GetBastion(value, config))
|
||||
name = value
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
for i := len(bastionList) - 1; i >= 0; i-- {
|
||||
if client == nil {
|
||||
client, err = bastionList[i].Dial(ctx, stopChan)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
client, err = JumpTo(ctx, client, bastionList[i], stopChan)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (config SshConfig) JumpRecursion(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) {
|
||||
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
|
||||
var sshConf = &SshConfig{}
|
||||
AddSshFlags(flags, sshConf)
|
||||
err = flags.Parse(strings.Split(config.Jump, " "))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var baseClient *ssh.Client
|
||||
baseClient, err = DialSshRemote(ctx, sshConf, stopChan)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var bastionList []SshConfig
|
||||
if config.ConfigAlias != "" {
|
||||
var name = config.ConfigAlias
|
||||
var jumper = "ProxyJump"
|
||||
bastionList = append(bastionList, GetBastion(name, config))
|
||||
for {
|
||||
value := confList.Get(name, jumper)
|
||||
if value != "" {
|
||||
bastionList = append(bastionList, GetBastion(value, config))
|
||||
name = value
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if config.Addr != "" {
|
||||
bastionList = append(bastionList, config)
|
||||
}
|
||||
|
||||
for _, sshConfig := range bastionList {
|
||||
client, err = JumpTo(ctx, baseClient, sshConfig, stopChan)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func GetBastion(name string, defaultValue SshConfig) SshConfig {
|
||||
var host, port string
|
||||
config := SshConfig{
|
||||
ConfigAlias: name,
|
||||
}
|
||||
var propertyList = []string{"ProxyJump", "Hostname", "User", "Port", "IdentityFile"}
|
||||
for i, s := range propertyList {
|
||||
value := confList.Get(name, s)
|
||||
switch i {
|
||||
case 0:
|
||||
|
||||
case 1:
|
||||
host = value
|
||||
case 2:
|
||||
config.User = value
|
||||
case 3:
|
||||
if port = value; port == "" {
|
||||
port = strconv.Itoa(22)
|
||||
}
|
||||
case 4:
|
||||
if value == "" {
|
||||
config.Keyfile = defaultValue.Keyfile
|
||||
config.Password = defaultValue.Password
|
||||
config.GSSAPIKeytabConf = defaultValue.GSSAPIKeytabConf
|
||||
config.GSSAPIPassword = defaultValue.GSSAPIPassword
|
||||
config.GSSAPICacheFile = defaultValue.GSSAPICacheFile
|
||||
} else {
|
||||
config.Keyfile = value
|
||||
}
|
||||
}
|
||||
}
|
||||
config.Addr = net.JoinHostPort(host, port)
|
||||
return config
|
||||
}
|
||||
|
||||
func (config SshConfig) Dial(ctx context.Context, stopChan <-chan struct{}) (client *ssh.Client, err error) {
|
||||
if _, _, err = net.SplitHostPort(config.Addr); err != nil {
|
||||
// use default ssh port 22
|
||||
config.Addr = net.JoinHostPort(config.Addr, "22")
|
||||
err = nil
|
||||
}
|
||||
// connect to the bastion host
|
||||
authMethod, err := config.GetAuth()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := net.Dialer{Timeout: time.Second * 10}
|
||||
conn, err := d.DialContext(ctx, "tcp", config.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go func() {
|
||||
if stopChan != nil {
|
||||
<-stopChan
|
||||
conn.Close()
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
c, chans, reqs, err := ssh.NewClientConn(conn, config.Addr, &ssh.ClientConfig{
|
||||
User: config.User,
|
||||
Auth: authMethod,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
//BannerCallback: ssh.BannerDisplayStderr(),
|
||||
Timeout: time.Second * 10,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.NewClient(c, chans, reqs), nil
|
||||
}
|
||||
|
||||
func JumpTo(ctx context.Context, bClient *ssh.Client, to SshConfig, stopChan <-chan struct{}) (client *ssh.Client, err error) {
|
||||
if _, _, err = net.SplitHostPort(to.Addr); err != nil {
|
||||
// use default ssh port 22
|
||||
to.Addr = net.JoinHostPort(to.Addr, "22")
|
||||
err = nil
|
||||
}
|
||||
|
||||
var authMethod []ssh.AuthMethod
|
||||
authMethod, err = to.GetAuth()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Dial a connection to the service host, from the bastion
|
||||
var conn net.Conn
|
||||
conn, err = bClient.DialContext(ctx, "tcp", to.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
if stopChan != nil {
|
||||
<-stopChan
|
||||
conn.Close()
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
bClient.Close()
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
var ncc ssh.Conn
|
||||
var chans <-chan ssh.NewChannel
|
||||
var reqs <-chan *ssh.Request
|
||||
ncc, chans, reqs, err = ssh.NewClientConn(conn, to.Addr, &ssh.ClientConfig{
|
||||
User: to.User,
|
||||
Auth: authMethod,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
//BannerCallback: ssh.BannerDisplayStderr(),
|
||||
Timeout: time.Second * 10,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
client = ssh.NewClient(ncc, chans, reqs)
|
||||
return
|
||||
}
|
||||
|
||||
type conf []*ssh_config.Config
|
||||
|
||||
func (c conf) Get(alias string, key string) string {
|
||||
for _, s := range c {
|
||||
if v, err := s.Get(alias, key); err == nil {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ssh_config.Get(alias, key)
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
|
||||
var confList conf
|
||||
|
||||
func init() {
|
||||
once.Do(func() {
|
||||
strings := []string{
|
||||
filepath.Join(homedir.HomeDir(), ".ssh", "config"),
|
||||
filepath.Join("/", "etc", "ssh", "ssh_config"),
|
||||
}
|
||||
for _, s := range strings {
|
||||
file, err := os.ReadFile(s)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
cfg, err := ssh_config.DecodeBytes(file)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
confList = append(confList, cfg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func newSshClient(client *ssh.Client, cancel context.CancelFunc) *sshClient {
|
||||
return &sshClient{Client: client, cancel: cancel}
|
||||
}
|
||||
|
||||
type sshClient struct {
|
||||
cancel context.CancelFunc
|
||||
*ssh.Client
|
||||
}
|
||||
|
||||
func (c *sshClient) Close() error {
|
||||
c.cancel()
|
||||
return c.Client.Close()
|
||||
}
|
||||
|
||||
func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.AddrPort) error {
|
||||
// Listen on remote server port
|
||||
var lc net.ListenConfig
|
||||
@@ -555,7 +98,7 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
||||
go func() {
|
||||
defer localListen.Close()
|
||||
|
||||
var sshClientChan = make(chan *sshClient, 1000*1000)
|
||||
var clientMap = &sync.Map{}
|
||||
ctx1, cancelFunc1 := context.WithCancel(ctx)
|
||||
defer cancelFunc1()
|
||||
|
||||
@@ -568,11 +111,11 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
||||
go func() {
|
||||
defer localConn.Close()
|
||||
|
||||
remoteConn, err := getRemoteConn(ctx, sshClientChan, conf, remote)
|
||||
remoteConn, err := getRemoteConn(ctx1, clientMap, conf, remote)
|
||||
if err != nil {
|
||||
var openChannelError *ssh.OpenChannelError
|
||||
var openChannelError *gossh.OpenChannelError
|
||||
// if ssh server not permitted ssh port-forward, do nothing until exit
|
||||
if errors.As(err, &openChannelError) && openChannelError.Reason == ssh.Prohibited {
|
||||
if errors.As(err, &openChannelError) && openChannelError.Reason == gossh.Prohibited {
|
||||
plog.G(ctx).Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
|
||||
cancelFunc1()
|
||||
}
|
||||
@@ -588,61 +131,6 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRemoteConn(ctx context.Context, sshClientChan chan *sshClient, conf *SshConfig, remote netip.AddrPort) (conn net.Conn, err error) {
|
||||
select {
|
||||
case cli, ok := <-sshClientChan:
|
||||
if !ok {
|
||||
return nil, errors.New("ssh client chan closed")
|
||||
}
|
||||
ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc1()
|
||||
conn, err = cli.DialContext(ctx1, "tcp", remote.String())
|
||||
if err != nil {
|
||||
plog.G(ctx).Debugf("Failed to dial remote address %s: %s", remote.String(), err)
|
||||
_ = cli.Close()
|
||||
return nil, err
|
||||
}
|
||||
safeWrite(ctx, sshClientChan, cli)
|
||||
return conn, nil
|
||||
default:
|
||||
ctx1, cancelFunc1 := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cancelFunc1()
|
||||
}
|
||||
}()
|
||||
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc2()
|
||||
var client *ssh.Client
|
||||
client, err = DialSshRemote(ctx2, conf, ctx1.Done())
|
||||
if err != nil {
|
||||
plog.G(ctx).Debugf("Failed to dial remote ssh server: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc3()
|
||||
conn, err = client.DialContext(ctx3, "tcp", remote.String())
|
||||
if err != nil {
|
||||
plog.G(ctx).Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
|
||||
client.Close()
|
||||
return nil, err
|
||||
}
|
||||
cli := newSshClient(client, cancelFunc1)
|
||||
safeWrite(ctx1, sshClientChan, cli)
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func safeWrite(ctx context.Context, sshClientChan chan *sshClient, cli *sshClient) {
|
||||
write := pkgutil.SafeWrite(sshClientChan, cli)
|
||||
if !write {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
cli.Close()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print bool) (path string, err error) {
|
||||
if conf.Addr == "" && conf.ConfigAlias == "" {
|
||||
if flags != nil {
|
||||
@@ -665,14 +153,6 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
|
||||
}
|
||||
}()
|
||||
|
||||
// pre-check network ip connect
|
||||
var cli *ssh.Client
|
||||
cli, err = DialSshRemote(ctx, conf, ctx.Done())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer cli.Close()
|
||||
|
||||
configFlags := genericclioptions.NewConfigFlags(true).WithDeprecatedPasswordFlag()
|
||||
|
||||
if conf.RemoteKubeconfig != "" || (flags != nil && flags.Changed("remote-kubeconfig")) {
|
||||
@@ -685,6 +165,13 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
|
||||
// if `--remote-kubeconfig` is parsed then Entrypoint is reset
|
||||
conf.RemoteKubeconfig = filepath.Join("/home", conf.User, clientcmd.RecommendedHomeDir, clientcmd.RecommendedFileName)
|
||||
}
|
||||
// pre-check network ip connect
|
||||
var cli *gossh.Client
|
||||
cli, err = DialSshRemote(ctx, conf, ctx.Done())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer cli.Close()
|
||||
stdout, stderr, err = RemoteRun(cli,
|
||||
fmt.Sprintf("sh -c 'kubectl config view --flatten --raw --kubeconfig %s || minikube kubectl -- config view --flatten --raw --kubeconfig %s'",
|
||||
conf.RemoteKubeconfig,
|
||||
@@ -863,3 +350,138 @@ func SshJumpAndSetEnv(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet
|
||||
}
|
||||
return os.Setenv(config.EnvSSHJump, path)
|
||||
}
|
||||
|
||||
func JumpTo(ctx context.Context, bClient *gossh.Client, to SshConfig, stopChan <-chan struct{}) (client *gossh.Client, err error) {
|
||||
if _, _, err = net.SplitHostPort(to.Addr); err != nil {
|
||||
// use default ssh port 22
|
||||
to.Addr = net.JoinHostPort(to.Addr, "22")
|
||||
err = nil
|
||||
}
|
||||
|
||||
var authMethod []gossh.AuthMethod
|
||||
authMethod, err = to.GetAuth()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Dial a connection to the service host, from the bastion
|
||||
var conn net.Conn
|
||||
conn, err = bClient.DialContext(ctx, "tcp", to.Addr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
if stopChan != nil {
|
||||
<-stopChan
|
||||
conn.Close()
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
bClient.Close()
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
var ncc gossh.Conn
|
||||
var chans <-chan gossh.NewChannel
|
||||
var reqs <-chan *gossh.Request
|
||||
ncc, chans, reqs, err = gossh.NewClientConn(conn, to.Addr, &gossh.ClientConfig{
|
||||
User: to.User,
|
||||
Auth: authMethod,
|
||||
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
|
||||
//BannerCallback: ssh.BannerDisplayStderr(),
|
||||
Timeout: time.Second * 10,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
client = gossh.NewClient(ncc, chans, reqs)
|
||||
return
|
||||
}
|
||||
|
||||
func getRemoteConn(ctx context.Context, clientMap *sync.Map, conf *SshConfig, remote netip.AddrPort) (conn net.Conn, err error) {
|
||||
clientMap.Range(func(key, value any) bool {
|
||||
cli := value.(*sshClientWrap)
|
||||
ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10)
|
||||
conn, err = cli.DialContext(ctx1, "tcp", remote.String())
|
||||
cancelFunc1()
|
||||
if err != nil {
|
||||
plog.G(ctx).Debugf("Failed to dial remote address %s: %s", remote.String(), err)
|
||||
clientMap.Delete(key)
|
||||
_ = cli.Close()
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
if conn != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx1, cancelFunc1 := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cancelFunc1()
|
||||
}
|
||||
}()
|
||||
|
||||
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc2()
|
||||
var client *gossh.Client
|
||||
client, err = DialSshRemote(ctx2, conf, ctx1.Done())
|
||||
if err != nil {
|
||||
plog.G(ctx).Debugf("Failed to dial remote ssh server: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc3()
|
||||
conn, err = client.DialContext(ctx3, "tcp", remote.String())
|
||||
if err != nil {
|
||||
plog.G(ctx).Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
|
||||
_ = client.Close()
|
||||
return nil, err
|
||||
}
|
||||
clientMap.Store(uuid.NewString(), newSshClientWrap(client, cancelFunc1))
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func copyStream(ctx context.Context, local net.Conn, remote net.Conn) {
|
||||
chDone := make(chan bool, 2)
|
||||
|
||||
// start remote -> local data transfer
|
||||
go func() {
|
||||
buf := config.LPool.Get().([]byte)[:]
|
||||
defer config.LPool.Put(buf[:])
|
||||
_, err := io.CopyBuffer(local, remote, buf)
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
|
||||
plog.G(ctx).Debugf("Failed to copy remote -> local: %s", err)
|
||||
}
|
||||
chDone <- true
|
||||
}()
|
||||
|
||||
// start local -> remote data transfer
|
||||
go func() {
|
||||
buf := config.LPool.Get().([]byte)[:]
|
||||
defer config.LPool.Put(buf[:])
|
||||
_, err := io.CopyBuffer(remote, local, buf)
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
|
||||
plog.G(ctx).Debugf("Failed to copy local -> remote: %s", err)
|
||||
}
|
||||
chDone <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-chDone:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user