hotfix: not cancel context after handle new local connection of PortmapUtil, otherwise ssh.client stop channel also closed

This commit is contained in:
fengcaiwen
2025-02-12 23:27:45 +08:00
parent 24367b1b82
commit 399bc4efe0

View File

@@ -530,6 +530,10 @@ func init() {
}) })
} }
func newSshClient(client *ssh.Client, cancel context.CancelFunc) *sshClient {
return &sshClient{Client: client, cancel: cancel}
}
type sshClient struct { type sshClient struct {
cancel context.CancelFunc cancel context.CancelFunc
*ssh.Client *ssh.Client
@@ -553,8 +557,39 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
defer localListen.Close() defer localListen.Close()
var sshClientChan = make(chan *sshClient, 1000*1000) var sshClientChan = make(chan *sshClient, 1000*1000)
ctx1, cancelFunc1 := context.WithCancel(ctx)
defer cancelFunc1()
var getRemoteConnFunc = func(connCtx context.Context) (conn net.Conn, err error) { for ctx1.Err() == nil {
localConn, err1 := localListen.Accept()
if err1 != nil {
log.Debugf("Failed to accept ssh conn: %v", err1)
continue
}
go func() {
defer localConn.Close()
remoteConn, err := getRemoteConn(ctx, sshClientChan, conf, remote)
if err != nil {
var openChannelError *ssh.OpenChannelError
// if ssh server not permitted ssh port-forward, do nothing until exit
if errors.As(err, &openChannelError) && openChannelError.Reason == ssh.Prohibited {
log.Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
cancelFunc1()
}
log.Debugf("Failed to get remote conn: %v", err)
return
}
defer remoteConn.Close()
copyStream(ctx, localConn, remoteConn)
}()
}
}()
return nil
}
func getRemoteConn(ctx context.Context, sshClientChan chan *sshClient, conf *SshConfig, remote netip.AddrPort) (conn net.Conn, err error) {
select { select {
case cli, ok := <-sshClientChan: case cli, ok := <-sshClientChan:
if !ok { if !ok {
@@ -565,16 +600,10 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
conn, err = cli.DialContext(ctx1, "tcp", remote.String()) conn, err = cli.DialContext(ctx1, "tcp", remote.String())
if err != nil { if err != nil {
log.Debugf("Failed to dial remote address %s: %s", remote.String(), err) log.Debugf("Failed to dial remote address %s: %s", remote.String(), err)
cli.Close() _ = cli.Close()
return nil, err return nil, err
} }
write := pkgutil.SafeWrite(sshClientChan, cli) safeWrite(ctx, sshClientChan, cli)
if !write {
go func() {
<-connCtx.Done()
cli.Close()
}()
}
return conn, nil return conn, nil
default: default:
ctx1, cancelFunc1 := context.WithCancel(ctx) ctx1, cancelFunc1 := context.WithCancel(ctx)
@@ -588,68 +617,31 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
var client *ssh.Client var client *ssh.Client
client, err = DialSshRemote(ctx2, conf, ctx1.Done()) client, err = DialSshRemote(ctx2, conf, ctx1.Done())
if err != nil { if err != nil {
marshal, _ := json.Marshal(conf) log.Debugf("Failed to dial remote ssh server: %v", err)
log.Debugf("Failed to dial remote ssh server %v: %v", string(marshal), err)
return nil, err return nil, err
} }
ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10) ctx3, cancelFunc3 := context.WithTimeout(ctx, time.Second*10)
defer cancelFunc3() defer cancelFunc3()
conn, err = client.DialContext(ctx3, "tcp", remote.String()) conn, err = client.DialContext(ctx3, "tcp", remote.String())
if err != nil { if err != nil {
var openChannelError *ssh.OpenChannelError
// if ssh server not permitted ssh port-forward, do nothing until exit
if errors.As(err, &openChannelError) && openChannelError.Reason == ssh.Prohibited {
_ = client.Close()
log.Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
<-connCtx.Done()
return nil, err
}
log.Debugf("Failed to dial remote addr: %s: %v", remote.String(), err) log.Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
client.Close() client.Close()
return nil, err return nil, err
} }
cli := &sshClient{cancel: cancelFunc1, Client: client} cli := newSshClient(client, cancelFunc1)
safeWrite(ctx1, sshClientChan, cli)
return conn, nil
}
}
func safeWrite(ctx context.Context, sshClientChan chan *sshClient, cli *sshClient) {
write := pkgutil.SafeWrite(sshClientChan, cli) write := pkgutil.SafeWrite(sshClientChan, cli)
if !write { if !write {
go func() { go func() {
<-connCtx.Done() <-ctx.Done()
cli.Close() cli.Close()
}() }()
} }
return conn, nil
}
}
for ctx.Err() == nil {
localConn, err1 := localListen.Accept()
if err1 != nil {
log.Debugf("Failed to accept ssh conn: %v", err1)
continue
}
go func() {
defer localConn.Close()
cCtx, cancelFunc := context.WithCancel(ctx)
defer cancelFunc()
var remoteConn net.Conn
var err error
for i := 0; i < 5; i++ {
remoteConn, err = getRemoteConnFunc(cCtx)
if err == nil {
break
}
}
if err != nil {
log.Debugf("Failed to get remote conn: %v", err)
return
}
defer remoteConn.Close()
copyStream(cCtx, localConn, remoteConn)
}()
}
}()
return nil
} }
func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print bool) (path string, err error) { func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print bool) (path string, err error) {