mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-05 07:56:52 +08:00
1741 lines
49 KiB
Go
1741 lines
49 KiB
Go
// Copyright 2024 Kelvin Clement Mwinuka
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package acl_test
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"fmt"
|
|
"github.com/echovault/echovault/echovault"
|
|
"github.com/echovault/echovault/internal"
|
|
"github.com/echovault/echovault/internal/config"
|
|
"github.com/echovault/echovault/internal/constants"
|
|
"github.com/tidwall/resp"
|
|
"os"
|
|
"path"
|
|
"slices"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func setUpServer(port int, requirePass bool, aclConfig string) (*echovault.EchoVault, error) {
|
|
conf := config.Config{
|
|
BindAddr: "localhost",
|
|
Port: uint16(port),
|
|
DataDir: "",
|
|
EvictionPolicy: constants.NoEviction,
|
|
RequirePass: requirePass,
|
|
Password: "password1",
|
|
AclConfig: aclConfig,
|
|
}
|
|
|
|
mockServer, err := echovault.NewEchoVault(
|
|
echovault.WithConfig(conf),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Add the initial test users to the ACL module.
|
|
for _, user := range generateInitialTestUsers() {
|
|
if _, err := mockServer.ACLSetUser(user); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return mockServer, nil
|
|
}
|
|
|
|
func generateInitialTestUsers() []echovault.User {
|
|
return []echovault.User{
|
|
{
|
|
// User with both hash password and plaintext password.
|
|
Username: "with_password_user",
|
|
Enabled: true,
|
|
IncludeCategories: []string{"*"},
|
|
IncludeCommands: []string{"*"},
|
|
AddPlainPasswords: []string{"password2"},
|
|
AddHashPasswords: []string{generateSHA256Password("password3")},
|
|
},
|
|
{
|
|
// User with NoPassword option.
|
|
Username: "no_password_user",
|
|
Enabled: true,
|
|
NoPassword: true,
|
|
AddPlainPasswords: []string{"password4"},
|
|
},
|
|
{
|
|
// Disabled user.
|
|
Username: "disabled_user",
|
|
Enabled: false,
|
|
AddPlainPasswords: []string{"password5"},
|
|
},
|
|
}
|
|
}
|
|
|
|
// compareSlices compare the elements in 2 slices, it checks if every element is s1 is contained in s2
|
|
// and vice versa. It essentially does a deep equality comparison.
|
|
// This is done manually rather than using slices.Equal because it would be ideal to throw an error
|
|
// specifying exactly which items are missing in either slice.
|
|
func compareSlices[T comparable](res, expected []T) error {
|
|
if len(res) != len(expected) {
|
|
return fmt.Errorf("expected slice of length %d, got slice of length %d", len(expected), len(res))
|
|
}
|
|
// Check whether all elements in res are contained in expected
|
|
for _, r := range res {
|
|
if !slices.Contains(expected, r) {
|
|
return fmt.Errorf("got response item %+v, but it's not contained in expected slices", r)
|
|
}
|
|
}
|
|
// Check whether all elements in expected are contained in res
|
|
for _, e := range expected {
|
|
if !slices.Contains(res, e) {
|
|
return fmt.Errorf("expected element %+v, not found in res slice", e)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// compareUsers compares 2 users and checks if all their fields are equal
|
|
func compareUsers(user1, user2 map[string][]string) error {
|
|
// Compare flags
|
|
if user1["username"][0] != user2["username"][0] {
|
|
return fmt.Errorf("mismatched usernames \"%s\", and \"%s\"", user1["username"][0], user2["username"][0])
|
|
}
|
|
|
|
// Check if both users are enabled.
|
|
if slices.Contains(user1["flags"], "on") != slices.Contains(user2["flags"], "on") {
|
|
return fmt.Errorf("mismatched enabled flag \"%+v\", and \"%+v\"",
|
|
slices.Contains(user1["flags"], "on"), slices.Contains(user2["flags"], "on"))
|
|
}
|
|
|
|
// Check if "nokeys" is present
|
|
if slices.Contains(user1["flags"], "nokeys") != slices.Contains(user2["flags"], "nokeys") {
|
|
return fmt.Errorf("mismatched nokeys flag \"%+v\", and \"%+v\"",
|
|
slices.Contains(user1["flags"], "nokeys"), slices.Contains(user2["flags"], "nokeys"))
|
|
}
|
|
|
|
// Check if "nopass" is present
|
|
if slices.Contains(user1["flags"], "nopass") != slices.Contains(user1["flags"], "nopass") {
|
|
return fmt.Errorf("mismatched nopassword flag \"%+v\", and \"%+v\"",
|
|
slices.Contains(user1["flags"], "nopass"), slices.Contains(user1["flags"], "nopass"))
|
|
}
|
|
|
|
// Compare permissions
|
|
permissions := [][][]string{
|
|
{user1["categories"], user2["categories"]},
|
|
{user1["commands"], user2["commands"]},
|
|
{user1["keys"], user2["keys"]},
|
|
{user1["channels"], user2["channels"]},
|
|
}
|
|
for _, p := range permissions {
|
|
if err := compareSlices(p[0], p[1]); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func generateSHA256Password(plain string) string {
|
|
h := sha256.New()
|
|
h.Write([]byte(plain))
|
|
return string(h.Sum(nil))
|
|
}
|
|
|
|
func Test_ACL(t *testing.T) {
|
|
port, err := internal.GetFreePort()
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
mockServer, err := setUpServer(port, true, "")
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
mockServer.Start()
|
|
}()
|
|
|
|
t.Cleanup(func() {
|
|
mockServer.ShutDown()
|
|
})
|
|
|
|
t.Run("Test_HandleAuth", func(t *testing.T) {
|
|
t.Parallel()
|
|
conn, err := internal.GetConnection("localhost", port)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
|
|
r := resp.NewConn(conn)
|
|
|
|
tests := []struct {
|
|
cmd []resp.Value
|
|
wantRes string
|
|
wantErr string
|
|
}{
|
|
{ // 1. Authenticate with default user without specifying username
|
|
cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
},
|
|
{ // 2. Authenticate with plaintext password
|
|
cmd: []resp.Value{
|
|
resp.StringValue("AUTH"),
|
|
resp.StringValue("with_password_user"),
|
|
resp.StringValue("password2"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
},
|
|
{ // 3. Authenticate with SHA256 password
|
|
cmd: []resp.Value{
|
|
resp.StringValue("AUTH"),
|
|
resp.StringValue("with_password_user"),
|
|
resp.StringValue("password3"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
},
|
|
{ // 4. Authenticate with no password user
|
|
cmd: []resp.Value{
|
|
resp.StringValue("AUTH"),
|
|
resp.StringValue("no_password_user"),
|
|
resp.StringValue("password4"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
},
|
|
{ // 5. Fail to authenticate with disabled user
|
|
cmd: []resp.Value{
|
|
resp.StringValue("AUTH"),
|
|
resp.StringValue("disabled_user"),
|
|
resp.StringValue("password5"),
|
|
},
|
|
wantRes: "",
|
|
wantErr: "Error user disabled_user is disabled",
|
|
},
|
|
{ // 6. Fail to authenticate with non-existent user
|
|
cmd: []resp.Value{
|
|
resp.StringValue("AUTH"),
|
|
resp.StringValue("non_existent_user"),
|
|
resp.StringValue("password6"),
|
|
},
|
|
wantRes: "",
|
|
wantErr: "Error no user with username non_existent_user",
|
|
},
|
|
{ // 7. Command too short
|
|
cmd: []resp.Value{resp.StringValue("AUTH")},
|
|
wantRes: "",
|
|
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
|
|
},
|
|
{ // 8. Command too long
|
|
cmd: []resp.Value{
|
|
resp.StringValue("AUTH"),
|
|
resp.StringValue("user"),
|
|
resp.StringValue("password1"),
|
|
resp.StringValue("password2"),
|
|
},
|
|
wantRes: "",
|
|
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
if err = r.WriteArray(test.cmd); err != nil {
|
|
t.Error(err)
|
|
}
|
|
rv, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if test.wantErr != "" {
|
|
if rv.Error().Error() != test.wantErr {
|
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
|
|
}
|
|
continue
|
|
}
|
|
if rv.String() != test.wantRes {
|
|
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("Test_HandleCat", func(t *testing.T) {
|
|
t.Parallel()
|
|
conn, err := internal.GetConnection("localhost", port)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
r := resp.NewConn(conn)
|
|
|
|
// Authenticate connection
|
|
if err = r.WriteArray([]resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}); err != nil {
|
|
t.Error(err)
|
|
}
|
|
rv, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if rv.String() != "OK" {
|
|
t.Error("could not authenticate user")
|
|
}
|
|
|
|
// Since only ACL commands are loaded in this test suite, this test will only test against the
|
|
// list of categories and commands available in the ACL module.
|
|
tests := []struct {
|
|
cmd []resp.Value
|
|
wantRes []string
|
|
wantErr string
|
|
}{
|
|
{ // 1. Return list of categories
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT")},
|
|
wantRes: []string{
|
|
constants.ConnectionCategory,
|
|
constants.SlowCategory,
|
|
constants.FastCategory,
|
|
constants.AdminCategory,
|
|
constants.DangerousCategory,
|
|
},
|
|
wantErr: "",
|
|
},
|
|
{ // 2. Return list of commands in connection category
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.ConnectionCategory)},
|
|
wantRes: []string{"auth"},
|
|
wantErr: "",
|
|
},
|
|
{ // 3. Return list of commands in slow category
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.SlowCategory)},
|
|
wantRes: []string{"auth", "acl|cat", "acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
|
|
wantErr: "",
|
|
},
|
|
{ // 4. Return list of commands in fast category
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.FastCategory)},
|
|
wantRes: []string{"acl|whoami"},
|
|
wantErr: "",
|
|
},
|
|
{ // 5. Return list of commands in admin category
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.AdminCategory)},
|
|
wantRes: []string{"acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
|
|
wantErr: "",
|
|
},
|
|
{ // 6. Return list of commands in dangerous category
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(constants.DangerousCategory)},
|
|
wantRes: []string{"acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
|
|
wantErr: "",
|
|
},
|
|
{ // 7. Return error when category does not exist
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue("non-existent")},
|
|
wantRes: nil,
|
|
wantErr: "Error category NON-EXISTENT not found",
|
|
},
|
|
{ // 8. Command too long
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue("category1"), resp.StringValue("category2")},
|
|
wantRes: nil,
|
|
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
if err = r.WriteArray(test.cmd); err != nil {
|
|
t.Error(err)
|
|
}
|
|
rv, _, err = r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if test.wantErr != "" {
|
|
if rv.Error().Error() != test.wantErr {
|
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
|
|
}
|
|
continue
|
|
}
|
|
resArr := rv.Array()
|
|
// Check if all the elements in the expected array are in the response array
|
|
for _, expected := range test.wantRes {
|
|
if !slices.ContainsFunc(resArr, func(value resp.Value) bool {
|
|
return value.String() == expected
|
|
}) {
|
|
t.Errorf("could not find expected command \"%s\" in the response array for category", expected)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("Test_HandleUsers", func(t *testing.T) {
|
|
t.Parallel()
|
|
port, err := internal.GetFreePort()
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
mockServer, err := setUpServer(port, false, "")
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
mockServer.Start()
|
|
}()
|
|
|
|
conn, err := internal.GetConnection("localhost", port)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
|
|
r := resp.NewConn(conn)
|
|
|
|
users := []string{"default", "with_password_user", "no_password_user", "disabled_user"}
|
|
|
|
if err = r.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("USERS")}); err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
rv, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
|
|
resArr := rv.Array()
|
|
|
|
// Check if all the expected users are in the response array
|
|
for _, user := range users {
|
|
if !slices.ContainsFunc(resArr, func(value resp.Value) bool {
|
|
return value.String() == user
|
|
}) {
|
|
t.Errorf("could not find expected user \"%s\" in response array", user)
|
|
}
|
|
}
|
|
|
|
// Check if all the users in the response array are in the expected users
|
|
for _, value := range resArr {
|
|
if !slices.ContainsFunc(users, func(user string) bool {
|
|
return value.String() == user
|
|
}) {
|
|
t.Errorf("could not find response user \"%s\" in expected users array", value.String())
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("Test_HandleSetUser", func(t *testing.T) {
|
|
t.Parallel()
|
|
port, err := internal.GetFreePort()
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
mockServer, err := setUpServer(port, false, "")
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
mockServer.Start()
|
|
}()
|
|
|
|
conn, err := internal.GetConnection("localhost", port)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
|
|
r := resp.NewConn(conn)
|
|
|
|
t.Cleanup(func() {
|
|
mockServer.ShutDown()
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
presetUser *echovault.User
|
|
cmd []resp.Value
|
|
wantRes string
|
|
wantErr string
|
|
wantUser map[string][]string
|
|
}{
|
|
{
|
|
name: "1. Create new enabled user",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_1"),
|
|
resp.StringValue("on"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_1"},
|
|
"flags": {"on"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "2. Create new disabled user",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_2"),
|
|
resp.StringValue("off"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_2"},
|
|
"flags": {"off"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "3. Create new enabled user with both plaintext and SHA256 passwords",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_3"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue(">set_user_3_plaintext_password_1"),
|
|
resp.StringValue(">set_user_3_plaintext_password_2"),
|
|
resp.StringValue(fmt.Sprintf("#%s", generateSHA256Password("set_user_3_hash_password_1"))),
|
|
resp.StringValue(fmt.Sprintf("#%s", generateSHA256Password("set_user_3_hash_password_2"))),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_3"},
|
|
"flags": {"on"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "4. Remove plaintext and SHA256 password from existing user",
|
|
presetUser: &echovault.User{
|
|
Username: "set_user_4",
|
|
Enabled: true,
|
|
AddPlainPasswords: []string{"set_user_4_plaintext_password_1", "set_user_4_plaintext_password_2"},
|
|
AddHashPasswords: []string{
|
|
generateSHA256Password("set_user_4_hash_password_1"),
|
|
generateSHA256Password("set_user_4_hash_password_2"),
|
|
},
|
|
},
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_4"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("<set_user_3_plaintext_password_2"),
|
|
resp.StringValue(fmt.Sprintf("!%s", generateSHA256Password("set_user_3_hash_password_2"))),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_4"},
|
|
"flags": {"on"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "5. Create user with no commands allowed to be executed",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_5"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("nocommands"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_5"},
|
|
"flags": {"on"},
|
|
"categories": {"-@all"},
|
|
"commands": {"-all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "6. Create user that can access all categories with +@*",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_6"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("+@*"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_6"},
|
|
"flags": {"on"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "7. Create user that can access all categories with allcategories flag",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_7"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("allcategories"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_7"},
|
|
"flags": {"on"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "8. Create user with a few allowed categories and a few disallowed categories",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_8"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue(fmt.Sprintf("+@%s", constants.WriteCategory)),
|
|
resp.StringValue(fmt.Sprintf("+@%s", constants.ReadCategory)),
|
|
resp.StringValue(fmt.Sprintf("+@%s", constants.PubSubCategory)),
|
|
resp.StringValue(fmt.Sprintf("-@%s", constants.AdminCategory)),
|
|
resp.StringValue(fmt.Sprintf("-@%s", constants.ConnectionCategory)),
|
|
resp.StringValue(fmt.Sprintf("-@%s", constants.DangerousCategory)),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_8"},
|
|
"flags": {"on"},
|
|
"categories": {
|
|
fmt.Sprintf("+@%s", constants.WriteCategory),
|
|
fmt.Sprintf("+@%s", constants.ReadCategory),
|
|
fmt.Sprintf("+@%s", constants.PubSubCategory),
|
|
fmt.Sprintf("-@%s", constants.AdminCategory),
|
|
fmt.Sprintf("-@%s", constants.ConnectionCategory),
|
|
fmt.Sprintf("-@%s", constants.DangerousCategory),
|
|
},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "9. Create user that is not allowed to access any keys",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_9"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("resetkeys"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_9"},
|
|
"flags": {"on", "nokeys"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: `10. Create user that can access some read keys and some write keys.
|
|
Provide keys that are RW, W-Only and R-Only`,
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_10"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("%RW~key1"),
|
|
resp.StringValue("%RW~key2"),
|
|
resp.StringValue("%RW~key3"),
|
|
resp.StringValue("%RW~key4"),
|
|
resp.StringValue("%R~key5"),
|
|
resp.StringValue("%R~key6"),
|
|
resp.StringValue("%W~key7"),
|
|
resp.StringValue("%W~key8"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_10"},
|
|
"flags": {"on"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~key1", "%RW~key2", "%RW~key3", "%RW~key4", "%R~key5", "%R~key6", "%W~key7", "%W~key8"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "11. Create user that can access all pubsub channels with +&*",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_11"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("+&*"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_11"},
|
|
"flags": {"on"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "12. Create user that can access all pubsub channels with allchannels flag",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_12"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("allchannels"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_12"},
|
|
"flags": {"on"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "13. Create user with a few allowed pubsub channels and a few disallowed channels",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_13"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("+&channel1"),
|
|
resp.StringValue("+&channel2"),
|
|
resp.StringValue("-&channel3"),
|
|
resp.StringValue("-&channel4"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_13"},
|
|
"flags": {"on"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&channel1", "+&channel2", "-&channel3", "-&channel4"},
|
|
},
|
|
},
|
|
{
|
|
name: "14. Create user that can access all commands",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_14"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("allcommands"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_14"},
|
|
"flags": {"on"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "15. Create user with some allowed commands and disallowed commands",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_15"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("+acl|getuser"),
|
|
resp.StringValue("+acl|setuser"),
|
|
resp.StringValue("+acl|deluser"),
|
|
resp.StringValue("-rewriteaof"),
|
|
resp.StringValue("-save"),
|
|
resp.StringValue("-publish"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_15"},
|
|
"flags": {"on"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+acl|getuser", "+acl|setuser", "+acl|deluser", "-rewriteaof", "-save", "-publish"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: `16. Create new user with no password using 'nopass'.
|
|
When nopass is provided, ignore any passwords that may have been provided in the command.`,
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_16"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("nopass"),
|
|
resp.StringValue(">password1"),
|
|
resp.StringValue(fmt.Sprintf("#%s", generateSHA256Password("password2"))),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_16"},
|
|
"flags": {"on", "nopass"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "17. Delete all existing users passwords using 'nopass'",
|
|
presetUser: &echovault.User{
|
|
Username: "set_user_17",
|
|
Enabled: true,
|
|
NoPassword: true,
|
|
AddPlainPasswords: []string{"password1"},
|
|
AddHashPasswords: []string{generateSHA256Password("password2")},
|
|
},
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_17"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("nopass"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_17"},
|
|
"flags": {"on", "nopass"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "18. Clear all of an existing user's passwords using 'resetpass'",
|
|
presetUser: &echovault.User{
|
|
Username: "set_user_18",
|
|
Enabled: true,
|
|
NoPassword: true,
|
|
AddPlainPasswords: []string{"password1"},
|
|
AddHashPasswords: []string{generateSHA256Password("password2")},
|
|
},
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_18"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("nopass"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_18"},
|
|
"flags": {"on", "nopass"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "19. Clear all of an existing user's command privileges using 'nocommands'",
|
|
presetUser: &echovault.User{
|
|
Username: "set_user_19",
|
|
Enabled: true,
|
|
IncludeCommands: []string{"acl|getuser", "acl|setuser", "acl|deluser"},
|
|
ExcludeCommands: []string{"rewriteaof", "save"},
|
|
},
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_19"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("nocommands"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_19"},
|
|
"flags": {"on"},
|
|
"categories": {"-@all"},
|
|
"commands": {"-all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "20. Clear all of an existing user's allowed keys using 'resetkeys'",
|
|
presetUser: &echovault.User{
|
|
Username: "set_user_20",
|
|
Enabled: true,
|
|
IncludeWriteKeys: []string{"key1", "key2", "key3", "key4", "key5", "key6"},
|
|
IncludeReadKeys: []string{"key1", "key2", "key3", "key7", "key8", "key9"},
|
|
},
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_20"),
|
|
resp.StringValue("on"),
|
|
resp.StringValue("resetkeys"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_20"},
|
|
"flags": {"on", "nokeys"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {},
|
|
"channels": {"+&*"},
|
|
},
|
|
},
|
|
{
|
|
name: "21. Allow user to access all channels using 'resetchannels'",
|
|
presetUser: &echovault.User{
|
|
Username: "set_user_21",
|
|
Enabled: true,
|
|
IncludeChannels: []string{"channel1", "channel2"},
|
|
ExcludeChannels: []string{"channel3", "channel4"},
|
|
},
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("SETUSER"),
|
|
resp.StringValue("set_user_21"),
|
|
resp.StringValue("resetchannels"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
wantUser: map[string][]string{
|
|
"username": {"set_user_21"},
|
|
"flags": {"on"},
|
|
"categories": {"+@all"},
|
|
"commands": {"+all"},
|
|
"keys": {"%RW~*"},
|
|
"channels": {"-&*"},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
if test.presetUser != nil {
|
|
if _, err := mockServer.ACLSetUser(*test.presetUser); err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
}
|
|
if err = r.WriteArray(test.cmd); err != nil {
|
|
t.Error(err)
|
|
}
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if test.wantErr != "" {
|
|
if v.Error().Error() != test.wantErr {
|
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, v.Error().Error())
|
|
}
|
|
return
|
|
}
|
|
if v.String() != test.wantRes {
|
|
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, v.String())
|
|
}
|
|
if test.wantUser == nil {
|
|
return
|
|
}
|
|
|
|
user, err := mockServer.ACLGetUser(test.wantUser["username"][0])
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
if err = compareUsers(test.wantUser, user); err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("Test_HandleGetUser", func(t *testing.T) {
|
|
t.Parallel()
|
|
port, err := internal.GetFreePort()
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
mockServer, err := setUpServer(port, false, "")
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
go func() {
|
|
mockServer.Start()
|
|
}()
|
|
|
|
t.Cleanup(func() {
|
|
mockServer.ShutDown()
|
|
})
|
|
|
|
conn, err := internal.GetConnection("localhost", port)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
|
|
r := resp.NewConn(conn)
|
|
|
|
tests := []struct {
|
|
name string
|
|
presetUser *echovault.User
|
|
cmd []resp.Value
|
|
wantRes []resp.Value
|
|
wantErr string
|
|
}{
|
|
{
|
|
name: "1. Get the user and all their details",
|
|
presetUser: &echovault.User{
|
|
Username: "get_user_1",
|
|
Enabled: true,
|
|
NoPassword: false,
|
|
NoKeys: false,
|
|
AddPlainPasswords: []string{"get_user_password_1"},
|
|
AddHashPasswords: []string{generateSHA256Password("get_user_password_2")},
|
|
IncludeCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
|
ExcludeCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
|
IncludeCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"},
|
|
ExcludeCommands: []string{"rewriteaof", "save", "acl|load", "acl|save"},
|
|
IncludeReadKeys: []string{"key1", "key2", "key3", "key4"},
|
|
IncludeWriteKeys: []string{"key1", "key2", "key5", "key6"},
|
|
IncludeChannels: []string{"channel1", "channel2"},
|
|
ExcludeChannels: []string{"channel3", "channel4"},
|
|
},
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("GETUSER"), resp.StringValue("get_user_1")},
|
|
wantRes: []resp.Value{
|
|
resp.StringValue("username"),
|
|
resp.ArrayValue([]resp.Value{resp.StringValue("get_user_1")}),
|
|
resp.StringValue("flags"),
|
|
resp.ArrayValue([]resp.Value{
|
|
resp.StringValue("on"),
|
|
}),
|
|
resp.StringValue("categories"),
|
|
resp.ArrayValue([]resp.Value{
|
|
resp.StringValue(fmt.Sprintf("+@%s", constants.WriteCategory)),
|
|
resp.StringValue(fmt.Sprintf("+@%s", constants.ReadCategory)),
|
|
resp.StringValue(fmt.Sprintf("+@%s", constants.PubSubCategory)),
|
|
resp.StringValue(fmt.Sprintf("-@%s", constants.AdminCategory)),
|
|
resp.StringValue(fmt.Sprintf("-@%s", constants.ConnectionCategory)),
|
|
resp.StringValue(fmt.Sprintf("-@%s", constants.DangerousCategory)),
|
|
}),
|
|
resp.StringValue("commands"),
|
|
resp.ArrayValue([]resp.Value{
|
|
resp.StringValue("+acl|setuser"),
|
|
resp.StringValue("+acl|getuser"),
|
|
resp.StringValue("+acl|deluser"),
|
|
resp.StringValue("-rewriteaof"),
|
|
resp.StringValue("-save"),
|
|
resp.StringValue("-acl|load"),
|
|
resp.StringValue("-acl|save"),
|
|
}),
|
|
resp.StringValue("keys"),
|
|
resp.ArrayValue([]resp.Value{
|
|
// Keys here
|
|
resp.StringValue("%RW~key1"),
|
|
resp.StringValue("%RW~key2"),
|
|
resp.StringValue("%R~key3"),
|
|
resp.StringValue("%R~key4"),
|
|
resp.StringValue("%W~key5"),
|
|
resp.StringValue("%W~key6"),
|
|
}),
|
|
resp.StringValue("channels"),
|
|
resp.ArrayValue([]resp.Value{
|
|
// Channels here
|
|
resp.StringValue("+&channel1"),
|
|
resp.StringValue("+&channel2"),
|
|
resp.StringValue("-&channel3"),
|
|
resp.StringValue("-&channel4"),
|
|
}),
|
|
},
|
|
wantErr: "",
|
|
},
|
|
{
|
|
name: "2. Return user not found error",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("GETUSER"),
|
|
resp.StringValue("non_existent_user")},
|
|
wantRes: nil,
|
|
wantErr: "Error user not found",
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
if test.presetUser != nil {
|
|
if _, err := mockServer.ACLSetUser(*test.presetUser); err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
}
|
|
if err = r.WriteArray(test.cmd); err != nil {
|
|
t.Error(err)
|
|
}
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if test.wantErr != "" {
|
|
if v.Error().Error() != test.wantErr {
|
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, v.Error().Error())
|
|
}
|
|
return
|
|
}
|
|
resArr := v.Array()
|
|
for i := 0; i < len(resArr); i++ {
|
|
if slices.Contains([]string{"username", "flags", "categories", "commands", "keys", "channels"}, resArr[i].String()) {
|
|
// String item
|
|
if resArr[i].String() != test.wantRes[i].String() {
|
|
t.Errorf("expected response component %+v, got %+v", test.wantRes[i], resArr[i])
|
|
}
|
|
} else {
|
|
// Array item
|
|
var expected []string
|
|
for _, item := range test.wantRes[i].Array() {
|
|
expected = append(expected, item.String())
|
|
}
|
|
|
|
var res []string
|
|
for _, item := range resArr[i].Array() {
|
|
res = append(res, item.String())
|
|
}
|
|
|
|
if err = compareSlices(res, expected); err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("Test_HandleDelUser", func(t *testing.T) {
|
|
t.Parallel()
|
|
port, err := internal.GetFreePort()
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
mockServer, err := setUpServer(port, false, "")
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
mockServer.Start()
|
|
}()
|
|
|
|
t.Cleanup(func() {
|
|
mockServer.ShutDown()
|
|
})
|
|
|
|
conn, err := internal.GetConnection("localhost", port)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
r := resp.NewConn(conn)
|
|
|
|
tests := []struct {
|
|
name string
|
|
presetUser *echovault.User
|
|
cmd []resp.Value
|
|
wantRes string
|
|
wantErr string
|
|
}{
|
|
{
|
|
name: "1. Delete existing user while skipping default user and non-existent user",
|
|
presetUser: &echovault.User{
|
|
Username: "user_to_delete",
|
|
Enabled: true,
|
|
},
|
|
cmd: []resp.Value{
|
|
resp.StringValue("ACL"),
|
|
resp.StringValue("DELUSER"),
|
|
resp.StringValue("default"),
|
|
resp.StringValue("user_to_delete"),
|
|
resp.StringValue("non_existent_user"),
|
|
},
|
|
wantRes: "OK",
|
|
wantErr: "",
|
|
},
|
|
{
|
|
name: "2. Command too short",
|
|
presetUser: nil,
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("DELUSER")},
|
|
wantRes: "",
|
|
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
if test.presetUser != nil {
|
|
if _, err := mockServer.ACLSetUser(*test.presetUser); err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
}
|
|
if err = r.WriteArray(test.cmd); err != nil {
|
|
t.Error(err)
|
|
}
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if test.wantErr != "" {
|
|
if v.Error().Error() != test.wantErr {
|
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, v.Error().Error())
|
|
}
|
|
return
|
|
}
|
|
|
|
usernames, err := mockServer.ACLUsers()
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
// Check that default user still exists in the list of users
|
|
if !slices.Contains(usernames, "default") {
|
|
t.Error("could not find user with username \"default\" in the ACL after deleting user")
|
|
return
|
|
}
|
|
|
|
// Check that the deleted user is no longer in the list
|
|
if slices.Contains(usernames, "user_to_delete") {
|
|
t.Error("deleted user found in the ACL")
|
|
return
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("Test_HandleWhoAmI", func(t *testing.T) {
|
|
t.Parallel()
|
|
conn, err := internal.GetConnection("localhost", port)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
|
|
r := resp.NewConn(conn)
|
|
|
|
tests := []struct {
|
|
name string
|
|
username string
|
|
password string
|
|
wantRes string
|
|
}{
|
|
{
|
|
name: "1. With default user",
|
|
username: "default",
|
|
password: "password1",
|
|
wantRes: "default",
|
|
},
|
|
{
|
|
name: "2. With user authenticated by plaintext password",
|
|
username: "with_password_user",
|
|
password: "password2",
|
|
wantRes: "with_password_user",
|
|
},
|
|
{
|
|
name: "3. With user authenticated by SHA256 password",
|
|
username: "with_password_user",
|
|
password: "password3",
|
|
wantRes: "with_password_user",
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
// Authenticate
|
|
if err = r.WriteArray([]resp.Value{
|
|
resp.StringValue("AUTH"),
|
|
resp.StringValue(test.username),
|
|
resp.StringValue(test.password),
|
|
}); err != nil {
|
|
t.Error(err)
|
|
}
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if v.String() != "OK" {
|
|
t.Errorf("expected response for auth with %s:%s to be \"OK\", got %s", test.username, test.password, v.String())
|
|
}
|
|
// Check whoami response value
|
|
if err = r.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("WHOAMI")}); err != nil {
|
|
t.Error(err)
|
|
}
|
|
v, _, err = r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
if v.String() != test.wantRes {
|
|
t.Errorf("expected whoami response to be \"%s\", got \"%s\"", test.wantRes, v.String())
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("Test_HandleList", func(t *testing.T) {
|
|
t.Parallel()
|
|
port, err := internal.GetFreePort()
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
mockServer, err := setUpServer(port, false, "")
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
go func() {
|
|
mockServer.Start()
|
|
}()
|
|
|
|
t.Cleanup(func() {
|
|
mockServer.ShutDown()
|
|
})
|
|
|
|
conn, err := internal.GetConnection("localhost", port)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if conn != nil {
|
|
_ = conn.Close()
|
|
}
|
|
}()
|
|
|
|
r := resp.NewConn(conn)
|
|
|
|
tests := []struct {
|
|
name string
|
|
presetUsers []*echovault.User
|
|
cmd []resp.Value
|
|
wantRes []string
|
|
wantErr string
|
|
}{
|
|
{
|
|
name: "1. Get the user and all their details",
|
|
presetUsers: []*echovault.User{
|
|
{
|
|
Username: "list_user_1",
|
|
Enabled: true,
|
|
NoPassword: false,
|
|
NoKeys: false,
|
|
AddPlainPasswords: []string{"list_user_password_1"},
|
|
AddHashPasswords: []string{generateSHA256Password("list_user_password_2")},
|
|
IncludeCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
|
ExcludeCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
|
IncludeCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"},
|
|
ExcludeCommands: []string{"rewriteaof", "save", "acl|load", "acl|save"},
|
|
IncludeReadKeys: []string{"key1", "key2", "key3", "key4"},
|
|
IncludeWriteKeys: []string{"key1", "key2", "key5", "key6"},
|
|
IncludeChannels: []string{"channel1", "channel2"},
|
|
ExcludeChannels: []string{"channel3", "channel4"},
|
|
},
|
|
{
|
|
Username: "list_user_2",
|
|
Enabled: true,
|
|
NoPassword: true,
|
|
NoKeys: true,
|
|
IncludeCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
|
ExcludeCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
|
IncludeCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"},
|
|
ExcludeCommands: []string{"rewriteaof", "save", "acl|load", "acl|save"},
|
|
IncludeReadKeys: []string{},
|
|
IncludeWriteKeys: []string{},
|
|
IncludeChannels: []string{"channel1", "channel2"},
|
|
ExcludeChannels: []string{"channel3", "channel4"},
|
|
},
|
|
{
|
|
Username: "list_user_3",
|
|
Enabled: true,
|
|
NoPassword: false,
|
|
NoKeys: false,
|
|
AddPlainPasswords: []string{"list_user_password_3"},
|
|
AddHashPasswords: []string{generateSHA256Password("list_user_password_4")},
|
|
IncludeCategories: []string{constants.WriteCategory, constants.ReadCategory, constants.PubSubCategory},
|
|
ExcludeCategories: []string{constants.AdminCategory, constants.ConnectionCategory, constants.DangerousCategory},
|
|
IncludeCommands: []string{"acl|setuser", "acl|getuser", "acl|deluser"},
|
|
ExcludeCommands: []string{"rewriteaof", "save", "acl|load", "acl|save"},
|
|
IncludeReadKeys: []string{"key1", "key2", "key3", "key4"},
|
|
IncludeWriteKeys: []string{"key1", "key2", "key5", "key6"},
|
|
IncludeChannels: []string{"channel1", "channel2"},
|
|
ExcludeChannels: []string{"channel3", "channel4"},
|
|
},
|
|
},
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("LIST")},
|
|
wantRes: []string{
|
|
"default on +@all +all %RW~* +&*",
|
|
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
|
|
generateSHA256Password("password3"), "%RW"),
|
|
"no_password_user on nopass +@all +all %RW~* +&*",
|
|
"disabled_user off >password5 +@all +all %RW~* +&*",
|
|
fmt.Sprintf(`list_user_1 on >list_user_password_1 #%s +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save %s +&channel1 +&channel2 -&channel3 -&channel4`, generateSHA256Password("list_user_password_2"), "%RW~key1 %RW~key2 %R~key3 %R~key4"),
|
|
fmt.Sprintf(`list_user_2 on nopass nokeys +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save +&channel1 +&channel2 -&channel3 -&channel4`),
|
|
fmt.Sprintf(`list_user_3 on >list_user_password_3 #%s +@write +@read +@pubsub -@admin -@connection -@dangerous +acl|setuser +acl|getuser +acl|deluser -rewriteaof -save -acl|load -acl|save %s +&channel1 +&channel2 -&channel3 -&channel4`, generateSHA256Password("list_user_password_4"), "%RW~key1 %RW~key2 %R~key3 %R~key4"),
|
|
},
|
|
wantErr: "",
|
|
},
|
|
{
|
|
name: "2. Command too long",
|
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("LIST"), resp.StringValue("USERNAME")},
|
|
wantRes: nil,
|
|
wantErr: constants.WrongArgsResponse,
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
if test.presetUsers != nil {
|
|
for _, user := range test.presetUsers {
|
|
if _, err := mockServer.ACLSetUser(*user); err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
if err = r.WriteArray(test.cmd); err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
v, _, err := r.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
if test.wantErr != "" {
|
|
if !strings.Contains(v.Error().Error(), test.wantErr) {
|
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, v.Error().Error())
|
|
}
|
|
return
|
|
}
|
|
resArr := v.Array()
|
|
if len(resArr) != len(test.wantRes) {
|
|
t.Errorf("expected response of lenght %d, got lenght %d", len(test.wantRes), len(resArr))
|
|
return
|
|
}
|
|
|
|
var resStr []string
|
|
for i := 0; i < len(resArr); i++ {
|
|
resStr = strings.Split(resArr[i].String(), " ")
|
|
if !slices.ContainsFunc(test.wantRes, func(s string) bool {
|
|
expectedUserSlice := strings.Split(s, " ")
|
|
return compareSlices(resStr, expectedUserSlice) == nil
|
|
}) {
|
|
t.Errorf("could not find the following user in expected slice: %+v", resStr)
|
|
return
|
|
}
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("Test_HandleSave", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
baseDir := path.Join(".", "testdata", "save")
|
|
|
|
t.Cleanup(func() {
|
|
_ = os.RemoveAll(baseDir)
|
|
})
|
|
|
|
tests := []struct {
|
|
name string
|
|
path string
|
|
want []string // Response from ACL List command.
|
|
}{
|
|
{
|
|
name: "1. Save ACL config to .json file",
|
|
path: path.Join(baseDir, "json_test.json"),
|
|
want: []string{
|
|
"default on +@all +all %RW~* +&*",
|
|
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
|
|
generateSHA256Password("password3"), "%RW"),
|
|
"no_password_user on nopass +@all +all %RW~* +&*",
|
|
"disabled_user off >password5 +@all +all %RW~* +&*",
|
|
},
|
|
},
|
|
{
|
|
name: "2. Save ACL config to .yaml file",
|
|
path: path.Join(baseDir, "yaml_test.yaml"),
|
|
want: []string{
|
|
"default on +@all +all %RW~* +&*",
|
|
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
|
|
generateSHA256Password("password3"), "%RW"),
|
|
"no_password_user on nopass +@all +all %RW~* +&*",
|
|
"disabled_user off >password5 +@all +all %RW~* +&*",
|
|
},
|
|
},
|
|
{
|
|
name: "3. Save ACL config to .yml file",
|
|
path: path.Join(baseDir, "yml_test.yml"),
|
|
want: []string{
|
|
"default on +@all +all %RW~* +&*",
|
|
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
|
|
generateSHA256Password("password3"), "%RW"),
|
|
"no_password_user on nopass +@all +all %RW~* +&*",
|
|
"disabled_user off >password5 +@all +all %RW~* +&*",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Get free port.
|
|
port, err := internal.GetFreePort()
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
// Create new server instance
|
|
mockServer, err := setUpServer(port, false, test.path)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
|
|
go func() {
|
|
mockServer.Start()
|
|
}()
|
|
|
|
conn, err := internal.GetConnection("localhost", port)
|
|
if err != nil {
|
|
t.Error(err)
|
|
mockServer.ShutDown()
|
|
return
|
|
}
|
|
client := resp.NewConn(conn)
|
|
|
|
if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("SAVE")}); err != nil {
|
|
t.Error(err)
|
|
mockServer.ShutDown()
|
|
return
|
|
}
|
|
|
|
res, _, err := client.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
mockServer.ShutDown()
|
|
return
|
|
}
|
|
|
|
if !strings.EqualFold(res.String(), "ok") {
|
|
t.Errorf("expected OK response, got \"%s\"", res.String())
|
|
mockServer.ShutDown()
|
|
return
|
|
}
|
|
|
|
// Close client connection
|
|
if err = conn.Close(); err != nil {
|
|
t.Error(err)
|
|
mockServer.ShutDown()
|
|
return
|
|
}
|
|
|
|
// Shutdown the mock server
|
|
mockServer.ShutDown()
|
|
|
|
// Restart server and create new client connection
|
|
port, err = internal.GetFreePort()
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
mockServer, err = setUpServer(port, false, test.path)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return
|
|
}
|
|
go func() {
|
|
mockServer.Start()
|
|
}()
|
|
|
|
conn, err = internal.GetConnection("localhost", port)
|
|
if err != nil {
|
|
t.Error(err)
|
|
mockServer.ShutDown()
|
|
return
|
|
}
|
|
client = resp.NewConn(conn)
|
|
|
|
if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("LIST")}); err != nil {
|
|
t.Error(err)
|
|
mockServer.ShutDown()
|
|
return
|
|
}
|
|
|
|
res, _, err = client.ReadValue()
|
|
if err != nil {
|
|
t.Error(err)
|
|
mockServer.ShutDown()
|
|
return
|
|
}
|
|
|
|
// Check if ACL LIST returns the expected list of users.
|
|
resArr := res.Array()
|
|
if len(resArr) != len(test.want) {
|
|
t.Errorf("expected response of lenght %d, got lenght %d", len(test.want), len(resArr))
|
|
mockServer.ShutDown()
|
|
return
|
|
}
|
|
|
|
var resStr []string
|
|
for i := 0; i < len(resArr); i++ {
|
|
resStr = strings.Split(resArr[i].String(), " ")
|
|
if !slices.ContainsFunc(test.want, func(s string) bool {
|
|
expectedUserSlice := strings.Split(s, " ")
|
|
return compareSlices(resStr, expectedUserSlice) == nil
|
|
}) {
|
|
t.Errorf("could not find the following user in expected slice: %+v", resStr)
|
|
mockServer.ShutDown()
|
|
return
|
|
}
|
|
}
|
|
|
|
mockServer.ShutDown()
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("Test_HandleLoad", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
baseDir := path.Join(".", "testdata", "load")
|
|
|
|
t.Cleanup(func() {
|
|
_ = os.RemoveAll(baseDir)
|
|
})
|
|
})
|
|
}
|