mirror of
https://github.com/kubenetworks/kubevpn.git
synced 2025-10-05 15:26:57 +08:00
fix: fix some bug and optimize ssh logic (#147)
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package cmds
|
package cmds
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -87,10 +88,27 @@ func CmdSSH(_ cmdutil.Factory) *cobra.Command {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
go moniteSize(sessionID)
|
|
||||||
go io.Copy(conn, os.Stdin)
|
errChan := make(chan error, 3)
|
||||||
_, err = io.Copy(os.Stdout, conn)
|
go func() {
|
||||||
return err
|
err := monitorSize(cmd.Context(), sessionID)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(conn, os.Stdin)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(os.Stdout, conn)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-errChan:
|
||||||
|
return err
|
||||||
|
case <-cmd.Context().Done():
|
||||||
|
return cmd.Context().Err()
|
||||||
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
addSshFlags(cmd, sshConf)
|
addSshFlags(cmd, sshConf)
|
||||||
@@ -98,7 +116,7 @@ func CmdSSH(_ cmdutil.Factory) *cobra.Command {
|
|||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
func moniteSize(sessionID string) error {
|
func monitorSize(ctx context.Context, sessionID string) error {
|
||||||
conn := daemon.GetTCPClient(true)
|
conn := daemon.GetTCPClient(true)
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return fmt.Errorf("conn is nil")
|
return fmt.Errorf("conn is nil")
|
||||||
@@ -125,14 +143,15 @@ func moniteSize(sessionID string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
encoder := json.NewEncoder(client)
|
encoder := json.NewEncoder(client)
|
||||||
for {
|
for ctx.Err() == nil {
|
||||||
size := sizeQueue.Next()
|
size := sizeQueue.Next()
|
||||||
if size == nil {
|
if size == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if err = encoder.Encode(&size); err != nil {
|
if err = encoder.Encode(&size); err != nil {
|
||||||
log.Errorf("Encode resize: %s", err)
|
log.Errorf("Encode resize: %s", err)
|
||||||
continue
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -4,10 +4,10 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/spf13/pflag"
|
"github.com/spf13/pflag"
|
||||||
"github.com/wencaiwulue/kubevpn/v2/pkg/handler"
|
|
||||||
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
|
|
||||||
|
|
||||||
"github.com/wencaiwulue/kubevpn/v2/pkg/daemon/rpc"
|
"github.com/wencaiwulue/kubevpn/v2/pkg/daemon/rpc"
|
||||||
|
"github.com/wencaiwulue/kubevpn/v2/pkg/handler"
|
||||||
|
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var CancelFunc = make(map[string]context.CancelFunc)
|
var CancelFunc = make(map[string]context.CancelFunc)
|
||||||
|
@@ -202,6 +202,7 @@ func (svr *Server) redirectConnectForkToSudoDaemon(req *rpc.ConnectRequest, resp
|
|||||||
}
|
}
|
||||||
log.Info(recv.Message)
|
log.Info(recv.Message)
|
||||||
}
|
}
|
||||||
|
svr.secondaryConnect = append(svr.secondaryConnect[:i], svr.secondaryConnect[i+1:]...)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -25,7 +25,7 @@ import (
|
|||||||
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
|
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServer) error {
|
func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServer) (e error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
log.SetOutput(svr.LogFile)
|
log.SetOutput(svr.LogFile)
|
||||||
log.SetLevel(log.DebugLevel)
|
log.SetLevel(log.DebugLevel)
|
||||||
@@ -45,6 +45,15 @@ func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServe
|
|||||||
// todo define already connect error?
|
// todo define already connect error?
|
||||||
return status.Error(codes.AlreadyExists, s)
|
return status.Error(codes.AlreadyExists, s)
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
if e != nil {
|
||||||
|
if svr.connect != nil {
|
||||||
|
svr.connect.Cleanup()
|
||||||
|
svr.connect = nil
|
||||||
|
}
|
||||||
|
svr.t = time.Time{}
|
||||||
|
}
|
||||||
|
}()
|
||||||
svr.t = time.Now()
|
svr.t = time.Now()
|
||||||
svr.connect = &handler.ConnectOptions{
|
svr.connect = &handler.ConnectOptions{
|
||||||
Namespace: req.Namespace,
|
Namespace: req.Namespace,
|
||||||
|
@@ -43,14 +43,18 @@ type wsHandler struct {
|
|||||||
// 1) start remote kubevpn server
|
// 1) start remote kubevpn server
|
||||||
// 2) start local tunnel
|
// 2) start local tunnel
|
||||||
// 3) ssh terminal
|
// 3) ssh terminal
|
||||||
func (w *wsHandler) handle(ctx2 context.Context) {
|
func (w *wsHandler) handle(ctx context.Context) {
|
||||||
conn := w.conn
|
ctx, f := context.WithCancel(ctx)
|
||||||
sshConfig := w.sshConfig
|
defer f()
|
||||||
cidr := w.cidr
|
|
||||||
ctx, cancelFunc := context.WithCancel(ctx2)
|
|
||||||
defer cancelFunc()
|
|
||||||
|
|
||||||
err := w.remoteInstallKubevpnIfCommandNotFound(ctx, sshConfig)
|
cli, err := util.DialSshRemote(ctx, w.sshConfig)
|
||||||
|
if err != nil {
|
||||||
|
w.Log("Dial ssh remote error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer cli.Close()
|
||||||
|
|
||||||
|
err = w.installKubevpnOnRemote(ctx, cli)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.Log("Install kubevpn error: %v", err)
|
w.Log("Install kubevpn error: %v", err)
|
||||||
return
|
return
|
||||||
@@ -62,16 +66,32 @@ func (w *wsHandler) handle(ctx2 context.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
local, err := w.portMap(ctx, sshConfig)
|
remotePort := 10800
|
||||||
|
var localPort int
|
||||||
|
localPort, err = util.GetAvailableTCPPortOrDie()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var remote netip.AddrPort
|
||||||
|
remote, err = netip.ParseAddrPort(net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort)))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var local netip.AddrPort
|
||||||
|
local, err = netip.ParseAddrPort(net.JoinHostPort("127.0.0.1", strconv.Itoa(localPort)))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = util.PortMapUntil(ctx, cli, remote, local)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.Log("Port map error: %v", err)
|
w.Log("Port map error: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// startup daemon process if daemon process not start
|
// startup daemon process if daemon process not start
|
||||||
startDaemonCmd := fmt.Sprintf(`export %s=%s && kubevpn get service > /dev/null 2>&1 &`, config.EnvStartSudoKubeVPNByKubeVPN, "true")
|
startDaemonCmd := fmt.Sprintf(`export %s=%s && kubevpn get service > /dev/null 2>&1 &`, config.EnvStartSudoKubeVPNByKubeVPN, "true")
|
||||||
_, _, _ = util.RemoteRun(sshConfig, startDaemonCmd, nil)
|
util.RemoteRun(cli, startDaemonCmd, nil)
|
||||||
cmd := fmt.Sprintf(`export %s=%s && kubevpn ssh-daemon --client-ip %s`, config.EnvStartSudoKubeVPNByKubeVPN, "true", clientIP.String())
|
cmd := fmt.Sprintf(`export %s=%s && kubevpn ssh-daemon --client-ip %s`, config.EnvStartSudoKubeVPNByKubeVPN, "true", clientIP.String())
|
||||||
serverIP, stderr, err := util.RemoteRun(sshConfig, cmd, nil)
|
serverIP, stderr, err := util.RemoteRun(cli, cmd, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("run error: %v", err)
|
log.Errorf("run error: %v", err)
|
||||||
log.Errorf("run stdout: %v", string(serverIP))
|
log.Errorf("run stdout: %v", string(serverIP))
|
||||||
@@ -86,12 +106,12 @@ func (w *wsHandler) handle(ctx2 context.Context) {
|
|||||||
}
|
}
|
||||||
msg := fmt.Sprintf("| You can use client: %s to communicate with server: %s |", clientIP.IP.String(), ip.String())
|
msg := fmt.Sprintf("| You can use client: %s to communicate with server: %s |", clientIP.IP.String(), ip.String())
|
||||||
w.PrintLine(msg)
|
w.PrintLine(msg)
|
||||||
cidr = append(cidr, string(serverIP))
|
w.cidr = append(w.cidr, string(serverIP))
|
||||||
r := core.Route{
|
r := core.Route{
|
||||||
ServeNodes: []string{
|
ServeNodes: []string{
|
||||||
fmt.Sprintf("tun:/127.0.0.1:8422?net=%s&route=%s", clientIP, strings.Join(cidr, ",")),
|
fmt.Sprintf("tun:/127.0.0.1:8422?net=%s&route=%s", clientIP, strings.Join(w.cidr, ",")),
|
||||||
},
|
},
|
||||||
ChainNode: fmt.Sprintf("tcp://127.0.0.1:%d", local),
|
ChainNode: fmt.Sprintf("tcp://127.0.0.1:%d", localPort),
|
||||||
Retries: 5,
|
Retries: 5,
|
||||||
}
|
}
|
||||||
servers, err := handler.Parse(r)
|
servers, err := handler.Parse(r)
|
||||||
@@ -124,78 +144,14 @@ func (w *wsHandler) handle(ctx2 context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
err = w.terminal(ctx, sshConfig, conn)
|
err = w.terminal(ctx, cli, w.conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.Log("Enter terminal error: %v", err)
|
w.Log("Enter terminal error: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *wsHandler) portMap(ctx context.Context, conf *util.SshConfig) (localPort int, err error) {
|
func (w *wsHandler) terminal(ctx context.Context, cli *ssh.Client, conn *websocket.Conn) error {
|
||||||
remotePort := 10800
|
|
||||||
localPort, err = util.GetAvailableTCPPortOrDie()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var remote netip.AddrPort
|
|
||||||
remote, err = netip.ParseAddrPort(net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort)))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var local netip.AddrPort
|
|
||||||
local, err = netip.ParseAddrPort(net.JoinHostPort("127.0.0.1", strconv.Itoa(localPort)))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// pre-check network ip connect
|
|
||||||
var cli *ssh.Client
|
|
||||||
cli, err = util.DialSshRemote(conf)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
_ = cli.Close()
|
|
||||||
}
|
|
||||||
errChan := make(chan error, 1)
|
|
||||||
readyChan := make(chan struct{}, 1)
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
err := util.Main(ctx, remote, local, conf, readyChan)
|
|
||||||
if err != nil {
|
|
||||||
if !errors.Is(err, context.Canceled) {
|
|
||||||
log.Errorf("ssh forward failed err: %v", err)
|
|
||||||
w.Log("Ssh forward failed err: %v", err)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case errChan <- err:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second * 2)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
select {
|
|
||||||
case <-readyChan:
|
|
||||||
return
|
|
||||||
case err = <-errChan:
|
|
||||||
w.Log("Ssh forward failed err: %v", err)
|
|
||||||
log.Errorf("ssh proxy err: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *wsHandler) terminal(ctx context.Context, conf *util.SshConfig, conn *websocket.Conn) error {
|
|
||||||
cli, err := util.DialSshRemote(conf)
|
|
||||||
if err != nil {
|
|
||||||
w.Log("Dial remote error: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
session, err := cli.NewSession()
|
session, err := cli.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.Log("New session error: %v", err)
|
w.Log("New session error: %v", err)
|
||||||
@@ -205,7 +161,6 @@ func (w *wsHandler) terminal(ctx context.Context, conf *util.SshConfig, conn *we
|
|||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
session.Close()
|
session.Close()
|
||||||
cli.Close()
|
|
||||||
}()
|
}()
|
||||||
session.Stdout = conn
|
session.Stdout = conn
|
||||||
session.Stderr = conn
|
session.Stderr = conn
|
||||||
@@ -232,11 +187,11 @@ func (w *wsHandler) terminal(ctx context.Context, conf *util.SshConfig, conn *we
|
|||||||
return session.Wait()
|
return session.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *wsHandler) remoteInstallKubevpnIfCommandNotFound(ctx context.Context, sshConfig *util.SshConfig) error {
|
func (w *wsHandler) installKubevpnOnRemote(ctx context.Context, sshClient *ssh.Client) error {
|
||||||
cmd := `hash kubevpn || type kubevpn || which kubevpn || command -v kubevpn`
|
cmd := `hash kubevpn || type kubevpn || which kubevpn || command -v kubevpn`
|
||||||
_, _, err := util.RemoteRun(sshConfig, cmd, nil)
|
_, _, err := util.RemoteRun(sshClient, cmd, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
w.Log("Remote kubevpn command found, not needs to install")
|
w.Log("Remote kubevpn command found, use it")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
log.Infof("remote kubevpn command not found, try to install it...")
|
log.Infof("remote kubevpn command not found, try to install it...")
|
||||||
@@ -293,12 +248,12 @@ func (w *wsHandler) remoteInstallKubevpnIfCommandNotFound(ctx context.Context, s
|
|||||||
"chmod +x ~/.kubevpn/kubevpn",
|
"chmod +x ~/.kubevpn/kubevpn",
|
||||||
"sudo mv ~/.kubevpn/kubevpn /usr/local/bin/kubevpn",
|
"sudo mv ~/.kubevpn/kubevpn /usr/local/bin/kubevpn",
|
||||||
}
|
}
|
||||||
err = util.SCP(w.conn, w.conn, sshConfig, tempBin.Name(), "kubevpn", cmds...)
|
err = util.SCPAndExec(w.conn, w.conn, sshClient, tempBin.Name(), "kubevpn", cmds...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// try to startup daemon process
|
// try to startup daemon process
|
||||||
go util.RemoteRun(sshConfig, "kubevpn get pods", nil)
|
go util.RemoteRun(sshClient, "kubevpn get pods", nil)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -354,12 +309,15 @@ func init() {
|
|||||||
condReady: cancelFunc,
|
condReady: cancelFunc,
|
||||||
}
|
}
|
||||||
CondReady[sessionID] = ctx
|
CondReady[sessionID] = ctx
|
||||||
|
defer conn.Close()
|
||||||
h.handle(conn.Request().Context())
|
h.handle(conn.Request().Context())
|
||||||
}))
|
}))
|
||||||
http.Handle("/resize", websocket.Handler(func(conn *websocket.Conn) {
|
http.Handle("/resize", websocket.Handler(func(conn *websocket.Conn) {
|
||||||
sessionID := conn.Request().Header.Get("session-id")
|
sessionID := conn.Request().Header.Get("session-id")
|
||||||
log.Infof("resize: %s", sessionID)
|
log.Infof("resize: %s", sessionID)
|
||||||
|
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
var session *ssh.Session
|
var session *ssh.Session
|
||||||
select {
|
select {
|
||||||
case <-conn.Request().Context().Done():
|
case <-conn.Request().Context().Done():
|
||||||
|
@@ -759,11 +759,18 @@ func SshJump(ctx context.Context, conf *util.SshConfig, flags *pflag.FlagSet, pr
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// pre-check network ip connect
|
||||||
|
var cli *ssh.Client
|
||||||
|
cli, err = util.DialSshRemote(ctx, conf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
configFlags := genericclioptions.NewConfigFlags(true).WithDeprecatedPasswordFlag()
|
configFlags := genericclioptions.NewConfigFlags(true).WithDeprecatedPasswordFlag()
|
||||||
|
|
||||||
if conf.RemoteKubeconfig != "" || (flags != nil && flags.Changed("remote-kubeconfig")) {
|
if conf.RemoteKubeconfig != "" || (flags != nil && flags.Changed("remote-kubeconfig")) {
|
||||||
var stdOut []byte
|
var stdout []byte
|
||||||
var errOut []byte
|
var stderr []byte
|
||||||
if len(conf.RemoteKubeconfig) != 0 && conf.RemoteKubeconfig[0] == '~' {
|
if len(conf.RemoteKubeconfig) != 0 && conf.RemoteKubeconfig[0] == '~' {
|
||||||
conf.RemoteKubeconfig = filepath.Join("/home", conf.User, conf.RemoteKubeconfig[1:])
|
conf.RemoteKubeconfig = filepath.Join("/home", conf.User, conf.RemoteKubeconfig[1:])
|
||||||
}
|
}
|
||||||
@@ -771,18 +778,18 @@ func SshJump(ctx context.Context, conf *util.SshConfig, flags *pflag.FlagSet, pr
|
|||||||
// if `--remote-kubeconfig` is parsed then Entrypoint is reset
|
// if `--remote-kubeconfig` is parsed then Entrypoint is reset
|
||||||
conf.RemoteKubeconfig = filepath.Join("/home", conf.User, clientcmd.RecommendedHomeDir, clientcmd.RecommendedFileName)
|
conf.RemoteKubeconfig = filepath.Join("/home", conf.User, clientcmd.RecommendedHomeDir, clientcmd.RecommendedFileName)
|
||||||
}
|
}
|
||||||
stdOut, errOut, err = util.RemoteRun(conf,
|
stdout, stderr, err = util.RemoteRun(cli,
|
||||||
fmt.Sprintf("sh -c 'kubectl config view --flatten --raw --kubeconfig %s || minikube kubectl -- config view --flatten --raw --kubeconfig %s'",
|
fmt.Sprintf("sh -c 'kubectl config view --flatten --raw --kubeconfig %s || minikube kubectl -- config view --flatten --raw --kubeconfig %s'",
|
||||||
conf.RemoteKubeconfig,
|
conf.RemoteKubeconfig,
|
||||||
conf.RemoteKubeconfig),
|
conf.RemoteKubeconfig),
|
||||||
map[string]string{clientcmd.RecommendedConfigPathEnvVar: conf.RemoteKubeconfig},
|
map[string]string{clientcmd.RecommendedConfigPathEnvVar: conf.RemoteKubeconfig},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.Wrap(err, string(errOut))
|
err = errors.Wrap(err, string(stderr))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(stdOut) == 0 {
|
if len(stdout) == 0 {
|
||||||
err = errors.Errorf("can not get kubeconfig %s from remote ssh server: %s", conf.RemoteKubeconfig, string(errOut))
|
err = errors.Errorf("can not get kubeconfig %s from remote ssh server: %s", conf.RemoteKubeconfig, string(stderr))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -793,7 +800,7 @@ func SshJump(ctx context.Context, conf *util.SshConfig, flags *pflag.FlagSet, pr
|
|||||||
if err = temp.Close(); err != nil {
|
if err = temp.Close(); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = os.WriteFile(temp.Name(), stdOut, 0644); err != nil {
|
if err = os.WriteFile(temp.Name(), stdout, 0644); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err = os.Chmod(temp.Name(), 0644); err != nil {
|
if err = os.Chmod(temp.Name(), 0644); err != nil {
|
||||||
@@ -871,7 +878,6 @@ func SshJump(ctx context.Context, conf *util.SshConfig, flags *pflag.FlagSet, pr
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var port int
|
var port int
|
||||||
port, err = util.GetAvailableTCPPortOrDie()
|
port, err = util.GetAvailableTCPPortOrDie()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -883,43 +889,11 @@ func SshJump(ctx context.Context, conf *util.SshConfig, flags *pflag.FlagSet, pr
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// pre-check network ip connect
|
|
||||||
var cli *ssh.Client
|
|
||||||
cli, err = util.DialSshRemote(conf)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
_ = cli.Close()
|
|
||||||
}
|
|
||||||
errChan := make(chan error, 1)
|
|
||||||
readyChan := make(chan struct{}, 1)
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
err := util.Main(ctx, remote, local, conf, readyChan)
|
|
||||||
if err != nil {
|
|
||||||
if !errors.Is(err, context.Canceled) {
|
|
||||||
log.Errorf("ssh forward failed err: %v", err)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case errChan <- err:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(time.Second * 2)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if print {
|
if print {
|
||||||
log.Infof("wait jump to bastion host...")
|
log.Infof("wait jump to bastion host...")
|
||||||
}
|
}
|
||||||
select {
|
err = util.PortMapUntil(ctx, cli, remote, local)
|
||||||
case <-readyChan:
|
if err != nil {
|
||||||
case err = <-errChan:
|
|
||||||
log.Errorf("ssh proxy err: %v", err)
|
log.Errorf("ssh proxy err: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/opencontainers/image-spec/specs-go/v1"
|
"github.com/opencontainers/image-spec/specs-go/v1"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetClient() (*client.Client, *command.DockerCli, error) {
|
func GetClient() (*client.Client, *command.DockerCli, error) {
|
||||||
@@ -109,6 +110,11 @@ func TransferImage(ctx context.Context, conf *SshConfig, imageSource, imageTarge
|
|||||||
}
|
}
|
||||||
|
|
||||||
// transfer image to remote
|
// transfer image to remote
|
||||||
|
var sshClient *ssh.Client
|
||||||
|
sshClient, err = DialSshRemote(ctx, conf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
var responseReader io.ReadCloser
|
var responseReader io.ReadCloser
|
||||||
responseReader, err = cli.ImageSave(ctx, []string{imageTarget})
|
responseReader, err = cli.ImageSave(ctx, []string{imageTarget})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -132,12 +138,12 @@ func TransferImage(ctx context.Context, conf *SshConfig, imageSource, imageTarge
|
|||||||
logrus.Infof("Transferring image %s", imageTarget)
|
logrus.Infof("Transferring image %s", imageTarget)
|
||||||
filename := filepath.Base(file.Name())
|
filename := filepath.Base(file.Name())
|
||||||
cmd := fmt.Sprintf(
|
cmd := fmt.Sprintf(
|
||||||
"(docker load image -i ~/.kubevpn/%s && docker push %s) || (nerdctl image load -i ~/.kubevpn/%s && nerdctl image push %s)",
|
"(docker load -i ~/.kubevpn/%s && docker push %s) || (nerdctl image load -i ~/.kubevpn/%s && nerdctl image push %s)",
|
||||||
filename, imageTarget,
|
filename, imageTarget,
|
||||||
filename, imageTarget,
|
filename, imageTarget,
|
||||||
)
|
)
|
||||||
stdout := log.StandardLogger().Out
|
stdout := log.StandardLogger().Out
|
||||||
err = SCP(stdout, stdout, conf, file.Name(), filename, []string{cmd}...)
|
err = SCPAndExec(stdout, stdout, sshClient, file.Name(), filename, []string{cmd}...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@@ -10,29 +10,20 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SCP copy file to remote and exec command
|
// SCPAndExec copy file to remote and exec command
|
||||||
func SCP(stdout, stderr io.Writer, conf *SshConfig, filename, to string, commands ...string) error {
|
func SCPAndExec(stdout, stderr io.Writer, client *ssh.Client, filename, to string, commands ...string) error {
|
||||||
remote, err := DialSshRemote(conf)
|
err := SCP(client, stdout, stderr, filename, to)
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Dial into remote server error: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
sess, err := remote.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = main(sess, stdout, stderr, filename, to)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Copy file to remote error: %s", err)
|
log.Errorf("Copy file to remote error: %s", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, command := range commands {
|
for _, command := range commands {
|
||||||
sess, err = remote.NewSession()
|
var session *ssh.Session
|
||||||
|
session, err = client.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
output, err := sess.CombinedOutput(command)
|
output, err := session.CombinedOutput(command)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(string(output))
|
log.Error(string(output))
|
||||||
return err
|
return err
|
||||||
@@ -43,24 +34,28 @@ func SCP(stdout, stderr io.Writer, conf *SshConfig, filename, to string, command
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://blog.neilpang.com/%E6%94%B6%E8%97%8F-scp-secure-copy%E5%8D%8F%E8%AE%AE/
|
// SCP https://blog.neilpang.com/%E6%94%B6%E8%97%8F-scp-secure-copy%E5%8D%8F%E8%AE%AE/
|
||||||
func main(sess *ssh.Session, stdout, stderr io.Writer, filename string, to string) error {
|
func SCP(client *ssh.Client, stdout, stderr io.Writer, filename, to string) error {
|
||||||
open, err := os.Open(filename)
|
file, err := os.Open(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stat, err := open.Stat()
|
defer file.Close()
|
||||||
|
stat, err := file.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sess, err := client.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer open.Close()
|
|
||||||
defer sess.Close()
|
defer sess.Close()
|
||||||
go func() {
|
go func() {
|
||||||
w, _ := sess.StdinPipe()
|
w, _ := sess.StdinPipe()
|
||||||
defer w.Close()
|
defer w.Close()
|
||||||
fmt.Fprintln(w, "D0755", 0, ".kubevpn") // mkdir
|
fmt.Fprintln(w, "D0755", 0, ".kubevpn") // mkdir
|
||||||
fmt.Fprintln(w, "C0644", stat.Size(), to)
|
fmt.Fprintln(w, "C0644", stat.Size(), to)
|
||||||
err := sCopy(w, open, stat.Size(), stdout, stderr)
|
err := sCopy(w, file, stat.Size(), stdout, stderr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to transfer file to remote: %v", err)
|
log.Errorf("failed to transfer file to remote: %v", err)
|
||||||
return
|
return
|
||||||
@@ -72,9 +67,9 @@ func main(sess *ssh.Session, stdout, stderr io.Writer, filename string, to strin
|
|||||||
|
|
||||||
func sCopy(dst io.Writer, src io.Reader, size int64, stdout, stderr io.Writer) error {
|
func sCopy(dst io.Writer, src io.Reader, size int64, stdout, stderr io.Writer) error {
|
||||||
total := float64(size) / 1024 / 1024
|
total := float64(size) / 1024 / 1024
|
||||||
s := fmt.Sprintf("Length: %d (%0.2fM)\n", size, total)
|
s := fmt.Sprintf("Length: %d (%0.2fM)", size, total)
|
||||||
log.Info(s)
|
log.Info(s)
|
||||||
io.WriteString(stdout, s)
|
io.WriteString(stdout, s+"\n")
|
||||||
|
|
||||||
bar := progressbar.NewOptions(int(size),
|
bar := progressbar.NewOptions(int(size),
|
||||||
progressbar.OptionSetWriter(stdout),
|
progressbar.OptionSetWriter(stdout),
|
||||||
|
116
pkg/util/ssh.go
116
pkg/util/ssh.go
@@ -67,31 +67,7 @@ func (s *SshConfig) ToRPC() *rpc.SshJump {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Main(pctx context.Context, remoteEndpoint, localEndpoint netip.AddrPort, conf *SshConfig, done chan struct{}) error {
|
func PortMap(ctx context.Context, sshClient *ssh.Client, remoteEndpoint, localEndpoint netip.AddrPort, done chan struct{}) error {
|
||||||
ctx, cancelFunc := context.WithCancel(pctx)
|
|
||||||
defer cancelFunc()
|
|
||||||
|
|
||||||
sshClient, err := DialSshRemote(conf)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Dial into remote server error: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ref: https://github.com/golang/go/issues/21478
|
|
||||||
go func() {
|
|
||||||
ticker := time.NewTicker(time.Second * 5)
|
|
||||||
defer ticker.Stop()
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
_, _, err := sshClient.SendRequest("keepalive@golang.org", true, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to send keep alive error: %s", err)
|
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Listen on remote server port
|
// Listen on remote server port
|
||||||
var lc net.ListenConfig
|
var lc net.ListenConfig
|
||||||
listen, err := lc.Listen(ctx, "tcp", localEndpoint.String())
|
listen, err := lc.Listen(ctx, "tcp", localEndpoint.String())
|
||||||
@@ -105,34 +81,27 @@ func Main(pctx context.Context, remoteEndpoint, localEndpoint netip.AddrPort, co
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
// handle incoming connections on reverse forwarded tunnel
|
// handle incoming connections on reverse forwarded tunnel
|
||||||
for {
|
for ctx.Err() == nil {
|
||||||
select {
|
localConn, err := listen.Accept()
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
local, err := listen.Accept()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
go func(local net.Conn) {
|
go func(localConn net.Conn) {
|
||||||
defer local.Close()
|
defer localConn.Close()
|
||||||
conn, err := sshClient.Dial("tcp", remoteEndpoint.String())
|
remoteConn, err := sshClient.Dial("tcp", remoteEndpoint.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to dial %s: %s", remoteEndpoint.String(), err)
|
log.Errorf("Failed to dial %s: %s", remoteEndpoint.String(), err)
|
||||||
cancelFunc()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer remoteConn.Close()
|
||||||
handleClient(local, conn)
|
copyStream(localConn, remoteConn)
|
||||||
}(local)
|
}(localConn)
|
||||||
}
|
}
|
||||||
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo ssh heartbeats
|
// DialSshRemote https://github.com/golang/go/issues/21478
|
||||||
// https://github.com/golang/go/issues/21478
|
func DialSshRemote(ctx context.Context, conf *SshConfig) (*ssh.Client, error) {
|
||||||
func DialSshRemote(conf *SshConfig) (*ssh.Client, error) {
|
|
||||||
var remote *ssh.Client
|
var remote *ssh.Client
|
||||||
var err error
|
var err error
|
||||||
if conf.ConfigAlias != "" {
|
if conf.ConfigAlias != "" {
|
||||||
@@ -189,22 +158,33 @@ func DialSshRemote(conf *SshConfig) (*ssh.Client, error) {
|
|||||||
// Connect to SSH remote server using serverEndpoint
|
// Connect to SSH remote server using serverEndpoint
|
||||||
remote, err = ssh.Dial("tcp", conf.Addr, sshConfig)
|
remote, err = ssh.Dial("tcp", conf.Addr, sshConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ref: https://github.com/golang/go/issues/21478
|
||||||
|
if err == nil {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(time.Second * 5)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for ctx.Err() == nil {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
_, _, er := remote.SendRequest("keepalive@golang.org", true, nil)
|
||||||
|
if er != nil {
|
||||||
|
log.Errorf("failed to send keep alive error: %s", er)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
return remote, err
|
return remote, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func RemoteRun(conf *SshConfig, cmd string, env map[string]string) (output []byte, errOut []byte, err error) {
|
func RemoteRun(client *ssh.Client, cmd string, env map[string]string) (output []byte, errOut []byte, err error) {
|
||||||
var remote *ssh.Client
|
|
||||||
remote, err = DialSshRemote(conf)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Dial into remote server error: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer remote.Close()
|
|
||||||
var session *ssh.Session
|
var session *ssh.Session
|
||||||
session, err = remote.NewSession()
|
session, err = client.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer session.Close()
|
||||||
for k, v := range env {
|
for k, v := range env {
|
||||||
// /etc/ssh/sshd_config
|
// /etc/ssh/sshd_config
|
||||||
// AcceptEnv DEBIAN_FRONTEND
|
// AcceptEnv DEBIAN_FRONTEND
|
||||||
@@ -213,7 +193,6 @@ func RemoteRun(conf *SshConfig, cmd string, env map[string]string) (output []byt
|
|||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
defer remote.Close()
|
|
||||||
var out bytes.Buffer
|
var out bytes.Buffer
|
||||||
var er bytes.Buffer
|
var er bytes.Buffer
|
||||||
session.Stdout = &out
|
session.Stdout = &out
|
||||||
@@ -246,7 +225,7 @@ func publicKeyFile(file string) (ssh.AuthMethod, error) {
|
|||||||
return ssh.PublicKeys(key), nil
|
return ssh.PublicKeys(key), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleClient(client net.Conn, remote net.Conn) {
|
func copyStream(client net.Conn, remote net.Conn) {
|
||||||
chDone := make(chan bool, 2)
|
chDone := make(chan bool, 2)
|
||||||
|
|
||||||
// start remote -> local data transfer
|
// start remote -> local data transfer
|
||||||
@@ -408,3 +387,32 @@ func init() {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func PortMapUntil(ctx context.Context, cli *ssh.Client, remote, local netip.AddrPort) error {
|
||||||
|
errChan := make(chan error, 1)
|
||||||
|
readyChan := make(chan struct{}, 1)
|
||||||
|
go func() {
|
||||||
|
for ctx.Err() == nil {
|
||||||
|
err := PortMap(ctx, cli, remote, local, readyChan)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
log.Errorf("Ssh forward failed err: %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case errChan <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(time.Second * 2)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-readyChan:
|
||||||
|
return nil
|
||||||
|
case err := <-errChan:
|
||||||
|
log.Errorf("Ssh forward failed err: %v", err)
|
||||||
|
return err
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user