From 457fda89a0e4ce54157223b382f9cad020eb367e Mon Sep 17 00:00:00 2001 From: fengcaiwen Date: Fri, 27 Oct 2023 17:49:38 +0800 Subject: [PATCH] feat: optimize code --- cmd/kubevpn/cmds/dev.go | 7 ++++-- pkg/daemon/action/connect-fork.go | 12 ++++++++--- pkg/daemon/action/connect.go | 12 +++++++++-- pkg/dev/convert.go | 4 ++-- pkg/dev/main.go | 17 +++++++++++---- pkg/dns/dns.go | 2 ++ pkg/handler/cleaner.go | 7 +++--- pkg/handler/clone.go | 17 ++++++++++++--- pkg/handler/connect.go | 36 ++++++++++++++++++++++--------- pkg/handler/envoy.go | 2 +- pkg/handler/remote.go | 4 ++-- 11 files changed, 88 insertions(+), 32 deletions(-) diff --git a/cmd/kubevpn/cmds/dev.go b/cmd/kubevpn/cmds/dev.go index aff8b045..c9042f6e 100644 --- a/cmd/kubevpn/cmds/dev.go +++ b/cmd/kubevpn/cmds/dev.go @@ -6,6 +6,7 @@ import ( "github.com/docker/cli/cli/command" dockercomp "github.com/docker/cli/cli/command/completion" + log "github.com/sirupsen/logrus" "github.com/spf13/cobra" cmdutil "k8s.io/kubectl/pkg/cmd/util" "k8s.io/kubectl/pkg/util/completion" @@ -110,9 +111,11 @@ Startup your kubernetes workloads in local Docker container with same volume态e } err = dev.DoDev(cmd.Context(), devOptions, sshConf, cmd.Flags(), f, transferImage) - for _, fun := range devOptions.RollbackFuncList { + for _, fun := range devOptions.GetRollbackFuncList() { if fun != nil { - fun() + if err = fun(); err != nil { + log.Errorf("roll back failed, error: %s", err.Error()) + } } } return err diff --git a/pkg/daemon/action/connect-fork.go b/pkg/daemon/action/connect-fork.go index 37204f51..38556890 100644 --- a/pkg/daemon/action/connect-fork.go +++ b/pkg/daemon/action/connect-fork.go @@ -4,11 +4,11 @@ import ( "context" "fmt" "io" - "k8s.io/utils/pointer" defaultlog "log" log "github.com/sirupsen/logrus" "github.com/spf13/pflag" + "k8s.io/utils/pointer" "github.com/wencaiwulue/kubevpn/pkg/config" "github.com/wencaiwulue/kubevpn/pkg/daemon/rpc" @@ -61,7 +61,10 @@ func (svr *Server) ConnectFork(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectF }) sshCtx, sshCancel := context.WithCancel(context.Background()) - connect.RollbackFuncList = append(connect.RollbackFuncList, sshCancel) + connect.AddRolloutFunc(func() error { + sshCancel() + return nil + }) var path string path, err = handler.SshJump(sshCtx, sshConf, flags, false) if err != nil { @@ -118,7 +121,10 @@ func (svr *Server) redirectConnectForkToSudoDaemon(req *rpc.ConnectRequest, resp DefValue: file, }) sshCtx, sshCancel := context.WithCancel(context.Background()) - connect.RollbackFuncList = append(connect.RollbackFuncList, sshCancel) + connect.AddRolloutFunc(func() error { + sshCancel() + return nil + }) var path string path, err = handler.SshJump(sshCtx, sshConf, flags, true) if err != nil { diff --git a/pkg/daemon/action/connect.go b/pkg/daemon/action/connect.go index b2bdb28d..b9b70b87 100644 --- a/pkg/daemon/action/connect.go +++ b/pkg/daemon/action/connect.go @@ -76,7 +76,10 @@ func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServe }) sshCtx, sshCancel := context.WithCancel(context.Background()) - svr.connect.RollbackFuncList = append(svr.connect.RollbackFuncList, sshCancel) + svr.connect.AddRolloutFunc(func() error { + sshCancel() + return nil + }) var path string path, err = handler.SshJump(sshCtx, sshConf, flags, false) if err != nil { @@ -100,6 +103,8 @@ func (svr *Server) Connect(req *rpc.ConnectRequest, resp rpc.Daemon_ConnectServe if err != nil { log.Errorf("do connect error: %v", err) svr.connect.Cleanup() + svr.connect = nil + svr.t = time.Time{} return err } return nil @@ -131,7 +136,10 @@ func (svr *Server) redirectToSudoDaemon(req *rpc.ConnectRequest, resp rpc.Daemon DefValue: file, }) sshCtx, sshCancel := context.WithCancel(context.Background()) - connect.RollbackFuncList = append(connect.RollbackFuncList, sshCancel) + connect.AddRolloutFunc(func() error { + sshCancel() + return nil + }) var path string path, err = handler.SshJump(sshCtx, sshConf, flags, true) if err != nil { diff --git a/pkg/dev/convert.go b/pkg/dev/convert.go index ead5b333..d7b7b974 100644 --- a/pkg/dev/convert.go +++ b/pkg/dev/convert.go @@ -223,8 +223,8 @@ func GetVolume(ctx context.Context, f util.Factory, ns, pod string, d *Options) if volumeMount.SubPath != "" { join = filepath.Join(join, volumeMount.SubPath) } - d.RollbackFuncList = append(d.RollbackFuncList, func() { - _ = os.RemoveAll(join) + d.AddRollbackFunc(func() error { + return os.RemoveAll(join) }) // pod-namespace/pod-name:path remotePath := fmt.Sprintf("%s/%s:%s", ns, pod, volumeMount.MountPath) diff --git a/pkg/dev/main.go b/pkg/dev/main.go index 1c9256be..78169df8 100644 --- a/pkg/dev/main.go +++ b/pkg/dev/main.go @@ -81,7 +81,7 @@ type Options struct { DockerCli *command.DockerCli // rollback - RollbackFuncList []func() + rollbackFuncList []func() error } func (d *Options) Main(ctx context.Context, tempContainerConfig *containerConfig) error { @@ -217,8 +217,8 @@ func (d *Options) Main(ctx context.Context, tempContainerConfig *containerConfig } } - d.RollbackFuncList = append(d.RollbackFuncList, func() { - _ = runConfigList.Remove(ctx, d.Cli) + d.AddRollbackFunc(func() error { + return runConfigList.Remove(ctx, d.Cli) }) err = runConfigList.Run(ctx, volume, d.Cli, d.DockerCli) if err != nil { @@ -576,8 +576,9 @@ func (d *Options) doConnect(ctx context.Context, f cmdutil.Factory, conf *util.S }, ) go h.Run(func() error { select {} }) - d.RollbackFuncList = append(d.RollbackFuncList, func() { + d.AddRollbackFunc(func() error { h.Close() + return nil }) err = runLogsWaitRunning(cancelCtx, d.DockerCli, id) if err != nil { @@ -864,3 +865,11 @@ func createKubevpnNetwork(ctx context.Context, cli *client.Client) (string, erro } return create.ID, nil } + +func (d *Options) AddRollbackFunc(f func() error) { + d.rollbackFuncList = append(d.rollbackFuncList, f) +} + +func (d *Options) GetRollbackFuncList() []func() error { + return d.rollbackFuncList +} diff --git a/pkg/dns/dns.go b/pkg/dns/dns.go index f20f7de8..5c24f36b 100644 --- a/pkg/dns/dns.go +++ b/pkg/dns/dns.go @@ -29,6 +29,8 @@ type Config struct { Ns []string UseLocalDNS bool TunName string + // lite mode means connect to another cluster + Lite bool Hosts []Entry } diff --git a/pkg/handler/cleaner.go b/pkg/handler/cleaner.go index e7e1b6bc..823b78b6 100644 --- a/pkg/handler/cleaner.go +++ b/pkg/handler/cleaner.go @@ -64,9 +64,11 @@ func (c *ConnectOptions) Cleanup() { log.Errorf("can not update ref-count: %v", err) } } - for _, function := range c.RollbackFuncList { + for _, function := range c.getRolloutFunc() { if function != nil { - function() + if err := function(); err != nil { + log.Warningf("rollout function error: %v", err) + } } } // leave proxy resources @@ -77,7 +79,6 @@ func (c *ConnectOptions) Cleanup() { if c.cancel != nil { c.cancel() } - c.RollbackFuncList = c.RollbackFuncList[:] if c.dnsConfig != nil { log.Infof("clean up dns") c.dnsConfig.CancelDNS() diff --git a/pkg/handler/clone.go b/pkg/handler/clone.go index 4c6264a0..83a4c538 100644 --- a/pkg/handler/clone.go +++ b/pkg/handler/clone.go @@ -68,7 +68,7 @@ type CloneOptions struct { config *rest.Config factory cmdutil.Factory - RollbackFuncList []func() + rollbackFuncList []func() error } func (d *CloneOptions) InitClient(f cmdutil.Factory) (err error) { @@ -189,8 +189,8 @@ func (d *CloneOptions) DoClone(ctx context.Context) error { if err != nil { return err } - d.RollbackFuncList = append(d.RollbackFuncList, func() { - _ = client.Resource(object.Mapping.Resource).Namespace(d.TargetNamespace).Delete(context.Background(), u.GetName(), metav1.DeleteOptions{}) + d.addRollbackFunc(func() error { + return client.Resource(object.Mapping.Resource).Namespace(d.TargetNamespace).Delete(context.Background(), u.GetName(), metav1.DeleteOptions{}) }) retryErr := retry.RetryOnConflict(retry.DefaultRetry, func() error { // (1) add annotation KUBECONFIG @@ -785,5 +785,16 @@ func (d *CloneOptions) Cleanup(workloads ...string) error { } log.Infof("clean up clone workload: %s successfully", workload) } + for _, f := range d.rollbackFuncList { + if f != nil { + if err := f(); err != nil { + log.Warningf("exec rollback function error: %s", err.Error()) + } + } + } return nil } + +func (d *CloneOptions) addRollbackFunc(f func() error) { + d.rollbackFuncList = append(d.rollbackFuncList, f) +} diff --git a/pkg/handler/connect.go b/pkg/handler/connect.go index c05a100d..f47c3b4f 100644 --- a/pkg/handler/connect.go +++ b/pkg/handler/connect.go @@ -92,7 +92,7 @@ type ConnectOptions struct { // needs to give it back to dhcp localTunIPv4 *net.IPNet localTunIPv6 *net.IPNet - RollbackFuncList []func() + rollbackFuncList []func() error dnsConfig *dns.Config apiServerIPs []net.IP @@ -266,7 +266,7 @@ func (c *ConnectOptions) DoConnect(ctx context.Context, isLite bool) (err error) log.Errorf("add extra route failed: %v", err) return } - if err = c.setupDNS(c.ctx); err != nil { + if err = c.setupDNS(c.ctx, isLite); err != nil { log.Errorf("set up dns failed: %v", err) return } @@ -614,11 +614,14 @@ func (c *ConnectOptions) deleteFirewallRule(ctx context.Context) { if !util.FindAllowFirewallRule() { util.AddAllowFirewallRule() } - c.RollbackFuncList = append(c.RollbackFuncList, util.DeleteAllowFirewallRule) + c.AddRolloutFunc(func() error { + util.DeleteAllowFirewallRule() + return nil + }) go util.DeleteBlockFirewallRule(ctx) } -func (c *ConnectOptions) setupDNS(ctx context.Context) error { +func (c *ConnectOptions) setupDNS(ctx context.Context, lite bool) error { const port = 53 pod, err := c.GetRunningPodList(ctx) if err != nil { @@ -655,6 +658,7 @@ func (c *ConnectOptions) setupDNS(ctx context.Context) error { Ns: ns.UnsortedList(), UseLocalDNS: c.UseLocalDNS, TunName: tunName, + Lite: lite, Hosts: c.extraHost, } if err = c.dnsConfig.SetupDNS(); err != nil { @@ -1230,13 +1234,17 @@ RetryWithDNSClient: for _, rr := range answer.Answer { switch a := rr.(type) { case *miekgdns.A: - addRouteFunc(domain, a.A.String()) - c.extraHost = append(c.extraHost, dns.Entry{IP: a.A.String(), Domain: domain}) - success = true + if ip := net.ParseIP(a.A.String()); ip != nil && !ip.IsLoopback() { + addRouteFunc(domain, a.A.String()) + c.extraHost = append(c.extraHost, dns.Entry{IP: a.A.String(), Domain: domain}) + success = true + } case *miekgdns.AAAA: - addRouteFunc(domain, a.AAAA.String()) - c.extraHost = append(c.extraHost, dns.Entry{IP: a.AAAA.String(), Domain: domain}) - success = true + if ip := net.ParseIP(a.AAAA.String()); ip != nil && !ip.IsLoopback() { + addRouteFunc(domain, a.AAAA.String()) + c.extraHost = append(c.extraHost, dns.Entry{IP: a.AAAA.String(), Domain: domain}) + success = true + } } } return nil @@ -1517,3 +1525,11 @@ func (c *ConnectOptions) GetKubeconfigCluster() string { } return "" } + +func (c *ConnectOptions) AddRolloutFunc(f func() error) { + c.rollbackFuncList = append(c.rollbackFuncList, f) +} + +func (c *ConnectOptions) getRolloutFunc() []func() error { + return c.rollbackFuncList +} diff --git a/pkg/handler/envoy.go b/pkg/handler/envoy.go index 76ad84f8..c62e03b7 100644 --- a/pkg/handler/envoy.go +++ b/pkg/handler/envoy.go @@ -67,7 +67,7 @@ func InjectVPNAndEnvoySidecar(ctx1 context.Context, factory cmdutil.Factory, cli } if containerNames.HasAll(config.ContainerSidecarVPN, config.ContainerSidecarEnvoyProxy) { // add rollback func to remove envoy config - //RollbackFuncList = append(RollbackFuncList, func() { + //rollbackFuncList = append(rollbackFuncList, func() { // err := UnPatchContainer(factory, clientset, namespace, workload, c.LocalTunIPv4) // if err != nil { // log.Error(err) diff --git a/pkg/handler/remote.go b/pkg/handler/remote.go index e197cce3..d19ed112 100644 --- a/pkg/handler/remote.go +++ b/pkg/handler/remote.go @@ -527,7 +527,7 @@ func InjectVPNSidecar(ctx1 context.Context, factory cmdutil.Factory, namespace, return err } - //RollbackFuncList = append(RollbackFuncList, func() { + //rollbackFuncList = append(rollbackFuncList, func() { // p2 := &v1.Pod{ObjectMeta: origin.ObjectMeta, Spec: origin.Spec} // CleanupUselessInfo(p2) // if err = CreateAfterDeletePod(factory, p2, helper); err != nil { @@ -560,7 +560,7 @@ func InjectVPNSidecar(ctx1 context.Context, factory cmdutil.Factory, namespace, return err } - //RollbackFuncList = append(RollbackFuncList, func() { + //rollbackFuncList = append(rollbackFuncList, func() { // if err = removeInboundContainer(factory, namespace, workload); err != nil { // log.Error(err) // }