diff --git a/pkg/ssh/gssapi.go b/pkg/ssh/gssapi.go index a109aad2..a45a6871 100644 --- a/pkg/ssh/gssapi.go +++ b/pkg/ssh/gssapi.go @@ -35,13 +35,6 @@ func NewKrb5InitiatorClientWithPassword(username, password, krb5Conf string) (kc return } - // Set to lookup KDCs in DNS - c.LibDefaults.DNSLookupKDC = true - c.LibDefaults.DNSLookupRealm = true - - // Blank out the KDCs to ensure they are not being used - c.Realms = []config.Realm{} - defaultRealm := c.LibDefaults.DefaultRealm cl := client.NewWithPassword(username, defaultRealm, password, c) @@ -65,12 +58,6 @@ func NewKrb5InitiatorClientWithKeytab(username string, krb5Conf, keytabConf stri if err != nil { return } - // Set to lookup KDCs in DNS - c.LibDefaults.DNSLookupKDC = true - c.LibDefaults.DNSLookupRealm = true - - // Blank out the KDCs to ensure they are not being used - c.Realms = []config.Realm{} // Init keytab from conf cache, err := keytab.Load(keytabConf) @@ -81,9 +68,6 @@ func NewKrb5InitiatorClientWithKeytab(username string, krb5Conf, keytabConf stri defaultRealm := c.LibDefaults.DefaultRealm cl := client.NewWithKeytab(username, defaultRealm, cache, c) - if err != nil { - return - } err = cl.Login() if err != nil { return @@ -105,13 +89,6 @@ func NewKrb5InitiatorClientWithCache(krb5Conf, cacheFile string) (kcl Krb5Initia return } - // Set to lookup KDCs in DNS - c.LibDefaults.DNSLookupKDC = true - c.LibDefaults.DNSLookupRealm = true - - // Blank out the KDCs to ensure they are not being used - c.Realms = []config.Realm{} - // Init krb5 client and login cache, err := credentials.LoadCCache(cacheFile) // https://stackoverflow.com/questions/58653482/what-is-the-default-kerberos-credential-cache-on-osx diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index 7de1f1fb..70cd5425 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -530,6 +530,10 @@ func init() { }) } +func newSshClient(client *ssh.Client, cancel context.CancelFunc) *sshClient { + return &sshClient{Client: client, cancel: cancel} +} + type sshClient struct { cancel context.CancelFunc *ssh.Client @@ -553,74 +557,10 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr defer localListen.Close() var sshClientChan = make(chan *sshClient, 1000*1000) + ctx1, cancelFunc1 := context.WithCancel(ctx) + defer cancelFunc1() - var getRemoteConnFunc = func(connCtx context.Context) (conn net.Conn, err error) { - select { - case cli, ok := <-sshClientChan: - if !ok { - return nil, errors.New("ssh client chan closed") - } - ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10) - defer cancelFunc1() - conn, err = cli.DialContext(ctx1, "tcp", remote.String()) - if err != nil { - log.Debugf("Failed to dial remote address %s: %s", remote.String(), err) - cli.Close() - return nil, err - } - write := pkgutil.SafeWrite(sshClientChan, cli) - if !write { - go func() { - <-connCtx.Done() - cli.Close() - }() - } - return conn, nil - default: - ctx1, cancelFunc1 := context.WithCancel(ctx) - defer func() { - if err != nil { - cancelFunc1() - } - }() - ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10) - defer cancelFunc2() - var client *ssh.Client - client, err = DialSshRemote(ctx2, conf, ctx1.Done()) - if err != nil { - marshal, _ := json.Marshal(conf) - log.Debugf("Failed to dial remote ssh server %v: %v", string(marshal), err) - return nil, err - } - ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10) - defer cancelFunc3() - conn, err = client.DialContext(ctx3, "tcp", remote.String()) - if err != nil { - var openChannelError *ssh.OpenChannelError - // if ssh server not permitted ssh port-forward, do nothing until exit - if errors.As(err, &openChannelError) && openChannelError.Reason == ssh.Prohibited { - _ = client.Close() - log.Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err) - <-connCtx.Done() - return nil, err - } - log.Debugf("Failed to dial remote addr: %s: %v", remote.String(), err) - client.Close() - return nil, err - } - cli := &sshClient{cancel: cancelFunc1, Client: client} - write := pkgutil.SafeWrite(sshClientChan, cli) - if !write { - go func() { - <-connCtx.Done() - cli.Close() - }() - } - return conn, nil - } - } - - for ctx.Err() == nil { + for ctx1.Err() == nil { localConn, err1 := localListen.Accept() if err1 != nil { log.Debugf("Failed to accept ssh conn: %v", err1) @@ -628,30 +568,82 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr } go func() { defer localConn.Close() - cCtx, cancelFunc := context.WithCancel(ctx) - defer cancelFunc() - var remoteConn net.Conn - var err error - for i := 0; i < 5; i++ { - remoteConn, err = getRemoteConnFunc(cCtx) - if err == nil { - break - } - } + remoteConn, err := getRemoteConn(ctx, sshClientChan, conf, remote) if err != nil { + var openChannelError *ssh.OpenChannelError + // if ssh server not permitted ssh port-forward, do nothing until exit + if errors.As(err, &openChannelError) && openChannelError.Reason == ssh.Prohibited { + log.Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err) + cancelFunc1() + } log.Debugf("Failed to get remote conn: %v", err) return } defer remoteConn.Close() - copyStream(cCtx, localConn, remoteConn) + copyStream(ctx, localConn, remoteConn) }() } }() return nil } +func getRemoteConn(ctx context.Context, sshClientChan chan *sshClient, conf *SshConfig, remote netip.AddrPort) (conn net.Conn, err error) { + select { + case cli, ok := <-sshClientChan: + if !ok { + return nil, errors.New("ssh client chan closed") + } + ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10) + defer cancelFunc1() + conn, err = cli.DialContext(ctx1, "tcp", remote.String()) + if err != nil { + log.Debugf("Failed to dial remote address %s: %s", remote.String(), err) + _ = cli.Close() + return nil, err + } + safeWrite(ctx, sshClientChan, cli) + return conn, nil + default: + ctx1, cancelFunc1 := context.WithCancel(ctx) + defer func() { + if err != nil { + cancelFunc1() + } + }() + ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10) + defer cancelFunc2() + var client *ssh.Client + client, err = DialSshRemote(ctx2, conf, ctx1.Done()) + if err != nil { + log.Debugf("Failed to dial remote ssh server: %v", err) + return nil, err + } + ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10) + defer cancelFunc3() + conn, err = client.DialContext(ctx3, "tcp", remote.String()) + if err != nil { + log.Debugf("Failed to dial remote addr: %s: %v", remote.String(), err) + client.Close() + return nil, err + } + cli := newSshClient(client, cancelFunc1) + safeWrite(ctx1, sshClientChan, cli) + return conn, nil + } +} + +func safeWrite(ctx context.Context, sshClientChan chan *sshClient, cli *sshClient) { + write := pkgutil.SafeWrite(sshClientChan, cli) + if !write { + go func() { + <-ctx.Done() + cli.Close() + }() + } +} + func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print bool) (path string, err error) { if conf.Addr == "" && conf.ConfigAlias == "" { if flags != nil {