feat: optimize code (#149)

Co-authored-by: wencaiwulue <895703375@qq.com>
This commit is contained in:
naison
2024-02-15 11:38:14 +08:00
committed by GitHub
parent 14e91d5110
commit 3ad6127132
9 changed files with 63 additions and 48 deletions

View File

@@ -20,10 +20,11 @@ func CmdControlPlane(_ cmdutil.Factory) *cobra.Command {
Hidden: true, Hidden: true,
Short: "Control-plane is a envoy xds server", Short: "Control-plane is a envoy xds server",
Long: `Control-plane is a envoy xds server, distribute envoy route configuration`, Long: `Control-plane is a envoy xds server, distribute envoy route configuration`,
Run: func(cmd *cobra.Command, args []string) { RunE: func(cmd *cobra.Command, args []string) error {
util.InitLoggerForServer(config.Debug) util.InitLoggerForServer(config.Debug)
go util.StartupPProf(0) go util.StartupPProf(0)
controlplane.Main(watchDirectoryFilename, port, log.StandardLogger()) err := controlplane.Main(cmd.Context(), watchDirectoryFilename, port, log.StandardLogger())
return err
}, },
} }
cmd.Flags().StringVarP(&watchDirectoryFilename, "watchDirectoryFilename", "w", "/etc/envoy/envoy-config.yaml", "full path to directory to watch for files") cmd.Flags().StringVarP(&watchDirectoryFilename, "watchDirectoryFilename", "w", "/etc/envoy/envoy-config.yaml", "full path to directory to watch for files")

View File

@@ -10,14 +10,15 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
func Main(filename string, port uint, logger *log.Logger) { func Main(ctx context.Context, filename string, port uint, logger *log.Logger) error {
snapshotCache := cache.NewSnapshotCache(false, cache.IDHash{}, logger) snapshotCache := cache.NewSnapshotCache(false, cache.IDHash{}, logger)
proc := NewProcessor(snapshotCache, logger) proc := NewProcessor(snapshotCache, logger)
errChan := make(chan error, 2)
go func() { go func() {
ctx := context.Background()
server := serverv3.NewServer(ctx, snapshotCache, nil) server := serverv3.NewServer(ctx, snapshotCache, nil)
RunServer(ctx, server, port) errChan <- RunServer(ctx, server, port)
}() }()
notifyCh := make(chan NotifyMessage, 100) notifyCh := make(chan NotifyMessage, 100)
@@ -29,20 +30,29 @@ func Main(filename string, port uint, logger *log.Logger) {
watcher, err := fsnotify.NewWatcher() watcher, err := fsnotify.NewWatcher()
if err != nil { if err != nil {
log.Fatal(fmt.Errorf("failed to create file watcher, err: %v", err)) return fmt.Errorf("failed to create file watcher: %v", err)
} }
defer watcher.Close() defer watcher.Close()
if err = watcher.Add(filename); err != nil { err = watcher.Add(filename)
log.Fatal(fmt.Errorf("failed to add file: %s to wather, err: %v", filename, err)) if err != nil {
return fmt.Errorf("failed to add file: %s to wather: %v", filename, err)
} }
go func() { go func() {
log.Fatal(Watch(watcher, filename, notifyCh)) errChan <- Watch(watcher, filename, notifyCh)
}() }()
for { for {
select { select {
case msg := <-notifyCh: case msg := <-notifyCh:
proc.ProcessFile(msg) err = proc.ProcessFile(msg)
if err != nil {
log.Errorf("failed to process file: %v", err)
return err
}
case err = <-errChan:
return err
case <-ctx.Done():
return ctx.Err()
} }
} }
} }

View File

@@ -44,11 +44,11 @@ func (p *Processor) newVersion() string {
return strconv.FormatInt(p.version, 10) return strconv.FormatInt(p.version, 10)
} }
func (p *Processor) ProcessFile(file NotifyMessage) { func (p *Processor) ProcessFile(file NotifyMessage) error {
configList, err := ParseYaml(file.FilePath) configList, err := ParseYaml(file.FilePath)
if err != nil { if err != nil {
p.logger.Errorf("error parsing yaml file: %+v", err) p.logger.Errorf("error parsing yaml file: %+v", err)
return return err
} }
for _, config := range configList { for _, config := range configList {
if len(config.Uid) == 0 { if len(config.Uid) == 0 {
@@ -76,21 +76,22 @@ func (p *Processor) ProcessFile(file NotifyMessage) {
if err != nil { if err != nil {
p.logger.Errorf("snapshot inconsistency: %v, err: %v", snapshot, err) p.logger.Errorf("snapshot inconsistency: %v, err: %v", snapshot, err)
return return err
} }
if err = snapshot.Consistent(); err != nil { if err = snapshot.Consistent(); err != nil {
p.logger.Errorf("snapshot inconsistency: %v, err: %v", snapshot, err) p.logger.Errorf("snapshot inconsistency: %v, err: %v", snapshot, err)
return return err
} }
p.logger.Debugf("will serve snapshot %+v, nodeID: %s", snapshot, config.Uid) p.logger.Debugf("will serve snapshot %+v, nodeID: %s", snapshot, config.Uid)
if err = p.cache.SetSnapshot(context.Background(), config.Uid, snapshot); err != nil { if err = p.cache.SetSnapshot(context.Background(), config.Uid, snapshot); err != nil {
p.logger.Errorf("snapshot error %q for %v", err, snapshot) p.logger.Errorf("snapshot error %q for %v", err, snapshot)
p.logger.Fatal(err) return err
} }
p.expireCache.Set(config.Uid, config, time.Minute*5) p.expireCache.Set(config.Uid, config, time.Minute*5)
} }
return nil
} }
func ParseYaml(file string) ([]*Virtual, error) { func ParseYaml(file string) ([]*Virtual, error) {

View File

@@ -21,13 +21,13 @@ const (
grpcMaxConcurrentStreams = 1000000 grpcMaxConcurrentStreams = 1000000
) )
func RunServer(ctx context.Context, server serverv3.Server, port uint) { func RunServer(ctx context.Context, server serverv3.Server, port uint) error {
grpcServer := grpc.NewServer(grpc.MaxConcurrentStreams(grpcMaxConcurrentStreams)) grpcServer := grpc.NewServer(grpc.MaxConcurrentStreams(grpcMaxConcurrentStreams))
var lc net.ListenConfig var lc net.ListenConfig
listener, err := lc.Listen(ctx, "tcp", fmt.Sprintf(":%d", port)) listener, err := lc.Listen(ctx, "tcp", fmt.Sprintf(":%d", port))
if err != nil { if err != nil {
log.Fatal(err) return err
} }
discoverygrpc.RegisterAggregatedDiscoveryServiceServer(grpcServer, server) discoverygrpc.RegisterAggregatedDiscoveryServiceServer(grpcServer, server)
@@ -39,7 +39,5 @@ func RunServer(ctx context.Context, server serverv3.Server, port uint) {
runtimeservice.RegisterRuntimeDiscoveryServiceServer(grpcServer, server) runtimeservice.RegisterRuntimeDiscoveryServiceServer(grpcServer, server)
log.Infof("management server listening on %d", port) log.Infof("management server listening on %d", port)
if err = grpcServer.Serve(listener); err != nil { return grpcServer.Serve(listener)
log.Fatal(err)
}
} }

View File

@@ -125,7 +125,7 @@ func GvisorUDPListener(addr string) (net.Listener, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &tcpKeepAliveListener{ln}, nil return &tcpKeepAliveListener{TCPListener: ln}, nil
} }
func handle(ctx context.Context, tcpConn net.Conn, udpConn *net.UDPConn) { func handle(ctx context.Context, tcpConn net.Conn, udpConn *net.UDPConn) {

View File

@@ -27,7 +27,7 @@ func TCPListener(addr string) (net.Listener, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &tcpKeepAliveListener{ln}, nil return &tcpKeepAliveListener{TCPListener: ln}, nil
} }
type tcpKeepAliveListener struct { type tcpKeepAliveListener struct {

View File

@@ -142,8 +142,8 @@ func (h *tunHandler) Handle(ctx context.Context, tun net.Conn) {
} }
} }
func (h tunHandler) printRoute() { func (h *tunHandler) printRoute(ctx context.Context) {
for { for ctx.Err() == nil {
select { select {
case <-time.Tick(time.Second * 5): case <-time.Tick(time.Second * 5):
var i int var i int
@@ -370,7 +370,7 @@ func (d *Device) SetTunInboundHandler(handler func(tunInbound <-chan *DataElem,
} }
func (h *tunHandler) HandleServer(ctx context.Context, tun net.Conn) { func (h *tunHandler) HandleServer(ctx context.Context, tun net.Conn) {
go h.printRoute() go h.printRoute(ctx)
device := &Device{ device := &Device{
tun: tun, tun: tun,

View File

@@ -29,7 +29,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/pflag" "github.com/spf13/pflag"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
v1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
@@ -684,29 +683,31 @@ func (c *ConnectOptions) setupDNS(ctx context.Context, lite bool) error {
} }
func Run(ctx context.Context, servers []core.Server) error { func Run(ctx context.Context, servers []core.Server) error {
group, ctx := errgroup.WithContext(ctx) errChan := make(chan error, len(servers))
for i := range servers { for i := range servers {
i := i go func(i int) {
group.Go(func() error { errChan <- func() error {
l := servers[i].Listener svr := servers[i]
defer l.Close() defer svr.Listener.Close()
for { for ctx.Err() == nil {
select { conn, err := svr.Listener.Accept()
case <-ctx.Done(): if err != nil {
return ctx.Err() log.Debugf("server accept connect error: %v", err)
default: return err
}
go svr.Handler.Handle(ctx, conn)
} }
return ctx.Err()
conn, errs := l.Accept() }()
if errs != nil { }(i)
log.Debugf("server accept connect error: %v", errs) }
continue
} select {
go servers[i].Handler.Handle(ctx, conn) case err := <-errChan:
} return err
}) case <-ctx.Done():
return ctx.Err()
} }
return group.Wait()
} }
func Parse(r core.Route) ([]core.Server, error) { func Parse(r core.Route) ([]core.Server, error) {

View File

@@ -31,7 +31,8 @@ func Complete(ctx context.Context, route *core.Route) error {
if err != nil { if err != nil {
return err return err
} }
resp, err := client.RentIP(context.Background(), &rpc.RentIPRequest{ var resp *rpc.RentIPResponse
resp, err = client.RentIP(context.Background(), &rpc.RentIPRequest{
PodName: os.Getenv(config.EnvPodName), PodName: os.Getenv(config.EnvPodName),
PodNamespace: ns, PodNamespace: ns,
}) })
@@ -44,6 +45,8 @@ func Complete(ctx context.Context, route *core.Route) error {
err := release(context.Background(), client) err := release(context.Background(), client)
if err != nil { if err != nil {
log.Errorf("release ip failed: %v", err) log.Errorf("release ip failed: %v", err)
} else {
log.Errorf("release ip secuess")
} }
}() }()
@@ -57,7 +60,8 @@ func Complete(ctx context.Context, route *core.Route) error {
return err return err
} }
for i := 0; i < len(route.ServeNodes); i++ { for i := 0; i < len(route.ServeNodes); i++ {
node, err := core.ParseNode(route.ServeNodes[i]) var node *core.Node
node, err = core.ParseNode(route.ServeNodes[i])
if err != nil { if err != nil {
return err return err
} }