mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-05 07:56:52 +08:00

Return "empty command" error from handleCommand method if an empty command is passed to the server. Wait until connection is no longer nil in acl package tests.
1511 lines
45 KiB
Go
1511 lines
45 KiB
Go
// Copyright 2024 Kelvin Clement Mwinuka
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package acl_test
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"fmt"
|
|
"github.com/echovault/echovault/echovault"
|
|
"github.com/echovault/echovault/internal"
|
|
"github.com/echovault/echovault/internal/config"
|
|
"github.com/echovault/echovault/internal/constants"
|
|
"github.com/echovault/echovault/internal/modules/acl"
|
|
"github.com/tidwall/resp"
|
|
"net"
|
|
"reflect"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"unsafe"
|
|
)
|
|
|
|
var bindAddr string
|
|
var port uint16
|
|
var mockServer *echovault.EchoVault
|
|
|
|
func init() {
|
|
bindAddr = "localhost"
|
|
|
|
p, _ := internal.GetFreePort()
|
|
port = uint16(p)
|
|
|
|
mockServer = setUpServer(bindAddr, port, true, "")
|
|
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
wg.Done()
|
|
mockServer.Start()
|
|
}()
|
|
wg.Wait()
|
|
}
|
|
|
|
func setUpServer(bindAddr string, port uint16, requirePass bool, aclConfig string) *echovault.EchoVault {
|
|
conf := config.Config{
|
|
BindAddr: bindAddr,
|
|
Port: port,
|
|
DataDir: "",
|
|
EvictionPolicy: constants.NoEviction,
|
|
RequirePass: requirePass,
|
|
Password: "password1",
|
|
AclConfig: aclConfig,
|
|
}
|
|
|
|
mockServer, _ := echovault.NewEchoVault(
|
|
echovault.WithConfig(conf),
|
|
)
|
|
|
|
// Add the initial test users to the ACL module
|
|
a := getACL(mockServer)
|
|
a.AddUsers(generateInitialTestUsers())
|
|
|
|
return mockServer
|
|
}
|
|
|
|
// getUnexportedField returns the value held by field even when the field is
// unexported, which reflect normally refuses to expose through Interface.
//
// It rebuilds a reflect.Value of the same type at the field's address, which
// drops the read-only restriction. The field must therefore be addressable,
// i.e. obtained through a pointer to the enclosing struct:
//
//	reflect.ValueOf(ptrToStruct).Elem().FieldByName("name")
//
// It panics with a descriptive message if field is the zero Value (for example
// because FieldByName did not find the name) or is not addressable.
func getUnexportedField(field reflect.Value) any {
	if !field.IsValid() {
		panic("getUnexportedField: invalid reflect.Value (field not found?)")
	}
	if !field.CanAddr() {
		panic("getUnexportedField: field is not addressable; obtain it via a pointer to the struct")
	}
	return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
}
|
|
|
|
func getACL(mockServer *echovault.EchoVault) *acl.ACL {
|
|
method := getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getACL"))
|
|
f := method.(func() interface{})
|
|
return f().(*acl.ACL)
|
|
}
|
|
|
|
func generateInitialTestUsers() []*acl.User {
|
|
// User with both hash password and plaintext password
|
|
withPasswordUser := acl.CreateUser("with_password_user")
|
|
h := sha256.New()
|
|
h.Write([]byte("password3"))
|
|
withPasswordUser.Passwords = []acl.Password{
|
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "password2"},
|
|
{PasswordType: acl.PasswordSHA256, PasswordValue: string(h.Sum(nil))},
|
|
}
|
|
withPasswordUser.IncludedCategories = []string{"*"}
|
|
withPasswordUser.IncludedCommands = []string{"*"}
|
|
|
|
// User with NoPassword option
|
|
noPasswordUser := acl.CreateUser("no_password_user")
|
|
noPasswordUser.Passwords = []acl.Password{
|
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "password4"},
|
|
}
|
|
noPasswordUser.NoPassword = true
|
|
|
|
// Disabled user
|
|
disabledUser := acl.CreateUser("disabled_user")
|
|
disabledUser.Passwords = []acl.Password{
|
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "password5"},
|
|
}
|
|
disabledUser.Enabled = false
|
|
|
|
return []*acl.User{
|
|
withPasswordUser,
|
|
noPasswordUser,
|
|
disabledUser,
|
|
}
|
|
}
|
|
|
|
// compareSlices reports whether res and expected hold the same elements,
// ignoring order. It returns nil when they match and otherwise an error that
// names the first offending element.
//
// This is done manually rather than with slices.Equal so that the error can
// say exactly which item is missing from which slice.
//
// Besides checking that every element of one slice is contained in the other,
// it also compares occurrence counts. Without that step two equally long
// slices such as [a a b] and [a b b] would wrongly be reported as equal.
func compareSlices[T comparable](res, expected []T) error {
	if len(res) != len(expected) {
		return fmt.Errorf("expected slice of length %d, got slice of length %d", len(expected), len(res))
	}
	// Check whether all elements in res are contained in expected.
	for _, r := range res {
		if !slices.Contains(expected, r) {
			return fmt.Errorf("got response item %+v, but it's not contained in expected slices", r)
		}
	}
	// Check whether all elements in expected are contained in res.
	for _, e := range expected {
		if !slices.Contains(res, e) {
			return fmt.Errorf("expected element %+v, not found in res slice", e)
		}
	}
	// Both slices contain the same distinct values and have the same length;
	// make sure each value also occurs the same number of times.
	remaining := make(map[T]int, len(expected))
	for _, e := range expected {
		remaining[e]++
	}
	for _, r := range res {
		remaining[r]--
		if remaining[r] < 0 {
			return fmt.Errorf("element %+v occurs more often in res slice than in expected slice", r)
		}
	}
	return nil
}
|
|
|
|
// compareUsers compares 2 users and checks if all their fields are equal
|
|
func compareUsers(user1, user2 *acl.User) error {
|
|
// Compare flags
|
|
if user1.Username != user2.Username {
|
|
return fmt.Errorf("mismatched usernames \"%s\", and \"%s\"", user1.Username, user2.Username)
|
|
}
|
|
if user1.Enabled != user2.Enabled {
|
|
return fmt.Errorf("mismatched enabled flag \"%+v\", and \"%+v\"", user1.Enabled, user2.Enabled)
|
|
}
|
|
if user1.NoPassword != user2.NoPassword {
|
|
return fmt.Errorf("mismatched nopassword flag \"%+v\", and \"%+v\"", user1.NoPassword, user2.NoPassword)
|
|
}
|
|
if user1.NoKeys != user2.NoKeys {
|
|
return fmt.Errorf("mismatched nokeys flag \"%+v\", and \"%+v\"", user1.NoKeys, user2.NoKeys)
|
|
}
|
|
|
|
// Compare passwords
|
|
for _, password1 := range user1.Passwords {
|
|
if !slices.ContainsFunc(user2.Passwords, func(password2 acl.Password) bool {
|
|
return password1.PasswordType == password2.PasswordType && password1.PasswordValue == password2.PasswordValue
|
|
}) {
|
|
return fmt.Errorf("found password %+v in user1 that was not found in user2", password1)
|
|
}
|
|
}
|
|
for _, password2 := range user2.Passwords {
|
|
if !slices.ContainsFunc(user1.Passwords, func(password1 acl.Password) bool {
|
|
return password1.PasswordType == password2.PasswordType && password1.PasswordValue == password2.PasswordValue
|
|
}) {
|
|
return fmt.Errorf("found password %+v in user2 that was not found in user1", password2)
|
|
}
|
|
}
|
|
|
|
// Compare permissions
|
|
permissions := [][][]string{
|
|
{user1.IncludedCategories, user2.IncludedCategories},
|
|
{user1.ExcludedCategories, user2.ExcludedCategories},
|
|
{user1.IncludedCommands, user2.IncludedCommands},
|
|
{user1.ExcludedCommands, user2.ExcludedCommands},
|
|
{user1.IncludedReadKeys, user2.IncludedReadKeys},
|
|
{user1.IncludedWriteKeys, user2.IncludedWriteKeys},
|
|
{user1.IncludedPubSubChannels, user2.IncludedPubSubChannels},
|
|
{user1.ExcludedPubSubChannels, user2.ExcludedPubSubChannels},
|
|
}
|
|
for _, p := range permissions {
|
|
if err := compareSlices(p[0], p[1]); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// generateSHA256Password returns the raw (not hex-encoded) 32-byte SHA256
// digest of plain, carried in a string.
func generateSHA256Password(plain string) string {
	digest := sha256.Sum256([]byte(plain))
	return string(digest[:])
}
|
|
|
|
func Test_HandleAuth(t *testing.T) {
|
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
for {
|
|
// Wait until connection is not nil before breaking out.
|
|
if conn != nil {
|
|
break
|
|
}
|
|
}
|
|
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
|
|
r := resp.NewConn(conn)
|
|
|
|
tests := []struct {
|
|
cmd []resp.Value
|
|
wantRes string
|
|
wantErr string
|
|
}{
|
|
{ // 1. Authenticate with default user without specifying username
|
|
cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
},
|
|
{ // 2. Authenticate with plaintext password
|
|
cmd: []resp.Value{
|
|
resp.StringValue("AUTH"),
|
|
resp.StringValue("with_password_user"),
|
|
resp.StringValue("password2"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
},
|
|
{ // 3. Authenticate with SHA256 password
|
|
cmd: []resp.Value{
|
|
resp.StringValue("AUTH"),
|
|
resp.StringValue("with_password_user"),
|
|
resp.StringValue("password3"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
},
|
|
{ // 4. Authenticate with no password user
|
|
cmd: []resp.Value{
|
|
resp.StringValue("AUTH"),
|
|
resp.StringValue("no_password_user"),
|
|
resp.StringValue("password4"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
},
|
|
{ // 5. Fail to authenticate with disabled user
|
|
cmd: []resp.Value{
|
|
resp.StringValue("AUTH"),
|
|
resp.StringValue("disabled_user"),
|
|
resp.StringValue("password5"),
|
|
},
|
|
wantRes: "",
|
|
wantErr: "Error user disabled_user is disabled",
|
|
},
|
|
{ // 6. Fail to authenticate with non-existent user
|
|
cmd: []resp.Value{
|
|
resp.StringValue("AUTH"),
|
|
resp.StringValue("non_existent_user"),
|
|
resp.StringValue("password6"),
|
|
},
|
|
wantRes: "",
|
|
wantErr: "Error no user with username non_existent_user",
|
|
},
|
|
{ // 7. Command too short
|
|
cmd: []resp.Value{resp.StringValue("AUTH")},
|
|
wantRes: "",
|
|
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
|
|
},
|
|
{ // 8. Command too long
|
|
cmd: []resp.Value{
|
|
resp.StringValue("AUTH"),
|
|
resp.StringValue("user"),
|
|
resp.StringValue("password1"),
|
|
resp.StringValue("password2"),
|
|
},
|
|
wantRes: "",
|
|
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
if err = r.WriteArray(test.cmd); err != nil {
|
|
t.Error(err)
|
|
}
|
|
rv, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if test.wantErr != "" {
|
|
if rv.Error().Error() != test.wantErr {
|
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
|
|
}
|
|
continue
|
|
}
|
|
if rv.String() != test.wantRes {
|
|
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
|
|
}
|
|
}
|
|
}
|
|
|
|
func Test_HandleCat(t *testing.T) {
|
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
r := resp.NewConn(conn)
|
|
|
|
// Authenticate connection
|
|
if err = r.WriteArray([]resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}); err != nil {
|
|
t.Error(err)
|
|
}
|
|
rv, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if rv.String() != "OK" {
|
|
t.Error("could not authenticate user")
|
|
}
|
|
|
|
// Since only ACL commands are loaded in this test suite, this test will only test against the
|
|
// list of categories and commands available in the ACL module.
|
|
tests := []struct {
|
|
cmd []resp.Value
|
|
wantRes []string
|
|
wantErr string
|
|
}{
|
|
{ // 1. Return list of categories
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT")},
|
|
wantRes: []string{
|
|
constants.ConnectionCategory,
|
|
constants.SlowCategory,
|
|
constants.FastCategory,
|
|
constants.AdminCategory,
|
|
constants.DangerousCategory,
|
|
},
|
|
wantErr: "",
|
|
},
|
|
{ // 2. Return list of commands in connection category
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.ConnectionCategory)},
|
|
wantRes: []string{"auth"},
|
|
wantErr: "",
|
|
},
|
|
{ // 3. Return list of commands in slow category
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.SlowCategory)},
|
|
wantRes: []string{"auth", "acl|cat", "acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
|
|
wantErr: "",
|
|
},
|
|
{ // 4. Return list of commands in fast category
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.FastCategory)},
|
|
wantRes: []string{"acl|whoami"},
|
|
wantErr: "",
|
|
},
|
|
{ // 5. Return list of commands in admin category
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.AdminCategory)},
|
|
wantRes: []string{"acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
|
|
wantErr: "",
|
|
},
|
|
{ // 6. Return list of commands in dangerous category
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.DangerousCategory)},
|
|
wantRes: []string{"acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
|
|
wantErr: "",
|
|
},
|
|
{ // 7. Return error when category does not exist
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue("non-existent")},
|
|
wantRes: nil,
|
|
wantErr: "Error category NON-EXISTENT not found",
|
|
},
|
|
{ // 8. Command too long
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue("category1"), resp.StringValue("category2")},
|
|
wantRes: nil,
|
|
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
if err = r.WriteArray(test.cmd); err != nil {
|
|
t.Error(err)
|
|
}
|
|
rv, _, err = r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if test.wantErr != "" {
|
|
if rv.Error().Error() != test.wantErr {
|
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
|
|
}
|
|
continue
|
|
}
|
|
resArr := rv.Array()
|
|
// Check if all the elements in the expected array are in the response array
|
|
for _, expected := range test.wantRes {
|
|
if !slices.ContainsFunc(resArr, func(value resp.Value) bool {
|
|
return value.String() == expected
|
|
}) {
|
|
t.Errorf("could not find expected command \"%s\" in the response array for category", expected)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func Test_HandleUsers(t *testing.T) {
|
|
port, _ := internal.GetFreePort()
|
|
mockServer := setUpServer(bindAddr, uint16(port), false, "")
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
wg.Done()
|
|
mockServer.Start()
|
|
}()
|
|
wg.Wait()
|
|
|
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
for {
|
|
// Wait until connection is not nil before continuing.
|
|
if conn != nil {
|
|
break
|
|
}
|
|
}
|
|
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
|
|
r := resp.NewConn(conn)
|
|
|
|
users := []string{"default", "with_password_user", "no_password_user", "disabled_user"}
|
|
|
|
if err = r.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("USERS")}); err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
rv, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resArr := rv.Array()
|
|
|
|
// Check if all the expected users are in the response array
|
|
for _, user := range users {
|
|
if !slices.ContainsFunc(resArr, func(value resp.Value) bool {
|
|
return value.String() == user
|
|
}) {
|
|
t.Errorf("could not find expected user \"%s\" in response array", user)
|
|
}
|
|
}
|
|
|
|
// Check if all the users in the response array are in the expected users
|
|
for _, value := range resArr {
|
|
if !slices.ContainsFunc(users, func(user string) bool {
|
|
return value.String() == user
|
|
}) {
|
|
t.Errorf("could not find response user \"%s\" in expected users array", value.String())
|
|
}
|
|
}
|
|
}
|
|
|
|
func Test_HandleSetUser(t *testing.T) {
|
|
port, _ := internal.GetFreePort()
|
|
mockServer := setUpServer(bindAddr, uint16(port), false, "")
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
wg.Done()
|
|
mockServer.Start()
|
|
}()
|
|
wg.Wait()
|
|
|
|
a := getACL(mockServer)
|
|
|
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
|
|
r := resp.NewConn(conn)
|
|
|
|
tests := []struct {
|
|
presetUser *acl.User
|
|
cmd []resp.Value
|
|
wantRes string
|
|
wantErr string
|
|
wantUser *acl.User
|
|
}{
|
|
{
|
|
// 1. Create new enabled user
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_1"),
|
|
resp.StringValue("on"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_1")
|
|
user.Enabled = true
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 2. Create new disabled user
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_2"),
|
|
resp.StringValue("off"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_2")
|
|
user.Enabled = false
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 3. Create new enabled user with both plaintext and SHA256 passwords
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_3"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue(">set_user_3_plaintext_password_1"),
|
|
resp.StringValue(">set_user_3_plaintext_password_2"),
|
|
resp.StringValue(fmt.Sprintf("#%s", generateSHA256Password("set_user_3_hash_password_1"))),
|
|
resp.StringValue(fmt.Sprintf("#%s", generateSHA256Password("set_user_3_hash_password_2"))),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_3")
|
|
user.Enabled = true
|
|
user.Passwords = []acl.Password{
|
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
|
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"},
|
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
|
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")},
|
|
}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 4. Remove plaintext and SHA256 password from existing user
|
|
presetUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_4")
|
|
user.Enabled = true
|
|
user.Passwords = []acl.Password{
|
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
|
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_2"},
|
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
|
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_2")},
|
|
}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_4"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("<set_user_3_plaintext_password_2"),
|
|
resp.StringValue(fmt.Sprintf("!%s", generateSHA256Password("set_user_3_hash_password_2"))),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_4")
|
|
user.Enabled = true
|
|
user.Passwords = []acl.Password{
|
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "set_user_3_plaintext_password_1"},
|
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("set_user_3_hash_password_1")},
|
|
}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 5. Create user with no commands allowed to be executed
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_5"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("nocommands"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_5")
|
|
user.Enabled = true
|
|
user.ExcludedCommands = []string{"*"}
|
|
user.ExcludedCategories = []string{"*"}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 6. Create user that can access all categories with +@*
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_6"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("+@*"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_6")
|
|
user.Enabled = true
|
|
user.IncludedCategories = []string{"*"}
|
|
user.ExcludedCategories = []string{}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 7. Create user that can access all categories with allcategories flag
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_7"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("allcategories"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_7")
|
|
user.Enabled = true
|
|
user.IncludedCategories = []string{"*"}
|
|
user.ExcludedCategories = []string{}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 8. Create user with a few allowed categories and a few disallowed categories
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_8"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue(fmt.Sprintf("+@%s", constants.WriteCategory)),
|
|
resp.StringValue(fmt.Sprintf("+@%s", constants.ReadCategory)),
|
|
resp.StringValue(fmt.Sprintf("+@%s", constants.PubSubCategory)),
|
|
resp.StringValue(fmt.Sprintf("-@%s", constants.AdminCategory)),
|
|
resp.StringValue(fmt.Sprintf("-@%s", constants.ConnectionCategory)),
|
|
resp.StringValue(fmt.Sprintf("-@%s", constants.DangerousCategory)),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_8")
|
|
user.Enabled = true
|
|
user.IncludedCategories = []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory}
|
|
user.ExcludedCategories = []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 9. Create user that is not allowed to access any keys
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_9"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("resetkeys"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_9")
|
|
user.Enabled = true
|
|
user.NoKeys = true
|
|
user.IncludedReadKeys = []string{}
|
|
user.IncludedWriteKeys = []string{}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 10. Create user that can access some read keys and some write keys
|
|
// Provide keys that are RW, W-Only and R-Only
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_10"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("~key1"),
|
|
resp.StringValue("~key2"),
|
|
resp.StringValue("%RW~key3"),
|
|
resp.StringValue("%RW~key4"),
|
|
resp.StringValue("%R~key5"),
|
|
resp.StringValue("%R~key6"),
|
|
resp.StringValue("%W~key7"),
|
|
resp.StringValue("%W~key8"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_10")
|
|
user.Enabled = true
|
|
user.NoKeys = false
|
|
user.IncludedReadKeys = []string{"key1", "key2", "key3", "key4", "key5", "key6"}
|
|
user.IncludedWriteKeys = []string{"key1", "key2", "key3", "key4", "key7", "key8"}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 11. Create user that can access all pubsub channels with +&*
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_11"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("+&*"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_11")
|
|
user.Enabled = true
|
|
user.IncludedPubSubChannels = []string{"*"}
|
|
user.ExcludedPubSubChannels = []string{}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 12. Create user that can access all pubsub channels with allchannels flag
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_12"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("allchannels"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_12")
|
|
user.Enabled = true
|
|
user.IncludedPubSubChannels = []string{"*"}
|
|
user.ExcludedPubSubChannels = []string{}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 13. Create user with a few allowed pubsub channels and a few disallowed channels
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_13"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("+&channel1"),
|
|
resp.StringValue("+&channel2"),
|
|
resp.StringValue("-&channel3"),
|
|
resp.StringValue("-&channel4"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_13")
|
|
user.Enabled = true
|
|
user.IncludedPubSubChannels = []string{"channel1", "channel2"}
|
|
user.ExcludedPubSubChannels = []string{"channel3", "channel4"}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 14. Create user that can access all commands
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_14"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("allcommands"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_14")
|
|
user.Enabled = true
|
|
user.IncludedCommands = []string{"*"}
|
|
user.ExcludedCommands = []string{}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 15. Create user with some allowed commands and disallowed commands
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_15"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("+acl|getuser"),
|
|
resp.StringValue("+acl|setuser"),
|
|
resp.StringValue("+acl|deluser"),
|
|
resp.StringValue("-rewriteaof"),
|
|
resp.StringValue("-save"),
|
|
resp.StringValue("-publish"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_15")
|
|
user.Enabled = true
|
|
user.IncludedCommands = []string{"acl|getuser", "acl|setuser", "acl|deluser"}
|
|
user.ExcludedCommands = []string{"rewriteaof", "save", "publish"}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 16. Create new user with no password using 'nopass'.
|
|
// When nopass is provided, ignore any passwords that may have been provided in the command.
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_16"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("nopass"),
|
|
resp.StringValue(">password1"),
|
|
resp.StringValue(fmt.Sprintf("#%s", generateSHA256Password("password2"))),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_16")
|
|
user.Enabled = true
|
|
user.NoPassword = true
|
|
user.Passwords = []acl.Password{}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 17. Delete all existing users passwords using 'nopass'
|
|
presetUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_17")
|
|
user.Enabled = true
|
|
user.NoPassword = true
|
|
user.Passwords = []acl.Password{
|
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "password1"},
|
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")},
|
|
}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_17"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("nopass"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_17")
|
|
user.Enabled = true
|
|
user.NoPassword = true
|
|
user.Passwords = []acl.Password{}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 18. Clear all of an existing user's passwords using 'resetpass'
|
|
presetUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_18")
|
|
user.Enabled = true
|
|
user.NoPassword = true
|
|
user.Passwords = []acl.Password{
|
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "password1"},
|
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("password2")},
|
|
}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_18"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("nopass"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_18")
|
|
user.Enabled = true
|
|
user.NoPassword = true
|
|
user.Passwords = []acl.Password{}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 19. Clear all of an existing user's command privileges using 'nocommands'
|
|
presetUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_19")
|
|
user.Enabled = true
|
|
user.IncludedCommands = []string{"acl|getuser", "acl|setuser", "acl|deluser"}
|
|
user.ExcludedCommands = []string{"rewriteaof", "save"}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_19"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("nocommands"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_19")
|
|
user.Enabled = true
|
|
user.IncludedCommands = []string{}
|
|
user.ExcludedCommands = []string{"*"}
|
|
user.IncludedCategories = []string{}
|
|
user.ExcludedCategories = []string{"*"}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 20. Clear all of an existing user's allowed keys using 'resetkeys'
|
|
presetUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_20")
|
|
user.Enabled = true
|
|
user.IncludedWriteKeys = []string{"key1", "key2", "key3", "key4", "key5", "key6"}
|
|
user.IncludedReadKeys = []string{"key1", "key2", "key3", "key7", "key8", "key9"}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_20"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("resetkeys"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_20")
|
|
user.Enabled = true
|
|
user.NoKeys = true
|
|
user.IncludedReadKeys = []string{}
|
|
user.IncludedWriteKeys = []string{}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
{
|
|
// 21. Allow user to access all channels using 'resetchannels'
|
|
presetUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_21")
|
|
user.IncludedPubSubChannels = []string{"channel1", "channel2"}
|
|
user.ExcludedPubSubChannels = []string{"channel3", "channel4"}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_21"),
|
|
resp.StringValue("resetchannels"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: func() *acl.User {
|
|
user := acl.CreateUser("set_user_21")
|
|
user.IncludedPubSubChannels = []string{}
|
|
user.ExcludedPubSubChannels = []string{"*"}
|
|
user.Normalise()
|
|
return user
|
|
}(),
|
|
},
|
|
}
|
|
|
|
for i, test := range tests {
|
|
if test.presetUser != nil {
|
|
a.AddUsers([]*acl.User{test.presetUser})
|
|
}
|
|
if err = r.WriteArray(test.cmd); err != nil {
|
|
t.Error(err)
|
|
}
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if test.wantErr != "" {
|
|
if v.Error().Error() != test.wantErr {
|
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, v.Error().Error())
|
|
}
|
|
continue
|
|
}
|
|
if v.String() != test.wantRes {
|
|
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, v.String())
|
|
}
|
|
if test.wantUser == nil {
|
|
continue
|
|
}
|
|
expectedUser := test.wantUser
|
|
currUserIdx := slices.IndexFunc(a.Users, func(user *acl.User) bool {
|
|
return user.Username == expectedUser.Username
|
|
})
|
|
if currUserIdx == -1 {
|
|
t.Errorf("expected to find user with username \"%s\" but could not find them.", expectedUser.Username)
|
|
}
|
|
if err = compareUsers(expectedUser, a.Users[currUserIdx]); err != nil {
|
|
t.Errorf("test idx: %d, %+v", i, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func Test_HandleGetUser(t *testing.T) {
|
|
port, _ := internal.GetFreePort()
|
|
mockServer := setUpServer(bindAddr, uint16(port), false, "")
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
wg.Done()
|
|
mockServer.Start()
|
|
}()
|
|
wg.Wait()
|
|
|
|
a := getACL(mockServer)
|
|
|
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
|
|
r := resp.NewConn(conn)
|
|
|
|
tests := []struct {
|
|
presetUser *acl.User
|
|
cmd []resp.Value
|
|
wantRes []resp.Value
|
|
wantErr string
|
|
}{
|
|
{ // 1. Get the user and all their details
|
|
presetUser: &acl.User{
|
|
Username: "get_user_1",
|
|
Enabled: true,
|
|
NoPassword: false,
|
|
NoKeys: false,
|
|
Passwords: []acl.Password{
|
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "get_user_password_1"},
|
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("get_user_password_2")},
|
|
},
|
|
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
|
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
|
IncludedCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"},
|
|
ExcludedCommands: []string{"rewriteaof", "save", "acl|load", "acl|save"},
|
|
IncludedReadKeys: []string{"key1", "key2", "key3", "key4"},
|
|
IncludedWriteKeys: []string{"key1", "key2", "key5", "key6"},
|
|
IncludedPubSubChannels: []string{"channel1", "channel2"},
|
|
ExcludedPubSubChannels: []string{"channel3", "channel4"},
|
|
},
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("GETUSER"), resp.StringValue("get_user_1")},
|
|
wantRes: []resp.Value{
|
|
resp.StringValue("username"),
|
|
resp.ArrayValue([]resp.Value{resp.StringValue("get_user_1")}),
|
|
resp.StringValue("flags"),
|
|
resp.ArrayValue([]resp.Value{
|
|
resp.StringValue("on"),
|
|
}),
|
|
resp.StringValue("categories"),
|
|
resp.ArrayValue([]resp.Value{
|
|
resp.StringValue(fmt.Sprintf("+@%s", constants.WriteCategory)),
|
|
resp.StringValue(fmt.Sprintf("+@%s", constants.ReadCategory)),
|
|
resp.StringValue(fmt.Sprintf("+@%s", constants.PubSubCategory)),
|
|
resp.StringValue(fmt.Sprintf("-@%s", constants.AdminCategory)),
|
|
resp.StringValue(fmt.Sprintf("-@%s", constants.ConnectionCategory)),
|
|
resp.StringValue(fmt.Sprintf("-@%s", constants.DangerousCategory)),
|
|
}),
|
|
resp.StringValue("commands"),
|
|
resp.ArrayValue([]resp.Value{
|
|
resp.StringValue("+acl|setuser"),
|
|
resp.StringValue("+acl|getuser"),
|
|
resp.StringValue("+acl|deluser"),
|
|
resp.StringValue("-rewriteaof"),
|
|
resp.StringValue("-save"),
|
|
resp.StringValue("-acl|load"),
|
|
resp.StringValue("-acl|save"),
|
|
}),
|
|
resp.StringValue("keys"),
|
|
resp.ArrayValue([]resp.Value{
|
|
// Keys here
|
|
resp.StringValue("%RW~key1"),
|
|
resp.StringValue("%RW~key2"),
|
|
resp.StringValue("%R~key3"),
|
|
resp.StringValue("%R~key4"),
|
|
resp.StringValue("%W~key5"),
|
|
resp.StringValue("%W~key6"),
|
|
}),
|
|
resp.StringValue("channels"),
|
|
resp.ArrayValue([]resp.Value{
|
|
// Channels here
|
|
resp.StringValue("+&channel1"),
|
|
resp.StringValue("+&channel2"),
|
|
resp.StringValue("-&channel3"),
|
|
resp.StringValue("-&channel4"),
|
|
}),
|
|
},
|
|
wantErr: "",
|
|
},
|
|
{ // 2. Return user not found error
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("GETUSER"),
|
|
resp.StringValue("non_existent_user")},
|
|
wantRes: nil,
|
|
wantErr: "Error user not found",
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
if test.presetUser != nil {
|
|
a.AddUsers([]*acl.User{test.presetUser})
|
|
}
|
|
if err = r.WriteArray(test.cmd); err != nil {
|
|
t.Error(err)
|
|
}
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if test.wantErr != "" {
|
|
if v.Error().Error() != test.wantErr {
|
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, v.Error().Error())
|
|
}
|
|
continue
|
|
}
|
|
resArr := v.Array()
|
|
for i := 0; i < len(resArr); i++ {
|
|
if slices.Contains([]string{"username", "flags", "categories", "commands", "keys", "channels"}, resArr[i].String()) {
|
|
// String item
|
|
if resArr[i].String() != test.wantRes[i].String() {
|
|
t.Errorf("expected response component %+v, got %+v", test.wantRes[i], resArr[i])
|
|
}
|
|
} else {
|
|
// Array item
|
|
var expected []string
|
|
for _, item := range test.wantRes[i].Array() {
|
|
expected = append(expected, item.String())
|
|
}
|
|
|
|
var res []string
|
|
for _, item := range resArr[i].Array() {
|
|
res = append(res, item.String())
|
|
}
|
|
|
|
if err = compareSlices(res, expected); err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func Test_HandleDelUser(t *testing.T) {
|
|
port, _ := internal.GetFreePort()
|
|
mockServer := setUpServer(bindAddr, uint16(port), false, "")
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
wg.Done()
|
|
mockServer.Start()
|
|
}()
|
|
wg.Wait()
|
|
|
|
a := getACL(mockServer)
|
|
|
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
|
|
r := resp.NewConn(conn)
|
|
|
|
tests := []struct {
|
|
presetUser *acl.User
|
|
cmd []resp.Value
|
|
wantRes string
|
|
wantErr string
|
|
}{
|
|
{
|
|
// 1. Delete existing user while skipping default user and non-existent user
|
|
presetUser: acl.CreateUser("user_to_delete"),
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("DELUSER"),
|
|
resp.StringValue("default"),
|
|
resp.StringValue("user_to_delete"),
|
|
resp.StringValue("non_existent_user"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
},
|
|
{
|
|
// 2. Command too short
|
|
presetUser: nil,
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("DELUSER")},
|
|
wantRes: "",
|
|
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
if test.presetUser != nil {
|
|
a.AddUsers([]*acl.User{test.presetUser})
|
|
}
|
|
if err = r.WriteArray(test.cmd); err != nil {
|
|
t.Error(err)
|
|
}
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if test.wantErr != "" {
|
|
if v.Error().Error() != test.wantErr {
|
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, v.Error().Error())
|
|
}
|
|
continue
|
|
}
|
|
// Check that default user still exists in the list of users
|
|
if !slices.ContainsFunc(a.Users, func(user *acl.User) bool {
|
|
return user.Username == "default"
|
|
}) {
|
|
t.Error("could not find user with username \"default\" in the ACL after deleting user")
|
|
}
|
|
// Check that the deleted user is no longer in the list
|
|
if slices.ContainsFunc(a.Users, func(user *acl.User) bool {
|
|
return user.Username == "user_to_delete"
|
|
}) {
|
|
t.Error("deleted user found in the ACL")
|
|
}
|
|
}
|
|
}
|
|
|
|
func Test_HandleWhoAmI(t *testing.T) {
|
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
|
|
r := resp.NewConn(conn)
|
|
|
|
tests := []struct {
|
|
username string
|
|
password string
|
|
wantRes string
|
|
}{
|
|
{ // 1. With default user
|
|
username: "default",
|
|
password: "password1",
|
|
wantRes: "default",
|
|
},
|
|
{ // 2. With user authenticated by plaintext password
|
|
username: "with_password_user",
|
|
password: "password2",
|
|
wantRes: "with_password_user",
|
|
},
|
|
{ // 3. With user authenticated by SHA256 password
|
|
username: "with_password_user",
|
|
password: "password3",
|
|
wantRes: "with_password_user",
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
// Authenticate
|
|
if err = r.WriteArray([]resp.Value{
|
|
resp.StringValue("AUTH"),
|
|
resp.StringValue(test.username),
|
|
resp.StringValue(test.password),
|
|
}); err != nil {
|
|
t.Error(err)
|
|
}
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if v.String() != "OK" {
|
|
t.Errorf("expected response for auth with %s:%s to be \"OK\", got %s", test.username, test.password, v.String())
|
|
}
|
|
// Check whoami response value
|
|
if err = r.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("WHOAMI")}); err != nil {
|
|
t.Error(err)
|
|
}
|
|
v, _, err = r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if v.String() != test.wantRes {
|
|
t.Errorf("expected whoami response to be \"%s\", got \"%s\"", test.wantRes, v.String())
|
|
}
|
|
}
|
|
}
|
|
|
|
func Test_HandleList(t *testing.T) {
|
|
port, _ := internal.GetFreePort()
|
|
mockServer := setUpServer(bindAddr, uint16(port), false, "")
|
|
wg := sync.WaitGroup{}
|
|
wg.Add(1)
|
|
go func() {
|
|
wg.Done()
|
|
mockServer.Start()
|
|
}()
|
|
wg.Wait()
|
|
|
|
a := getACL(mockServer)
|
|
|
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
|
|
r := resp.NewConn(conn)
|
|
|
|
tests := []struct {
|
|
presetUsers []*acl.User
|
|
cmd []resp.Value
|
|
wantRes []string
|
|
wantErr string
|
|
}{
|
|
{ // 1. Get the user and all their details
|
|
presetUsers: []*acl.User{
|
|
{
|
|
Username: "list_user_1",
|
|
Enabled: true,
|
|
NoPassword: false,
|
|
NoKeys: false,
|
|
Passwords: []acl.Password{
|
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "list_user_password_1"},
|
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_2")},
|
|
},
|
|
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
|
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
|
IncludedCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"},
|
|
ExcludedCommands: []string{"rewriteaof", "save", "acl|load", "acl|save"},
|
|
IncludedReadKeys: []string{"key1", "key2", "key3", "key4"},
|
|
IncludedWriteKeys: []string{"key1", "key2", "key5", "key6"},
|
|
IncludedPubSubChannels: []string{"channel1", "channel2"},
|
|
ExcludedPubSubChannels: []string{"channel3", "channel4"},
|
|
},
|
|
{
|
|
Username: "list_user_2",
|
|
Enabled: true,
|
|
NoPassword: true,
|
|
NoKeys: true,
|
|
Passwords: []acl.Password{},
|
|
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
|
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
|
IncludedCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"},
|
|
ExcludedCommands: []string{"rewriteaof", "save", "acl|load", "acl|save"},
|
|
IncludedReadKeys: []string{},
|
|
IncludedWriteKeys: []string{},
|
|
IncludedPubSubChannels: []string{"channel1", "channel2"},
|
|
ExcludedPubSubChannels: []string{"channel3", "channel4"},
|
|
},
|
|
{
|
|
Username: "list_user_3",
|
|
Enabled: true,
|
|
NoPassword: false,
|
|
NoKeys: false,
|
|
Passwords: []acl.Password{
|
|
{PasswordType: acl.PasswordPlainText, PasswordValue: "list_user_password_3"},
|
|
{PasswordType: acl.PasswordSHA256, PasswordValue: generateSHA256Password("list_user_password_4")},
|
|
},
|
|
IncludedCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
|
ExcludedCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
|
IncludedCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"},
|
|
ExcludedCommands: []string{"rewriteaof", "save", "acl|load", "acl|save"},
|
|
IncludedReadKeys: []string{"key1", "key2", "key3", "key4"},
|
|
IncludedWriteKeys: []string{"key1", "key2", "key5", "key6"},
|
|
IncludedPubSubChannels: []string{"channel1", "channel2"},
|
|
ExcludedPubSubChannels: []string{"channel3", "channel4"},
|
|
},
|
|
},
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("LIST")},
|
|
wantRes: []string{
|
|
"default on +@all +all %RW~* +&*",
|
|
fmt.Sprintf("with_password_user on >password2 #%s +@all +all", generateSHA256Password("password3")),
|
|
"no_password_user on nopass >password4",
|
|
"disabled_user off >password5",
|
|
fmt.Sprintf(`list_user_1 on >list_user_password_1 #%s +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save %s +&channel1 +&channel2 -&channel3 -&channel4`, generateSHA256Password("list_user_password_2"), "%RW~key1 %RW~key2 %R~key3 %R~key4"),
|
|
fmt.Sprintf(`list_user_2 on nopass nokeys +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save +&channel1 +&channel2 -&channel3 -&channel4`),
|
|
fmt.Sprintf(`list_user_3 on >list_user_password_3 #%s +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save %s +&channel1 +&channel2 -&channel3 -&channel4`, generateSHA256Password("list_user_password_4"), "%RW~key1 %RW~key2 %R~key3 %R~key4"),
|
|
},
|
|
wantErr: "",
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
a.AddUsers(test.presetUsers)
|
|
|
|
if err = r.WriteArray(test.cmd); err != nil {
|
|
t.Error(err)
|
|
}
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if test.wantErr != "" {
|
|
if v.Error().Error() != test.wantErr {
|
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, v.Error().Error())
|
|
}
|
|
continue
|
|
}
|
|
resArr := v.Array()
|
|
if len(resArr) != len(test.wantRes) {
|
|
t.Errorf("expected response of lenght %d, got lenght %d", len(test.wantRes), len(resArr))
|
|
}
|
|
var resStr []string
|
|
for i := 0; i < len(resArr); i++ {
|
|
resStr = strings.Split(resArr[i].String(), " ")
|
|
if !slices.ContainsFunc(test.wantRes, func(s string) bool {
|
|
expectedUserSlice := strings.Split(s, " ")
|
|
return compareSlices(resStr, expectedUserSlice) == nil
|
|
}) {
|
|
t.Errorf("could not find the following user in expected slice: %+v", resStr)
|
|
}
|
|
clear(resStr)
|
|
}
|
|
}
|
|
}
|