mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-10-14 03:23:47 +08:00
fix: fix some bug and optimize ssh logic (#147)
This commit is contained in:
116
pkg/util/ssh.go
116
pkg/util/ssh.go
@@ -67,31 +67,7 @@ func (s *SshConfig) ToRPC() *rpc.SshJump {
|
||||
}
|
||||
}
|
||||
|
||||
func Main(pctx context.Context, remoteEndpoint, localEndpoint netip.AddrPort, conf *SshConfig, done chan struct{}) error {
|
||||
ctx, cancelFunc := context.WithCancel(pctx)
|
||||
defer cancelFunc()
|
||||
|
||||
sshClient, err := DialSshRemote(conf)
|
||||
if err != nil {
|
||||
log.Errorf("Dial into remote server error: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// ref: https://github.com/golang/go/issues/21478
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second * 5)
|
||||
defer ticker.Stop()
|
||||
select {
|
||||
case <-ticker.C:
|
||||
_, _, err := sshClient.SendRequest("keepalive@golang.org", true, nil)
|
||||
if err != nil {
|
||||
log.Errorf("failed to send keep alive error: %s", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
func PortMap(ctx context.Context, sshClient *ssh.Client, remoteEndpoint, localEndpoint netip.AddrPort, done chan struct{}) error {
|
||||
// Listen on remote server port
|
||||
var lc net.ListenConfig
|
||||
listen, err := lc.Listen(ctx, "tcp", localEndpoint.String())
|
||||
@@ -105,34 +81,27 @@ func Main(pctx context.Context, remoteEndpoint, localEndpoint netip.AddrPort, co
|
||||
default:
|
||||
}
|
||||
// handle incoming connections on reverse forwarded tunnel
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
local, err := listen.Accept()
|
||||
for ctx.Err() == nil {
|
||||
localConn, err := listen.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func(local net.Conn) {
|
||||
defer local.Close()
|
||||
conn, err := sshClient.Dial("tcp", remoteEndpoint.String())
|
||||
go func(localConn net.Conn) {
|
||||
defer localConn.Close()
|
||||
remoteConn, err := sshClient.Dial("tcp", remoteEndpoint.String())
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial %s: %s", remoteEndpoint.String(), err)
|
||||
cancelFunc()
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
handleClient(local, conn)
|
||||
}(local)
|
||||
defer remoteConn.Close()
|
||||
copyStream(localConn, remoteConn)
|
||||
}(localConn)
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// todo ssh heartbeats
|
||||
// https://github.com/golang/go/issues/21478
|
||||
func DialSshRemote(conf *SshConfig) (*ssh.Client, error) {
|
||||
// DialSshRemote https://github.com/golang/go/issues/21478
|
||||
func DialSshRemote(ctx context.Context, conf *SshConfig) (*ssh.Client, error) {
|
||||
var remote *ssh.Client
|
||||
var err error
|
||||
if conf.ConfigAlias != "" {
|
||||
@@ -189,22 +158,33 @@ func DialSshRemote(conf *SshConfig) (*ssh.Client, error) {
|
||||
// Connect to SSH remote server using serverEndpoint
|
||||
remote, err = ssh.Dial("tcp", conf.Addr, sshConfig)
|
||||
}
|
||||
|
||||
// ref: https://github.com/golang/go/issues/21478
|
||||
if err == nil {
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second * 5)
|
||||
defer ticker.Stop()
|
||||
for ctx.Err() == nil {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
_, _, er := remote.SendRequest("keepalive@golang.org", true, nil)
|
||||
if er != nil {
|
||||
log.Errorf("failed to send keep alive error: %s", er)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
return remote, err
|
||||
}
|
||||
|
||||
func RemoteRun(conf *SshConfig, cmd string, env map[string]string) (output []byte, errOut []byte, err error) {
|
||||
var remote *ssh.Client
|
||||
remote, err = DialSshRemote(conf)
|
||||
if err != nil {
|
||||
log.Errorf("Dial into remote server error: %s", err)
|
||||
return
|
||||
}
|
||||
defer remote.Close()
|
||||
func RemoteRun(client *ssh.Client, cmd string, env map[string]string) (output []byte, errOut []byte, err error) {
|
||||
var session *ssh.Session
|
||||
session, err = remote.NewSession()
|
||||
session, err = client.NewSession()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer session.Close()
|
||||
for k, v := range env {
|
||||
// /etc/ssh/sshd_config
|
||||
// AcceptEnv DEBIAN_FRONTEND
|
||||
@@ -213,7 +193,6 @@ func RemoteRun(conf *SshConfig, cmd string, env map[string]string) (output []byt
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
defer remote.Close()
|
||||
var out bytes.Buffer
|
||||
var er bytes.Buffer
|
||||
session.Stdout = &out
|
||||
@@ -246,7 +225,7 @@ func publicKeyFile(file string) (ssh.AuthMethod, error) {
|
||||
return ssh.PublicKeys(key), nil
|
||||
}
|
||||
|
||||
func handleClient(client net.Conn, remote net.Conn) {
|
||||
func copyStream(client net.Conn, remote net.Conn) {
|
||||
chDone := make(chan bool, 2)
|
||||
|
||||
// start remote -> local data transfer
|
||||
@@ -408,3 +387,32 @@ func init() {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func PortMapUntil(ctx context.Context, cli *ssh.Client, remote, local netip.AddrPort) error {
|
||||
errChan := make(chan error, 1)
|
||||
readyChan := make(chan struct{}, 1)
|
||||
go func() {
|
||||
for ctx.Err() == nil {
|
||||
err := PortMap(ctx, cli, remote, local, readyChan)
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
log.Errorf("Ssh forward failed err: %v", err)
|
||||
}
|
||||
select {
|
||||
case errChan <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second * 2)
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-readyChan:
|
||||
return nil
|
||||
case err := <-errChan:
|
||||
log.Errorf("Ssh forward failed err: %v", err)
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user