Implemented ACL SAVE command

This commit is contained in:
Kelvin Clement Mwinuka
2023-12-19 12:12:01 +08:00
parent 1d36d2c772
commit 30e70e3804
4 changed files with 194 additions and 151 deletions

View File

@@ -2,6 +2,7 @@ package acl
import (
"github.com/kelvinmwinuka/memstore/src/utils"
"strings"
)
type User struct {
@@ -84,6 +85,149 @@ func RemoveDuplicateEntries(entries []string, allAlias string) (res []string) {
return
}
func (user *User) UpdateUser(cmd []string) error {
for _, str := range cmd {
// Parse enabled
if strings.EqualFold(str, "on") {
user.Enabled = true
}
if strings.EqualFold(str, "off") {
user.Enabled = false
}
// Parse passwords
if str[0] == '>' || str[0] == '#' {
user.Passwords = append(user.Passwords, Password{
PasswordType: GetPasswordType(str),
PasswordValue: str[1:],
})
user.NoPassword = false
continue
}
if str[0] == '<' {
user.Passwords = utils.Filter(user.Passwords, func(password Password) bool {
if strings.EqualFold(password.PasswordType, "SHA256") {
return true
}
return password.PasswordValue == str[1:]
})
continue
}
if str[0] == '!' {
user.Passwords = utils.Filter(user.Passwords, func(password Password) bool {
if strings.EqualFold(password.PasswordType, "plaintext") {
return true
}
return password.PasswordValue == str[1:]
})
continue
}
// Parse categories
if strings.EqualFold(str, "nocommands") {
user.ExcludedCategories = []string{"*"}
user.ExcludedCommands = []string{"*"}
continue
}
if strings.EqualFold(str, "allCategories") {
user.IncludedCategories = []string{"*"}
continue
}
if len(str) > 3 && str[1] == '@' {
if str[0] == '+' {
user.IncludedCategories = append(user.IncludedCategories, str[2:])
continue
}
if str[0] == '-' {
user.ExcludedCategories = append(user.ExcludedCategories, str[2:])
continue
}
}
// Parse keys
if strings.EqualFold(str, "allKeys") {
user.IncludedKeys = []string{"*"}
user.IncludedReadKeys = []string{"*"}
user.IncludedWriteKeys = []string{"*"}
continue
}
if len(str) > 1 && str[0] == '~' {
user.IncludedKeys = append(user.IncludedKeys, str[1:])
continue
}
if len(str) > 4 && strings.EqualFold(str[0:4], "%RW~") {
user.IncludedKeys = append(user.IncludedKeys, str[4:])
continue
}
if len(str) > 3 && strings.EqualFold(str[0:3], "%R~") {
user.IncludedReadKeys = append(user.IncludedReadKeys, str[3:])
continue
}
if len(str) > 3 && strings.EqualFold(str[0:3], "%W~") {
user.IncludedWriteKeys = append(user.IncludedWriteKeys, str[3:])
continue
}
// Parse channels
if strings.EqualFold(str, "allChannels") {
user.IncludedPubSubChannels = []string{"*"}
}
if len(str) > 2 && str[1] == '&' {
if str[0] == '+' {
user.IncludedPubSubChannels = append(user.IncludedPubSubChannels, str[2:])
continue
}
if str[0] == '-' {
user.ExcludedPubSubChannels = append(user.ExcludedPubSubChannels, str[2:])
continue
}
}
// Parse commands
if strings.EqualFold(str, "allCommands") {
user.IncludedCommands = []string{"*"}
continue
}
if len(str) > 2 && !utils.Contains([]uint8{'&', '@'}, str[1]) {
if str[0] == '+' {
user.IncludedCommands = append(user.IncludedCommands, str[1:])
continue
}
if str[0] == '-' {
user.ExcludedCommands = append(user.ExcludedCommands, str[1:])
continue
}
}
}
// If nopass is provided, delete all passwords
for _, str := range cmd {
if strings.EqualFold(str, "nopass") {
user.Passwords = []Password{}
user.NoPassword = true
}
}
for _, str := range cmd {
// If resetpass is provided, delete all passwords and set NoPassword to false
if strings.EqualFold(str, "resetpass") {
user.Passwords = []Password{}
user.NoPassword = false
}
// If nocommands is provided, disable all commands for this user
if strings.EqualFold(str, "nocommands") {
user.ExcludedCommands = []string{"*"}
}
// If resetkeys is provided, reset all keys that the user can access
if strings.EqualFold(str, "resetkeys") {
user.IncludedKeys = []string{}
user.IncludedReadKeys = []string{}
user.IncludedWriteKeys = []string{}
user.NoKeys = true
}
// If resetchannels is provided, remove all the pub/sub channels that the user can access
if strings.EqualFold(str, "resetchannels") {
user.ExcludedPubSubChannels = []string{"*"}
}
}
return nil
}
func GetPasswordType(password string) string {
if password[0] == '#' {
return "SHA256"