diff --git a/src/modules/acl/commands.go b/src/modules/acl/commands.go index 03caa1b..ef3d308 100644 --- a/src/modules/acl/commands.go +++ b/src/modules/acl/commands.go @@ -359,47 +359,113 @@ func (p Plugin) handleList(ctx context.Context, cmd []string, server utils.Serve } func (p Plugin) handleLoad(ctx context.Context, cmd []string, server utils.Server) ([]byte, error) { - return nil, errors.New("ACL LOAD not implemented") + if len(cmd) != 3 { + return nil, errors.New(utils.WRONG_ARGS_RESPONSE) + } + + f, err := os.Open(p.acl.Config.AclConfig) + if err != nil { + return nil, err + } + + defer func() { + if err := f.Close(); err != nil { + // TODO: Log file close error with context + fmt.Println(err) + } + }() + + ext := path.Ext(f.Name()) + + var users []*User + + if ext == ".json" { + if err := json.NewDecoder(f).Decode(&users); err != nil { + return nil, err + } + } + + if ext == ".yaml" || ext == ".yml" { + if err := yaml.NewDecoder(f).Decode(&users); err != nil { + return nil, err + } + } + + // Normalise each user + for _, user := range users { + user.Normalise() + // Traverse the list of users. + userFound := false + for _, u := range p.acl.Users { + if u.Username == user.Username { + userFound = true + // If we have a user with the current username and are in merge mode, merge the two users. + if strings.EqualFold(cmd[2], "merge") { + u.Merge(user) + } else { + // If we have a user with the current username and are in replace mode, merge the two users. + u.Replace(user) + } + break + } + } + // If the no user with current loaded username is already in acl list, then append the user to the list + if !userFound { + p.acl.Users = append(p.acl.Users, user) + } + } + + return []byte(utils.OK_RESPONSE), nil } func (p Plugin) handleSave(ctx context.Context, cmd []string, server utils.Server) ([]byte, error) { - if f, err := os.OpenFile(p.acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE, os.ModeAppend); err != nil { + if len(cmd) > 2 { + return nil, errors.New(utils.WRONG_ARGS_RESPONSE) + } + + f, err := os.OpenFile(p.acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE, os.ModeAppend) + if err != nil { return nil, err - } else { - defer func() { - if err := f.Close(); err != nil { - // TODO: Log file close error - fmt.Println(err) - } - }() - ext := path.Ext(f.Name()) - if ext == ".json" { - // Write to JSON config file - out, err := json.Marshal(p.acl.Users) - if err != nil { - return nil, err - } - _, err = f.Write(out) - if err != nil { - return nil, err - } + } + + defer func() { + if err := f.Close(); err != nil { + // TODO: Log file close error with context + fmt.Println(err) } - if ext == ".yaml" || ext == ".yml" { - // Write to yaml file - out, err := yaml.Marshal(p.acl.Users) - if err != nil { - return nil, err - } - _, err = f.Write(out) - if err != nil { - return nil, err - } + }() + + ext := path.Ext(f.Name()) + + if ext == ".json" { + // Write to JSON config file + out, err := json.Marshal(p.acl.Users) + if err != nil { + return nil, err } - err = f.Sync() + _, err = f.Write(out) if err != nil { return nil, err } } + + if ext == ".yaml" || ext == ".yml" { + // Write to yaml file + out, err := yaml.Marshal(p.acl.Users) + if err != nil { + return nil, err + } + _, err = f.Write(out) + if err != nil { + return nil, err + } + } + + err = f.Sync() + if err != nil { + return nil, err + } + return []byte(utils.OK_RESPONSE), nil } @@ -487,10 +553,13 @@ func NewModule(acl *ACL) Plugin { }, }, { - Command: "load", - Categories: []string{utils.AdminCategory, utils.SlowCategory, utils.DangerousCategory}, - Description: "(ACL LOAD) Reloads the rules from the configured ACL config file", - Sync: true, + Command: "load", + Categories: []string{utils.AdminCategory, utils.SlowCategory, utils.DangerousCategory}, + Description: ` +(ACL LOAD ) Reloads the rules from the configured ACL config file. +When 'MERGE' is passed, users from config file who share a username with users in memory will be merged. +When 'REPLACED' is passed, users from config file who share a username with users in memory will replace the user in memory.`, + Sync: true, KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, diff --git a/src/modules/acl/user.go b/src/modules/acl/user.go index bc30a51..0ed1151 100644 --- a/src/modules/acl/user.go +++ b/src/modules/acl/user.go @@ -228,6 +228,39 @@ func (user *User) UpdateUser(cmd []string) error { return nil } +func (user *User) Merge(new *User) { + user.Enabled = new.Enabled + user.NoKeys = new.NoKeys + user.NoPassword = new.NoPassword + user.Passwords = append(user.Passwords, new.Passwords...) + user.IncludedCategories = append(user.IncludedCategories, new.IncludedCategories...) + user.ExcludedCategories = append(user.ExcludedCategories, new.ExcludedCategories...) + user.IncludedCommands = append(user.IncludedCommands, new.IncludedCommands...) + user.ExcludedCommands = append(user.ExcludedCommands, new.ExcludedCommands...) + user.IncludedKeys = append(user.IncludedKeys, new.IncludedKeys...) + user.IncludedReadKeys = append(user.IncludedReadKeys, new.IncludedReadKeys...) + user.IncludedWriteKeys = append(user.IncludedWriteKeys, new.IncludedWriteKeys...) + user.IncludedPubSubChannels = append(user.IncludedPubSubChannels, new.IncludedPubSubChannels...) + user.ExcludedPubSubChannels = append(user.ExcludedPubSubChannels, new.ExcludedPubSubChannels...) + user.Normalise() +} + +func (user *User) Replace(new *User) { + user.Enabled = new.Enabled + user.NoKeys = new.NoKeys + user.NoPassword = new.NoPassword + user.Passwords = new.Passwords + user.IncludedCategories = new.IncludedCategories + user.ExcludedCategories = new.ExcludedCategories + user.IncludedCommands = new.IncludedCommands + user.ExcludedCommands = new.ExcludedCommands + user.IncludedKeys = new.IncludedKeys + user.IncludedReadKeys = new.IncludedReadKeys + user.IncludedWriteKeys = new.IncludedWriteKeys + user.IncludedPubSubChannels = new.IncludedPubSubChannels + user.ExcludedPubSubChannels = new.ExcludedPubSubChannels +} + func GetPasswordType(password string) string { if password[0] == '#' { return "SHA256"