Files
kubevpn/pkg/ssh/ssh.go
naison 98c4a61ca1 feat: sync mode modify api server & add ut for mode run and mode sync (#704)
* feat: mode sync modify kubeconfig apiserver to in cluster apiserver
* feat: add ut for sync mode and run mode
* fix: bugs
2025-08-21 22:45:47 +08:00

361 lines
9.6 KiB
Go

package ssh
import (
"bytes"
"context"
"fmt"
"io"
"net"
"net/netip"
"os"
"path/filepath"
"strconv"
"sync"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
gossh "golang.org/x/crypto/ssh"
"k8s.io/client-go/tools/clientcmd"
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
plog "github.com/wencaiwulue/kubevpn/v2/pkg/log"
pkgutil "github.com/wencaiwulue/kubevpn/v2/pkg/util"
)
// DialSshRemote https://github.com/golang/go/issues/21478
func DialSshRemote(ctx context.Context, conf *SshConfig, stopChan <-chan struct{}) (remote *gossh.Client, err error) {
defer func() {
if err != nil {
if remote != nil {
remote.Close()
}
}
}()
if conf.ConfigAlias != "" {
remote, err = conf.AliasRecursion(ctx, stopChan)
} else if conf.Jump != "" {
remote, err = conf.JumpRecursion(ctx, stopChan)
} else {
remote, err = conf.Dial(ctx, stopChan)
}
// ref: https://github.com/golang/go/issues/21478
if err == nil {
//go func() {
// err2 := keepAlive(remote, conn, ctx.Done())
// if err2 != nil {
// plog.G(ctx).Debugf("Failed to send keep-alive request: %v", err2)
// }
//}()
}
return remote, err
}
func RemoteRun(client *gossh.Client, cmd string, env map[string]string) (output []byte, errOut []byte, err error) {
var session *gossh.Session
session, err = client.NewSession()
if err != nil {
return
}
defer session.Close()
for k, v := range env {
// /etc/ssh/sshd_config
// AcceptEnv DEBIAN_FRONTEND
_ = session.Setenv(k, v)
}
var out bytes.Buffer
var er bytes.Buffer
session.Stdout = &out
session.Stderr = &er
err = session.Run(cmd)
return out.Bytes(), er.Bytes(), err
}
func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.AddrPort) error {
// Listen on remote server port
var lc net.ListenConfig
localListen, e := lc.Listen(ctx, "tcp", local.String())
if e != nil {
plog.G(ctx).Errorf("failed to listen %s: %v", local.String(), e)
return e
}
plog.G(ctx).Debugf("SSH listening on local %s forward to %s", local.String(), remote.String())
go func() {
<-ctx.Done()
localListen.Close()
}()
go func() {
defer localListen.Close()
var clientMap = &sync.Map{}
ctx1, cancelFunc1 := context.WithCancel(ctx)
defer cancelFunc1()
for ctx1.Err() == nil {
localConn, err1 := localListen.Accept()
if err1 != nil {
if errors.Is(err1, net.ErrClosed) {
return
}
plog.G(ctx).Debugf("Failed to accept ssh conn: %v", err1)
continue
}
plog.G(ctx).Debugf("Accepted ssh conn from %s", localConn.RemoteAddr().String())
go func() {
defer localConn.Close()
remoteConn, err := getRemoteConn(ctx1, clientMap, conf, remote)
if err != nil {
var openChannelError *gossh.OpenChannelError
// if ssh server not permitted ssh port-forward, do nothing until exit
if errors.As(err, &openChannelError) && openChannelError.Reason == gossh.Prohibited {
plog.G(ctx).Debugf("Failed to open ssh port-forward to %s: %v", remote.String(), err)
plog.G(ctx).Errorf("Failed to open ssh port-forward to %s: %v", remote.String(), err)
cancelFunc1()
}
plog.G(ctx).Debugf("Failed to dial into remote %s: %v", remote.String(), err)
return
}
plog.G(ctx).Debugf("Opened ssh port-forward to %s", remote.String())
defer remoteConn.Close()
copyStream(ctx, localConn, remoteConn)
}()
}
}()
return nil
}
func SshJump(ctx context.Context, conf *SshConfig, kubeconfigBytes []byte, print bool) (path string, err error) {
if len(conf.RemoteKubeconfig) != 0 {
var stdout []byte
var stderr []byte
// pre-check network ip connect
var cli *gossh.Client
cli, err = DialSshRemote(ctx, conf, ctx.Done())
if err != nil {
return
}
defer cli.Close()
stdout, stderr, err = RemoteRun(cli,
fmt.Sprintf("sh -c 'kubectl config view --flatten --raw --kubeconfig %s || minikube kubectl -- config view --flatten --raw --kubeconfig %s || cat %s'",
conf.RemoteKubeconfig,
conf.RemoteKubeconfig,
conf.RemoteKubeconfig,
),
map[string]string{clientcmd.RecommendedConfigPathEnvVar: conf.RemoteKubeconfig},
)
if err != nil {
err = errors.Wrap(err, string(stderr))
return
}
if len(bytes.TrimSpace(stdout)) == 0 {
err = errors.Errorf("can not get kubeconfig %s from remote ssh server: %s", conf.RemoteKubeconfig, string(stderr))
return
}
kubeconfigBytes = bytes.TrimSpace(stdout)
}
var port int
port, err = pkgutil.GetAvailableTCPPortOrDie()
if err != nil {
return
}
var local netip.AddrPort
local, err = netip.ParseAddrPort(net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
if err != nil {
return
}
var oldAPIServer netip.AddrPort
var newKubeconfigBytes []byte
newKubeconfigBytes, oldAPIServer, err = pkgutil.ModifyAPIServer(ctx, kubeconfigBytes, local)
if err != nil {
return
}
if print {
plog.G(ctx).Infof("Waiting jump to bastion host...")
plog.G(ctx).Infof("Jump ssh bastion host to apiserver: %s", oldAPIServer.String())
} else {
plog.G(ctx).Debugf("Waiting jump to bastion host...")
plog.G(ctx).Debugf("Jump ssh bastion host to apiserver: %s", oldAPIServer.String())
}
err = PortMapUntil(ctx, conf, oldAPIServer, local)
if err != nil {
plog.G(ctx).Errorf("SSH port map error: %v", err)
return
}
path, err = pkgutil.ConvertToTempKubeconfigFile(newKubeconfigBytes, GenKubeconfigTempPath(conf, kubeconfigBytes))
if err != nil {
plog.G(ctx).Errorf("failed to write kubeconfig: %v", err)
return
}
go func() {
<-ctx.Done()
_ = os.Remove(path)
}()
if print {
plog.G(ctx).Infof("Use temp kubeconfig: %s", path)
} else {
plog.G(ctx).Debugf("Use temp kubeconfig: %s", path)
}
return
}
func SshJumpAndSetEnv(ctx context.Context, sshConf *SshConfig, kubeconfigBytes []byte, print bool) error {
path, err := SshJump(ctx, sshConf, kubeconfigBytes, print)
if err != nil {
return err
}
err = os.Setenv(clientcmd.RecommendedConfigPathEnvVar, path)
if err != nil {
return err
}
return os.Setenv(config.EnvSSHJump, path)
}
func JumpTo(ctx context.Context, bClient *gossh.Client, to SshConfig, stopChan <-chan struct{}) (client *gossh.Client, err error) {
if _, _, err = net.SplitHostPort(to.Addr); err != nil {
// use default ssh port 22
to.Addr = net.JoinHostPort(to.Addr, "22")
err = nil
}
var authMethod []gossh.AuthMethod
authMethod, err = to.GetAuth()
if err != nil {
return nil, err
}
// Dial a connection to the service host, from the bastion
var conn net.Conn
conn, err = bClient.DialContext(ctx, "tcp", to.Addr)
if err != nil {
return
}
go func() {
if stopChan != nil {
<-stopChan
conn.Close()
if client != nil {
client.Close()
}
bClient.Close()
}
}()
defer func() {
if err != nil {
if client != nil {
client.Close()
}
if conn != nil {
conn.Close()
}
}
}()
var ncc gossh.Conn
var chans <-chan gossh.NewChannel
var reqs <-chan *gossh.Request
ncc, chans, reqs, err = gossh.NewClientConn(conn, to.Addr, &gossh.ClientConfig{
User: to.User,
Auth: authMethod,
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
//BannerCallback: ssh.BannerDisplayStderr(),
Timeout: time.Second * 10,
})
if err != nil {
return
}
client = gossh.NewClient(ncc, chans, reqs)
return
}
func getRemoteConn(ctx context.Context, clientMap *sync.Map, conf *SshConfig, remote netip.AddrPort) (net.Conn, error) {
var conn net.Conn
var err error
clientMap.Range(func(key, value any) bool {
cli := value.(*sshClientWrap)
ctx1, cancelFunc1 := context.WithTimeout(ctx, time.Second*10)
conn, err = cli.DialContext(ctx1, "tcp", remote.String())
cancelFunc1()
if err != nil {
plog.G(ctx).Debugf("Failed to dial remote address %s: %v", remote.String(), err)
clientMap.Delete(key)
plog.G(ctx).Error("Delete invalid ssh client from map")
_ = cli.Close()
return true
}
return false
})
if conn != nil {
return conn, nil
}
ctx1, cancelFunc1 := context.WithCancel(ctx)
var client *gossh.Client
client, err = DialSshRemote(ctx1, conf, ctx1.Done())
if err != nil {
plog.G(ctx).Debugf("Failed to dial remote ssh server: %v", err)
cancelFunc1()
return nil, err
}
clientMap.Store(uuid.NewString(), newSshClientWrap(client, cancelFunc1))
plog.G(ctx1).Debug("Connected to remote ssh server")
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
defer cancelFunc2()
conn, err = client.DialContext(ctx2, "tcp", remote.String())
if err != nil {
plog.G(ctx).Debugf("Failed to dial remote addr %s: %v", remote.String(), err)
return nil, err
}
plog.G(ctx).Debugf("Connected to remote addr %s", remote.String())
return conn, nil
}
func copyStream(ctx context.Context, local net.Conn, remote net.Conn) {
chDone := make(chan bool, 2)
// start remote -> local data transfer
go func() {
buf := config.LPool.Get().([]byte)[:]
defer config.LPool.Put(buf[:])
_, err := io.CopyBuffer(local, remote, buf)
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
plog.G(ctx).Errorf("Failed to copy remote -> local: %s", err)
}
chDone <- true
}()
// start local -> remote data transfer
go func() {
buf := config.LPool.Get().([]byte)[:]
defer config.LPool.Put(buf[:])
_, err := io.CopyBuffer(remote, local, buf)
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
plog.G(ctx).Errorf("Failed to copy local -> remote: %s", err)
}
chDone <- true
}()
select {
case <-chDone:
return
case <-ctx.Done():
return
}
}
func GenKubeconfigTempPath(conf *SshConfig, kubeconfigBytes []byte) string {
if conf != nil && conf.RemoteKubeconfig != "" {
return filepath.Join(config.GetTempPath(), fmt.Sprintf("%s_%d", conf.GenKubeconfigIdentify(), time.Now().Unix()))
}
return pkgutil.GenKubeconfigTempPath(kubeconfigBytes)
}