refactor: optimize code (#561)

This commit is contained in:
naison
2025-04-25 19:37:03 +08:00
committed by GitHub
parent 28657e3832
commit 9661a122bd
7 changed files with 53 additions and 73 deletions

View File

@@ -39,6 +39,7 @@ func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectF
sshCtx, sshCancel := context.WithCancel(context.Background())
connect.AddRolloutFunc(func() error {
sshCancel()
os.Remove(file)
return nil
})
sshCtx = plog.WithLogger(sshCtx, logger)
@@ -99,6 +100,7 @@ func (svr *Server) redirectConnectForkToSudoDaemon(req *rpc.ConnectRequest, resp
}
connect.AddRolloutFunc(func() error {
sshCancel()
os.Remove(file)
return nil
})
defer func() {
@@ -112,6 +114,10 @@ func (svr *Server) redirectConnectForkToSudoDaemon(req *rpc.ConnectRequest, resp
if err != nil {
return err
}
connect.AddRolloutFunc(func() error {
os.Remove(path)
return nil
})
err = connect.InitClient(util.InitFactoryByPath(path, req.Namespace))
if err != nil {
return err
@@ -137,6 +143,8 @@ func (svr *Server) redirectConnectForkToSudoDaemon(req *rpc.ConnectRequest, resp
)
if isSameCluster {
sshCancel()
os.Remove(file)
os.Remove(path)
// same cluster, do nothing
logger.Infof("Connected with cluster")
return nil

View File

@@ -58,6 +58,7 @@ func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServe
sshCtx, sshCancel := context.WithCancel(context.Background())
svr.connect.AddRolloutFunc(func() error {
sshCancel()
os.Remove(file)
return nil
})
sshCtx = plog.WithLogger(sshCtx, logger)
@@ -67,6 +68,7 @@ func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServe
svr.connect.Cleanup(sshCtx)
svr.connect = nil
svr.t = time.Time{}
os.Remove(file)
sshCancel()
}
}()
@@ -114,12 +116,14 @@ func (svr *Server) redirectToSudoDaemon(req *rpc.ConnectRequest, resp rpc.Daemon
}
connect.AddRolloutFunc(func() error {
sshCancel()
os.Remove(file)
return nil
})
defer func() {
if e != nil {
connect.Cleanup(plog.WithLogger(context.Background(), logger))
sshCancel()
os.Remove(file)
}
}()
var path string
@@ -127,6 +131,10 @@ func (svr *Server) redirectToSudoDaemon(req *rpc.ConnectRequest, resp rpc.Daemon
if err != nil {
return err
}
connect.AddRolloutFunc(func() error {
os.Remove(path)
return nil
})
err = connect.InitClient(util.InitFactoryByPath(path, req.Namespace))
if err != nil {
return err
@@ -152,6 +160,8 @@ func (svr *Server) redirectToSudoDaemon(req *rpc.ConnectRequest, resp rpc.Daemon
)
if isSameCluster {
sshCancel()
os.Remove(path)
os.Remove(file)
// same cluster, do nothing
logger.Infof("Connected to cluster")
return nil

View File

@@ -64,6 +64,9 @@ func (c *ConnectOptions) Cleanup(ctx context.Context) {
if err != nil {
plog.G(ctx).Errorf("Leave proxy resources error: %v", err)
}
if c.cancel != nil {
c.cancel()
}
for _, function := range c.getRolloutFunc() {
if function != nil {
@@ -72,9 +75,6 @@ func (c *ConnectOptions) Cleanup(ctx context.Context) {
}
}
}
if c.cancel != nil {
c.cancel()
}
if c.dnsConfig != nil {
if inUserDaemon {
plog.G(ctx2).Infof("Clearing DNS settings")

View File

@@ -168,13 +168,13 @@ func (conf SshConfig) AliasRecursion(ctx context.Context, stopChan <-chan struct
if client == nil {
client, err = bastionList[i].Dial(ctx, stopChan)
if err != nil {
err = errors.Wrap(err, fmt.Sprintf("Failed to connect to %s", bastionList[i]))
err = errors.Wrap(err, fmt.Sprintf("Failed to connect to %v", bastionList[i]))
return
}
} else {
client, err = JumpTo(ctx, client, bastionList[i], stopChan)
if err != nil {
err = errors.Wrap(err, fmt.Sprintf("Failed to jump to %s", bastionList[i]))
err = errors.Wrap(err, fmt.Sprintf("Failed to jump to %v", bastionList[i]))
return
}
}

View File

@@ -111,6 +111,9 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
for ctx1.Err() == nil {
localConn, err1 := localListen.Accept()
if err1 != nil {
if errors.Is(err1, net.ErrClosed) {
return
}
plog.G(ctx).Debugf("Failed to accept ssh conn: %v", err1)
continue
}
@@ -123,14 +126,14 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
var openChannelError *gossh.OpenChannelError
// if ssh server not permitted ssh port-forward, do nothing until exit
if errors.As(err, &openChannelError) && openChannelError.Reason == gossh.Prohibited {
plog.G(ctx).Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
plog.G(ctx).Errorf("Failed to open ssh port-forward: %s: %v", remote.String(), err)
plog.G(ctx).Debugf("Failed to open ssh port-forward to %s: %v", remote.String(), err)
plog.G(ctx).Errorf("Failed to open ssh port-forward to %s: %v", remote.String(), err)
cancelFunc1()
}
plog.G(ctx).Debugf("Failed to get remote conn: %v", err)
plog.G(ctx).Debugf("Failed to dial into remote %s: %v", remote.String(), err)
return
}
plog.G(ctx).Debugf("Opened ssh port-forward: %s", remote.String())
plog.G(ctx).Debugf("Opened ssh port-forward to %s", remote.String())
defer remoteConn.Close()
copyStream(ctx, localConn, remoteConn)
@@ -196,20 +199,12 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
return
}
var temp *os.File
if temp, err = os.CreateTemp("", "*.kubeconfig"); err != nil {
var file string
file, err = pkgutil.ConvertToTempKubeconfigFile(bytes.TrimSpace(stdout))
if err != nil {
return
}
if err = temp.Close(); err != nil {
return
}
if err = os.WriteFile(temp.Name(), stdout, 0644); err != nil {
return
}
if err = os.Chmod(temp.Name(), 0644); err != nil {
return
}
configFlags.KubeConfig = pointer.String(temp.Name())
configFlags.KubeConfig = pointer.String(file)
} else {
if flags != nil {
lookup := flags.Lookup("kubeconfig")
@@ -332,26 +327,16 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
plog.G(ctx).Errorf("failed to marshal config: %v", err)
return
}
var temp *os.File
temp, err = os.CreateTemp("", "*.kubeconfig")
path, err = pkgutil.ConvertToTempKubeconfigFile(marshal)
if err != nil {
return
}
if err = temp.Close(); err != nil {
return
}
if err = os.WriteFile(temp.Name(), marshal, 0644); err != nil {
return
}
if err = os.Chmod(temp.Name(), 0644); err != nil {
plog.G(ctx).Errorf("failed to write kubeconfig: %v", err)
return
}
if print {
plog.G(ctx).Infof("Use temporary kubeconfig: %s", temp.Name())
plog.G(ctx).Infof("Use temporary kubeconfig: %s", path)
} else {
plog.G(ctx).Debugf("Use temporary kubeconfig: %s", temp.Name())
plog.G(ctx).Debugf("Use temporary kubeconfig: %s", path)
}
path = temp.Name()
return
}
@@ -435,7 +420,7 @@ func getRemoteConn(ctx context.Context, clientMap *sync.Map, conf *SshConfig, re
conn, err = cli.DialContext(ctx1, "tcp", remote.String())
cancelFunc1()
if err != nil {
plog.G(ctx).Debugf("Failed to dial remote address %s: %s", remote.String(), err)
plog.G(ctx).Debugf("Failed to dial remote address %s: %v", remote.String(), err)
clientMap.Delete(key)
plog.G(ctx).Error("Delete invalid ssh client from map")
_ = cli.Close()
@@ -456,16 +441,16 @@ func getRemoteConn(ctx context.Context, clientMap *sync.Map, conf *SshConfig, re
return nil, err
}
clientMap.Store(uuid.NewString(), newSshClientWrap(client, cancelFunc1))
plog.G(ctx1).Debug("Connected to remote SSH server")
plog.G(ctx1).Debug("Connected to remote ssh server")
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
defer cancelFunc2()
conn, err = client.DialContext(ctx2, "tcp", remote.String())
if err != nil {
plog.G(ctx).Debugf("Failed to dial remote addr: %s: %v", remote.String(), err)
plog.G(ctx).Debugf("Failed to dial remote addr %s: %v", remote.String(), err)
return nil, err
}
plog.G(ctx).Debugf("Connected to remote addr: %s", remote.String())
plog.G(ctx).Debugf("Connected to remote addr %s", remote.String())
return conn, nil
}

View File

@@ -85,19 +85,10 @@ func ConvertK8sApiServerToDomain(kubeConfigPath string) (newPath string, err err
if err != nil {
return
}
var temp *os.File
temp, err = os.CreateTemp("", "*.kubeconfig")
newPath, err = ConvertToTempKubeconfigFile(marshal)
if err != nil {
return
}
if err = temp.Close(); err != nil {
return
}
err = os.WriteFile(temp.Name(), marshal, 0644)
if err != nil {
return
}
newPath = temp.Name()
return
}

View File

@@ -132,11 +132,15 @@ func ConvertToTempKubeconfigFile(kubeconfigBytes []byte) (string, error) {
if err != nil {
return "", err
}
err = temp.Close()
_, err = temp.Write(kubeconfigBytes)
if err != nil {
return "", err
}
err = os.WriteFile(temp.Name(), kubeconfigBytes, os.ModePerm)
err = temp.Chmod(0644)
if err != nil {
return "", err
}
err = temp.Close()
if err != nil {
return "", err
}
@@ -156,19 +160,11 @@ func InitFactory(kubeconfigBytes string, ns string) cmdutil.Factory {
}
return c
}
temp, err := os.CreateTemp("", "*.kubeconfig")
file, err := ConvertToTempKubeconfigFile([]byte(kubeconfigBytes))
if err != nil {
return nil
}
err = temp.Close()
if err != nil {
return nil
}
err = os.WriteFile(temp.Name(), []byte(kubeconfigBytes), os.ModePerm)
if err != nil {
return nil
}
configFlags.KubeConfig = pointer.String(temp.Name())
configFlags.KubeConfig = pointer.String(file)
configFlags.Namespace = pointer.String(ns)
matchVersionFlags := cmdutil.NewMatchVersionFlags(configFlags)
return cmdutil.NewFactory(matchVersionFlags)
@@ -214,19 +210,9 @@ func GetKubeconfigPath(factory cmdutil.Factory) (string, error) {
return "", err
}
temp, err := os.CreateTemp("", "*.kubeconfig")
file, err := ConvertToTempKubeconfigFile(kubeconfigJsonBytes)
if err != nil {
return "", err
}
temp.Close()
err = os.WriteFile(temp.Name(), kubeconfigJsonBytes, 0644)
if err != nil {
return "", err
}
err = os.Chmod(temp.Name(), 0644)
if err != nil {
return "", err
}
return temp.Name(), nil
return file, nil
}