hotfix: not cancel context after handle new local connection of PortmapUtil, otherwise ssh.client stop channel also closed

This commit is contained in:
fengcaiwen
2025-02-12 23:27:45 +08:00
parent 24367b1b82
commit 399bc4efe0

View File

@@ -530,6 +530,10 @@ func init() {
})
}
func newSshClient(client *ssh.Client, cancel context.CancelFunc) *sshClient {
return &sshClient{Client: client, cancel: cancel}
}
type sshClient struct {
cancel context.CancelFunc
*ssh.Client
@@ -553,8 +557,39 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
defer localListen.Close()
var sshClientChan = make(chan *sshClient, 1000*1000)
ctx1, cancelFunc1 := context.WithCancel(ctx)
defer cancelFunc1()
var getRemoteConnFunc = func(connCtx context.Context) (conn net.Conn, err error) {
for ctx1.Err() == nil {
localConn, err1 := localListen.Accept()
if err1 != nil {
log.Debugf("Failed to accept ssh conn: %v", err1)
continue
}
go func() {
defer localConn.Close()
remoteConn, err := getRemoteConn(ctx, sshClientChan, conf, remote)
if err != nil {
var openChannelError *ssh.OpenChannelError
// if ssh server not permitted ssh port-forward, do nothing until exit
if errors.As(err, &openChannelError) && openChannelError.Reason == ssh.Prohibited {
log.Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
cancelFunc1()
}
log.Debugf("Failed to get remote conn: %v", err)
return
}
defer remoteConn.Close()
copyStream(ctx, localConn, remoteConn)
}()
}
}()
return nil
}
func getRemoteConn(ctx context.Context, sshClientChan chan *sshClient, conf *SshConfig, remote netip.AddrPort) (conn net.Conn, err error) {
select {
case cli, ok := <-sshClientChan:
if !ok {
@@ -565,16 +600,10 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
conn, err = cli.DialContext(ctx1, "tcp", remote.String())
if err != nil {
log.Debugf("Failed to dial remote address %s: %s", remote.String(), err)
cli.Close()
_ = cli.Close()
return nil, err
}
write := pkgutil.SafeWrite(sshClientChan, cli)
if !write {
go func() {
<-connCtx.Done()
cli.Close()
}()
}
safeWrite(ctx, sshClientChan, cli)
return conn, nil
default:
ctx1, cancelFunc1 := context.WithCancel(ctx)
@@ -588,68 +617,31 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
var client *ssh.Client
client, err = DialSshRemote(ctx2, conf, ctx1.Done())
if err != nil {
marshal, _ := json.Marshal(conf)
log.Debugf("Failed to dial remote ssh server %v: %v", string(marshal), err)
log.Debugf("Failed to dial remote ssh server: %v", err)
return nil, err
}
ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10)
defer cancelFunc3()
conn, err = client.DialContext(ctx3, "tcp", remote.String())
if err != nil {
var openChannelError *ssh.OpenChannelError
// if ssh server not permitted ssh port-forward, do nothing until exit
if errors.As(err, &openChannelError) && openChannelError.Reason == ssh.Prohibited {
_ = client.Close()
log.Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
<-connCtx.Done()
return nil, err
}
log.Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
client.Close()
return nil, err
}
cli := &sshClient{cancel: cancelFunc1, Client: client}
write := pkgutil.SafeWrite(sshClientChan, cli)
if !write {
go func() {
<-connCtx.Done()
cli.Close()
}()
}
cli := newSshClient(client, cancelFunc1)
safeWrite(ctx1, sshClientChan, cli)
return conn, nil
}
}
for ctx.Err() == nil {
localConn, err1 := localListen.Accept()
if err1 != nil {
log.Debugf("Failed to accept ssh conn: %v", err1)
continue
}
func safeWrite(ctx context.Context, sshClientChan chan *sshClient, cli *sshClient) {
write := pkgutil.SafeWrite(sshClientChan, cli)
if !write {
go func() {
defer localConn.Close()
cCtx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
var remoteConn net.Conn
var err error
for i := 0; i < 5; i++ {
remoteConn, err = getRemoteConnFunc(cCtx)
if err == nil {
break
}
}
if err != nil {
log.Debugf("Failed to get remote conn: %v", err)
return
}
defer remoteConn.Close()
copyStream(cCtx, localConn, remoteConn)
<-ctx.Done()
cli.Close()
}()
}
}()
return nil
}
func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print bool) (path string, err error) {