Added tests for ACL LOAD and ACL SAVE commands.

This commit is contained in:
Kelvin Clement Mwinuka
2024-06-02 04:01:45 +08:00
parent bdfaf5446a
commit d4506ce54d
5 changed files with 1412 additions and 2193 deletions

View File

@@ -16,6 +16,7 @@ package acl_test
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"github.com/echovault/echovault/echovault"
"github.com/echovault/echovault/internal"
@@ -151,7 +152,7 @@ func compareUsers(user1, user2 map[string][]string) error {
func generateSHA256Password(plain string) string {
h := sha256.New()
h.Write([]byte(plain))
return string(h.Sum(nil))
return hex.EncodeToString(h.Sum(nil))
}
func Test_ACL(t *testing.T) {
@@ -1497,9 +1498,11 @@ When nopass is provided, ignore any passwords that may have been provided in the
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
fmt.Sprintf(`list_user_1 on >list_user_password_1 #%s +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save %s +&channel1 +&channel2 -&channel3 -&channel4`, generateSHA256Password("list_user_password_2"), "%RW~key1 %RW~key2 %R~key3 %R~key4"),
fmt.Sprintf(`list_user_1 on >list_user_password_1 #%s +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save %s +&channel1 +&channel2 -&channel3 -&channel4`,
generateSHA256Password("list_user_password_2"), "%RW~key1 %RW~key2 %R~key3 %R~key4 %W~key5 %W~key6"),
fmt.Sprintf(`list_user_2 on nopass nokeys +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save +&channel1 +&channel2 -&channel3 -&channel4`),
fmt.Sprintf(`list_user_3 on >list_user_password_3 #%s +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save %s +&channel1 +&channel2 -&channel3 -&channel4`, generateSHA256Password("list_user_password_4"), "%RW~key1 %RW~key2 %R~key3 %R~key4"),
fmt.Sprintf(`list_user_3 on >list_user_password_3 #%s +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save %s +&channel1 +&channel2 -&channel3 -&channel4`,
generateSHA256Password("list_user_password_4"), "%RW~key1 %RW~key2 %R~key3 %R~key4 %W~key5 %W~key6"),
},
wantErr: "",
},
@@ -1736,5 +1739,219 @@ When nopass is provided, ignore any passwords that may have been provided in the
t.Cleanup(func() {
_ = os.RemoveAll(baseDir)
})
servers := make([]*echovault.EchoVault, 5)
defer func() {
for _, server := range servers {
if server != nil {
server.ShutDown()
}
}
}()
tests := []struct {
name string
path string
users []echovault.User // Add users after server startup.
cmd []resp.Value // Command to load users from ACL config.
want []string
}{
{
name: "1. Load config from the .json file",
path: path.Join(baseDir, "json_test.json"),
users: []echovault.User{
{Username: "user1", Enabled: true},
},
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("LOAD"), resp.StringValue("REPLACE")},
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
"user1 on +@all +all %RW~* +&*",
},
},
{
name: "2. Load users from the .yaml file",
path: path.Join(baseDir, "yaml_test.yaml"),
users: []echovault.User{
{Username: "user1", Enabled: true},
},
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("LOAD"), resp.StringValue("REPLACE")},
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
"user1 on +@all +all %RW~* +&*",
},
},
{
name: "3. Load users from the .yml file",
path: path.Join(baseDir, "yml_test.yml"),
users: []echovault.User{
{Username: "user1", Enabled: true},
},
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("LOAD"), resp.StringValue("REPLACE")},
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
"user1 on +@all +all %RW~* +&*",
},
},
{
name: "4. Merge loaded users",
path: path.Join(baseDir, "merge.yml"),
users: []echovault.User{
{ // Disable user1.
Username: "user1",
Enabled: false,
},
{ // Update with_password_user. This should be merged with the existing user.
Username: "with_password_user",
AddPlainPasswords: []string{"password3", "password4"},
IncludeReadWriteKeys: []string{"key1", "key2"},
IncludeWriteKeys: []string{"key3", "key4"},
IncludeReadKeys: []string{"key5", "key6"},
IncludeChannels: []string{"channel[12]"},
ExcludeChannels: []string{"channel[34]"},
},
},
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("LOAD"), resp.StringValue("MERGE")},
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf(`with_password_user on >password2 >password3 >password4 #%s +@all +all %s~key1 %s~key2 %s~key5 %s~key6 %s~key3 %s~key4 +&channel[12] -&channel[34]`,
generateSHA256Password("password3"), "%RW", "%RW", "%R", "%R", "%W", "%W"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
"user1 off +@all +all %RW~* +&*",
},
},
{
name: "5. Replace loaded users",
path: path.Join(baseDir, "replace.yml"),
users: []echovault.User{
{ // Disable user1.
Username: "user1",
Enabled: false,
},
{ // Update with_password_user. This should be merged with the existing user.
Username: "with_password_user",
AddPlainPasswords: []string{"password3", "password4"},
IncludeReadWriteKeys: []string{"key1", "key2"},
IncludeWriteKeys: []string{"key3", "key4"},
IncludeReadKeys: []string{"key5", "key6"},
IncludeChannels: []string{"channel[12]"},
ExcludeChannels: []string{"channel[34]"},
},
},
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("LOAD"), resp.StringValue("REPLACE")},
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
"user1 off +@all +all %RW~* +&*",
},
},
}
for i, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Create server with pre-generated users.
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err := setUpServer(port, false, test.path)
if err != nil {
t.Error(err)
return
}
servers[i] = mockServer
go func() {
mockServer.Start()
}()
// Save the current users to the ACL config file.
if _, err := mockServer.ACLSave(); err != nil {
t.Error(err)
return
}
// Add some users to the ACL.
for _, user := range test.users {
if _, err := mockServer.ACLSetUser(user); err != nil {
t.Error(err)
return
}
}
// Establish client connection
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
return
}
client := resp.NewConn(conn)
// Load the users from the ACL config file.
if err := client.WriteArray(test.cmd); err != nil {
t.Error(err)
return
}
fmt.Println("COMMAND WRITTEN")
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
return
}
if !strings.EqualFold(res.String(), "ok") {
t.Error(err)
mockServer.ShutDown()
return
}
// Get ACL List
if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("LIST")}); err != nil {
t.Error(err)
return
}
res, _, err = client.ReadValue()
if err != nil {
t.Error(err)
return
}
// Check if ACL LIST returns the expected list of users.
resArr := res.Array()
if len(resArr) != len(test.want) {
t.Errorf("expected response of length %d, got lenght %d", len(test.want), len(resArr))
return
}
var resStr []string
for i := 0; i < len(resArr); i++ {
resStr = strings.Split(resArr[i].String(), " ")
if !slices.ContainsFunc(test.want, func(s string) bool {
expectedUserSlice := strings.Split(s, " ")
return compareSlices(resStr, expectedUserSlice) == nil
}) {
t.Errorf("could not find the following user in expected slice: %+v", resStr)
return
}
}
})
}
})
}