hotfix: fix ssh bug (#294)

This commit is contained in:
naison
2024-07-12 22:08:17 +08:00
committed by GitHub
parent 62b0de99f9
commit b0a6a0d054
9 changed files with 126 additions and 86 deletions

View File

@@ -720,7 +720,7 @@ Answer: here are two solutions to solve this problem
➜ ~ kubevpn version
KubeVPN: CLI
Version: v2.0.0
DaemonVersion: v2.0.0
Daemon: v2.0.0
Image: docker.io/naison/kubevpn:v2.0.0
Branch: feature/daemon
Git commit: 7c3a87e14e05c238d8fb23548f95fa1dd6e96936

View File

@@ -614,7 +614,7 @@ d0b3dab8912a naison/kubevpn:v2.0.0 "/bin/bash" 5 minutes ago
➜ ~ kubevpn version
KubeVPN: CLI
Version: v2.0.0
DaemonVersion: v2.0.0
Daemon: v2.0.0
Image: docker.io/naison/kubevpn:v2.0.0
Branch: feature/daemon
Git commit: 7c3a87e14e05c238d8fb23548f95fa1dd6e96936

View File

@@ -24,8 +24,7 @@ import (
)
// CmdSSH
// 设置本地的IP是223.254.0.1/32 ,记得一定是掩码 32位
// 这样别的路由不会走到这里来
// Remember to use network mask 32, because ssh using unique network cidr 223.255.0.0/16
func CmdSSH(_ cmdutil.Factory) *cobra.Command {
var sshConf = &util.SshConfig{}
var ExtraCIDR []string
@@ -71,16 +70,13 @@ func CmdSSH(_ cmdutil.Factory) *cobra.Command {
if err != nil {
return fmt.Errorf("terminal get size: %s", err)
}
marshal, err := json.Marshal(sshConf)
if err != nil {
return err
}
sessionID := uuid.NewString()
config.Header.Set("ssh-addr", sshConf.Addr)
config.Header.Set("ssh-username", sshConf.User)
config.Header.Set("ssh-password", sshConf.Password)
config.Header.Set("ssh-keyfile", sshConf.Keyfile)
config.Header.Set("ssh-alias", sshConf.ConfigAlias)
config.Header.Set("ssh", string(marshal))
config.Header.Set("extra-cidr", strings.Join(ExtraCIDR, ","))
config.Header.Set("gssapi-password", sshConf.GSSAPIPassword)
config.Header.Set("gssapi-keytab", sshConf.GSSAPIKeytabConf)
config.Header.Set("gssapi-cache", sshConf.GSSAPICacheFile)
config.Header.Set("terminal-width", strconv.Itoa(width))
config.Header.Set("terminal-height", strconv.Itoa(height))
config.Header.Set("session-id", sessionID)
@@ -93,8 +89,7 @@ func CmdSSH(_ cmdutil.Factory) *cobra.Command {
errChan := make(chan error, 3)
go func() {
err := monitorSize(cmd.Context(), sessionID)
errChan <- err
errChan <- monitorSize(cmd.Context(), sessionID)
}()
go func() {
_, err := io.Copy(conn, os.Stdin)

View File

@@ -45,9 +45,10 @@ func (svr *Server) SshStart(ctx context.Context, req *rpc.SshStartRequest) (*rpc
log.Errorf("parse route error: %v", err)
return nil, err
}
ctx, sshCancelFunc = context.WithCancel(context.Background())
var ctx1 context.Context
ctx1, sshCancelFunc = context.WithCancel(context.Background())
go func() {
err := handler.Run(ctx, servers)
err := handler.Run(ctx1, servers)
if err != nil {
log.Errorf("run route error: %v", err)
}

View File

@@ -14,6 +14,7 @@ import (
"os"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
@@ -44,8 +45,8 @@ type wsHandler struct {
// 1) start remote kubevpn server
// 2) start local tunnel
// 3) ssh terminal
func (w *wsHandler) handle(ctx context.Context) {
ctx, f := context.WithCancel(ctx)
func (w *wsHandler) handle(c context.Context) {
ctx, f := context.WithCancel(c)
defer f()
cli, err := util.DialSshRemote(ctx, w.sshConfig)
@@ -57,7 +58,7 @@ func (w *wsHandler) handle(ctx context.Context) {
err = w.installKubevpnOnRemote(ctx, cli)
if err != nil {
w.Log("Install kubevpn error: %v", err)
//w.Log("Install kubevpn error: %v", err)
return
}
@@ -130,7 +131,12 @@ func (w *wsHandler) handle(ctx context.Context) {
time.Sleep(time.Second * 5)
}
}()
err = w.terminal(ctx, cli, w.conn)
rw := NewReadWriteWrapper(w.conn)
go func() {
<-rw.IsClosed()
f()
}()
err = w.terminal(ctx, cli, rw)
if err != nil {
w.Log("Enter terminal error: %v", err)
}
@@ -138,21 +144,15 @@ func (w *wsHandler) handle(ctx context.Context) {
}
// startup daemon process if daemon process not start
func startDaemonProcess(cli *ssh.Client) {
func startDaemonProcess(cli *ssh.Client) string {
startDaemonCmd := fmt.Sprintf(`export %s=%s && kubevpn status > /dev/null 2>&1 &`, config.EnvStartSudoKubeVPNByKubeVPN, "true")
_, _, _ = util.RemoteRun(cli, startDaemonCmd, nil)
ticker := time.NewTicker(time.Millisecond * 50)
defer ticker.Stop()
for range ticker.C {
output, _, err := util.RemoteRun(cli, "kubevpn version", nil)
if err != nil {
continue
}
version := getDaemonVersionFromOutput(output)
if version != "" && version != "unknown" {
break
}
output, _, err := util.RemoteRun(cli, "kubevpn version", nil)
if err != nil {
return ""
}
version := getDaemonVersionFromOutput(output)
return version
}
func getDaemonVersionFromOutput(output []byte) (version string) {
@@ -178,7 +178,45 @@ func getDaemonVersionFromOutput(output []byte) (version string) {
return data.DaemonVersion
}
func (w *wsHandler) terminal(ctx context.Context, cli *ssh.Client, conn *websocket.Conn) error {
type ReadWriteWrapper struct {
closed chan any
sync.Once
net.Conn
}
func NewReadWriteWrapper(conn net.Conn) *ReadWriteWrapper {
return &ReadWriteWrapper{
closed: make(chan any),
Once: sync.Once{},
Conn: conn,
}
}
func (rw *ReadWriteWrapper) Read(b []byte) (int, error) {
n, err := rw.Conn.Read(b)
if err != nil {
rw.Do(func() {
close(rw.closed)
})
}
return n, err
}
func (rw *ReadWriteWrapper) Write(p []byte) (int, error) {
n, err := rw.Conn.Write(p)
if err != nil {
rw.Do(func() {
close(rw.closed)
})
}
return n, err
}
func (rw *ReadWriteWrapper) IsClosed() chan any {
return rw.closed
}
func (w *wsHandler) terminal(ctx context.Context, cli *ssh.Client, conn io.ReadWriter) error {
session, err := cli.NewSession()
if err != nil {
w.Log("New session error: %v", err)
@@ -203,7 +241,7 @@ func (w *wsHandler) terminal(ctx context.Context, cli *ssh.Client, conn *websock
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
}
if err := session.RequestPty("xterm", height, width, modes); err != nil {
if err = session.RequestPty("xterm", height, width, modes); err != nil {
w.Log("Request pty error: %v", err)
return err
}
@@ -217,7 +255,7 @@ func (w *wsHandler) terminal(ctx context.Context, cli *ssh.Client, conn *websock
func (w *wsHandler) installKubevpnOnRemote(ctx context.Context, sshClient *ssh.Client) (err error) {
defer func() {
if err == nil {
startDaemonProcess(sshClient)
w.Log("Remote daemon server version: %s", startDaemonProcess(sshClient))
}
}()
@@ -306,15 +344,12 @@ var CondReady = make(map[string]context.Context)
func init() {
http.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
sshConfig := util.SshConfig{
Addr: conn.Request().Header.Get("ssh-addr"),
User: conn.Request().Header.Get("ssh-username"),
Password: conn.Request().Header.Get("ssh-password"),
Keyfile: conn.Request().Header.Get("ssh-keyfile"),
ConfigAlias: conn.Request().Header.Get("ssh-alias"),
GSSAPIPassword: conn.Request().Header.Get("gssapi-password"),
GSSAPIKeytabConf: conn.Request().Header.Get("gssapi-keytab"),
GSSAPICacheFile: conn.Request().Header.Get("gssapi-cache"),
var sshConfig util.SshConfig
b := conn.Request().Header.Get("ssh")
if err := json.Unmarshal([]byte(b), &sshConfig); err != nil {
_, _ = conn.Write([]byte(err.Error()))
_ = conn.Close()
return
}
var extraCIDR []string
if v := conn.Request().Header.Get("extra-cidr"); v != "" {

View File

@@ -10,7 +10,7 @@ func TestGetVersionFromOutput(t *testing.T) {
{
output: `KubeVPN: CLI
Version: v2.2.3
DaemonVersion: v2.2.3
Daemon: v2.2.3
Image: docker.io/naison/kubevpn:v2.2.3
Branch: feat/ssh-heartbeat
Git commit: 1272e86a337d3075427ee3a1c3681d378558d133
@@ -22,7 +22,7 @@ func TestGetVersionFromOutput(t *testing.T) {
{
output: `KubeVPN: CLI
Version: v2.2.3
DaemonVersion: unknown
Daemon: unknown
Image: docker.io/naison/kubevpn:v2.2.3
Branch: feat/ssh-heartbeat
Git commit: 1272e86a337d3075427ee3a1c3681d378558d133

View File

@@ -68,16 +68,15 @@ func SCP(client *ssh.Client, stdout, stderr io.Writer, filename, to string) erro
func sCopy(dst io.Writer, src io.Reader, size int64, stdout, stderr io.Writer) error {
total := float64(size) / 1024 / 1024
s := fmt.Sprintf("Length: %d (%0.2fM)", size, total)
log.Info(s)
io.WriteString(stdout, s+"\n")
bar := progressbar.NewOptions(int(size),
progressbar.OptionSetWriter(stdout),
progressbar.OptionEnableColorCodes(true),
progressbar.OptionShowBytes(true),
progressbar.OptionSetWidth(50),
progressbar.OptionSetWidth(25),
progressbar.OptionOnCompletion(func() {
_, _ = fmt.Fprint(stderr, "\n")
_, _ = fmt.Fprint(stderr, "\n\r")
}),
progressbar.OptionSetRenderBlankState(true),
progressbar.OptionSetDescription("Transferring file..."),

View File

@@ -764,17 +764,33 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
return
}
if print {
msg := fmt.Sprintf("| To use: export KUBECONFIG=%s |", temp.Name())
printLine(msg)
log.Infof(msg)
printLine(msg)
msg := fmt.Sprintf("To use: export KUBECONFIG=%s", temp.Name())
PrintLine(log.Info, msg)
}
path = temp.Name()
return
}
func printLine(msg string) {
line := "+" + strings.Repeat("-", len(msg)-2) + "+"
log.Infof(line)
func PrintLine(f func(...any), msg ...string) {
var length = -1
for _, s := range msg {
length = max(len(s), length)
}
if f == nil {
f = func(a ...any) {
fmt.Println(a...)
}
}
line := "+" + strings.Repeat("-", length+2) + "+"
f(line)
for _, s := range msg {
var padding string
if length != len(s) {
padding = strings.Repeat(" ", length-len(s))
}
f(fmt.Sprintf("| %s%s |", s, padding))
}
f(line)
}
func SshJumpAndSetEnv(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print bool) error {

View File

@@ -10,6 +10,7 @@ import (
"path/filepath"
"strings"
"github.com/pkg/errors"
"github.com/schollz/progressbar/v3"
utilerrors "k8s.io/apimachinery/pkg/util/errors"
"k8s.io/apimachinery/pkg/util/sets"
@@ -29,27 +30,25 @@ const (
func GetManifest(httpCli *http.Client, os string, arch string) (version string, url string, err error) {
var resp *http.Response
var errs []error
for _, addr := range address {
resp, err = httpCli.Get(addr)
if err != nil {
errs = append(errs, err)
for _, a := range address {
resp, err = httpCli.Get(a)
if err == nil {
break
}
errs = append(errs, err)
}
if resp == nil {
aggregate := utilerrors.NewAggregate(errs)
err = fmt.Errorf("failed to call github api, err: %v", aggregate)
err = errors.Wrap(utilerrors.NewAggregate(errs), "failed to call github api")
return
}
var all []byte
all, err = io.ReadAll(resp.Body)
if err != nil {
err = fmt.Errorf("failed to read all response from github api, err: %v", err)
if all, err = io.ReadAll(resp.Body); err != nil {
err = errors.Wrap(err, "failed to read all response from github api")
return
}
var m RootEntity
err = json.Unmarshal(all, &m)
if err != nil {
if err = json.Unmarshal(all, &m); err != nil {
err = fmt.Errorf("failed to unmarshal response, err: %v", err)
return
}
@@ -57,30 +56,25 @@ func GetManifest(httpCli *http.Client, os string, arch string) (version string,
for _, asset := range m.Assets {
if strings.Contains(asset.Name, arch) && strings.Contains(asset.Name, os) {
url = asset.BrowserDownloadUrl
break
}
}
if len(url) == 0 {
var found bool
// if os is not windows and darwin, default is linux
if !sets.New[string]("windows", "darwin").Has(os) {
for _, asset := range m.Assets {
if strings.Contains(asset.Name, "linux") && strings.Contains(asset.Name, arch) {
url = asset.BrowserDownloadUrl
found = true
break
}
}
}
if !found {
err = fmt.Errorf("Can not found latest version url of KubeVPN, you can download it manually: \n%s\n", addr)
return
}
}
// if os is not windows and darwin, default is linux
if !sets.New[string]("windows", "darwin").Has(strings.ToLower(os)) {
for _, asset := range m.Assets {
if strings.Contains(asset.Name, "linux") && strings.Contains(asset.Name, arch) {
url = asset.BrowserDownloadUrl
return
}
}
}
err = fmt.Errorf("can not found latest version url of KubeVPN, you can download it manually: %s", addr)
return
}
// Download
// https://api.github.com/repos/kubenetworks/kubevpn/releases
// https://github.com/kubenetworks/kubevpn/releases/download/v1.1.13/kubevpn-windows-arm64.exe
func Download(client *http.Client, url string, filename string, stdout, stderr io.Writer) error {
@@ -104,7 +98,7 @@ func Download(client *http.Client, url string, filename string, stdout, stderr i
progressbar.OptionShowBytes(true),
progressbar.OptionSetWidth(25),
progressbar.OptionOnCompletion(func() {
_, _ = fmt.Fprint(stderr, "\n")
_, _ = fmt.Fprint(stderr, "\n\r")
}),
progressbar.OptionSetRenderBlankState(true),
progressbar.OptionSetDescription("Writing temp file..."),