refactor: optimize code (#561)

This commit is contained in:
naison
2025-04-25 19:37:03 +08:00
committed by GitHub
parent 28657e3832
commit 9661a122bd
7 changed files with 53 additions and 73 deletions

View File

@@ -39,6 +39,7 @@ func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectF
sshCtx, sshCancel := context.WithCancel(context.Background()) sshCtx, sshCancel := context.WithCancel(context.Background())
connect.AddRolloutFunc(func() error { connect.AddRolloutFunc(func() error {
sshCancel() sshCancel()
os.Remove(file)
return nil return nil
}) })
sshCtx = plog.WithLogger(sshCtx, logger) sshCtx = plog.WithLogger(sshCtx, logger)
@@ -99,6 +100,7 @@ func (svr *Server) redirectConnectForkToSudoDaemon(req *rpc.ConnectRequest, resp
} }
connect.AddRolloutFunc(func() error { connect.AddRolloutFunc(func() error {
sshCancel() sshCancel()
os.Remove(file)
return nil return nil
}) })
defer func() { defer func() {
@@ -112,6 +114,10 @@ func (svr *Server) redirectConnectForkToSudoDaemon(req *rpc.ConnectRequest, resp
if err != nil { if err != nil {
return err return err
} }
connect.AddRolloutFunc(func() error {
os.Remove(path)
return nil
})
err = connect.InitClient(util.InitFactoryByPath(path, req.Namespace)) err = connect.InitClient(util.InitFactoryByPath(path, req.Namespace))
if err != nil { if err != nil {
return err return err
@@ -137,6 +143,8 @@ func (svr *Server) redirectConnectForkToSudoDaemon(req *rpc.ConnectRequest, resp
) )
if isSameCluster { if isSameCluster {
sshCancel() sshCancel()
os.Remove(file)
os.Remove(path)
// same cluster, do nothing // same cluster, do nothing
logger.Infof("Connected with cluster") logger.Infof("Connected with cluster")
return nil return nil

View File

@@ -58,6 +58,7 @@ func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServe
sshCtx, sshCancel := context.WithCancel(context.Background()) sshCtx, sshCancel := context.WithCancel(context.Background())
svr.connect.AddRolloutFunc(func() error { svr.connect.AddRolloutFunc(func() error {
sshCancel() sshCancel()
os.Remove(file)
return nil return nil
}) })
sshCtx = plog.WithLogger(sshCtx, logger) sshCtx = plog.WithLogger(sshCtx, logger)
@@ -67,6 +68,7 @@ func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServe
svr.connect.Cleanup(sshCtx) svr.connect.Cleanup(sshCtx)
svr.connect = nil svr.connect = nil
svr.t = time.Time{} svr.t = time.Time{}
os.Remove(file)
sshCancel() sshCancel()
} }
}() }()
@@ -114,12 +116,14 @@ func (svr *Server) redirectToSudoDaemon(req *rpc.ConnectRequest, resp rpc.Daemon
} }
connect.AddRolloutFunc(func() error { connect.AddRolloutFunc(func() error {
sshCancel() sshCancel()
os.Remove(file)
return nil return nil
}) })
defer func() { defer func() {
if e != nil { if e != nil {
connect.Cleanup(plog.WithLogger(context.Background(), logger)) connect.Cleanup(plog.WithLogger(context.Background(), logger))
sshCancel() sshCancel()
os.Remove(file)
} }
}() }()
var path string var path string
@@ -127,6 +131,10 @@ func (svr *Server) redirectToSudoDaemon(req *rpc.ConnectRequest, resp rpc.Daemon
if err != nil { if err != nil {
return err return err
} }
connect.AddRolloutFunc(func() error {
os.Remove(path)
return nil
})
err = connect.InitClient(util.InitFactoryByPath(path, req.Namespace)) err = connect.InitClient(util.InitFactoryByPath(path, req.Namespace))
if err != nil { if err != nil {
return err return err
@@ -152,6 +160,8 @@ func (svr *Server) redirectToSudoDaemon(req *rpc.ConnectRequest, resp rpc.Daemon
) )
if isSameCluster { if isSameCluster {
sshCancel() sshCancel()
os.Remove(path)
os.Remove(file)
// same cluster, do nothing // same cluster, do nothing
logger.Infof("Connected to cluster") logger.Infof("Connected to cluster")
return nil return nil

View File

@@ -64,6 +64,9 @@ func (c *ConnectOptions) Cleanup(ctx context.Context) {
if err != nil { if err != nil {
plog.G(ctx).Errorf("Leave proxy resources error: %v", err) plog.G(ctx).Errorf("Leave proxy resources error: %v", err)
} }
if c.cancel != nil {
c.cancel()
}
for _, function := range c.getRolloutFunc() { for _, function := range c.getRolloutFunc() {
if function != nil { if function != nil {
@@ -72,9 +75,6 @@ func (c *ConnectOptions) Cleanup(ctx context.Context) {
} }
} }
} }
if c.cancel != nil {
c.cancel()
}
if c.dnsConfig != nil { if c.dnsConfig != nil {
if inUserDaemon { if inUserDaemon {
plog.G(ctx2).Infof("Clearing DNS settings") plog.G(ctx2).Infof("Clearing DNS settings")

View File

@@ -168,13 +168,13 @@ func (conf SshConfig) AliasRecursion(ctx context.Context, stopChan <-chan struct
if client == nil { if client == nil {
client, err = bastionList[i].Dial(ctx, stopChan) client, err = bastionList[i].Dial(ctx, stopChan)
if err != nil { if err != nil {
err = errors.Wrap(err, fmt.Sprintf("Failed to connect to %s", bastionList[i])) err = errors.Wrap(err, fmt.Sprintf("Failed to connect to %v", bastionList[i]))
return return
} }
} else { } else {
client, err = JumpTo(ctx, client, bastionList[i], stopChan) client, err = JumpTo(ctx, client, bastionList[i], stopChan)
if err != nil { if err != nil {
err = errors.Wrap(err, fmt.Sprintf("Failed to jump to %s", bastionList[i])) err = errors.Wrap(err, fmt.Sprintf("Failed to jump to %v", bastionList[i]))
return return
} }
} }

View File

@@ -111,6 +111,9 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
for ctx1.Err() == nil { for ctx1.Err() == nil {
localConn, err1 := localListen.Accept() localConn, err1 := localListen.Accept()
if err1 != nil { if err1 != nil {
if errors.Is(err1, net.ErrClosed) {
return
}
plog.G(ctx).Debugf("Failed to accept ssh conn: %v", err1) plog.G(ctx).Debugf("Failed to accept ssh conn: %v", err1)
continue continue
} }
@@ -123,14 +126,14 @@ func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.Addr
var openChannelError *gossh.OpenChannelError var openChannelError *gossh.OpenChannelError
// if ssh server not permitted ssh port-forward, do nothing until exit // if ssh server not permitted ssh port-forward, do nothing until exit
if errors.As(err, &openChannelError) && openChannelError.Reason == gossh.Prohibited { if errors.As(err, &openChannelError) && openChannelError.Reason == gossh.Prohibited {
plog.G(ctx).Debugf("Failed to open ssh port-forward: %s: %v", remote.String(), err) plog.G(ctx).Debugf("Failed to open ssh port-forward to %s: %v", remote.String(), err)
plog.G(ctx).Errorf("Failed to open ssh port-forward: %s: %v", remote.String(), err) plog.G(ctx).Errorf("Failed to open ssh port-forward to %s: %v", remote.String(), err)
cancelFunc1() cancelFunc1()
} }
plog.G(ctx).Debugf("Failed to get remote conn: %v", err) plog.G(ctx).Debugf("Failed to dial into remote %s: %v", remote.String(), err)
return return
} }
plog.G(ctx).Debugf("Opened ssh port-forward: %s", remote.String()) plog.G(ctx).Debugf("Opened ssh port-forward to %s", remote.String())
defer remoteConn.Close() defer remoteConn.Close()
copyStream(ctx, localConn, remoteConn) copyStream(ctx, localConn, remoteConn)
@@ -196,20 +199,12 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
return return
} }
var temp *os.File var file string
if temp, err = os.CreateTemp("", "*.kubeconfig"); err != nil { file, err = pkgutil.ConvertToTempKubeconfigFile(bytes.TrimSpace(stdout))
if err != nil {
return return
} }
if err = temp.Close(); err != nil { configFlags.KubeConfig = pointer.String(file)
return
}
if err = os.WriteFile(temp.Name(), stdout, 0644); err != nil {
return
}
if err = os.Chmod(temp.Name(), 0644); err != nil {
return
}
configFlags.KubeConfig = pointer.String(temp.Name())
} else { } else {
if flags != nil { if flags != nil {
lookup := flags.Lookup("kubeconfig") lookup := flags.Lookup("kubeconfig")
@@ -332,26 +327,16 @@ func SshJump(ctx context.Context, conf *SshConfig, flags *pflag.FlagSet, print b
plog.G(ctx).Errorf("failed to marshal config: %v", err) plog.G(ctx).Errorf("failed to marshal config: %v", err)
return return
} }
var temp *os.File path, err = pkgutil.ConvertToTempKubeconfigFile(marshal)
temp, err = os.CreateTemp("", "*.kubeconfig")
if err != nil { if err != nil {
return plog.G(ctx).Errorf("failed to write kubeconfig: %v", err)
}
if err = temp.Close(); err != nil {
return
}
if err = os.WriteFile(temp.Name(), marshal, 0644); err != nil {
return
}
if err = os.Chmod(temp.Name(), 0644); err != nil {
return return
} }
if print { if print {
plog.G(ctx).Infof("Use temporary kubeconfig: %s", temp.Name()) plog.G(ctx).Infof("Use temporary kubeconfig: %s", path)
} else { } else {
plog.G(ctx).Debugf("Use temporary kubeconfig: %s", temp.Name()) plog.G(ctx).Debugf("Use temporary kubeconfig: %s", path)
} }
path = temp.Name()
return return
} }
@@ -435,7 +420,7 @@ func getRemoteConn(ctx context.Context, clientMap *sync.Map, conf *SshConfig, re
conn, err = cli.DialContext(ctx1, "tcp", remote.String()) conn, err = cli.DialContext(ctx1, "tcp", remote.String())
cancelFunc1() cancelFunc1()
if err != nil { if err != nil {
plog.G(ctx).Debugf("Failed to dial remote address %s: %s", remote.String(), err) plog.G(ctx).Debugf("Failed to dial remote address %s: %v", remote.String(), err)
clientMap.Delete(key) clientMap.Delete(key)
plog.G(ctx).Error("Delete invalid ssh client from map") plog.G(ctx).Error("Delete invalid ssh client from map")
_ = cli.Close() _ = cli.Close()
@@ -456,16 +441,16 @@ func getRemoteConn(ctx context.Context, clientMap *sync.Map, conf *SshConfig, re
return nil, err return nil, err
} }
clientMap.Store(uuid.NewString(), newSshClientWrap(client, cancelFunc1)) clientMap.Store(uuid.NewString(), newSshClientWrap(client, cancelFunc1))
plog.G(ctx1).Debug("Connected to remote SSH server") plog.G(ctx1).Debug("Connected to remote ssh server")
ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10) ctx2, cancelFunc2 := context.WithTimeout(ctx, time.Second*10)
defer cancelFunc2() defer cancelFunc2()
conn, err = client.DialContext(ctx2, "tcp", remote.String()) conn, err = client.DialContext(ctx2, "tcp", remote.String())
if err != nil { if err != nil {
plog.G(ctx).Debugf("Failed to dial remote addr: %s: %v", remote.String(), err) plog.G(ctx).Debugf("Failed to dial remote addr %s: %v", remote.String(), err)
return nil, err return nil, err
} }
plog.G(ctx).Debugf("Connected to remote addr: %s", remote.String()) plog.G(ctx).Debugf("Connected to remote addr %s", remote.String())
return conn, nil return conn, nil
} }

View File

@@ -85,19 +85,10 @@ func ConvertK8sApiServerToDomain(kubeConfigPath string) (newPath string, err err
if err != nil { if err != nil {
return return
} }
var temp *os.File newPath, err = ConvertToTempKubeconfigFile(marshal)
temp, err = os.CreateTemp("", "*.kubeconfig")
if err != nil { if err != nil {
return return
} }
if err = temp.Close(); err != nil {
return
}
err = os.WriteFile(temp.Name(), marshal, 0644)
if err != nil {
return
}
newPath = temp.Name()
return return
} }

View File

@@ -132,11 +132,15 @@ func ConvertToTempKubeconfigFile(kubeconfigBytes []byte) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
err = temp.Close() _, err = temp.Write(kubeconfigBytes)
if err != nil { if err != nil {
return "", err return "", err
} }
err = os.WriteFile(temp.Name(), kubeconfigBytes, os.ModePerm) err = temp.Chmod(0644)
if err != nil {
return "", err
}
err = temp.Close()
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -156,19 +160,11 @@ func InitFactory(kubeconfigBytes string, ns string) cmdutil.Factory {
} }
return c return c
} }
temp, err := os.CreateTemp("", "*.kubeconfig") file, err := ConvertToTempKubeconfigFile([]byte(kubeconfigBytes))
if err != nil { if err != nil {
return nil return nil
} }
err = temp.Close() configFlags.KubeConfig = pointer.String(file)
if err != nil {
return nil
}
err = os.WriteFile(temp.Name(), []byte(kubeconfigBytes), os.ModePerm)
if err != nil {
return nil
}
configFlags.KubeConfig = pointer.String(temp.Name())
configFlags.Namespace = pointer.String(ns) configFlags.Namespace = pointer.String(ns)
matchVersionFlags := cmdutil.NewMatchVersionFlags(configFlags) matchVersionFlags := cmdutil.NewMatchVersionFlags(configFlags)
return cmdutil.NewFactory(matchVersionFlags) return cmdutil.NewFactory(matchVersionFlags)
@@ -214,19 +210,9 @@ func GetKubeconfigPath(factory cmdutil.Factory) (string, error) {
return "", err return "", err
} }
temp, err := os.CreateTemp("", "*.kubeconfig") file, err := ConvertToTempKubeconfigFile(kubeconfigJsonBytes)
if err != nil { if err != nil {
return "", err return "", err
} }
temp.Close() return file, nil
err = os.WriteFile(temp.Name(), kubeconfigJsonBytes, 0644)
if err != nil {
return "", err
}
err = os.Chmod(temp.Name(), 0644)
if err != nil {
return "", err
}
return temp.Name(), nil
} }