mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-10-04 23:12:42 +08:00
hotfix: not cancel context after handle new local connection of PortmapUtil, otherwise ssh.client stop channel also closed
This commit is contained in:
148
pkg/ssh/ssh.go
148
pkg/ssh/ssh.go
@@ -530,6 +530,10 @@ func init() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newSshClient(client *ssh.Client, cancel context.CancelFunc) *sshClient {
|
||||||
|
return &sshClient{Client: client, cancel: cancel}
|
||||||
|
}
|
||||||
|
|
||||||
type sshClient struct {
|
type sshClient struct {
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
*ssh.Client
|
*ssh.Client
|
||||||
@@ -553,74 +557,10 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
|||||||
defer localListen.Close()
|
defer localListen.Close()
|
||||||
|
|
||||||
var sshClientChan = make(chan *sshClient, 1000*1000)
|
var sshClientChan = make(chan *sshClient, 1000*1000)
|
||||||
|
ctx1, cancelFunc1 := context.WithCancel(ctx)
|
||||||
|
defer cancelFunc1()
|
||||||
|
|
||||||
var getRemoteConnFunc = func(connCtx context.Context) (conn net.Conn, err error) {
|
for ctx1.Err() == nil {
|
||||||
select {
|
|
||||||
case cli, ok := <-sshClientChan:
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("ssh client chan closed")
|
|
||||||
}
|
|
||||||
ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10)
|
|
||||||
defer cancelFunc1()
|
|
||||||
conn, err = cli.DialContext(ctx1, "tcp", remote.String())
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("Failed to dial remote address %s: %s", remote.String(), err)
|
|
||||||
cli.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
write := pkgutil.SafeWrite(sshClientChan, cli)
|
|
||||||
if !write {
|
|
||||||
go func() {
|
|
||||||
<-connCtx.Done()
|
|
||||||
cli.Close()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
default:
|
|
||||||
ctx1, cancelFunc1 := context.WithCancel(ctx)
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
cancelFunc1()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
|
|
||||||
defer cancelFunc2()
|
|
||||||
var client *ssh.Client
|
|
||||||
client, err = DialSshRemote(ctx2, conf, ctx1.Done())
|
|
||||||
if err != nil {
|
|
||||||
marshal, _ := json.Marshal(conf)
|
|
||||||
log.Debugf("Failed to dial remote ssh server %v: %v", string(marshal), err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10)
|
|
||||||
defer cancelFunc3()
|
|
||||||
conn, err = client.DialContext(ctx3, "tcp", remote.String())
|
|
||||||
if err != nil {
|
|
||||||
var openChannelError *ssh.OpenChannelError
|
|
||||||
// if ssh server not permitted ssh port-forward, do nothing until exit
|
|
||||||
if errors.As(err, &openChannelError) && openChannelError.Reason == ssh.Prohibited {
|
|
||||||
_ = client.Close()
|
|
||||||
log.Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
|
|
||||||
<-connCtx.Done()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
log.Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
|
|
||||||
client.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
cli := &sshClient{cancel: cancelFunc1, Client: client}
|
|
||||||
write := pkgutil.SafeWrite(sshClientChan, cli)
|
|
||||||
if !write {
|
|
||||||
go func() {
|
|
||||||
<-connCtx.Done()
|
|
||||||
cli.Close()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for ctx.Err() == nil {
|
|
||||||
localConn, err1 := localListen.Accept()
|
localConn, err1 := localListen.Accept()
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
log.Debugf("Failed to accept ssh conn: %v", err1)
|
log.Debugf("Failed to accept ssh conn: %v", err1)
|
||||||
@@ -628,30 +568,82 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
|
|||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
defer localConn.Close()
|
defer localConn.Close()
|
||||||
cCtx, cancelFunc := context.WithCancel(ctx)
|
|
||||||
defer cancelFunc()
|
|
||||||
|
|
||||||
var remoteConn net.Conn
|
remoteConn, err := getRemoteConn(ctx, sshClientChan, conf, remote)
|
||||||
var err error
|
|
||||||
for i := 0; i < 5; i++ {
|
|
||||||
remoteConn, err = getRemoteConnFunc(cCtx)
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
var openChannelError *ssh.OpenChannelError
|
||||||
|
// if ssh server not permitted ssh port-forward, do nothing until exit
|
||||||
|
if errors.As(err, &openChannelError) && openChannelError.Reason == ssh.Prohibited {
|
||||||
|
log.Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
|
||||||
|
cancelFunc1()
|
||||||
|
}
|
||||||
log.Debugf("Failed to get remote conn: %v", err)
|
log.Debugf("Failed to get remote conn: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer remoteConn.Close()
|
defer remoteConn.Close()
|
||||||
copyStream(cCtx, localConn, remoteConn)
|
copyStream(ctx, localConn, remoteConn)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getRemoteConn(ctx context.Context, sshClientChan chan *sshClient, conf *SshConfig, remote netip.AddrPort) (conn net.Conn, err error) {
|
||||||
|
select {
|
||||||
|
case cli, ok := <-sshClientChan:
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("ssh client chan closed")
|
||||||
|
}
|
||||||
|
ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10)
|
||||||
|
defer cancelFunc1()
|
||||||
|
conn, err = cli.DialContext(ctx1, "tcp", remote.String())
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Failed to dial remote address %s: %s", remote.String(), err)
|
||||||
|
_ = cli.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
safeWrite(ctx, sshClientChan, cli)
|
||||||
|
return conn, nil
|
||||||
|
default:
|
||||||
|
ctx1, cancelFunc1 := context.WithCancel(ctx)
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
cancelFunc1()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
|
||||||
|
defer cancelFunc2()
|
||||||
|
var client *ssh.Client
|
||||||
|
client, err = DialSshRemote(ctx2, conf, ctx1.Done())
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Failed to dial remote ssh server: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10)
|
||||||
|
defer cancelFunc3()
|
||||||
|
conn, err = client.DialContext(ctx3, "tcp", remote.String())
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
|
||||||
|
client.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cli := newSshClient(client, cancelFunc1)
|
||||||
|
safeWrite(ctx1, sshClientChan, cli)
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func safeWrite(ctx context.Context, sshClientChan chan *sshClient, cli *sshClient) {
|
||||||
|
write := pkgutil.SafeWrite(sshClientChan, cli)
|
||||||
|
if !write {
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
cli.Close()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print bool) (path string, err error) {
|
func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print bool) (path string, err error) {
|
||||||
if conf.Addr == "" && conf.ConfigAlias == "" {
|
if conf.Addr == "" && conf.ConfigAlias == "" {
|
||||||
if flags != nil {
|
if flags != nil {
|
||||||
|
Reference in New Issue
Block a user