diff --git a/cmd/kubevpn/cmds/controlplane.go b/cmd/kubevpn/cmds/controlplane.go index 5d7c8199..69378f39 100644 --- a/cmd/kubevpn/cmds/controlplane.go +++ b/cmd/kubevpn/cmds/controlplane.go @@ -20,10 +20,11 @@ func CmdControlPlane(_ cmdutil.Factory) *cobra.Command { Hidden: true, Short: "Control-plane is a envoy xds server", Long: `Control-plane is a envoy xds server, distribute envoy route configuration`, - Run: func(cmd *cobra.Command, args []string) { + RunE: func(cmd *cobra.Command, args []string) error { util.InitLoggerForServer(config.Debug) go util.StartupPProf(0) - controlplane.Main(watchDirectoryFilename, port, log.StandardLogger()) + err := controlplane.Main(cmd.Context(), watchDirectoryFilename, port, log.StandardLogger()) + return err }, } cmd.Flags().StringVarP(&watchDirectoryFilename, "watchDirectoryFilename", "w", "/etc/envoy/envoy-config.yaml", "full path to directory to watch for files") diff --git a/pkg/controlplane/main.go b/pkg/controlplane/main.go index 58ac3f89..1ca7aa09 100644 --- a/pkg/controlplane/main.go +++ b/pkg/controlplane/main.go @@ -10,14 +10,15 @@ import ( log "github.com/sirupsen/logrus" ) -func Main(filename string, port uint, logger *log.Logger) { +func Main(ctx context.Context, filename string, port uint, logger *log.Logger) error { snapshotCache := cache.NewSnapshotCache(false, cache.IDHash{}, logger) proc := NewProcessor(snapshotCache, logger) + errChan := make(chan error, 2) + go func() { - ctx := context.Background() server := serverv3.NewServer(ctx, snapshotCache, nil) - RunServer(ctx, server, port) + errChan <- RunServer(ctx, server, port) }() notifyCh := make(chan NotifyMessage, 100) @@ -29,20 +30,29 @@ func Main(filename string, port uint, logger *log.Logger) { watcher, err := fsnotify.NewWatcher() if err != nil { - log.Fatal(fmt.Errorf("failed to create file watcher, err: %v", err)) + return fmt.Errorf("failed to create file watcher: %v", err) } defer watcher.Close() - if err = watcher.Add(filename); err != nil { - log.Fatal(fmt.Errorf("failed to add file: %s to wather, err: %v", filename, err)) + err = watcher.Add(filename) + if err != nil { + return fmt.Errorf("failed to add file: %s to wather: %v", filename, err) } go func() { - log.Fatal(Watch(watcher, filename, notifyCh)) + errChan <- Watch(watcher, filename, notifyCh) }() for { select { case msg := <-notifyCh: - proc.ProcessFile(msg) + err = proc.ProcessFile(msg) + if err != nil { + log.Errorf("failed to process file: %v", err) + return err + } + case err = <-errChan: + return err + case <-ctx.Done(): + return ctx.Err() } } } diff --git a/pkg/controlplane/processor.go b/pkg/controlplane/processor.go index 93048477..888cd449 100644 --- a/pkg/controlplane/processor.go +++ b/pkg/controlplane/processor.go @@ -44,11 +44,11 @@ func (p *Processor) newVersion() string { return strconv.FormatInt(p.version, 10) } -func (p *Processor) ProcessFile(file NotifyMessage) { +func (p *Processor) ProcessFile(file NotifyMessage) error { configList, err := ParseYaml(file.FilePath) if err != nil { p.logger.Errorf("error parsing yaml file: %+v", err) - return + return err } for _, config := range configList { if len(config.Uid) == 0 { @@ -76,21 +76,22 @@ func (p *Processor) ProcessFile(file NotifyMessage) { if err != nil { p.logger.Errorf("snapshot inconsistency: %v, err: %v", snapshot, err) - return + return err } if err = snapshot.Consistent(); err != nil { p.logger.Errorf("snapshot inconsistency: %v, err: %v", snapshot, err) - return + return err } p.logger.Debugf("will serve snapshot %+v, nodeID: %s", snapshot, config.Uid) if err = p.cache.SetSnapshot(context.Background(), config.Uid, snapshot); err != nil { p.logger.Errorf("snapshot error %q for %v", err, snapshot) - p.logger.Fatal(err) + return err } p.expireCache.Set(config.Uid, config, time.Minute*5) } + return nil } func ParseYaml(file string) ([]*Virtual, error) { diff --git a/pkg/controlplane/server.go b/pkg/controlplane/server.go index 53cbe7bf..637451e2 100644 --- a/pkg/controlplane/server.go +++ b/pkg/controlplane/server.go @@ -21,13 +21,13 @@ const ( grpcMaxConcurrentStreams = 1000000 ) -func RunServer(ctx context.Context, server serverv3.Server, port uint) { +func RunServer(ctx context.Context, server serverv3.Server, port uint) error { grpcServer := grpc.NewServer(grpc.MaxConcurrentStreams(grpcMaxConcurrentStreams)) var lc net.ListenConfig listener, err := lc.Listen(ctx, "tcp", fmt.Sprintf(":%d", port)) if err != nil { - log.Fatal(err) + return err } discoverygrpc.RegisterAggregatedDiscoveryServiceServer(grpcServer, server) @@ -39,7 +39,5 @@ func RunServer(ctx context.Context, server serverv3.Server, port uint) { runtimeservice.RegisterRuntimeDiscoveryServiceServer(grpcServer, server) log.Infof("management server listening on %d", port) - if err = grpcServer.Serve(listener); err != nil { - log.Fatal(err) - } + return grpcServer.Serve(listener) } diff --git a/pkg/core/gvisorudphandler.go b/pkg/core/gvisorudphandler.go index 834c05f3..39d09f73 100644 --- a/pkg/core/gvisorudphandler.go +++ b/pkg/core/gvisorudphandler.go @@ -125,7 +125,7 @@ func GvisorUDPListener(addr string) (net.Listener, error) { if err != nil { return nil, err } - return &tcpKeepAliveListener{ln}, nil + return &tcpKeepAliveListener{TCPListener: ln}, nil } func handle(ctx context.Context, tcpConn net.Conn, udpConn *net.UDPConn) { diff --git a/pkg/core/tcp.go b/pkg/core/tcp.go index 3b2ce46c..bf772b88 100644 --- a/pkg/core/tcp.go +++ b/pkg/core/tcp.go @@ -27,7 +27,7 @@ func TCPListener(addr string) (net.Listener, error) { if err != nil { return nil, err } - return &tcpKeepAliveListener{ln}, nil + return &tcpKeepAliveListener{TCPListener: ln}, nil } type tcpKeepAliveListener struct { diff --git a/pkg/core/tunhandler.go b/pkg/core/tunhandler.go index 812e1d30..a66335ac 100644 --- a/pkg/core/tunhandler.go +++ b/pkg/core/tunhandler.go @@ -142,8 +142,8 @@ func (h *tunHandler) Handle(ctx context.Context, tun net.Conn) { } } -func (h tunHandler) printRoute() { - for { +func (h *tunHandler) printRoute(ctx context.Context) { + for ctx.Err() == nil { select { case <-time.Tick(time.Second * 5): var i int @@ -370,7 +370,7 @@ func (d *Device) SetTunInboundHandler(handler func(tunInbound <-chan *DataElem, } func (h *tunHandler) HandleServer(ctx context.Context, tun net.Conn) { - go h.printRoute() + go h.printRoute(ctx) device := &Device{ tun: tun, diff --git a/pkg/handler/connect.go b/pkg/handler/connect.go index b44cd4f1..367e45da 100644 --- a/pkg/handler/connect.go +++ b/pkg/handler/connect.go @@ -29,7 +29,6 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/pflag" "golang.org/x/crypto/ssh" - "golang.org/x/sync/errgroup" "google.golang.org/grpc/metadata" v1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -684,29 +683,31 @@ func (c *ConnectOptions) setupDNS(ctx context.Context, lite bool) error { } func Run(ctx context.Context, servers []core.Server) error { - group, ctx := errgroup.WithContext(ctx) + errChan := make(chan error, len(servers)) for i := range servers { - i := i - group.Go(func() error { - l := servers[i].Listener - defer l.Close() - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: + go func(i int) { + errChan <- func() error { + svr := servers[i] + defer svr.Listener.Close() + for ctx.Err() == nil { + conn, err := svr.Listener.Accept() + if err != nil { + log.Debugf("server accept connect error: %v", err) + return err + } + go svr.Handler.Handle(ctx, conn) } - - conn, errs := l.Accept() - if errs != nil { - log.Debugf("server accept connect error: %v", errs) - continue - } - go servers[i].Handler.Handle(ctx, conn) - } - }) + return ctx.Err() + }() + }(i) + } + + select { + case err := <-errChan: + return err + case <-ctx.Done(): + return ctx.Err() } - return group.Wait() } func Parse(r core.Route) ([]core.Server, error) { diff --git a/pkg/handler/tools.go b/pkg/handler/tools.go index edefffc9..60f82376 100644 --- a/pkg/handler/tools.go +++ b/pkg/handler/tools.go @@ -31,7 +31,8 @@ func Complete(ctx context.Context, route *core.Route) error { if err != nil { return err } - resp, err := client.RentIP(context.Background(), &rpc.RentIPRequest{ + var resp *rpc.RentIPResponse + resp, err = client.RentIP(context.Background(), &rpc.RentIPRequest{ PodName: os.Getenv(config.EnvPodName), PodNamespace: ns, }) @@ -44,6 +45,8 @@ func Complete(ctx context.Context, route *core.Route) error { err := release(context.Background(), client) if err != nil { log.Errorf("release ip failed: %v", err) + } else { + log.Errorf("release ip secuess") } }() @@ -57,7 +60,8 @@ func Complete(ctx context.Context, route *core.Route) error { return err } for i := 0; i < len(route.ServeNodes); i++ { - node, err := core.ParseNode(route.ServeNodes[i]) + var node *core.Node + node, err = core.ParseNode(route.ServeNodes[i]) if err != nil { return err }