diff --git a/internal/modules/acl/acl.go b/internal/modules/acl/acl.go index c2b65cc..d9226f6 100644 --- a/internal/modules/acl/acl.go +++ b/internal/modules/acl/acl.go @@ -325,8 +325,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern return nil } - // Skip PING - if strings.EqualFold(comm, "ping") { + // Skip certain commands from authorization + if slices.Contains([]string{"ping", "echo"}, strings.ToLower(comm)) { return nil } diff --git a/internal/modules/acl/commands.go b/internal/modules/acl/commands.go index cc87c47..b4ac08f 100644 --- a/internal/modules/acl/commands.go +++ b/internal/modules/acl/commands.go @@ -28,23 +28,6 @@ import ( "strings" ) -func handleAuth(params internal.HandlerFuncParams) ([]byte, error) { - if len(params.Command) < 2 || len(params.Command) > 3 { - return nil, errors.New(constants.WrongArgsResponse) - } - acl, ok := params.GetACL().(*ACL) - if !ok { - return nil, errors.New("could not load ACL") - } - acl.LockUsers() - defer acl.UnlockUsers() - - if err := acl.AuthenticateConnection(params.Context, params.Connection, params.Command); err != nil { - return nil, err - } - return []byte(constants.OkResponse), nil -} - func handleCat(params internal.HandlerFuncParams) ([]byte, error) { if len(params.Command) > 3 { return nil, errors.New(constants.WrongArgsResponse) @@ -496,23 +479,6 @@ func handleSave(params internal.HandlerFuncParams) ([]byte, error) { func Commands() []internal.Command { return []internal.Command{ - { - Command: "auth", - Module: constants.ACLModule, - Categories: []string{constants.ConnectionCategory, constants.SlowCategory}, - Description: `(AUTH [username] password) -Authenticates the connection. If the username is not provided, the connection will be authenticated against the -default ACL user. Otherwise, it is authenticated against the ACL user with the provided username.`, - Sync: false, - KeyExtractionFunc: func(cmd []string) (internal.KeyExtractionFuncResult, error) { - return internal.KeyExtractionFuncResult{ - Channels: make([]string, 0), - ReadKeys: make([]string, 0), - WriteKeys: make([]string, 0), - }, nil - }, - HandlerFunc: handleAuth, - }, { Command: "acl", Module: constants.ACLModule, diff --git a/internal/modules/acl/commands_test.go b/internal/modules/acl/commands_test.go index 1ad7466..1bf861e 100644 --- a/internal/modules/acl/commands_test.go +++ b/internal/modules/acl/commands_test.go @@ -176,135 +176,6 @@ func Test_ACL(t *testing.T) { mockServer.ShutDown() }) - t.Run("Test_HandleAuth", func(t *testing.T) { - t.Parallel() - - conn, err := internal.GetConnection("localhost", port) - if err != nil { - t.Error(err) - return - } - defer func() { - if conn != nil { - _ = conn.Close() - } - }() - - r := resp.NewConn(conn) - - tests := []struct { - name string - cmd []resp.Value - wantRes string - wantErr string - }{ - { - name: "1. Authenticate with default user without specifying username", - cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}, - wantRes: "OK", - wantErr: "", - }, - { - name: "2. Authenticate with plaintext password", - cmd: []resp.Value{ - resp.StringValue("AUTH"), - resp.StringValue("with_password_user"), - resp.StringValue("password2"), - }, - wantRes: "OK", - wantErr: "", - }, - { - name: "3. Authenticate with SHA256 password", - cmd: []resp.Value{ - resp.StringValue("AUTH"), - resp.StringValue("with_password_user"), - resp.StringValue("password3"), - }, - wantRes: "OK", - wantErr: "", - }, - { - name: "4. Authenticate with no password user", - cmd: []resp.Value{ - resp.StringValue("AUTH"), - resp.StringValue("no_password_user"), - resp.StringValue("password4"), - }, - wantRes: "OK", - wantErr: "", - }, - { - name: "5. Fail to authenticate with disabled user", - cmd: []resp.Value{ - resp.StringValue("AUTH"), - resp.StringValue("disabled_user"), - resp.StringValue("password5"), - }, - wantRes: "", - wantErr: "Error user disabled_user is disabled", - }, - { - name: "6. Fail to authenticate with non-existent user", - cmd: []resp.Value{ - resp.StringValue("AUTH"), - resp.StringValue("non_existent_user"), - resp.StringValue("password6"), - }, - wantRes: "", - wantErr: "Error no user with username non_existent_user", - }, - { - name: "7. Fail to authenticate with the wrong password", - cmd: []resp.Value{ - resp.StringValue("AUTH"), - resp.StringValue("with_password_user"), - resp.StringValue("wrong_password"), - }, - wantRes: "", - wantErr: "Error could not authenticate user", - }, - { - name: "8. Command too short", - cmd: []resp.Value{resp.StringValue("AUTH")}, - wantRes: "", - wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse), - }, - { - name: "9. Command too long", - cmd: []resp.Value{ - resp.StringValue("AUTH"), - resp.StringValue("user"), - resp.StringValue("password1"), - resp.StringValue("password2"), - }, - wantRes: "", - wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if err = r.WriteArray(test.cmd); err != nil { - t.Error(err) - } - rv, _, err := r.ReadValue() - if err != nil { - t.Error(err) - } - if test.wantErr != "" { - if rv.Error().Error() != test.wantErr { - t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error()) - } - return - } - if rv.String() != test.wantRes { - t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String()) - } - }) - } - }) - t.Run("Test_Permissions", func(t *testing.T) { port, err := internal.GetFreePort() if err != nil { diff --git a/internal/modules/connection/commands.go b/internal/modules/connection/commands.go index 3bb4839..b3079a5 100644 --- a/internal/modules/connection/commands.go +++ b/internal/modules/connection/commands.go @@ -25,6 +25,23 @@ import ( "github.com/echovault/echovault/internal/constants" ) +func handleAuth(params internal.HandlerFuncParams) ([]byte, error) { + if len(params.Command) < 2 || len(params.Command) > 3 { + return nil, errors.New(constants.WrongArgsResponse) + } + accessControlList, ok := params.GetACL().(*acl.ACL) + if !ok { + return nil, errors.New("could not load ACL") + } + accessControlList.LockUsers() + defer accessControlList.UnlockUsers() + + if err := accessControlList.AuthenticateConnection(params.Context, params.Connection, params.Command); err != nil { + return nil, err + } + return []byte(constants.OkResponse), nil +} + func handlePing(params internal.HandlerFuncParams) ([]byte, error) { switch len(params.Command) { default: @@ -112,6 +129,23 @@ func handleHello(params internal.HandlerFuncParams) ([]byte, error) { func Commands() []internal.Command { return []internal.Command{ + { + Command: "auth", + Module: constants.ConnectionModule, + Categories: []string{constants.ConnectionCategory, constants.SlowCategory}, + Description: `(AUTH [username] password) +Authenticates the connection. If the username is not provided, the connection will be authenticated against the +default ACL user. Otherwise, it is authenticated against the ACL user with the provided username.`, + Sync: false, + KeyExtractionFunc: func(cmd []string) (internal.KeyExtractionFuncResult, error) { + return internal.KeyExtractionFuncResult{ + Channels: make([]string, 0), + ReadKeys: make([]string, 0), + WriteKeys: make([]string, 0), + }, nil + }, + HandlerFunc: handleAuth, + }, { Command: "ping", Module: constants.ConnectionModule, diff --git a/internal/modules/connection/commands_test.go b/internal/modules/connection/commands_test.go index d4e77f8..f175ef3 100644 --- a/internal/modules/connection/commands_test.go +++ b/internal/modules/connection/commands_test.go @@ -15,7 +15,10 @@ package connection_test import ( + "crypto/sha256" + "encoding/hex" "errors" + "fmt" "strings" "testing" @@ -26,6 +29,67 @@ import ( "github.com/tidwall/resp" ) +func setUpServer(port int, requirePass bool, aclConfig string) (*echovault.EchoVault, error) { + conf := config.Config{ + BindAddr: "localhost", + Port: uint16(port), + DataDir: "", + EvictionPolicy: constants.NoEviction, + RequirePass: requirePass, + Password: "password1", + AclConfig: aclConfig, + } + + mockServer, err := echovault.NewEchoVault( + echovault.WithConfig(conf), + ) + if err != nil { + return nil, err + } + + // Add the initial test users to the ACL module. + for _, user := range generateInitialTestUsers() { + if _, err := mockServer.ACLSetUser(user); err != nil { + return nil, err + } + } + + return mockServer, nil +} + +func generateInitialTestUsers() []echovault.User { + return []echovault.User{ + { + // User with both hash password and plaintext password. + Username: "with_password_user", + Enabled: true, + IncludeCategories: []string{"*"}, + IncludeCommands: []string{"*"}, + AddPlainPasswords: []string{"password2"}, + AddHashPasswords: []string{generateSHA256Password("password3")}, + }, + { + // User with NoPassword option. + Username: "no_password_user", + Enabled: true, + NoPassword: true, + AddPlainPasswords: []string{"password4"}, + }, + { + // Disabled user. + Username: "disabled_user", + Enabled: false, + AddPlainPasswords: []string{"password5"}, + }, + } +} + +func generateSHA256Password(plain string) string { + h := sha256.New() + h.Write([]byte(plain)) + return hex.EncodeToString(h.Sum(nil)) +} + func Test_Connection(t *testing.T) { port, err := internal.GetFreePort() if err != nil { @@ -33,14 +97,7 @@ func Test_Connection(t *testing.T) { return } - mockServer, err := echovault.NewEchoVault( - echovault.WithConfig(config.Config{ - DataDir: "", - EvictionPolicy: constants.NoEviction, - BindAddr: "localhost", - Port: uint16(port), - }), - ) + mockServer, err := setUpServer(port, true, "") if err != nil { t.Error(err) return @@ -54,6 +111,135 @@ func Test_Connection(t *testing.T) { mockServer.ShutDown() }) + t.Run("Test_HandleAuth", func(t *testing.T) { + t.Parallel() + + conn, err := internal.GetConnection("localhost", port) + if err != nil { + t.Error(err) + return + } + defer func() { + if conn != nil { + _ = conn.Close() + } + }() + + r := resp.NewConn(conn) + + tests := []struct { + name string + cmd []resp.Value + wantRes string + wantErr string + }{ + { + name: "1. Authenticate with default user without specifying username", + cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}, + wantRes: "OK", + wantErr: "", + }, + { + name: "2. Authenticate with plaintext password", + cmd: []resp.Value{ + resp.StringValue("AUTH"), + resp.StringValue("with_password_user"), + resp.StringValue("password2"), + }, + wantRes: "OK", + wantErr: "", + }, + { + name: "3. Authenticate with SHA256 password", + cmd: []resp.Value{ + resp.StringValue("AUTH"), + resp.StringValue("with_password_user"), + resp.StringValue("password3"), + }, + wantRes: "OK", + wantErr: "", + }, + { + name: "4. Authenticate with no password user", + cmd: []resp.Value{ + resp.StringValue("AUTH"), + resp.StringValue("no_password_user"), + resp.StringValue("password4"), + }, + wantRes: "OK", + wantErr: "", + }, + { + name: "5. Fail to authenticate with disabled user", + cmd: []resp.Value{ + resp.StringValue("AUTH"), + resp.StringValue("disabled_user"), + resp.StringValue("password5"), + }, + wantRes: "", + wantErr: "Error user disabled_user is disabled", + }, + { + name: "6. Fail to authenticate with non-existent user", + cmd: []resp.Value{ + resp.StringValue("AUTH"), + resp.StringValue("non_existent_user"), + resp.StringValue("password6"), + }, + wantRes: "", + wantErr: "Error no user with username non_existent_user", + }, + { + name: "7. Fail to authenticate with the wrong password", + cmd: []resp.Value{ + resp.StringValue("AUTH"), + resp.StringValue("with_password_user"), + resp.StringValue("wrong_password"), + }, + wantRes: "", + wantErr: "Error could not authenticate user", + }, + { + name: "8. Command too short", + cmd: []resp.Value{resp.StringValue("AUTH")}, + wantRes: "", + wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse), + }, + { + name: "9. Command too long", + cmd: []resp.Value{ + resp.StringValue("AUTH"), + resp.StringValue("user"), + resp.StringValue("password1"), + resp.StringValue("password2"), + }, + wantRes: "", + wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if err = r.WriteArray(test.cmd); err != nil { + t.Error(err) + } + rv, _, err := r.ReadValue() + if err != nil { + t.Error(err) + } + if test.wantErr != "" { + if rv.Error().Error() != test.wantErr { + t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error()) + } + return + } + if rv.String() != test.wantRes { + t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String()) + } + }) + } + }) + t.Run("Test_HandlePing", func(t *testing.T) { conn, err := internal.GetConnection("localhost", port) if err != nil { @@ -115,66 +301,65 @@ func Test_Connection(t *testing.T) { } }) - t.Run("Test_HandleEcho", func(t *testing.T) { - conn, err := internal.GetConnection("localhost", port) - if err != nil { + conn, err := internal.GetConnection("localhost", port) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + command []resp.Value + expected string + expectedErr error + }{ + { + command: []resp.Value{resp.StringValue("ECHO"), resp.StringValue("Hello, EchoVault!")}, + expected: "Hello, EchoVault!", + expectedErr: nil, + }, + { + command: []resp.Value{resp.StringValue("ECHO")}, + expected: "", + expectedErr: errors.New(constants.WrongArgsResponse), + }, + { + command: []resp.Value{ + resp.StringValue("ECHO"), + resp.StringValue("Hello, EchoVault!"), + resp.StringValue("Once more"), + }, + expected: "", + expectedErr: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + if err = client.WriteArray(test.command); err != nil { t.Error(err) return } - defer func() { - _ = conn.Close() - }() - client := resp.NewConn(conn) - tests := []struct { - command []resp.Value - expected string - expectedErr error - }{ - { - command: []resp.Value{resp.StringValue("ECHO"), resp.StringValue("Hello, EchoVault!")}, - expected: "Hello, EchoVault!", - expectedErr: nil, - }, - { - command: []resp.Value{resp.StringValue("ECHO")}, - expected: "", - expectedErr: errors.New(constants.WrongArgsResponse), - }, - { - command: []resp.Value{ - resp.StringValue("ECHO"), - resp.StringValue("Hello, EchoVault!"), - resp.StringValue("Once more"), - }, - expected: "", - expectedErr: errors.New(constants.WrongArgsResponse), - }, + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - for _, test := range tests { - if err = client.WriteArray(test.command); err != nil { - t.Error(err) - return - } - - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedErr != nil { - if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedErr.Error(), res.Error().Error()) - } - continue - } - - if res.String() != test.expected { - t.Errorf("expected response \"%s\", got \"%s\"", test.expected, res.String()) + if test.expectedErr != nil { + if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedErr.Error(), res.Error().Error()) } + continue } - }) + + if res.String() != test.expected { + t.Errorf("expected response \"%s\", got \"%s\"", test.expected, res.String()) + } + } + }) }