fix: ssh portmap redo ssh dial (#170)

This commit is contained in:
naison
2024-02-17 19:25:35 +08:00
committed by GitHub
parent 46fcf5521f
commit 01e3456ad3
3 changed files with 64 additions and 62 deletions

View File

@@ -82,7 +82,7 @@ func (w *wsHandler) handle(ctx context.Context) {
if err != nil {
return
}
err = util.PortMapUntil(ctx, cli, remote, local)
err = util.PortMapUntil(ctx, w.sshConfig, remote, local)
if err != nil {
w.Log("Port map error: %v", err)
return

View File

@@ -893,7 +893,7 @@ func SshJump(ctx context.Context, conf *util.SshConfig, flags *pflag.FlagSet, pr
if print {
log.Infof("wait jump to bastion host...")
}
err = util.PortMapUntil(ctx, cli, remote, local)
err = util.PortMapUntil(ctx, conf, remote, local)
if err != nil {
log.Errorf("ssh proxy err: %v", err)
return

View File

@@ -67,39 +67,6 @@ func (s *SshConfig) ToRPC() *rpc.SshJump {
}
}
func PortMap(ctx context.Context, sshClient *ssh.Client, remoteEndpoint, localEndpoint netip.AddrPort, done chan struct{}) error {
// Listen on remote server port
var lc net.ListenConfig
listen, err := lc.Listen(ctx, "tcp", localEndpoint.String())
if err != nil {
return err
}
defer listen.Close()
select {
case done <- struct{}{}:
default:
}
// handle incoming connections on reverse forwarded tunnel
for ctx.Err() == nil {
localConn, err := listen.Accept()
if err != nil {
return err
}
go func(localConn net.Conn) {
defer localConn.Close()
remoteConn, err := sshClient.Dial("tcp", remoteEndpoint.String())
if err != nil {
log.Errorf("Failed to dial %s: %s", remoteEndpoint.String(), err)
return
}
defer remoteConn.Close()
copyStream(localConn, remoteConn)
}(localConn)
}
return ctx.Err()
}
// DialSshRemote https://github.com/golang/go/issues/21478
func DialSshRemote(ctx context.Context, conf *SshConfig) (*ssh.Client, error) {
var remote *ssh.Client
@@ -170,6 +137,7 @@ func DialSshRemote(ctx context.Context, conf *SshConfig) (*ssh.Client, error) {
_, _, er := remote.SendRequest("keepalive@golang.org", true, nil)
if er != nil {
log.Errorf("failed to send keep alive error: %s", er)
return
}
}
}
@@ -225,12 +193,12 @@ func publicKeyFile(file string) (ssh.AuthMethod, error) {
return ssh.PublicKeys(key), nil
}
func copyStream(client net.Conn, remote net.Conn) {
func copyStream(local net.Conn, remote net.Conn) {
chDone := make(chan bool, 2)
// start remote -> local data transfer
go func() {
_, err := io.Copy(client, remote)
_, err := io.Copy(local, remote)
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Debugf("error while copy remote->local: %s", err)
}
@@ -242,7 +210,7 @@ func copyStream(client net.Conn, remote net.Conn) {
// start local -> remote data transfer
go func() {
_, err := io.Copy(remote, client)
_, err := io.Copy(remote, local)
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Debugf("error while copy local->remote: %s", err)
}
@@ -388,31 +356,65 @@ func init() {
})
}
func PortMapUntil(ctx context.Context, cli *ssh.Client, remote, local netip.AddrPort) error {
errChan := make(chan error, 1)
readyChan := make(chan struct{}, 1)
go func() {
for ctx.Err() == nil {
err := PortMap(ctx, cli, remote, local, readyChan)
func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.AddrPort) error {
// Listen on remote server port
var lc net.ListenConfig
localListen, err := lc.Listen(ctx, "tcp", local.String())
if err != nil {
if !errors.Is(err, context.Canceled) {
log.Errorf("Ssh forward failed err: %v", err)
return err
}
select {
case errChan <- err:
default:
var lock sync.Mutex
var cancelFunc context.CancelFunc
var sshClient *ssh.Client
var getRemoteConnFunc = func() (net.Conn, error) {
lock.Lock()
defer lock.Unlock()
if sshClient != nil {
remoteConn, err := sshClient.Dial("tcp", remote.String())
if err == nil {
return remoteConn, nil
}
sshClient.Close()
if cancelFunc != nil {
cancelFunc()
}
}
time.Sleep(time.Second * 2)
var ctx2 context.Context
ctx2, cancelFunc = context.WithCancel(ctx)
sshClient, err = DialSshRemote(ctx2, conf)
if err != nil {
cancelFunc()
cancelFunc = nil
log.Errorf("failed to dial remote ssh server: %v", err)
return nil, err
}
return sshClient.Dial("tcp", remote.String())
}
go func() {
defer localListen.Close()
for ctx.Err() == nil {
localConn, err := localListen.Accept()
if err != nil {
log.Errorf("failed to accept conn: %v", err)
return
}
go func() {
defer localConn.Close()
remoteConn, err := getRemoteConnFunc()
if err != nil {
log.Errorf("Failed to dial %s: %s", remote.String(), err)
return
}
defer remoteConn.Close()
copyStream(localConn, remoteConn)
}()
}
}()
select {
case <-readyChan:
return nil
case err := <-errChan:
log.Errorf("Ssh forward failed err: %v", err)
return err
case <-ctx.Done():
return ctx.Err()
}
}