hotfix: close ssh client if ctx done (#559)

This commit is contained in:
naison
2025-04-24 22:41:24 +08:00
committed by GitHub
parent 31186fc1d9
commit 6a8a197f48
6 changed files with 56 additions and 28 deletions

View File

@@ -81,5 +81,6 @@ func GvisorTCPListener(addr string) (net.Listener, error) {
_ = listener.Close()
return nil, err
}
plog.G(context.Background()).Debugf("Use tls mode")
return tls.NewListener(&tcpKeepAliveListener{TCPListener: listener}, serverConfig), nil
}

View File

@@ -38,7 +38,7 @@ func (tr *tcpTransporter) Dial(ctx context.Context, addr string) (net.Conn, erro
plog.G(ctx).Debugf("tls config not found in config, use raw tcp mode")
return conn, nil
}
plog.G(ctx).Debugf("use tls mode")
plog.G(ctx).Debugf("Use tls mode")
return tls.Client(conn, tr.tlsConfig), nil
}

View File

@@ -150,8 +150,8 @@ func (w *wsHandler) createTwoWayTUNTunnel(ctx context.Context, cli *ssh.Client)
plog.G(ctx).Info("Connected private safe tunnel")
go func() {
for ctx.Err() == nil {
util.Ping(ctx, clientIP.IP.String(), ip.String())
time.Sleep(time.Second * 5)
util.PingOnce(ctx, clientIP.IP.String(), ip.String())
time.Sleep(time.Second * 15)
}
}()
return nil
@@ -436,7 +436,7 @@ func init() {
if errors.Is(err, io.EOF) {
return
} else if err != nil {
plog.G(context.Background()).Errorf("Session %s windos change w: %d h: %d failed: %v", sessionID, r.Width, r.Height, err)
plog.G(context.Background()).Errorf("Session %s windows change w: %d h: %d failed: %v", sessionID, r.Width, r.Height, err)
}
}
}))

View File

@@ -162,7 +162,11 @@ func (m *Mapper) Run(connectNamespace string) {
local := netip.AddrPortFrom(netip.IPv4Unspecified(), uint16(containerPort))
remote := netip.AddrPortFrom(netip.IPv4Unspecified(), uint16(envoyRulePort))
for ctx.Err() == nil {
_ = ssh.ExposeLocalPortToRemote(ctx, remoteSSHServer, remote, local)
func() {
ctx2, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
_ = ssh.ExposeLocalPortToRemote(ctx2, remoteSSHServer, remote, local)
}()
time.Sleep(time.Second * 2)
}
}(containerPort, envoyRulePort)

View File

@@ -37,6 +37,7 @@ func ExposeLocalPortToRemote(ctx context.Context, remoteSSHServer, remotePort, l
plog.G(ctx).Errorf("Dial into remote server error: %s", err)
return err
}
defer serverConn.Close()
// Listen on remote server port
listener, err := serverConn.Listen("tcp", remotePort.String())
@@ -46,6 +47,12 @@ func ExposeLocalPortToRemote(ctx context.Context, remoteSSHServer, remotePort, l
}
defer listener.Close()
go func() {
<-ctx.Done()
listener.Close()
serverConn.Close()
}()
// handle incoming connections on reverse forwarded tunnel
for {
client, err := listener.Accept()

View File

@@ -91,10 +91,16 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
var lc net.ListenConfig
localListen, e := lc.Listen(ctx, "tcp", local.String())
if e != nil {
plog.G(ctx).Errorf("failed to listen %s: %v", local.String(), e)
return e
}
plog.G(ctx).Debugf("SSH listening on local %s forward to %s", local.String(), remote.String())
go func() {
<-ctx.Done()
localListen.Close()
}()
go func() {
defer localListen.Close()
@@ -108,6 +114,7 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
plog.G(ctx).Debugf("Failed to accept ssh conn: %v", err1)
continue
}
plog.G(ctx).Debugf("Accepted ssh conn from %s", localConn.RemoteAddr().String())
go func() {
defer localConn.Close()
@@ -117,11 +124,13 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
// if ssh server not permitted ssh port-forward, do nothing until exit
if errors.As(err, &openChannelError) && openChannelError.Reason == gossh.Prohibited {
plog.G(ctx).Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
plog.G(ctx).Errorf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
cancelFunc1()
}
plog.G(ctx).Debugf("Failed to get remote conn: %v", err)
return
}
plog.G(ctx).Debugf("Opened ssh port-forward: %s", remote.String())
defer remoteConn.Close()
copyStream(ctx, localConn, remoteConn)
@@ -182,7 +191,7 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
err = errors.Wrap(err, string(stderr))
return
}
if len(stdout) == 0 {
if len(bytes.TrimSpace(stdout)) == 0 {
err = errors.Errorf("can not get kubeconfig %s from remote ssh server: %s", conf.RemoteKubeconfig, string(stderr))
return
}
@@ -217,28 +226,34 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
var rawConfig api.Config
rawConfig, err = matchVersionFlags.ToRawKubeConfigLoader().RawConfig()
if err != nil {
plog.G(ctx).WithError(err).Errorf("failed to build config: %v", err)
return
}
if err = api.FlattenConfig(&rawConfig); err != nil {
plog.G(ctx).Errorf("failed to flatten config: %v", err)
return
}
if rawConfig.Contexts == nil {
err = errors.New("kubeconfig is invalid")
plog.G(ctx).Error("can not get contexts")
return
}
kubeContext := rawConfig.Contexts[rawConfig.CurrentContext]
if kubeContext == nil {
err = errors.New("kubeconfig is invalid")
plog.G(ctx).Errorf("can not find kubeconfig context %s", rawConfig.CurrentContext)
return
}
cluster := rawConfig.Clusters[kubeContext.Cluster]
if cluster == nil {
err = errors.New("kubeconfig is invalid")
plog.G(ctx).Errorf("can not find cluster %s", kubeContext.Cluster)
return
}
var u *url.URL
u, err = url.Parse(cluster.Server)
if err != nil {
plog.G(ctx).Errorf("failed to parse cluster url: %v", err)
return
}
@@ -252,6 +267,7 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
} else {
// handle other schemes if necessary
err = errors.New("kubeconfig is invalid: wrong protocol")
plog.G(ctx).Error(err)
return
}
}
@@ -263,6 +279,7 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
if len(ips) == 0 {
// handle error: no IP associated with the hostname
err = fmt.Errorf("kubeconfig: no IP associated with the hostname %s", serverHost)
plog.G(ctx).Error(err)
return
}
@@ -306,11 +323,13 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
var convertedObj runtime.Object
convertedObj, err = latest.Scheme.ConvertToVersion(&rawConfig, latest.ExternalVersion)
if err != nil {
plog.G(ctx).Errorf("failed to build config: %v", err)
return
}
var marshal []byte
marshal, err = json.Marshal(convertedObj)
if err != nil {
plog.G(ctx).Errorf("failed to marshal config: %v", err)
return
}
var temp *os.File
@@ -407,7 +426,9 @@ func JumpTo(ctx context.Context, bClient *gossh.Client, to SshConfig, stopChan <
return
}
func getRemoteConn(ctx context.Context, clientMap *sync.Map, conf *SshConfig, remote netip.AddrPort) (conn net.Conn, err error) {
func getRemoteConn(ctx context.Context, clientMap *sync.Map, conf *SshConfig, remote netip.AddrPort) (net.Conn, error) {
var conn net.Conn
var err error
clientMap.Range(func(key, value any) bool {
cli := value.(*sshClientWrap)
ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10)
@@ -416,40 +437,35 @@ func getRemoteConn(ctx context.Context, clientMap *sync.Map, conf *SshConfig, re
if err != nil {
plog.G(ctx).Debugf("Failed to dial remote address %s: %s", remote.String(), err)
clientMap.Delete(key)
plog.G(ctx).Error("Delete invalid ssh client from map")
_ = cli.Close()
return true
}
return false
})
if conn != nil {
return
return conn, nil
}
ctx1, cancelFunc1 := context.WithCancel(ctx)
defer func() {
if err != nil {
cancelFunc1()
}
}()
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
defer cancelFunc2()
var client *gossh.Client
client, err = DialSshRemote(ctx2, conf, ctx1.Done())
client, err = DialSshRemote(ctx1, conf, ctx1.Done())
if err != nil {
plog.G(ctx).Debugf("Failed to dial remote ssh server: %v", err)
return nil, err
}
ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10)
defer cancelFunc3()
conn, err = client.DialContext(ctx3, "tcp", remote.String())
if err != nil {
plog.G(ctx).Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
_ = client.Close()
cancelFunc1()
return nil, err
}
clientMap.Store(uuid.NewString(), newSshClientWrap(client, cancelFunc1))
plog.G(ctx1).Debug("Connected to remote SSH server")
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
defer cancelFunc2()
conn, err = client.DialContext(ctx2, "tcp", remote.String())
if err != nil {
plog.G(ctx).Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
return nil, err
}
plog.G(ctx).Debugf("Connected to remote addr: %s", remote.String())
return conn, nil
}
@@ -462,7 +478,7 @@ func copyStream(ctx context.Context, local net.Conn, remote net.Conn) {
defer config.LPool.Put(buf[:])
_, err := io.CopyBuffer(local, remote, buf)
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
plog.G(ctx).Debugf("Failed to copy remote -> local: %s", err)
plog.G(ctx).Errorf("Failed to copy remote -> local: %s", err)
}
chDone <- true
}()
@@ -473,7 +489,7 @@ func copyStream(ctx context.Context, local net.Conn, remote net.Conn) {
defer config.LPool.Put(buf[:])
_, err := io.CopyBuffer(remote, local, buf)
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
plog.G(ctx).Debugf("Failed to copy local -> remote: %s", err)
plog.G(ctx).Errorf("Failed to copy local -> remote: %s", err)
}
chDone <- true
}()