mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-12-24 11:51:13 +08:00
Merge pull request #433 from kubenetworks/hotfix/use-default-krb5-config
hotfix: use default krb5 config and not cancel context after handle new local connection of PortmapUtil, otherwise ssh.client stop channel also closed
This commit is contained in:
@@ -35,13 +35,6 @@ func NewKrb5InitiatorClientWithPassword(username, password, krb5Conf string) (kc
|
||||
return
|
||||
}
|
||||
|
||||
// Set to lookup KDCs in DNS
|
||||
c.LibDefaults.DNSLookupKDC = true
|
||||
c.LibDefaults.DNSLookupRealm = true
|
||||
|
||||
// Blank out the KDCs to ensure they are not being used
|
||||
c.Realms = []config.Realm{}
|
||||
|
||||
defaultRealm := c.LibDefaults.DefaultRealm
|
||||
|
||||
cl := client.NewWithPassword(username, defaultRealm, password, c)
|
||||
@@ -65,12 +58,6 @@ func NewKrb5InitiatorClientWithKeytab(username string, krb5Conf, keytabConf stri
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Set to lookup KDCs in DNS
|
||||
c.LibDefaults.DNSLookupKDC = true
|
||||
c.LibDefaults.DNSLookupRealm = true
|
||||
|
||||
// Blank out the KDCs to ensure they are not being used
|
||||
c.Realms = []config.Realm{}
|
||||
|
||||
// Init keytab from conf
|
||||
cache, err := keytab.Load(keytabConf)
|
||||
@@ -81,9 +68,6 @@ func NewKrb5InitiatorClientWithKeytab(username string, krb5Conf, keytabConf stri
|
||||
defaultRealm := c.LibDefaults.DefaultRealm
|
||||
|
||||
cl := client.NewWithKeytab(username, defaultRealm, cache, c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = cl.Login()
|
||||
if err != nil {
|
||||
return
|
||||
@@ -105,13 +89,6 @@ func NewKrb5InitiatorClientWithCache(krb5Conf, cacheFile string) (kcl Krb5Initia
|
||||
return
|
||||
}
|
||||
|
||||
// Set to lookup KDCs in DNS
|
||||
c.LibDefaults.DNSLookupKDC = true
|
||||
c.LibDefaults.DNSLookupRealm = true
|
||||
|
||||
// Blank out the KDCs to ensure they are not being used
|
||||
c.Realms = []config.Realm{}
|
||||
|
||||
// Init krb5 client and login
|
||||
cache, err := credentials.LoadCCache(cacheFile)
|
||||
// https://stackoverflow.com/questions/58653482/what-is-the-default-kerberos-credential-cache-on-osx
|
||||
|
||||
148
pkg/ssh/ssh.go
148
pkg/ssh/ssh.go
@@ -530,6 +530,10 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
func newSshClient(client *ssh.Client, cancel context.CancelFunc) *sshClient {
|
||||
return &sshClient{Client: client, cancel: cancel}
|
||||
}
|
||||
|
||||
type sshClient struct {
|
||||
cancel context.CancelFunc
|
||||
*ssh.Client
|
||||
@@ -553,74 +557,10 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
||||
defer localListen.Close()
|
||||
|
||||
var sshClientChan = make(chan *sshClient, 1000*1000)
|
||||
ctx1, cancelFunc1 := context.WithCancel(ctx)
|
||||
defer cancelFunc1()
|
||||
|
||||
var getRemoteConnFunc = func(connCtx context.Context) (conn net.Conn, err error) {
|
||||
select {
|
||||
case cli, ok := <-sshClientChan:
|
||||
if !ok {
|
||||
return nil, errors.New("ssh client chan closed")
|
||||
}
|
||||
ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc1()
|
||||
conn, err = cli.DialContext(ctx1, "tcp", remote.String())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to dial remote address %s: %s", remote.String(), err)
|
||||
cli.Close()
|
||||
return nil, err
|
||||
}
|
||||
write := pkgutil.SafeWrite(sshClientChan, cli)
|
||||
if !write {
|
||||
go func() {
|
||||
<-connCtx.Done()
|
||||
cli.Close()
|
||||
}()
|
||||
}
|
||||
return conn, nil
|
||||
default:
|
||||
ctx1, cancelFunc1 := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cancelFunc1()
|
||||
}
|
||||
}()
|
||||
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc2()
|
||||
var client *ssh.Client
|
||||
client, err = DialSshRemote(ctx2, conf, ctx1.Done())
|
||||
if err != nil {
|
||||
marshal, _ := json.Marshal(conf)
|
||||
log.Debugf("Failed to dial remote ssh server %v: %v", string(marshal), err)
|
||||
return nil, err
|
||||
}
|
||||
ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc3()
|
||||
conn, err = client.DialContext(ctx3, "tcp", remote.String())
|
||||
if err != nil {
|
||||
var openChannelError *ssh.OpenChannelError
|
||||
// if ssh server not permitted ssh port-forward, do nothing until exit
|
||||
if errors.As(err, &openChannelError) && openChannelError.Reason == ssh.Prohibited {
|
||||
_ = client.Close()
|
||||
log.Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
|
||||
<-connCtx.Done()
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
|
||||
client.Close()
|
||||
return nil, err
|
||||
}
|
||||
cli := &sshClient{cancel: cancelFunc1, Client: client}
|
||||
write := pkgutil.SafeWrite(sshClientChan, cli)
|
||||
if !write {
|
||||
go func() {
|
||||
<-connCtx.Done()
|
||||
cli.Close()
|
||||
}()
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
for ctx.Err() == nil {
|
||||
for ctx1.Err() == nil {
|
||||
localConn, err1 := localListen.Accept()
|
||||
if err1 != nil {
|
||||
log.Debugf("Failed to accept ssh conn: %v", err1)
|
||||
@@ -628,30 +568,82 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
||||
}
|
||||
go func() {
|
||||
defer localConn.Close()
|
||||
cCtx, cancelFunc := context.WithCancel(ctx)
|
||||
defer cancelFunc()
|
||||
|
||||
var remoteConn net.Conn
|
||||
var err error
|
||||
for i := 0; i < 5; i++ {
|
||||
remoteConn, err = getRemoteConnFunc(cCtx)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
remoteConn, err := getRemoteConn(ctx, sshClientChan, conf, remote)
|
||||
if err != nil {
|
||||
var openChannelError *ssh.OpenChannelError
|
||||
// if ssh server not permitted ssh port-forward, do nothing until exit
|
||||
if errors.As(err, &openChannelError) && openChannelError.Reason == ssh.Prohibited {
|
||||
log.Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
|
||||
cancelFunc1()
|
||||
}
|
||||
log.Debugf("Failed to get remote conn: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer remoteConn.Close()
|
||||
copyStream(cCtx, localConn, remoteConn)
|
||||
copyStream(ctx, localConn, remoteConn)
|
||||
}()
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRemoteConn(ctx context.Context, sshClientChan chan *sshClient, conf *SshConfig, remote netip.AddrPort) (conn net.Conn, err error) {
|
||||
select {
|
||||
case cli, ok := <-sshClientChan:
|
||||
if !ok {
|
||||
return nil, errors.New("ssh client chan closed")
|
||||
}
|
||||
ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc1()
|
||||
conn, err = cli.DialContext(ctx1, "tcp", remote.String())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to dial remote address %s: %s", remote.String(), err)
|
||||
_ = cli.Close()
|
||||
return nil, err
|
||||
}
|
||||
safeWrite(ctx, sshClientChan, cli)
|
||||
return conn, nil
|
||||
default:
|
||||
ctx1, cancelFunc1 := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cancelFunc1()
|
||||
}
|
||||
}()
|
||||
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc2()
|
||||
var client *ssh.Client
|
||||
client, err = DialSshRemote(ctx2, conf, ctx1.Done())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to dial remote ssh server: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc3()
|
||||
conn, err = client.DialContext(ctx3, "tcp", remote.String())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
|
||||
client.Close()
|
||||
return nil, err
|
||||
}
|
||||
cli := newSshClient(client, cancelFunc1)
|
||||
safeWrite(ctx1, sshClientChan, cli)
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func safeWrite(ctx context.Context, sshClientChan chan *sshClient, cli *sshClient) {
|
||||
write := pkgutil.SafeWrite(sshClientChan, cli)
|
||||
if !write {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
cli.Close()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print bool) (path string, err error) {
|
||||
if conf.Addr == "" && conf.ConfigAlias == "" {
|
||||
if flags != nil {
|
||||
|
||||
Reference in New Issue
Block a user