Added tests for ACL LOAD and ACL SAVE commands.

This commit is contained in:
Kelvin Clement Mwinuka
2024-06-02 04:01:45 +08:00
parent bdfaf5446a
commit d4506ce54d
5 changed files with 1412 additions and 2193 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -17,6 +17,7 @@ package acl
import ( import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -57,7 +58,7 @@ func loadUsersFromConfigFile(users []*User, filePath string) {
return return
} }
// Open the config file. Create it if it does not exist. // Open the config file. Create it if it does not exist.
f, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, os.ModePerm) f, err := os.OpenFile(filePath, os.O_RDONLY|os.O_CREATE, os.ModePerm)
if err != nil { if err != nil {
log.Printf("open ACL config: %v\n", err) log.Printf("open ACL config: %v\n", err)
return return
@@ -231,14 +232,13 @@ func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []
var passwords []Password var passwords []Password
var user *User var user *User
h := sha256.New()
if len(cmd) == 2 { if len(cmd) == 2 {
// Process AUTH <password> // Process AUTH <password>
h := sha256.New()
h.Write([]byte(cmd[1])) h.Write([]byte(cmd[1]))
passwords = []Password{ passwords = []Password{
{PasswordType: "plaintext", PasswordValue: cmd[1]}, {PasswordType: PasswordPlainText, PasswordValue: cmd[1]},
{PasswordType: "SHA256", PasswordValue: string(h.Sum(nil))}, {PasswordType: PasswordSHA256, PasswordValue: hex.EncodeToString(h.Sum(nil))},
} }
// Authenticate with default user // Authenticate with default user
idx := slices.IndexFunc(acl.Users, func(user *User) bool { idx := slices.IndexFunc(acl.Users, func(user *User) bool {
@@ -249,10 +249,11 @@ func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []
if len(cmd) == 3 { if len(cmd) == 3 {
// Process AUTH <username> <password> // Process AUTH <username> <password>
h := sha256.New()
h.Write([]byte(cmd[2])) h.Write([]byte(cmd[2]))
passwords = []Password{ passwords = []Password{
{PasswordType: "plaintext", PasswordValue: cmd[2]}, {PasswordType: PasswordPlainText, PasswordValue: cmd[2]},
{PasswordType: "SHA256", PasswordValue: string(h.Sum(nil))}, {PasswordType: PasswordSHA256, PasswordValue: hex.EncodeToString(h.Sum(nil))},
} }
// Find user with the specified username // Find user with the specified username
userFound := false userFound := false
@@ -284,7 +285,7 @@ func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []
for _, userPassword := range user.Passwords { for _, userPassword := range user.Passwords {
for _, password := range passwords { for _, password := range passwords {
if strings.EqualFold(userPassword.PasswordType, password.PasswordType) && if userPassword.PasswordType == password.PasswordType &&
userPassword.PasswordValue == password.PasswordValue && userPassword.PasswordValue == password.PasswordValue &&
user.Enabled { user.Enabled {
// Set the current connection to the selected user and set them as authenticated // Set the current connection to the selected user and set them as authenticated

View File

@@ -342,7 +342,7 @@ func handleList(params internal.HandlerFuncParams) ([]byte, error) {
s += fmt.Sprintf(" %s~%s", "%R", key) s += fmt.Sprintf(" %s~%s", "%R", key)
} }
// Included write keys // Included write keys
for _, key := range user.IncludedReadKeys { for _, key := range user.IncludedWriteKeys {
if !slices.Contains(user.IncludedReadKeys, key) { if !slices.Contains(user.IncludedReadKeys, key) {
s += fmt.Sprintf(" %s~%s", "%W", key) s += fmt.Sprintf(" %s~%s", "%W", key)
} }
@@ -375,7 +375,7 @@ func handleLoad(params internal.HandlerFuncParams) ([]byte, error) {
acl.LockUsers() acl.LockUsers()
defer acl.UnlockUsers() defer acl.UnlockUsers()
f, err := os.Open(acl.Config.AclConfig) f, err := os.OpenFile(acl.Config.AclConfig, os.O_RDONLY, os.ModePerm)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -390,13 +390,13 @@ func handleLoad(params internal.HandlerFuncParams) ([]byte, error) {
var users []*User var users []*User
if ext == ".json" { if strings.ToLower(ext) == ".json" {
if err := json.NewDecoder(f).Decode(&users); err != nil { if err := json.NewDecoder(f).Decode(&users); err != nil {
return nil, err return nil, err
} }
} }
if ext == ".yaml" || ext == ".yml" { if slices.Contains([]string{".yaml", ".yml"}, strings.ToLower(ext)) {
if err := yaml.NewDecoder(f).Decode(&users); err != nil { if err := yaml.NewDecoder(f).Decode(&users); err != nil {
return nil, err return nil, err
} }
@@ -420,7 +420,7 @@ func handleLoad(params internal.HandlerFuncParams) ([]byte, error) {
break break
} }
} }
// If the no user with current loaded username is already in acl list, then append the user to the list // If there is no user with current loaded username is already in acl list, then append the user to the list
if !userFound { if !userFound {
acl.Users = append(acl.Users, user) acl.Users = append(acl.Users, user)
} }
@@ -439,8 +439,8 @@ func handleSave(params internal.HandlerFuncParams) ([]byte, error) {
return nil, errors.New("could not load ACL") return nil, errors.New("could not load ACL")
} }
acl.RLockUsers() acl.LockUsers()
acl.RUnlockUsers() acl.UnlockUsers()
f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm)
if err != nil { if err != nil {

View File

@@ -16,6 +16,7 @@ package acl_test
import ( import (
"crypto/sha256" "crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"github.com/echovault/echovault/echovault" "github.com/echovault/echovault/echovault"
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
@@ -151,7 +152,7 @@ func compareUsers(user1, user2 map[string][]string) error {
func generateSHA256Password(plain string) string { func generateSHA256Password(plain string) string {
h := sha256.New() h := sha256.New()
h.Write([]byte(plain)) h.Write([]byte(plain))
return string(h.Sum(nil)) return hex.EncodeToString(h.Sum(nil))
} }
func Test_ACL(t *testing.T) { func Test_ACL(t *testing.T) {
@@ -1497,9 +1498,11 @@ When nopass is provided, ignore any passwords that may have been provided in the
generateSHA256Password("password3"), "%RW"), generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*", "no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*", "disabled_user off >password5 +@all +all %RW~* +&*",
fmt.Sprintf(`list_user_1 on >list_user_password_1 #%s +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save %s +&channel1 +&channel2 -&channel3 -&channel4`, generateSHA256Password("list_user_password_2"), "%RW~key1 %RW~key2 %R~key3 %R~key4"), fmt.Sprintf(`list_user_1 on >list_user_password_1 #%s +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save %s +&channel1 +&channel2 -&channel3 -&channel4`,
generateSHA256Password("list_user_password_2"), "%RW~key1 %RW~key2 %R~key3 %R~key4 %W~key5 %W~key6"),
fmt.Sprintf(`list_user_2 on nopass nokeys +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save +&channel1 +&channel2 -&channel3 -&channel4`), fmt.Sprintf(`list_user_2 on nopass nokeys +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save +&channel1 +&channel2 -&channel3 -&channel4`),
fmt.Sprintf(`list_user_3 on >list_user_password_3 #%s +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save %s +&channel1 +&channel2 -&channel3 -&channel4`, generateSHA256Password("list_user_password_4"), "%RW~key1 %RW~key2 %R~key3 %R~key4"), fmt.Sprintf(`list_user_3 on >list_user_password_3 #%s +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save %s +&channel1 +&channel2 -&channel3 -&channel4`,
generateSHA256Password("list_user_password_4"), "%RW~key1 %RW~key2 %R~key3 %R~key4 %W~key5 %W~key6"),
}, },
wantErr: "", wantErr: "",
}, },
@@ -1736,5 +1739,219 @@ When nopass is provided, ignore any passwords that may have been provided in the
t.Cleanup(func() { t.Cleanup(func() {
_ = os.RemoveAll(baseDir) _ = os.RemoveAll(baseDir)
}) })
servers := make([]*echovault.EchoVault, 5)
defer func() {
for _, server := range servers {
if server != nil {
server.ShutDown()
}
}
}()
tests := []struct {
name string
path string
users []echovault.User // Add users after server startup.
cmd []resp.Value // Command to load users from ACL config.
want []string
}{
{
name: "1. Load config from the .json file",
path: path.Join(baseDir, "json_test.json"),
users: []echovault.User{
{Username: "user1", Enabled: true},
},
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("LOAD"), resp.StringValue("REPLACE")},
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
"user1 on +@all +all %RW~* +&*",
},
},
{
name: "2. Load users from the .yaml file",
path: path.Join(baseDir, "yaml_test.yaml"),
users: []echovault.User{
{Username: "user1", Enabled: true},
},
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("LOAD"), resp.StringValue("REPLACE")},
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
"user1 on +@all +all %RW~* +&*",
},
},
{
name: "3. Load users from the .yml file",
path: path.Join(baseDir, "yml_test.yml"),
users: []echovault.User{
{Username: "user1", Enabled: true},
},
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("LOAD"), resp.StringValue("REPLACE")},
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
"user1 on +@all +all %RW~* +&*",
},
},
{
name: "4. Merge loaded users",
path: path.Join(baseDir, "merge.yml"),
users: []echovault.User{
{ // Disable user1.
Username: "user1",
Enabled: false,
},
{ // Update with_password_user. This should be merged with the existing user.
Username: "with_password_user",
AddPlainPasswords: []string{"password3", "password4"},
IncludeReadWriteKeys: []string{"key1", "key2"},
IncludeWriteKeys: []string{"key3", "key4"},
IncludeReadKeys: []string{"key5", "key6"},
IncludeChannels: []string{"channel[12]"},
ExcludeChannels: []string{"channel[34]"},
},
},
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("LOAD"), resp.StringValue("MERGE")},
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf(`with_password_user on >password2 >password3 >password4 #%s +@all +all %s~key1 %s~key2 %s~key5 %s~key6 %s~key3 %s~key4 +&channel[12] -&channel[34]`,
generateSHA256Password("password3"), "%RW", "%RW", "%R", "%R", "%W", "%W"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
"user1 off +@all +all %RW~* +&*",
},
},
{
name: "5. Replace loaded users",
path: path.Join(baseDir, "replace.yml"),
users: []echovault.User{
{ // Disable user1.
Username: "user1",
Enabled: false,
},
{ // Update with_password_user. This should be merged with the existing user.
Username: "with_password_user",
AddPlainPasswords: []string{"password3", "password4"},
IncludeReadWriteKeys: []string{"key1", "key2"},
IncludeWriteKeys: []string{"key3", "key4"},
IncludeReadKeys: []string{"key5", "key6"},
IncludeChannels: []string{"channel[12]"},
ExcludeChannels: []string{"channel[34]"},
},
},
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("LOAD"), resp.StringValue("REPLACE")},
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
"user1 off +@all +all %RW~* +&*",
},
},
}
for i, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Create server with pre-generated users.
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err := setUpServer(port, false, test.path)
if err != nil {
t.Error(err)
return
}
servers[i] = mockServer
go func() {
mockServer.Start()
}()
// Save the current users to the ACL config file.
if _, err := mockServer.ACLSave(); err != nil {
t.Error(err)
return
}
// Add some users to the ACL.
for _, user := range test.users {
if _, err := mockServer.ACLSetUser(user); err != nil {
t.Error(err)
return
}
}
// Establish client connection
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
return
}
client := resp.NewConn(conn)
// Load the users from the ACL config file.
if err := client.WriteArray(test.cmd); err != nil {
t.Error(err)
return
}
fmt.Println("COMMAND WRITTEN")
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
return
}
if !strings.EqualFold(res.String(), "ok") {
t.Error(err)
mockServer.ShutDown()
return
}
// Get ACL List
if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("LIST")}); err != nil {
t.Error(err)
return
}
res, _, err = client.ReadValue()
if err != nil {
t.Error(err)
return
}
// Check if ACL LIST returns the expected list of users.
resArr := res.Array()
if len(resArr) != len(test.want) {
t.Errorf("expected response of length %d, got lenght %d", len(test.want), len(resArr))
return
}
var resStr []string
for i := 0; i < len(resArr); i++ {
resStr = strings.Split(resArr[i].String(), " ")
if !slices.ContainsFunc(test.want, func(s string) bool {
expectedUserSlice := strings.Split(s, " ")
return compareSlices(resStr, expectedUserSlice) == nil
}) {
t.Errorf("could not find the following user in expected slice: %+v", resStr)
return
}
}
})
}
}) })
} }

View File

@@ -86,6 +86,15 @@ func (user *User) Normalise() {
if slices.Contains(user.ExcludedPubSubChannels, "*") { if slices.Contains(user.ExcludedPubSubChannels, "*") {
user.IncludedPubSubChannels = []string{} user.IncludedPubSubChannels = []string{}
} }
// Sort passwords
slices.SortStableFunc(user.Passwords, func(a, b Password) int {
types := map[string]int{
PasswordPlainText: 0,
PasswordSHA256: 1,
}
return types[a.PasswordType] - types[b.PasswordType]
})
} }
func RemoveDuplicateEntries(entries []string, allAlias string) (res []string) { func RemoveDuplicateEntries(entries []string, allAlias string) (res []string) {
@@ -98,11 +107,13 @@ func RemoveDuplicateEntries(entries []string, allAlias string) (res []string) {
entriesMap[entry] += 1 entriesMap[entry] += 1
} }
for key, _ := range entriesMap { for key, _ := range entriesMap {
if key == "*" { if key == "*" && len(entriesMap) == 1 {
res = []string{"*"} res = []string{"*"}
return return
} }
res = append(res, key) if key != "*" {
res = append(res, key)
}
} }
return return
} }
@@ -127,19 +138,13 @@ func (user *User) UpdateUser(cmd []string) error {
} }
if str[0] == '<' { if str[0] == '<' {
user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool { user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool {
if strings.EqualFold(password.PasswordType, PasswordSHA256) { return strings.EqualFold(password.PasswordType, PasswordPlainText) && password.PasswordValue == str[1:]
return false
}
return password.PasswordValue == str[1:]
}) })
continue continue
} }
if str[0] == '!' { if str[0] == '!' {
user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool { user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool {
if strings.EqualFold(password.PasswordType, PasswordPlainText) { return strings.EqualFold(password.PasswordType, PasswordSHA256) && password.PasswordValue == str[1:]
return false
}
return password.PasswordValue == str[1:]
}) })
continue continue
} }
@@ -253,6 +258,7 @@ func (user *User) UpdateUser(cmd []string) error {
user.ExcludedPubSubChannels = []string{"*"} user.ExcludedPubSubChannels = []string{"*"}
} }
} }
return nil return nil
} }
@@ -260,7 +266,6 @@ func (user *User) Merge(new *User) {
user.Enabled = new.Enabled user.Enabled = new.Enabled
user.NoKeys = new.NoKeys user.NoKeys = new.NoKeys
user.NoPassword = new.NoPassword user.NoPassword = new.NoPassword
user.Passwords = append(user.Passwords, new.Passwords...)
user.IncludedCategories = append(user.IncludedCategories, new.IncludedCategories...) user.IncludedCategories = append(user.IncludedCategories, new.IncludedCategories...)
user.ExcludedCategories = append(user.ExcludedCategories, new.ExcludedCategories...) user.ExcludedCategories = append(user.ExcludedCategories, new.ExcludedCategories...)
user.IncludedCommands = append(user.IncludedCommands, new.IncludedCommands...) user.IncludedCommands = append(user.IncludedCommands, new.IncludedCommands...)
@@ -269,6 +274,16 @@ func (user *User) Merge(new *User) {
user.IncludedWriteKeys = append(user.IncludedWriteKeys, new.IncludedWriteKeys...) user.IncludedWriteKeys = append(user.IncludedWriteKeys, new.IncludedWriteKeys...)
user.IncludedPubSubChannels = append(user.IncludedPubSubChannels, new.IncludedPubSubChannels...) user.IncludedPubSubChannels = append(user.IncludedPubSubChannels, new.IncludedPubSubChannels...)
user.ExcludedPubSubChannels = append(user.ExcludedPubSubChannels, new.ExcludedPubSubChannels...) user.ExcludedPubSubChannels = append(user.ExcludedPubSubChannels, new.ExcludedPubSubChannels...)
// Add passwords.
for _, password := range new.Passwords {
if !slices.ContainsFunc(user.Passwords, func(p Password) bool {
return p.PasswordType == password.PasswordType && p.PasswordValue == password.PasswordValue
}) {
user.Passwords = append(user.Passwords, new.Passwords...)
}
}
user.Normalise() user.Normalise()
} }