Initialise default IncludedCategories to allCategories("*").

Implemented test for ACL USERS command.
This commit is contained in:
Kelvin Mwinuka
2024-03-21 15:57:16 +08:00
parent 9191d16762
commit 018aea7785
3 changed files with 181 additions and 7 deletions

View File

@@ -202,7 +202,7 @@ func handleCat(ctx context.Context, cmd []string, server utils.Server, _ *net.Co
}
}
return nil, errors.New("category not found")
return nil, fmt.Errorf("category %s not found", strings.ToUpper(cmd[2]))
}
func handleUsers(_ context.Context, _ []string, server utils.Server, _ *net.Conn) ([]byte, error) {

View File

@@ -8,6 +8,7 @@ import (
"github.com/echovault/echovault/src/utils"
"github.com/tidwall/resp"
"net"
"slices"
"testing"
)
@@ -21,6 +22,14 @@ func init() {
bindAddr = "localhost"
port = 7490
mockServer = setUpServer(bindAddr, port)
go func() {
mockServer.Start(context.Background())
}()
}
func setUpServer(bindAddr string, port uint16) *server.Server {
config := utils.Config{
BindAddr: bindAddr,
Port: port,
@@ -33,15 +42,11 @@ func init() {
acl = NewACL(config)
acl.Users = append(acl.Users, generateInitialTestUsers()...)
mockServer = server.NewServer(server.Opts{
return server.NewServer(server.Opts{
Config: config,
ACL: acl,
Commands: Commands(),
})
go func() {
mockServer.Start(context.Background())
}()
}
func generateInitialTestUsers() []*User {
@@ -80,6 +85,9 @@ func Test_HandleAuth(t *testing.T) {
if err != nil {
t.Error(err)
}
defer func() {
_ = conn.Close()
}()
r := resp.NewConn(conn)
tests := []struct {
@@ -175,11 +183,174 @@ func Test_HandleAuth(t *testing.T) {
}
func Test_HandleCat(t *testing.T) {
// Since only ACL commands are loaded in this test suite, this test will only test against the
// list of categories and commands available in the ACL module.
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil {
t.Error(err)
}
defer func() {
_ = conn.Close()
}()
r := resp.NewConn(conn)
// Authenticate connection
if err = r.WriteArray([]resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}); err != nil {
t.Error(err)
}
rv, _, err := r.ReadValue()
if err != nil {
t.Error(err)
}
if rv.String() != "OK" {
t.Error("could not authenticate user")
}
func Test_HandleUsers(t *testing.T) {}
// Since only ACL commands are loaded in this test suite, this test will only test against the
// list of categories and commands available in the ACL module.
tests := []struct {
cmd []resp.Value
wantRes []string
wantErr string
}{
{ // 1. Return list of categories
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT")},
wantRes: []string{
utils.ConnectionCategory,
utils.SlowCategory,
utils.FastCategory,
utils.AdminCategory,
utils.DangerousCategory,
},
wantErr: "",
},
{ // 2. Return list of commands in connection category
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.ConnectionCategory)},
wantRes: []string{"auth"},
wantErr: "",
},
{ // 3. Return list of commands in slow category
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.SlowCategory)},
wantRes: []string{"auth", "acl|cat", "acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
wantErr: "",
},
{ // 4. Return list of commands in fast category
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.FastCategory)},
wantRes: []string{"acl|whoami"},
wantErr: "",
},
{ // 5. Return list of commands in admin category
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.AdminCategory)},
wantRes: []string{"acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
wantErr: "",
},
{ // 6. Return list of commands in dangerous category
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.DangerousCategory)},
wantRes: []string{"acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
wantErr: "",
},
{ // 7. Return error when category does not exist
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue("non-existent")},
wantRes: nil,
wantErr: "Error category NON-EXISTENT not found",
},
{ // 8. Command too long
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue("category1"), resp.StringValue("category2")},
wantRes: nil,
wantErr: fmt.Sprintf("Error %s", utils.WrongArgsResponse),
},
}
for _, test := range tests {
if err = r.WriteArray(test.cmd); err != nil {
t.Error(err)
}
rv, _, err = r.ReadValue()
if err != nil {
t.Error(err)
}
if test.wantErr != "" {
if rv.Error().Error() != test.wantErr {
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
}
continue
}
resArr := rv.Array()
// Check if all the elements in the expected array are in the response array
for _, expected := range test.wantRes {
if !slices.ContainsFunc(resArr, func(value resp.Value) bool {
return value.String() == expected
}) {
t.Errorf("could not find expected command \"%s\" in the response array for category", expected)
}
}
// Check if all the elements in the response array are in the expected array
for _, value := range resArr {
if !slices.ContainsFunc(test.wantRes, func(expected string) bool {
return value.String() == expected
}) {
t.Errorf("could not find response command \"%s\" in the expected array", value.String())
}
}
}
}
func Test_HandleUsers(t *testing.T) {
var port uint16 = 7491
mockServer := setUpServer(bindAddr, port)
go func() {
mockServer.Start(context.Background())
}()
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
if err != nil {
t.Error(err)
}
defer func() {
_ = conn.Close()
}()
r := resp.NewConn(conn)
if err = r.WriteArray([]resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}); err != nil {
t.Error(err)
}
rv, _, err := r.ReadValue()
if err != nil {
t.Error(err)
}
if rv.String() != "OK" {
t.Errorf("expected OK response, got \"%s\"", rv.String())
}
users := []string{"default", "with_password_user", "no_password_user", "disabled_user"}
if err = r.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("USERS")}); err != nil {
t.Error(err)
}
rv, _, err = r.ReadValue()
if err != nil {
t.Error(err)
}
resArr := rv.Array()
// Check if all the expected users are in the response array
for _, user := range users {
if !slices.ContainsFunc(resArr, func(value resp.Value) bool {
return value.String() == user
}) {
t.Errorf("could not find expected user \"%s\" in response array", user)
}
}
// Check if all the users in the response array are in the expected users
for _, value := range resArr {
if !slices.ContainsFunc(users, func(user string) bool {
return value.String() == user
}) {
t.Errorf("could not find response user \"%s\" in expected users array", value.String())
}
}
}
func Test_HandleSetUser(t *testing.T) {}

View File

@@ -38,6 +38,9 @@ type User struct {
func (user *User) Normalise() {
user.IncludedCategories = RemoveDuplicateEntries(user.IncludedCategories, "allCategories")
if len(user.IncludedCategories) == 0 {
user.IncludedCategories = []string{"*"}
}
user.ExcludedCategories = RemoveDuplicateEntries(user.ExcludedCategories, "allCategories")
if slices.Contains(user.ExcludedCategories, "*") {
user.IncludedCategories = []string{}