diff --git a/src/modules/acl/acl.go b/src/modules/acl/acl.go index 1a6545b..11652e2 100644 --- a/src/modules/acl/acl.go +++ b/src/modules/acl/acl.go @@ -164,7 +164,7 @@ func (acl *ACL) DeleteUser(ctx context.Context, usernames []string) error { // Terminate every connection attached to this user for connRef, connection := range acl.Connections { if connection.User.Username == user.Username { - (*connRef).SetReadDeadline(time.Now().Add(-1 * time.Second)) + _ = (*connRef).SetReadDeadline(time.Now().Add(-1 * time.Second)) } } // Delete the user from the ACL @@ -175,7 +175,7 @@ func (acl *ACL) DeleteUser(ctx context.Context, usernames []string) error { return nil } -func (acl *ACL) AuthenticateConnection(ctx context.Context, conn *net.Conn, cmd []string) error { +func (acl *ACL) AuthenticateConnection(_ context.Context, conn *net.Conn, cmd []string) error { var passwords []Password var user *User @@ -194,6 +194,7 @@ func (acl *ACL) AuthenticateConnection(ctx context.Context, conn *net.Conn, cmd }) user = acl.Users[idx] } + if len(cmd) == 3 { // Process AUTH h.Write([]byte(cmd[2])) @@ -278,7 +279,6 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils. // If the command is 'auth', then return early and allow it if strings.EqualFold(comm, "auth") { - // TODO: Add rate limiting to prevent auth spamming return nil } diff --git a/src/modules/acl/commands.go b/src/modules/acl/commands.go index 939e8f7..b6bc1d8 100644 --- a/src/modules/acl/commands.go +++ b/src/modules/acl/commands.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/echovault/echovault/src/utils" "gopkg.in/yaml.v3" + "log" "net" "os" "path" @@ -28,7 +29,7 @@ func handleAuth(ctx context.Context, cmd []string, server utils.Server, conn *ne return []byte(utils.OkResponse), nil } -func handleGetUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { +func handleGetUser(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { if len(cmd) != 3 { return nil, errors.New(utils.WrongArgsResponse) } @@ -144,7 +145,7 @@ func handleGetUser(ctx context.Context, cmd []string, server utils.Server, conn return []byte(res), nil } -func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { +func handleCat(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { if len(cmd) > 3 { return nil, errors.New(utils.WrongArgsResponse) } @@ -204,7 +205,7 @@ func handleCat(ctx context.Context, cmd []string, server utils.Server, conn *net return nil, errors.New("category not found") } -func handleUsers(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { +func handleUsers(_ context.Context, _ []string, server utils.Server, _ *net.Conn) ([]byte, error) { acl, ok := server.GetACL().(*ACL) if !ok { return nil, errors.New("could not load ACL") @@ -217,7 +218,7 @@ func handleUsers(ctx context.Context, cmd []string, server utils.Server, conn *n return []byte(res), nil } -func handleSetUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { +func handleSetUser(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { acl, ok := server.GetACL().(*ACL) if !ok { return nil, errors.New("could not load ACL") @@ -228,7 +229,7 @@ func handleSetUser(ctx context.Context, cmd []string, server utils.Server, conn return []byte(utils.OkResponse), nil } -func handleDelUser(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { +func handleDelUser(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { if len(cmd) < 3 { return nil, errors.New(utils.WrongArgsResponse) } @@ -242,7 +243,7 @@ func handleDelUser(ctx context.Context, cmd []string, server utils.Server, conn return []byte(utils.OkResponse), nil } -func handleWhoAmI(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { +func handleWhoAmI(_ context.Context, _ []string, server utils.Server, conn *net.Conn) ([]byte, error) { acl, ok := server.GetACL().(*ACL) if !ok { return nil, errors.New("could not load ACL") @@ -251,7 +252,7 @@ func handleWhoAmI(ctx context.Context, cmd []string, server utils.Server, conn * return []byte(fmt.Sprintf("+%s\r\n", connectionInfo.User.Username)), nil } -func handleList(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { +func handleList(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { if len(cmd) > 2 { return nil, errors.New(utils.WrongArgsResponse) } @@ -347,7 +348,7 @@ func handleList(ctx context.Context, cmd []string, server utils.Server, conn *ne return []byte(res), nil } -func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { +func handleLoad(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { if len(cmd) != 3 { return nil, errors.New(utils.WrongArgsResponse) } @@ -364,8 +365,7 @@ func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *ne defer func() { if err := f.Close(); err != nil { - // TODO: Log file close error with context - fmt.Println(err) + log.Println(err) } }() @@ -412,7 +412,7 @@ func handleLoad(ctx context.Context, cmd []string, server utils.Server, conn *ne return []byte(utils.OkResponse), nil } -func handleSave(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { +func handleSave(_ context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) { if len(cmd) > 2 { return nil, errors.New(utils.WrongArgsResponse) } @@ -429,8 +429,7 @@ func handleSave(ctx context.Context, cmd []string, server utils.Server, conn *ne defer func() { if err := f.Close(); err != nil { - // TODO: Log file close error with context - fmt.Println(err) + log.Println(err) } }() @@ -490,10 +489,11 @@ func Commands() []utils.Command { }, SubCommands: []utils.SubCommand{ { - Command: "cat", - Categories: []string{utils.SlowCategory}, - Description: "(ACL CAT [category]) List all the categories and commands inside a category.", - Sync: false, + Command: "cat", + Categories: []string{utils.SlowCategory}, + Description: `(ACL CAT [category]) List all the categories. +If the optional category is provided, list all the commands in the category`, + Sync: false, KeyExtractionFunc: func(cmd []string) ([]string, error) { return []string{}, nil }, diff --git a/src/modules/acl/commands_test.go b/src/modules/acl/commands_test.go index 90a5519..212b00a 100644 --- a/src/modules/acl/commands_test.go +++ b/src/modules/acl/commands_test.go @@ -2,8 +2,12 @@ package acl import ( "context" + "crypto/sha256" + "fmt" "github.com/echovault/echovault/src/server" "github.com/echovault/echovault/src/utils" + "github.com/tidwall/resp" + "net" "testing" ) @@ -22,13 +26,17 @@ func init() { Port: port, DataDir: "", EvictionPolicy: utils.NoEviction, + RequirePass: true, + Password: "password1", } acl = NewACL(config) + acl.Users = append(acl.Users, generateInitialTestUsers()...) mockServer = server.NewServer(server.Opts{ - Config: config, - ACL: acl, + Config: config, + ACL: acl, + Commands: Commands(), }) go func() { @@ -36,9 +44,140 @@ func init() { }() } -func Test_HandleAuth(t *testing.T) {} +func generateInitialTestUsers() []*User { + // User with both hash password and plaintext password + withPasswordUser := CreateUser("with_password_user") + h := sha256.New() + h.Write([]byte("password3")) + withPasswordUser.Passwords = []Password{ + {PasswordType: PasswordPlainText, PasswordValue: "password2"}, + {PasswordType: PasswordSHA256, PasswordValue: string(h.Sum(nil))}, + } -func Test_HandleCat(t *testing.T) {} + // User with NoPassword option + noPasswordUser := CreateUser("no_password_user") + noPasswordUser.Passwords = []Password{ + {PasswordType: PasswordPlainText, PasswordValue: "password4"}, + } + noPasswordUser.NoPassword = true + + // Disabled user + disabledUser := CreateUser("disabled_user") + disabledUser.Passwords = []Password{ + {PasswordType: PasswordPlainText, PasswordValue: "password5"}, + } + disabledUser.Enabled = false + + return []*User{ + withPasswordUser, + noPasswordUser, + disabledUser, + } +} + +func Test_HandleAuth(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port)) + if err != nil { + t.Error(err) + } + r := resp.NewConn(conn) + + tests := []struct { + cmd []resp.Value + wantRes string + wantErr string + }{ + { // 1. Authenticate with default user without specifying username + cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}, + wantRes: "OK", + wantErr: "", + }, + { // 2. Authenticate with plaintext password + cmd: []resp.Value{ + resp.StringValue("AUTH"), + resp.StringValue("with_password_user"), + resp.StringValue("password2"), + }, + wantRes: "OK", + wantErr: "", + }, + { // 3. Authenticate with SHA256 password + cmd: []resp.Value{ + resp.StringValue("AUTH"), + resp.StringValue("with_password_user"), + resp.StringValue("password3"), + }, + wantRes: "OK", + wantErr: "", + }, + { // 4. Authenticate with no password user + cmd: []resp.Value{ + resp.StringValue("AUTH"), + resp.StringValue("no_password_user"), + resp.StringValue("password4"), + }, + wantRes: "OK", + wantErr: "", + }, + { // 5. Fail to authenticate with disabled user + cmd: []resp.Value{ + resp.StringValue("AUTH"), + resp.StringValue("disabled_user"), + resp.StringValue("password5"), + }, + wantRes: "", + wantErr: "Error user disabled_user is disabled", + }, + { // 6. Fail to authenticate with non-existent user + cmd: []resp.Value{ + resp.StringValue("AUTH"), + resp.StringValue("non_existent_user"), + resp.StringValue("password6"), + }, + wantRes: "", + wantErr: "Error no user with username non_existent_user", + }, + { // 7. Command too short + cmd: []resp.Value{resp.StringValue("AUTH")}, + wantRes: "", + wantErr: fmt.Sprintf("Error %s", utils.WrongArgsResponse), + }, + { // 8. Command too long + cmd: []resp.Value{ + resp.StringValue("AUTH"), + resp.StringValue("user"), + resp.StringValue("password1"), + resp.StringValue("password2"), + }, + wantRes: "", + wantErr: fmt.Sprintf("Error %s", utils.WrongArgsResponse), + }, + } + + for _, test := range tests { + if err = r.WriteArray(test.cmd); err != nil { + t.Error(err) + } + rv, _, err := r.ReadValue() + if err != nil { + t.Error(err) + } + if test.wantErr != "" { + if rv.Error().Error() != test.wantErr { + t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error()) + } + continue + } + if rv.String() != test.wantRes { + t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String()) + } + } +} + +func Test_HandleCat(t *testing.T) { + // Since only ACL commands are loaded in this test suite, this test will only test against the + // list of categories and commands available in the ACL module. +} func Test_HandleUsers(t *testing.T) {} diff --git a/src/modules/acl/user.go b/src/modules/acl/user.go index 5137035..1fcd02e 100644 --- a/src/modules/acl/user.go +++ b/src/modules/acl/user.go @@ -5,6 +5,11 @@ import ( "strings" ) +const ( + PasswordPlainText = "plaintext" + PasswordSHA256 = "SHA256" +) + type Password struct { PasswordType string `json:"PasswordType" yaml:"PasswordType"` // plaintext, SHA256 PasswordValue string `json:"PasswordValue" yaml:"PasswordValue"` @@ -105,7 +110,7 @@ func (user *User) UpdateUser(cmd []string) error { } if str[0] == '<' { user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool { - if strings.EqualFold(password.PasswordType, "SHA256") { + if strings.EqualFold(password.PasswordType, PasswordSHA256) { return false } return password.PasswordValue == str[1:] @@ -114,7 +119,7 @@ func (user *User) UpdateUser(cmd []string) error { } if str[0] == '!' { user.Passwords = slices.DeleteFunc(user.Passwords, func(password Password) bool { - if strings.EqualFold(password.PasswordType, "plaintext") { + if strings.EqualFold(password.PasswordType, PasswordPlainText) { return false } return password.PasswordValue == str[1:] @@ -278,7 +283,7 @@ func CreateUser(username string) *User { func GetPasswordType(password string) string { if password[0] == '#' { - return "SHA256" + return PasswordSHA256 } - return "plaintext" + return PasswordPlainText }