Moved AUTH command from the acl module to the connection module. Added echo command to the list of commands that are skipped on ACL authorization.

This commit is contained in:
Kelvin Mwinuka
2024-06-24 04:16:25 +08:00
parent 93a165e9f9
commit 21aabda04d
5 changed files with 281 additions and 225 deletions

View File

@@ -325,8 +325,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
return nil
}
// Skip PING
if strings.EqualFold(comm, "ping") {
// Skip certain commands from authorization
if slices.Contains([]string{"ping", "echo"}, strings.ToLower(comm)) {
return nil
}

View File

@@ -28,23 +28,6 @@ import (
"strings"
)
func handleAuth(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) < 2 || len(params.Command) > 3 {
return nil, errors.New(constants.WrongArgsResponse)
}
acl, ok := params.GetACL().(*ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
acl.LockUsers()
defer acl.UnlockUsers()
if err := acl.AuthenticateConnection(params.Context, params.Connection, params.Command); err != nil {
return nil, err
}
return []byte(constants.OkResponse), nil
}
func handleCat(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) > 3 {
return nil, errors.New(constants.WrongArgsResponse)
@@ -496,23 +479,6 @@ func handleSave(params internal.HandlerFuncParams) ([]byte, error) {
func Commands() []internal.Command {
return []internal.Command{
{
Command: "auth",
Module: constants.ACLModule,
Categories: []string{constants.ConnectionCategory, constants.SlowCategory},
Description: `(AUTH [username] password)
Authenticates the connection. If the username is not provided, the connection will be authenticated against the
default ACL user. Otherwise, it is authenticated against the ACL user with the provided username.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) (internal.KeyExtractionFuncResult, error) {
return internal.KeyExtractionFuncResult{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleAuth,
},
{
Command: "acl",
Module: constants.ACLModule,

View File

@@ -176,135 +176,6 @@ func Test_ACL(t *testing.T) {
mockServer.ShutDown()
})
t.Run("Test_HandleAuth", func(t *testing.T) {
t.Parallel()
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
return
}
defer func() {
if conn != nil {
_ = conn.Close()
}
}()
r := resp.NewConn(conn)
tests := []struct {
name string
cmd []resp.Value
wantRes string
wantErr string
}{
{
name: "1. Authenticate with default user without specifying username",
cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")},
wantRes: "OK",
wantErr: "",
},
{
name: "2. Authenticate with plaintext password",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("with_password_user"),
resp.StringValue("password2"),
},
wantRes: "OK",
wantErr: "",
},
{
name: "3. Authenticate with SHA256 password",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("with_password_user"),
resp.StringValue("password3"),
},
wantRes: "OK",
wantErr: "",
},
{
name: "4. Authenticate with no password user",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("no_password_user"),
resp.StringValue("password4"),
},
wantRes: "OK",
wantErr: "",
},
{
name: "5. Fail to authenticate with disabled user",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("disabled_user"),
resp.StringValue("password5"),
},
wantRes: "",
wantErr: "Error user disabled_user is disabled",
},
{
name: "6. Fail to authenticate with non-existent user",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("non_existent_user"),
resp.StringValue("password6"),
},
wantRes: "",
wantErr: "Error no user with username non_existent_user",
},
{
name: "7. Fail to authenticate with the wrong password",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("with_password_user"),
resp.StringValue("wrong_password"),
},
wantRes: "",
wantErr: "Error could not authenticate user",
},
{
name: "8. Command too short",
cmd: []resp.Value{resp.StringValue("AUTH")},
wantRes: "",
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
},
{
name: "9. Command too long",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("user"),
resp.StringValue("password1"),
resp.StringValue("password2"),
},
wantRes: "",
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if err = r.WriteArray(test.cmd); err != nil {
t.Error(err)
}
rv, _, err := r.ReadValue()
if err != nil {
t.Error(err)
}
if test.wantErr != "" {
if rv.Error().Error() != test.wantErr {
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
}
return
}
if rv.String() != test.wantRes {
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
}
})
}
})
t.Run("Test_Permissions", func(t *testing.T) {
port, err := internal.GetFreePort()
if err != nil {

View File

@@ -25,6 +25,23 @@ import (
"github.com/echovault/echovault/internal/constants"
)
func handleAuth(params internal.HandlerFuncParams) ([]byte, error) {
if len(params.Command) < 2 || len(params.Command) > 3 {
return nil, errors.New(constants.WrongArgsResponse)
}
accessControlList, ok := params.GetACL().(*acl.ACL)
if !ok {
return nil, errors.New("could not load ACL")
}
accessControlList.LockUsers()
defer accessControlList.UnlockUsers()
if err := accessControlList.AuthenticateConnection(params.Context, params.Connection, params.Command); err != nil {
return nil, err
}
return []byte(constants.OkResponse), nil
}
func handlePing(params internal.HandlerFuncParams) ([]byte, error) {
switch len(params.Command) {
default:
@@ -112,6 +129,23 @@ func handleHello(params internal.HandlerFuncParams) ([]byte, error) {
func Commands() []internal.Command {
return []internal.Command{
{
Command: "auth",
Module: constants.ConnectionModule,
Categories: []string{constants.ConnectionCategory, constants.SlowCategory},
Description: `(AUTH [username] password)
Authenticates the connection. If the username is not provided, the connection will be authenticated against the
default ACL user. Otherwise, it is authenticated against the ACL user with the provided username.`,
Sync: false,
KeyExtractionFunc: func(cmd []string) (internal.KeyExtractionFuncResult, error) {
return internal.KeyExtractionFuncResult{
Channels: make([]string, 0),
ReadKeys: make([]string, 0),
WriteKeys: make([]string, 0),
}, nil
},
HandlerFunc: handleAuth,
},
{
Command: "ping",
Module: constants.ConnectionModule,

View File

@@ -15,7 +15,10 @@
package connection_test
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"strings"
"testing"
@@ -26,6 +29,67 @@ import (
"github.com/tidwall/resp"
)
func setUpServer(port int, requirePass bool, aclConfig string) (*echovault.EchoVault, error) {
conf := config.Config{
BindAddr: "localhost",
Port: uint16(port),
DataDir: "",
EvictionPolicy: constants.NoEviction,
RequirePass: requirePass,
Password: "password1",
AclConfig: aclConfig,
}
mockServer, err := echovault.NewEchoVault(
echovault.WithConfig(conf),
)
if err != nil {
return nil, err
}
// Add the initial test users to the ACL module.
for _, user := range generateInitialTestUsers() {
if _, err := mockServer.ACLSetUser(user); err != nil {
return nil, err
}
}
return mockServer, nil
}
func generateInitialTestUsers() []echovault.User {
return []echovault.User{
{
// User with both hash password and plaintext password.
Username: "with_password_user",
Enabled: true,
IncludeCategories: []string{"*"},
IncludeCommands: []string{"*"},
AddPlainPasswords: []string{"password2"},
AddHashPasswords: []string{generateSHA256Password("password3")},
},
{
// User with NoPassword option.
Username: "no_password_user",
Enabled: true,
NoPassword: true,
AddPlainPasswords: []string{"password4"},
},
{
// Disabled user.
Username: "disabled_user",
Enabled: false,
AddPlainPasswords: []string{"password5"},
},
}
}
func generateSHA256Password(plain string) string {
h := sha256.New()
h.Write([]byte(plain))
return hex.EncodeToString(h.Sum(nil))
}
func Test_Connection(t *testing.T) {
port, err := internal.GetFreePort()
if err != nil {
@@ -33,14 +97,7 @@ func Test_Connection(t *testing.T) {
return
}
mockServer, err := echovault.NewEchoVault(
echovault.WithConfig(config.Config{
DataDir: "",
EvictionPolicy: constants.NoEviction,
BindAddr: "localhost",
Port: uint16(port),
}),
)
mockServer, err := setUpServer(port, true, "")
if err != nil {
t.Error(err)
return
@@ -54,6 +111,135 @@ func Test_Connection(t *testing.T) {
mockServer.ShutDown()
})
t.Run("Test_HandleAuth", func(t *testing.T) {
t.Parallel()
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
return
}
defer func() {
if conn != nil {
_ = conn.Close()
}
}()
r := resp.NewConn(conn)
tests := []struct {
name string
cmd []resp.Value
wantRes string
wantErr string
}{
{
name: "1. Authenticate with default user without specifying username",
cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")},
wantRes: "OK",
wantErr: "",
},
{
name: "2. Authenticate with plaintext password",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("with_password_user"),
resp.StringValue("password2"),
},
wantRes: "OK",
wantErr: "",
},
{
name: "3. Authenticate with SHA256 password",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("with_password_user"),
resp.StringValue("password3"),
},
wantRes: "OK",
wantErr: "",
},
{
name: "4. Authenticate with no password user",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("no_password_user"),
resp.StringValue("password4"),
},
wantRes: "OK",
wantErr: "",
},
{
name: "5. Fail to authenticate with disabled user",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("disabled_user"),
resp.StringValue("password5"),
},
wantRes: "",
wantErr: "Error user disabled_user is disabled",
},
{
name: "6. Fail to authenticate with non-existent user",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("non_existent_user"),
resp.StringValue("password6"),
},
wantRes: "",
wantErr: "Error no user with username non_existent_user",
},
{
name: "7. Fail to authenticate with the wrong password",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("with_password_user"),
resp.StringValue("wrong_password"),
},
wantRes: "",
wantErr: "Error could not authenticate user",
},
{
name: "8. Command too short",
cmd: []resp.Value{resp.StringValue("AUTH")},
wantRes: "",
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
},
{
name: "9. Command too long",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("user"),
resp.StringValue("password1"),
resp.StringValue("password2"),
},
wantRes: "",
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if err = r.WriteArray(test.cmd); err != nil {
t.Error(err)
}
rv, _, err := r.ReadValue()
if err != nil {
t.Error(err)
}
if test.wantErr != "" {
if rv.Error().Error() != test.wantErr {
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
}
return
}
if rv.String() != test.wantRes {
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
}
})
}
})
t.Run("Test_HandlePing", func(t *testing.T) {
conn, err := internal.GetConnection("localhost", port)
if err != nil {
@@ -115,7 +301,6 @@ func Test_Connection(t *testing.T) {
}
})
t.Run("Test_HandleEcho", func(t *testing.T) {
conn, err := internal.GetConnection("localhost", port)
if err != nil {