mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-11-02 21:34:01 +08:00
fix: ssh portmap redo ssh dial (#170)
This commit is contained in:
@@ -82,7 +82,7 @@ func (w *wsHandler) handle(ctx context.Context) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = util.PortMapUntil(ctx, cli, remote, local)
|
||||
err = util.PortMapUntil(ctx, w.sshConfig, remote, local)
|
||||
if err != nil {
|
||||
w.Log("Port map error: %v", err)
|
||||
return
|
||||
|
||||
@@ -893,7 +893,7 @@ func SshJump(ctx context.Context, conf *util.SshConfig, flags *pflag.FlagSet, pr
|
||||
if print {
|
||||
log.Infof("wait jump to bastion host...")
|
||||
}
|
||||
err = util.PortMapUntil(ctx, cli, remote, local)
|
||||
err = util.PortMapUntil(ctx, conf, remote, local)
|
||||
if err != nil {
|
||||
log.Errorf("ssh proxy err: %v", err)
|
||||
return
|
||||
|
||||
114
pkg/util/ssh.go
114
pkg/util/ssh.go
@@ -67,39 +67,6 @@ func (s *SshConfig) ToRPC() *rpc.SshJump {
|
||||
}
|
||||
}
|
||||
|
||||
func PortMap(ctx context.Context, sshClient *ssh.Client, remoteEndpoint, localEndpoint netip.AddrPort, done chan struct{}) error {
|
||||
// Listen on remote server port
|
||||
var lc net.ListenConfig
|
||||
listen, err := lc.Listen(ctx, "tcp", localEndpoint.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer listen.Close()
|
||||
|
||||
select {
|
||||
case done <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
// handle incoming connections on reverse forwarded tunnel
|
||||
for ctx.Err() == nil {
|
||||
localConn, err := listen.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func(localConn net.Conn) {
|
||||
defer localConn.Close()
|
||||
remoteConn, err := sshClient.Dial("tcp", remoteEndpoint.String())
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial %s: %s", remoteEndpoint.String(), err)
|
||||
return
|
||||
}
|
||||
defer remoteConn.Close()
|
||||
copyStream(localConn, remoteConn)
|
||||
}(localConn)
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// DialSshRemote https://github.com/golang/go/issues/21478
|
||||
func DialSshRemote(ctx context.Context, conf *SshConfig) (*ssh.Client, error) {
|
||||
var remote *ssh.Client
|
||||
@@ -170,6 +137,7 @@ func DialSshRemote(ctx context.Context, conf *SshConfig) (*ssh.Client, error) {
|
||||
_, _, er := remote.SendRequest("keepalive@golang.org", true, nil)
|
||||
if er != nil {
|
||||
log.Errorf("failed to send keep alive error: %s", er)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -225,12 +193,12 @@ func publicKeyFile(file string) (ssh.AuthMethod, error) {
|
||||
return ssh.PublicKeys(key), nil
|
||||
}
|
||||
|
||||
func copyStream(client net.Conn, remote net.Conn) {
|
||||
func copyStream(local net.Conn, remote net.Conn) {
|
||||
chDone := make(chan bool, 2)
|
||||
|
||||
// start remote -> local data transfer
|
||||
go func() {
|
||||
_, err := io.Copy(client, remote)
|
||||
_, err := io.Copy(local, remote)
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
log.Debugf("error while copy remote->local: %s", err)
|
||||
}
|
||||
@@ -242,7 +210,7 @@ func copyStream(client net.Conn, remote net.Conn) {
|
||||
|
||||
// start local -> remote data transfer
|
||||
go func() {
|
||||
_, err := io.Copy(remote, client)
|
||||
_, err := io.Copy(remote, local)
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
log.Debugf("error while copy local->remote: %s", err)
|
||||
}
|
||||
@@ -388,31 +356,65 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
func PortMapUntil(ctx context.Context, cli *ssh.Client, remote, local netip.AddrPort) error {
|
||||
errChan := make(chan error, 1)
|
||||
readyChan := make(chan struct{}, 1)
|
||||
go func() {
|
||||
for ctx.Err() == nil {
|
||||
err := PortMap(ctx, cli, remote, local, readyChan)
|
||||
func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.AddrPort) error {
|
||||
// Listen on remote server port
|
||||
var lc net.ListenConfig
|
||||
localListen, err := lc.Listen(ctx, "tcp", local.String())
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
log.Errorf("Ssh forward failed err: %v", err)
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case errChan <- err:
|
||||
default:
|
||||
|
||||
var lock sync.Mutex
|
||||
var cancelFunc context.CancelFunc
|
||||
var sshClient *ssh.Client
|
||||
|
||||
var getRemoteConnFunc = func() (net.Conn, error) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
if sshClient != nil {
|
||||
remoteConn, err := sshClient.Dial("tcp", remote.String())
|
||||
if err == nil {
|
||||
return remoteConn, nil
|
||||
}
|
||||
sshClient.Close()
|
||||
if cancelFunc != nil {
|
||||
cancelFunc()
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second * 2)
|
||||
var ctx2 context.Context
|
||||
ctx2, cancelFunc = context.WithCancel(ctx)
|
||||
sshClient, err = DialSshRemote(ctx2, conf)
|
||||
if err != nil {
|
||||
cancelFunc()
|
||||
cancelFunc = nil
|
||||
log.Errorf("failed to dial remote ssh server: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return sshClient.Dial("tcp", remote.String())
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer localListen.Close()
|
||||
|
||||
for ctx.Err() == nil {
|
||||
localConn, err := localListen.Accept()
|
||||
if err != nil {
|
||||
log.Errorf("failed to accept conn: %v", err)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
defer localConn.Close()
|
||||
|
||||
remoteConn, err := getRemoteConnFunc()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial %s: %s", remote.String(), err)
|
||||
return
|
||||
}
|
||||
defer remoteConn.Close()
|
||||
copyStream(localConn, remoteConn)
|
||||
}()
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-readyChan:
|
||||
return nil
|
||||
case err := <-errChan:
|
||||
log.Errorf("Ssh forward failed err: %v", err)
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user