mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-10-04 06:56:39 +08:00
hotfix: not cancel context after handle new local connection of PortmapUtil, otherwise ssh.client stop channel also closed
This commit is contained in:
148
pkg/ssh/ssh.go
148
pkg/ssh/ssh.go
@@ -530,6 +530,10 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
func newSshClient(client *ssh.Client, cancel context.CancelFunc) *sshClient {
|
||||
return &sshClient{Client: client, cancel: cancel}
|
||||
}
|
||||
|
||||
type sshClient struct {
|
||||
cancel context.CancelFunc
|
||||
*ssh.Client
|
||||
@@ -553,74 +557,10 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
||||
defer localListen.Close()
|
||||
|
||||
var sshClientChan = make(chan *sshClient, 1000*1000)
|
||||
ctx1, cancelFunc1 := context.WithCancel(ctx)
|
||||
defer cancelFunc1()
|
||||
|
||||
var getRemoteConnFunc = func(connCtx context.Context) (conn net.Conn, err error) {
|
||||
select {
|
||||
case cli, ok := <-sshClientChan:
|
||||
if !ok {
|
||||
return nil, errors.New("ssh client chan closed")
|
||||
}
|
||||
ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc1()
|
||||
conn, err = cli.DialContext(ctx1, "tcp", remote.String())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to dial remote address %s: %s", remote.String(), err)
|
||||
cli.Close()
|
||||
return nil, err
|
||||
}
|
||||
write := pkgutil.SafeWrite(sshClientChan, cli)
|
||||
if !write {
|
||||
go func() {
|
||||
<-connCtx.Done()
|
||||
cli.Close()
|
||||
}()
|
||||
}
|
||||
return conn, nil
|
||||
default:
|
||||
ctx1, cancelFunc1 := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cancelFunc1()
|
||||
}
|
||||
}()
|
||||
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc2()
|
||||
var client *ssh.Client
|
||||
client, err = DialSshRemote(ctx2, conf, ctx1.Done())
|
||||
if err != nil {
|
||||
marshal, _ := json.Marshal(conf)
|
||||
log.Debugf("Failed to dial remote ssh server %v: %v", string(marshal), err)
|
||||
return nil, err
|
||||
}
|
||||
ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc3()
|
||||
conn, err = client.DialContext(ctx3, "tcp", remote.String())
|
||||
if err != nil {
|
||||
var openChannelError *ssh.OpenChannelError
|
||||
// if ssh server not permitted ssh port-forward, do nothing until exit
|
||||
if errors.As(err, &openChannelError) && openChannelError.Reason == ssh.Prohibited {
|
||||
_ = client.Close()
|
||||
log.Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
|
||||
<-connCtx.Done()
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
|
||||
client.Close()
|
||||
return nil, err
|
||||
}
|
||||
cli := &sshClient{cancel: cancelFunc1, Client: client}
|
||||
write := pkgutil.SafeWrite(sshClientChan, cli)
|
||||
if !write {
|
||||
go func() {
|
||||
<-connCtx.Done()
|
||||
cli.Close()
|
||||
}()
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
for ctx.Err() == nil {
|
||||
for ctx1.Err() == nil {
|
||||
localConn, err1 := localListen.Accept()
|
||||
if err1 != nil {
|
||||
log.Debugf("Failed to accept ssh conn: %v", err1)
|
||||
@@ -628,30 +568,82 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
||||
}
|
||||
go func() {
|
||||
defer localConn.Close()
|
||||
cCtx, cancelFunc := context.WithCancel(ctx)
|
||||
defer cancelFunc()
|
||||
|
||||
var remoteConn net.Conn
|
||||
var err error
|
||||
for i := 0; i < 5; i++ {
|
||||
remoteConn, err = getRemoteConnFunc(cCtx)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
remoteConn, err := getRemoteConn(ctx, sshClientChan, conf, remote)
|
||||
if err != nil {
|
||||
var openChannelError *ssh.OpenChannelError
|
||||
// if ssh server not permitted ssh port-forward, do nothing until exit
|
||||
if errors.As(err, &openChannelError) && openChannelError.Reason == ssh.Prohibited {
|
||||
log.Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
|
||||
cancelFunc1()
|
||||
}
|
||||
log.Debugf("Failed to get remote conn: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer remoteConn.Close()
|
||||
copyStream(cCtx, localConn, remoteConn)
|
||||
copyStream(ctx, localConn, remoteConn)
|
||||
}()
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRemoteConn(ctx context.Context, sshClientChan chan *sshClient, conf *SshConfig, remote netip.AddrPort) (conn net.Conn, err error) {
|
||||
select {
|
||||
case cli, ok := <-sshClientChan:
|
||||
if !ok {
|
||||
return nil, errors.New("ssh client chan closed")
|
||||
}
|
||||
ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc1()
|
||||
conn, err = cli.DialContext(ctx1, "tcp", remote.String())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to dial remote address %s: %s", remote.String(), err)
|
||||
_ = cli.Close()
|
||||
return nil, err
|
||||
}
|
||||
safeWrite(ctx, sshClientChan, cli)
|
||||
return conn, nil
|
||||
default:
|
||||
ctx1, cancelFunc1 := context.WithCancel(ctx)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
cancelFunc1()
|
||||
}
|
||||
}()
|
||||
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc2()
|
||||
var client *ssh.Client
|
||||
client, err = DialSshRemote(ctx2, conf, ctx1.Done())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to dial remote ssh server: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10)
|
||||
defer cancelFunc3()
|
||||
conn, err = client.DialContext(ctx3, "tcp", remote.String())
|
||||
if err != nil {
|
||||
log.Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
|
||||
client.Close()
|
||||
return nil, err
|
||||
}
|
||||
cli := newSshClient(client, cancelFunc1)
|
||||
safeWrite(ctx1, sshClientChan, cli)
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
func safeWrite(ctx context.Context, sshClientChan chan *sshClient, cli *sshClient) {
|
||||
write := pkgutil.SafeWrite(sshClientChan, cli)
|
||||
if !write {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
cli.Close()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print bool) (path string, err error) {
|
||||
if conf.Addr == "" && conf.ConfigAlias == "" {
|
||||
if flags != nil {
|
||||
|
Reference in New Issue
Block a user