// Copyright 2024 Kelvin Clement Mwinuka // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package acl_test import ( "crypto/sha256" "fmt" "github.com/echovault/echovault/echovault" "github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/constants" "github.com/echovault/echovault/internal/modules/acl" "slices" "sync" ) var bindAddr string var port uint16 var mockServer *echovault.EchoVault func init() { bindAddr = "localhost" p, _ := internal.GetFreePort() port = uint16(p) mockServer = setUpServer(bindAddr, port, true, "") wg := sync.WaitGroup{} wg.Add(1) go func() { wg.Done() mockServer.Start() }() wg.Wait() } func setUpServer(bindAddr string, port uint16, requirePass bool, aclConfig string) *echovault.EchoVault { conf := config.Config{ BindAddr: bindAddr, Port: port, DataDir: "", EvictionPolicy: constants.NoEviction, RequirePass: requirePass, Password: "password1", AclConfig: aclConfig, } mockServer, _ := echovault.NewEchoVault( echovault.WithConfig(conf), ) // Add the initial test users to the ACL module. // a.AddUsers(generateInitialTestUsers()) return mockServer } func generateInitialTestUsers() []*acl.User { // User with both hash password and plaintext password withPasswordUser := acl.CreateUser("with_password_user") h := sha256.New() h.Write([]byte("password3")) withPasswordUser.Passwords = []acl.Password{ {PasswordType: acl.PasswordPlainText, PasswordValue: "password2"}, {PasswordType: acl.PasswordSHA256, PasswordValue: string(h.Sum(nil))}, } withPasswordUser.IncludedCategories = []string{"*"} withPasswordUser.IncludedCommands = []string{"*"} // User with NoPassword option noPasswordUser := acl.CreateUser("no_password_user") noPasswordUser.Passwords = []acl.Password{ {PasswordType: acl.PasswordPlainText, PasswordValue: "password4"}, } noPasswordUser.NoPassword = true // Disabled user disabledUser := acl.CreateUser("disabled_user") disabledUser.Passwords = []acl.Password{ {PasswordType: acl.PasswordPlainText, PasswordValue: "password5"}, } disabledUser.Enabled = false return []*acl.User{ withPasswordUser, noPasswordUser, disabledUser, } } // compareSlices compare the elements in 2 slices, it checks if every element is s1 is contained in s2 // and vice versa. It essentially does a deep equality comparison. // This is done manually rather than using slices.Equal because it would be ideal to throw an error // specifying exactly which items are missing in either slice. func compareSlices[T comparable](res, expected []T) error { if len(res) != len(expected) { return fmt.Errorf("expected slice of length %d, got slice of length %d", len(expected), len(res)) } // Check whether all elements in res are contained in expected for _, r := range res { if !slices.Contains(expected, r) { return fmt.Errorf("got response item %+v, but it's not contained in expected slices", r) } } // Check whether all elements in expected are contained in res for _, e := range expected { if !slices.Contains(res, e) { return fmt.Errorf("expected element %+v, not found in res slice", e) } } return nil } // compareUsers compares 2 users and checks if all their fields are equal func compareUsers(user1, user2 *acl.User) error { // Compare flags if user1.Username != user2.Username { return fmt.Errorf("mismatched usernames \"%s\", and \"%s\"", user1.Username, user2.Username) } if user1.Enabled != user2.Enabled { return fmt.Errorf("mismatched enabled flag \"%+v\", and \"%+v\"", user1.Enabled, user2.Enabled) } if user1.NoPassword != user2.NoPassword { return fmt.Errorf("mismatched nopassword flag \"%+v\", and \"%+v\"", user1.NoPassword, user2.NoPassword) } if user1.NoKeys != user2.NoKeys { return fmt.Errorf("mismatched nokeys flag \"%+v\", and \"%+v\"", user1.NoKeys, user2.NoKeys) } // Compare passwords for _, password1 := range user1.Passwords { if !slices.ContainsFunc(user2.Passwords, func(password2 acl.Password) bool { return password1.PasswordType == password2.PasswordType && password1.PasswordValue == password2.PasswordValue }) { return fmt.Errorf("found password %+v in user1 that was not found in user2", password1) } } for _, password2 := range user2.Passwords { if !slices.ContainsFunc(user1.Passwords, func(password1 acl.Password) bool { return password1.PasswordType == password2.PasswordType && password1.PasswordValue == password2.PasswordValue }) { return fmt.Errorf("found password %+v in user2 that was not found in user1", password2) } } // Compare permissions permissions := [][][]string{ {user1.IncludedCategories, user2.IncludedCategories}, {user1.ExcludedCategories, user2.ExcludedCategories}, {user1.IncludedCommands, user2.IncludedCommands}, {user1.ExcludedCommands, user2.ExcludedCommands}, {user1.IncludedReadKeys, user2.IncludedReadKeys}, {user1.IncludedWriteKeys, user2.IncludedWriteKeys}, {user1.IncludedPubSubChannels, user2.IncludedPubSubChannels}, {user1.ExcludedPubSubChannels, user2.ExcludedPubSubChannels}, } for _, p := range permissions { if err := compareSlices(p[0], p[1]); err != nil { return err } } return nil } func generateSHA256Password(plain string) string { h := sha256.New() h.Write([]byte(plain)) return string(h.Sum(nil)) } // func Test_HandleAuth(t *testing.T) { // conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) // if err != nil { // t.Error(err) // return // } // // defer func() { // if conn != nil { // _ = conn.Close() // } // }() // // r := resp.NewConn(conn) // // tests := []struct { // cmd []resp.Value // wantRes string // wantErr string // }{ // { // 1. Authenticate with default user without specifying username // cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}, // wantRes: "OK", // wantErr: "", // }, // { // 2. Authenticate with plaintext password // cmd: []resp.Value{ // resp.StringValue("AUTH"), // resp.StringValue("with_password_user"), // resp.StringValue("password2"), // }, // wantRes: "OK", // wantErr: "", // }, // { // 3. Authenticate with SHA256 password // cmd: []resp.Value{ // resp.StringValue("AUTH"), // resp.StringValue("with_password_user"), // resp.StringValue("password3"), // }, // wantRes: "OK", // wantErr: "", // }, // { // 4. Authenticate with no password user // cmd: []resp.Value{ // resp.StringValue("AUTH"), // resp.StringValue("no_password_user"), // resp.StringValue("password4"), // }, // wantRes: "OK", // wantErr: "", // }, // { // 5. Fail to authenticate with disabled user // cmd: []resp.Value{ // resp.StringValue("AUTH"), // resp.StringValue("disabled_user"), // resp.StringValue("password5"), // }, // wantRes: "", // wantErr: "Error user disabled_user is disabled", // }, // { // 6. Fail to authenticate with non-existent user // cmd: []resp.Value{ // resp.StringValue("AUTH"), // resp.StringValue("non_existent_user"), // resp.StringValue("password6"), // }, // wantRes: "", // wantErr: "Error no user with username non_existent_user", // }, // { // 7. Command too short // cmd: []resp.Value{resp.StringValue("AUTH")}, // wantRes: "", // wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse), // }, // { // 8. Command too long // cmd: []resp.Value{ // resp.StringValue("AUTH"), // resp.StringValue("user"), // resp.StringValue("password1"), // resp.StringValue("password2"), // }, // wantRes: "", // wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse), // }, // } // // for _, test := range tests { // if err = r.WriteArray(test.cmd); err != nil { // t.Error(err) // } // rv, _, err := r.ReadValue() // if err != nil { // t.Error(err) // } // if test.wantErr != "" { // if rv.Error().Error() != test.wantErr { // t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error()) // } // continue // } // if rv.String() != test.wantRes { // t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String()) // } // } // } // func Test_HandleCat(t *testing.T) { // conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) // if err != nil { // t.Error(err) // return // } // defer func() { // if conn != nil { // _ = conn.Close() // } // }() // r := resp.NewConn(conn) // // // Authenticate connection // if err = r.WriteArray([]resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}); err != nil { // t.Error(err) // } // rv, _, err := r.ReadValue() // if err != nil { // t.Error(err) // } // if rv.String() != "OK" { // t.Error("could not authenticate user") // } // // // Since only ACL commands are loaded in this test suite, this test will only test against the // // list of categories and commands available in the ACL module. // tests := []struct { // cmd []resp.Value // wantRes []string // wantErr string // }{ // { // 1. Return list of categories // cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT")}, // wantRes: []string{ // constants.ConnectionCategory, // constants.SlowCategory, // constants.FastCategory, // constants.AdminCategory, // constants.DangerousCategory, // }, // wantErr: "", // }, // { // 2. Return list of commands in connection category // cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.ConnectionCategory)}, // wantRes: []string{"auth"}, // wantErr: "", // }, // { // 3. Return list of commands in slow category // cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.SlowCategory)}, // wantRes: []string{"auth", "acl|cat", "acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"}, // wantErr: "", // }, // { // 4. Return list of commands in fast category // cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.FastCategory)}, // wantRes: []string{"acl|whoami"}, // wantErr: "", // }, // { // 5. Return list of commands in admin category // cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.AdminCategory)}, // wantRes: []string{"acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"}, // wantErr: "", // }, // { // 6. Return list of commands in dangerous category // cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.DangerousCategory)}, // wantRes: []string{"acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"}, // wantErr: "", // }, // { // 7. Return error when category does not exist // cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue("non-existent")}, // wantRes: nil, // wantErr: "Error category NON-EXISTENT not found", // }, // { // 8. Command too long // cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue("category1"), resp.StringValue("category2")}, // wantRes: nil, // wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse), // }, // } // // for _, test := range tests { // if err = r.WriteArray(test.cmd); err != nil { // t.Error(err) // } // rv, _, err = r.ReadValue() // if err != nil { // t.Error(err) // } // if test.wantErr != "" { // if rv.Error().Error() != test.wantErr { // t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error()) // } // continue // } // resArr := rv.Array() // // Check if all the elements in the expected array are in the response array // for _, expected := range test.wantRes { // if !slices.ContainsFunc(resArr, func(value resp.Value) bool { // return value.String() == expected // }) { // t.Errorf("could not find expected command \"%s\" in the response array for category", expected) // } // } // } // } // func Test_HandleUsers(t *testing.T) { // port, _ := internal.GetFreePort() // mockServer := setUpServer(bindAddr, uint16(port), false, "") // wg := sync.WaitGroup{} // wg.Add(1) // go func() { // wg.Done() // mockServer.Start() // }() // wg.Wait() // // conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) // if err != nil { // t.Error(err) // return // } // // defer func() { // if conn != nil { // _ = conn.Close() // } // }() // // r := resp.NewConn(conn) // // users := []string{"default", "with_password_user", "no_password_user", "disabled_user"} // // if err = r.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("USERS")}); err != nil { // t.Error(err) // } // // rv, _, err := r.ReadValue() // if err != nil { // t.Error(err) // } // // resArr := rv.Array() // // // Check if all the expected users are in the response array // for _, user := range users { // if !slices.ContainsFunc(resArr, func(value resp.Value) bool { // return value.String() == user // }) { // t.Errorf("could not find expected user \"%s\" in response array", user) // } // } // // // Check if all the users in the response array are in the expected users // for _, value := range resArr { // if !slices.ContainsFunc(users, func(user string) bool { // return value.String() == user // }) { // t.Errorf("could not find response user \"%s\" in expected users array", value.String()) // } // } // } // func Test_HandleSetUser(t *testing.T) { // port, _ := internal.GetFreePort() // mockServer := setUpServer(bindAddr, uint16(port), false, "") // wg := sync.WaitGroup{} // wg.Add(1) // go func() { // wg.Done() // mockServer.Start() // }() // wg.Wait() // // a := getACL(mockServer) // // conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) // if err != nil { // t.Error(err) // return // } // defer func() { // if conn != nil { // _ = conn.Close() // } // }() // // r := resp.NewConn(conn) // // tests := []struct { // presetUser *acl.User // cmd []resp.Value // wantRes string // wantErr string // wantUser *acl.User // }{ // { // // 1. Create new enabled user // presetUser: nil, // cmd: []resp.Value{ // resp.StringValue("ACL"), // resp.StringValue("SETUSER"), // resp.StringValue("set_user_1"), // resp.StringValue("on"), // }, // wantRes: "OK", // wantErr: "", // wantUser: func() *acl.User { // user := acl.CreateUser("set_user_1") // user.Enabled = true // user.Normalise() // return user // }(), // }, // { // // 2. Create new disabled user // presetUser: nil, // cmd: []resp.Value{ // resp.StringValue("ACL"), // resp.StringValue("SETUSER"), // resp.StringValue("set_user_2"), // resp.StringValue("off"), // }, // wantRes: "OK", // wantErr: "", // wantUser: func() *acl.User { // user := acl.CreateUser("set_user_2") // user.Enabled = false // user.Normalise() // return user // }(), // }, // { // // 3. Create new enabled user with both plaintext and SHA256 passwords // presetUser: nil, // cmd: []resp.Value{ // resp.StringValue("ACL"), // resp.StringValue("SETUSER"), // resp.StringValue("set_user_3"), // resp.StringValue("on"), // resp.StringValue(">set_user_3_plaintext_password_1"), // resp.StringValue(">set_user_3_plaintext_password_2"), // resp.StringValue(fmt.Sprintf("#%s", generateSHA256Password("set_user_3_hash_password_1"))), // resp.StringValue(fmt.Sprintf("#%s", generateSHA256Password("set_user_3_hash_password_2"))), // }, // wantRes: "OK", // wantErr: "", // wantUser: func() *acl.User { // user := acl.CreateUser("set_user_3") // user.Enabled = true // user.Passwords = []acl.Password{ // {PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"}, // {PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"}, // {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")}, // {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")}, // } // user.Normalise() // return user // }(), // }, // { // // 4. Remove plaintext and SHA256 password from existing user // presetUser: func() *acl.User { // user := acl.CreateUser("set_user_4") // user.Enabled = true // user.Passwords = []acl.Password{ // {PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"}, // {PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"}, // {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")}, // {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")}, // } // user.Normalise() // return user // }(), // cmd: []resp.Value{ // resp.StringValue("ACL"), // resp.StringValue("SETUSER"), // resp.StringValue("set_user_4"), // resp.StringValue("on"), // resp.StringValue("password1"), // resp.StringValue(fmt.Sprintf("#%s", generateSHA256Password("password2"))), // }, // wantRes: "OK", // wantErr: "", // wantUser: func() *acl.User { // user := acl.CreateUser("set_user_16") // user.Enabled = true // user.NoPassword = true // user.Passwords = []acl.Password{} // user.Normalise() // return user // }(), // }, // { // // 17. Delete all existing users passwords using 'nopass' // presetUser: func() *acl.User { // user := acl.CreateUser("set_user_17") // user.Enabled = true // user.NoPassword = true // user.Passwords = []acl.Password{ // {PasswordType: acl.PasswordPlainText, PasswordValue: "password1"}, // {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")}, // } // user.Normalise() // return user // }(), // cmd: []resp.Value{ // resp.StringValue("ACL"), // resp.StringValue("SETUSER"), // resp.StringValue("set_user_17"), // resp.StringValue("on"), // resp.StringValue("nopass"), // }, // wantRes: "OK", // wantErr: "", // wantUser: func() *acl.User { // user := acl.CreateUser("set_user_17") // user.Enabled = true // user.NoPassword = true // user.Passwords = []acl.Password{} // user.Normalise() // return user // }(), // }, // { // // 18. Clear all of an existing user's passwords using 'resetpass' // presetUser: func() *acl.User { // user := acl.CreateUser("set_user_18") // user.Enabled = true // user.NoPassword = true // user.Passwords = []acl.Password{ // {PasswordType: acl.PasswordPlainText, PasswordValue: "password1"}, // {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")}, // } // user.Normalise() // return user // }(), // cmd: []resp.Value{ // resp.StringValue("ACL"), // resp.StringValue("SETUSER"), // resp.StringValue("set_user_18"), // resp.StringValue("on"), // resp.StringValue("nopass"), // }, // wantRes: "OK", // wantErr: "", // wantUser: func() *acl.User { // user := acl.CreateUser("set_user_18") // user.Enabled = true // user.NoPassword = true // user.Passwords = []acl.Password{} // user.Normalise() // return user // }(), // }, // { // // 19. Clear all of an existing user's command privileges using 'nocommands' // presetUser: func() *acl.User { // user := acl.CreateUser("set_user_19") // user.Enabled = true // user.IncludedCommands = []string{"acl|getuser", "acl|setuser", "acl|deluser"} // user.ExcludedCommands = []string{"rewriteaof", "save"} // user.Normalise() // return user // }(), // cmd: []resp.Value{ // resp.StringValue("ACL"), // resp.StringValue("SETUSER"), // resp.StringValue("set_user_19"), // resp.StringValue("on"), // resp.StringValue("nocommands"), // }, // wantRes: "OK", // wantErr: "", // wantUser: func() *acl.User { // user := acl.CreateUser("set_user_19") // user.Enabled = true // user.IncludedCommands = []string{} // user.ExcludedCommands = []string{"*"} // user.IncludedCategories = []string{} // user.ExcludedCategories = []string{"*"} // user.Normalise() // return user // }(), // }, // { // // 20. Clear all of an existing user's allowed keys using 'resetkeys' // presetUser: func() *acl.User { // user := acl.CreateUser("set_user_20") // user.Enabled = true // user.IncludedWriteKeys = []string{"key1", "key2", "key3", "key4", "key5", "key6"} // user.IncludedReadKeys = []string{"key1", "key2", "key3", "key7", "key8", "key9"} // user.Normalise() // return user // }(), // cmd: []resp.Value{ // resp.StringValue("ACL"), // resp.StringValue("SETUSER"), // resp.StringValue("set_user_20"), // resp.StringValue("on"), // resp.StringValue("resetkeys"), // }, // wantRes: "OK", // wantErr: "", // wantUser: func() *acl.User { // user := acl.CreateUser("set_user_20") // user.Enabled = true // user.NoKeys = true // user.IncludedReadKeys = []string{} // user.IncludedWriteKeys = []string{} // user.Normalise() // return user // }(), // }, // { // // 21. Allow user to access all channels using 'resetchannels' // presetUser: func() *acl.User { // user := acl.CreateUser("set_user_21") // user.IncludedPubSubChannels = []string{"channel1", "channel2"} // user.ExcludedPubSubChannels = []string{"channel3", "channel4"} // user.Normalise() // return user // }(), // cmd: []resp.Value{ // resp.StringValue("ACL"), // resp.StringValue("SETUSER"), // resp.StringValue("set_user_21"), // resp.StringValue("resetchannels"), // }, // wantRes: "OK", // wantErr: "", // wantUser: func() *acl.User { // user := acl.CreateUser("set_user_21") // user.IncludedPubSubChannels = []string{} // user.ExcludedPubSubChannels = []string{"*"} // user.Normalise() // return user // }(), // }, // } // // for i, test := range tests { // if test.presetUser != nil { // a.AddUsers([]*acl.User{test.presetUser}) // } // if err = r.WriteArray(test.cmd); err != nil { // t.Error(err) // } // v, _, err := r.ReadValue() // if err != nil { // t.Error(err) // } // if test.wantErr != "" { // if v.Error().Error() != test.wantErr { // t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, v.Error().Error()) // } // continue // } // if v.String() != test.wantRes { // t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, v.String()) // } // if test.wantUser == nil { // continue // } // expectedUser := test.wantUser // currUserIdx := slices.IndexFunc(a.Users, func(user *acl.User) bool { // return user.Username == expectedUser.Username // }) // if currUserIdx == -1 { // t.Errorf("expected to find user with username \"%s\" but could not find them.", expectedUser.Username) // } // if err = compareUsers(expectedUser, a.Users[currUserIdx]); err != nil { // t.Errorf("test idx: %d, %+v", i, err) // } // } // } // func Test_HandleGetUser(t *testing.T) { // port, _ := internal.GetFreePort() // mockServer := setUpServer(bindAddr, uint16(port), false, "") // wg := sync.WaitGroup{} // wg.Add(1) // go func() { // wg.Done() // mockServer.Start() // }() // wg.Wait() // // a := getACL(mockServer) // // conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) // if err != nil { // t.Error(err) // return // } // defer func() { // if conn != nil { // _ = conn.Close() // } // }() // // r := resp.NewConn(conn) // // tests := []struct { // presetUser *acl.User // cmd []resp.Value // wantRes []resp.Value // wantErr string // }{ // { // 1. Get the user and all their details // presetUser: &acl.User{ // Username: "get_user_1", // Enabled: true, // NoPassword: false, // NoKeys: false, // Passwords: []acl.Password{ // {PasswordType: acl.PasswordPlainText, PasswordValue: "get_user_password_1"}, // {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("get_user_password_2")}, // }, // IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory}, // ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory}, // IncludedCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"}, // ExcludedCommands: []string{"rewriteaof", "save", "acl|load", "acl|save"}, // IncludedReadKeys: []string{"key1", "key2", "key3", "key4"}, // IncludedWriteKeys: []string{"key1", "key2", "key5", "key6"}, // IncludedPubSubChannels: []string{"channel1", "channel2"}, // ExcludedPubSubChannels: []string{"channel3", "channel4"}, // }, // cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("GETUSER"), resp.StringValue("get_user_1")}, // wantRes: []resp.Value{ // resp.StringValue("username"), // resp.ArrayValue([]resp.Value{resp.StringValue("get_user_1")}), // resp.StringValue("flags"), // resp.ArrayValue([]resp.Value{ // resp.StringValue("on"), // }), // resp.StringValue("categories"), // resp.ArrayValue([]resp.Value{ // resp.StringValue(fmt.Sprintf("+@%s", constants.WriteCategory)), // resp.StringValue(fmt.Sprintf("+@%s", constants.ReadCategory)), // resp.StringValue(fmt.Sprintf("+@%s", constants.PubSubCategory)), // resp.StringValue(fmt.Sprintf("-@%s", constants.AdminCategory)), // resp.StringValue(fmt.Sprintf("-@%s", constants.ConnectionCategory)), // resp.StringValue(fmt.Sprintf("-@%s", constants.DangerousCategory)), // }), // resp.StringValue("commands"), // resp.ArrayValue([]resp.Value{ // resp.StringValue("+acl|setuser"), // resp.StringValue("+acl|getuser"), // resp.StringValue("+acl|deluser"), // resp.StringValue("-rewriteaof"), // resp.StringValue("-save"), // resp.StringValue("-acl|load"), // resp.StringValue("-acl|save"), // }), // resp.StringValue("keys"), // resp.ArrayValue([]resp.Value{ // // Keys here // resp.StringValue("%RW~key1"), // resp.StringValue("%RW~key2"), // resp.StringValue("%R~key3"), // resp.StringValue("%R~key4"), // resp.StringValue("%W~key5"), // resp.StringValue("%W~key6"), // }), // resp.StringValue("channels"), // resp.ArrayValue([]resp.Value{ // // Channels here // resp.StringValue("+&channel1"), // resp.StringValue("+&channel2"), // resp.StringValue("-&channel3"), // resp.StringValue("-&channel4"), // }), // }, // wantErr: "", // }, // { // 2. Return user not found error // presetUser: nil, // cmd: []resp.Value{ // resp.StringValue("ACL"), // resp.StringValue("GETUSER"), // resp.StringValue("non_existent_user")}, // wantRes: nil, // wantErr: "Error user not found", // }, // } // // for _, test := range tests { // if test.presetUser != nil { // a.AddUsers([]*acl.User{test.presetUser}) // } // if err = r.WriteArray(test.cmd); err != nil { // t.Error(err) // } // v, _, err := r.ReadValue() // if err != nil { // t.Error(err) // } // if test.wantErr != "" { // if v.Error().Error() != test.wantErr { // t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, v.Error().Error()) // } // continue // } // resArr := v.Array() // for i := 0; i < len(resArr); i++ { // if slices.Contains([]string{"username", "flags", "categories", "commands", "keys", "channels"}, resArr[i].String()) { // // String item // if resArr[i].String() != test.wantRes[i].String() { // t.Errorf("expected response component %+v, got %+v", test.wantRes[i], resArr[i]) // } // } else { // // Array item // var expected []string // for _, item := range test.wantRes[i].Array() { // expected = append(expected, item.String()) // } // // var res []string // for _, item := range resArr[i].Array() { // res = append(res, item.String()) // } // // if err = compareSlices(res, expected); err != nil { // t.Error(err) // } // } // } // } // } // func Test_HandleDelUser(t *testing.T) { // port, _ := internal.GetFreePort() // mockServer := setUpServer(bindAddr, uint16(port), false, "") // wg := sync.WaitGroup{} // wg.Add(1) // go func() { // wg.Done() // mockServer.Start() // }() // wg.Wait() // // a := getACL(mockServer) // // conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) // if err != nil { // t.Error(err) // return // } // defer func() { // if conn != nil { // _ = conn.Close() // } // }() // // r := resp.NewConn(conn) // // tests := []struct { // presetUser *acl.User // cmd []resp.Value // wantRes string // wantErr string // }{ // { // // 1. Delete existing user while skipping default user and non-existent user // presetUser: acl.CreateUser("user_to_delete"), // cmd: []resp.Value{ // resp.StringValue("ACL"), // resp.StringValue("DELUSER"), // resp.StringValue("default"), // resp.StringValue("user_to_delete"), // resp.StringValue("non_existent_user"), // }, // wantRes: "OK", // wantErr: "", // }, // { // // 2. Command too short // presetUser: nil, // cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("DELUSER")}, // wantRes: "", // wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse), // }, // } // // for _, test := range tests { // if test.presetUser != nil { // a.AddUsers([]*acl.User{test.presetUser}) // } // if err = r.WriteArray(test.cmd); err != nil { // t.Error(err) // } // v, _, err := r.ReadValue() // if err != nil { // t.Error(err) // } // if test.wantErr != "" { // if v.Error().Error() != test.wantErr { // t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, v.Error().Error()) // } // continue // } // // Check that default user still exists in the list of users // if !slices.ContainsFunc(a.Users, func(user *acl.User) bool { // return user.Username == "default" // }) { // t.Error("could not find user with username \"default\" in the ACL after deleting user") // } // // Check that the deleted user is no longer in the list // if slices.ContainsFunc(a.Users, func(user *acl.User) bool { // return user.Username == "user_to_delete" // }) { // t.Error("deleted user found in the ACL") // } // } // } // func Test_HandleWhoAmI(t *testing.T) { // conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) // if err != nil { // t.Error(err) // return // } // defer func() { // if conn != nil { // _ = conn.Close() // } // }() // // r := resp.NewConn(conn) // // tests := []struct { // username string // password string // wantRes string // }{ // { // 1. With default user // username: "default", // password: "password1", // wantRes: "default", // }, // { // 2. With user authenticated by plaintext password // username: "with_password_user", // password: "password2", // wantRes: "with_password_user", // }, // { // 3. With user authenticated by SHA256 password // username: "with_password_user", // password: "password3", // wantRes: "with_password_user", // }, // } // // for _, test := range tests { // // Authenticate // if err = r.WriteArray([]resp.Value{ // resp.StringValue("AUTH"), // resp.StringValue(test.username), // resp.StringValue(test.password), // }); err != nil { // t.Error(err) // } // v, _, err := r.ReadValue() // if err != nil { // t.Error(err) // } // if v.String() != "OK" { // t.Errorf("expected response for auth with %s:%s to be \"OK\", got %s", test.username, test.password, v.String()) // } // // Check whoami response value // if err = r.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("WHOAMI")}); err != nil { // t.Error(err) // } // v, _, err = r.ReadValue() // if err != nil { // t.Error(err) // } // if v.String() != test.wantRes { // t.Errorf("expected whoami response to be \"%s\", got \"%s\"", test.wantRes, v.String()) // } // } // } // func Test_HandleList(t *testing.T) { // port, _ := internal.GetFreePort() // mockServer := setUpServer(bindAddr, uint16(port), false, "") // wg := sync.WaitGroup{} // wg.Add(1) // go func() { // wg.Done() // mockServer.Start() // }() // wg.Wait() // // a := getACL(mockServer) // // conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) // if err != nil { // t.Error(err) // return // } // defer func() { // if conn != nil { // _ = conn.Close() // } // }() // // r := resp.NewConn(conn) // // tests := []struct { // presetUsers []*acl.User // cmd []resp.Value // wantRes []string // wantErr string // }{ // { // 1. Get the user and all their details // presetUsers: []*acl.User{ // { // Username: "list_user_1", // Enabled: true, // NoPassword: false, // NoKeys: false, // Passwords: []acl.Password{ // {PasswordType: acl.PasswordPlainText, PasswordValue: "list_user_password_1"}, // {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_2")}, // }, // IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory}, // ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory}, // IncludedCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"}, // ExcludedCommands: []string{"rewriteaof", "save", "acl|load", "acl|save"}, // IncludedReadKeys: []string{"key1", "key2", "key3", "key4"}, // IncludedWriteKeys: []string{"key1", "key2", "key5", "key6"}, // IncludedPubSubChannels: []string{"channel1", "channel2"}, // ExcludedPubSubChannels: []string{"channel3", "channel4"}, // }, // { // Username: "list_user_2", // Enabled: true, // NoPassword: true, // NoKeys: true, // Passwords: []acl.Password{}, // IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory}, // ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory}, // IncludedCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"}, // ExcludedCommands: []string{"rewriteaof", "save", "acl|load", "acl|save"}, // IncludedReadKeys: []string{}, // IncludedWriteKeys: []string{}, // IncludedPubSubChannels: []string{"channel1", "channel2"}, // ExcludedPubSubChannels: []string{"channel3", "channel4"}, // }, // { // Username: "list_user_3", // Enabled: true, // NoPassword: false, // NoKeys: false, // Passwords: []acl.Password{ // {PasswordType: acl.PasswordPlainText, PasswordValue: "list_user_password_3"}, // {PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_4")}, // }, // IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory}, // ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory}, // IncludedCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"}, // ExcludedCommands: []string{"rewriteaof", "save", "acl|load", "acl|save"}, // IncludedReadKeys: []string{"key1", "key2", "key3", "key4"}, // IncludedWriteKeys: []string{"key1", "key2", "key5", "key6"}, // IncludedPubSubChannels: []string{"channel1", "channel2"}, // ExcludedPubSubChannels: []string{"channel3", "channel4"}, // }, // }, // cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("LIST")}, // wantRes: []string{ // "default on +@all +all %RW~* +&*", // fmt.Sprintf("with_password_user on >password2 #%s +@all +all", generateSHA256Password("password3")), // "no_password_user on nopass >password4", // "disabled_user off >password5", // fmt.Sprintf(`list_user_1 on >list_user_password_1 #%s +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save %s +&channel1 +&channel2 -&channel3 -&channel4`, generateSHA256Password("list_user_password_2"), "%RW~key1 %RW~key2 %R~key3 %R~key4"), // fmt.Sprintf(`list_user_2 on nopass nokeys +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save +&channel1 +&channel2 -&channel3 -&channel4`), // fmt.Sprintf(`list_user_3 on >list_user_password_3 #%s +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save %s +&channel1 +&channel2 -&channel3 -&channel4`, generateSHA256Password("list_user_password_4"), "%RW~key1 %RW~key2 %R~key3 %R~key4"), // }, // wantErr: "", // }, // } // // for _, test := range tests { // a.AddUsers(test.presetUsers) // // if err = r.WriteArray(test.cmd); err != nil { // t.Error(err) // } // v, _, err := r.ReadValue() // if err != nil { // t.Error(err) // } // if test.wantErr != "" { // if v.Error().Error() != test.wantErr { // t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, v.Error().Error()) // } // continue // } // resArr := v.Array() // if len(resArr) != len(test.wantRes) { // t.Errorf("expected response of lenght %d, got lenght %d", len(test.wantRes), len(resArr)) // } // var resStr []string // for i := 0; i < len(resArr); i++ { // resStr = strings.Split(resArr[i].String(), " ") // if !slices.ContainsFunc(test.wantRes, func(s string) bool { // expectedUserSlice := strings.Split(s, " ") // return compareSlices(resStr, expectedUserSlice) == nil // }) { // t.Errorf("could not find the following user in expected slice: %+v", resStr) // } // clear(resStr) // } // } // }