diff --git a/src/modules/acl/commands.go b/src/modules/acl/commands.go index b6bc1d8..d4c3ed1 100644 --- a/src/modules/acl/commands.go +++ b/src/modules/acl/commands.go @@ -202,7 +202,7 @@ func handleCat(ctx context.Context, cmd []string, server utils.Server, _ *net.Co } } - return nil, errors.New("category not found") + return nil, fmt.Errorf("category %s not found", strings.ToUpper(cmd[2])) } func handleUsers(_ context.Context, _ []string, server utils.Server, _ *net.Conn) ([]byte, error) { diff --git a/src/modules/acl/commands_test.go b/src/modules/acl/commands_test.go index 212b00a..9732976 100644 --- a/src/modules/acl/commands_test.go +++ b/src/modules/acl/commands_test.go @@ -8,6 +8,7 @@ import ( "github.com/echovault/echovault/src/utils" "github.com/tidwall/resp" "net" + "slices" "testing" ) @@ -21,6 +22,14 @@ func init() { bindAddr = "localhost" port = 7490 + mockServer = setUpServer(bindAddr, port) + + go func() { + mockServer.Start(context.Background()) + }() +} + +func setUpServer(bindAddr string, port uint16) *server.Server { config := utils.Config{ BindAddr: bindAddr, Port: port, @@ -33,15 +42,11 @@ func init() { acl = NewACL(config) acl.Users = append(acl.Users, generateInitialTestUsers()...) - mockServer = server.NewServer(server.Opts{ + return server.NewServer(server.Opts{ Config: config, ACL: acl, Commands: Commands(), }) - - go func() { - mockServer.Start(context.Background()) - }() } func generateInitialTestUsers() []*User { @@ -80,6 +85,9 @@ func Test_HandleAuth(t *testing.T) { if err != nil { t.Error(err) } + defer func() { + _ = conn.Close() + }() r := resp.NewConn(conn) tests := []struct { @@ -175,11 +183,174 @@ func Test_HandleAuth(t *testing.T) { } func Test_HandleCat(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) + if err != nil { + t.Error(err) + } + defer func() { + _ = conn.Close() + }() + r := resp.NewConn(conn) + + // Authenticate connection + if err = r.WriteArray([]resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}); err != nil { + t.Error(err) + } + rv, _, err := r.ReadValue() + if err != nil { + t.Error(err) + } + if rv.String() != "OK" { + t.Error("could not authenticate user") + } + // Since only ACL commands are loaded in this test suite, this test will only test against the // list of categories and commands available in the ACL module. + tests := []struct { + cmd []resp.Value + wantRes []string + wantErr string + }{ + { // 1. Return list of categories + cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT")}, + wantRes: []string{ + utils.ConnectionCategory, + utils.SlowCategory, + utils.FastCategory, + utils.AdminCategory, + utils.DangerousCategory, + }, + wantErr: "", + }, + { // 2. Return list of commands in connection category + cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.ConnectionCategory)}, + wantRes: []string{"auth"}, + wantErr: "", + }, + { // 3. Return list of commands in slow category + cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.SlowCategory)}, + wantRes: []string{"auth", "acl|cat", "acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"}, + wantErr: "", + }, + { // 4. Return list of commands in fast category + cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.FastCategory)}, + wantRes: []string{"acl|whoami"}, + wantErr: "", + }, + { // 5. Return list of commands in admin category + cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.AdminCategory)}, + wantRes: []string{"acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"}, + wantErr: "", + }, + { // 6. Return list of commands in dangerous category + cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.DangerousCategory)}, + wantRes: []string{"acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"}, + wantErr: "", + }, + { // 7. Return error when category does not exist + cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue("non-existent")}, + wantRes: nil, + wantErr: "Error category NON-EXISTENT not found", + }, + { // 8. Command too long + cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue("category1"), resp.StringValue("category2")}, + wantRes: nil, + wantErr: fmt.Sprintf("Error %s", utils.WrongArgsResponse), + }, + } + + for _, test := range tests { + if err = r.WriteArray(test.cmd); err != nil { + t.Error(err) + } + rv, _, err = r.ReadValue() + if err != nil { + t.Error(err) + } + if test.wantErr != "" { + if rv.Error().Error() != test.wantErr { + t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error()) + } + continue + } + resArr := rv.Array() + // Check if all the elements in the expected array are in the response array + for _, expected := range test.wantRes { + if !slices.ContainsFunc(resArr, func(value resp.Value) bool { + return value.String() == expected + }) { + t.Errorf("could not find expected command \"%s\" in the response array for category", expected) + } + } + // Check if all the elements in the response array are in the expected array + for _, value := range resArr { + if !slices.ContainsFunc(test.wantRes, func(expected string) bool { + return value.String() == expected + }) { + t.Errorf("could not find response command \"%s\" in the expected array", value.String()) + } + } + } } -func Test_HandleUsers(t *testing.T) {} +func Test_HandleUsers(t *testing.T) { + var port uint16 = 7491 + mockServer := setUpServer(bindAddr, port) + go func() { + mockServer.Start(context.Background()) + }() + + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) + if err != nil { + t.Error(err) + } + defer func() { + _ = conn.Close() + }() + + r := resp.NewConn(conn) + if err = r.WriteArray([]resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}); err != nil { + t.Error(err) + } + rv, _, err := r.ReadValue() + if err != nil { + t.Error(err) + } + if rv.String() != "OK" { + t.Errorf("expected OK response, got \"%s\"", rv.String()) + } + + users := []string{"default", "with_password_user", "no_password_user", "disabled_user"} + + if err = r.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("USERS")}); err != nil { + t.Error(err) + } + + rv, _, err = r.ReadValue() + if err != nil { + t.Error(err) + } + + resArr := rv.Array() + + // Check if all the expected users are in the response array + for _, user := range users { + if !slices.ContainsFunc(resArr, func(value resp.Value) bool { + return value.String() == user + }) { + t.Errorf("could not find expected user \"%s\" in response array", user) + } + } + + // Check if all the users in the response array are in the expected users + for _, value := range resArr { + if !slices.ContainsFunc(users, func(user string) bool { + return value.String() == user + }) { + t.Errorf("could not find response user \"%s\" in expected users array", value.String()) + } + } +} func Test_HandleSetUser(t *testing.T) {} diff --git a/src/modules/acl/user.go b/src/modules/acl/user.go index 1fcd02e..8f030c0 100644 --- a/src/modules/acl/user.go +++ b/src/modules/acl/user.go @@ -38,6 +38,9 @@ type User struct { func (user *User) Normalise() { user.IncludedCategories = RemoveDuplicateEntries(user.IncludedCategories, "allCategories") + if len(user.IncludedCategories) == 0 { + user.IncludedCategories = []string{"*"} + } user.ExcludedCategories = RemoveDuplicateEntries(user.ExcludedCategories, "allCategories") if slices.Contains(user.ExcludedCategories, "*") { user.IncludedCategories = []string{}