Implemented tests for ACL permissions

This commit is contained in:
Kelvin Clement Mwinuka
2024-06-02 22:52:02 +08:00
parent 555387494b
commit 66b2842e11
3 changed files with 1397 additions and 1095 deletions

View File

@@ -178,12 +178,12 @@ func Test_ACL(t *testing.T) {
t.Run("Test_HandleAuth", func(t *testing.T) {
t.Parallel()
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
return
}
defer func() {
if conn != nil {
_ = conn.Close()
@@ -193,16 +193,19 @@ func Test_ACL(t *testing.T) {
r := resp.NewConn(conn)
tests := []struct {
name string
cmd []resp.Value
wantRes string
wantErr string
}{
{ // 1. Authenticate with default user without specifying username
{
name: "1. Authenticate with default user without specifying username",
cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")},
wantRes: "OK",
wantErr: "",
},
{ // 2. Authenticate with plaintext password
{
name: "2. Authenticate with plaintext password",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("with_password_user"),
@@ -211,7 +214,8 @@ func Test_ACL(t *testing.T) {
wantRes: "OK",
wantErr: "",
},
{ // 3. Authenticate with SHA256 password
{
name: "3. Authenticate with SHA256 password",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("with_password_user"),
@@ -220,7 +224,8 @@ func Test_ACL(t *testing.T) {
wantRes: "OK",
wantErr: "",
},
{ // 4. Authenticate with no password user
{
name: "4. Authenticate with no password user",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("no_password_user"),
@@ -229,7 +234,8 @@ func Test_ACL(t *testing.T) {
wantRes: "OK",
wantErr: "",
},
{ // 5. Fail to authenticate with disabled user
{
name: "5. Fail to authenticate with disabled user",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("disabled_user"),
@@ -238,7 +244,8 @@ func Test_ACL(t *testing.T) {
wantRes: "",
wantErr: "Error user disabled_user is disabled",
},
{ // 6. Fail to authenticate with non-existent user
{
name: "6. Fail to authenticate with non-existent user",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("non_existent_user"),
@@ -247,12 +254,24 @@ func Test_ACL(t *testing.T) {
wantRes: "",
wantErr: "Error no user with username non_existent_user",
},
{ // 7. Command too short
{
name: "7. Fail to authenticate with the wrong password",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("with_password_user"),
resp.StringValue("wrong_password"),
},
wantRes: "",
wantErr: "Error could not authenticate user",
},
{
name: "8. Command too short",
cmd: []resp.Value{resp.StringValue("AUTH")},
wantRes: "",
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
},
{ // 8. Command too long
{
name: "9. Command too long",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("user"),
@@ -265,23 +284,279 @@ func Test_ACL(t *testing.T) {
}
for _, test := range tests {
if err = r.WriteArray(test.cmd); err != nil {
t.Error(err)
}
rv, _, err := r.ReadValue()
if err != nil {
t.Error(err)
}
if test.wantErr != "" {
if rv.Error().Error() != test.wantErr {
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
t.Run(test.name, func(t *testing.T) {
if err = r.WriteArray(test.cmd); err != nil {
t.Error(err)
}
continue
}
if rv.String() != test.wantRes {
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
rv, _, err := r.ReadValue()
if err != nil {
t.Error(err)
}
if test.wantErr != "" {
if rv.Error().Error() != test.wantErr {
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
}
return
}
if rv.String() != test.wantRes {
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
}
})
}
})
t.Run("Test_Permissions", func(t *testing.T) {
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err := setUpServer(port, true, "")
if err != nil {
t.Error(err)
return
}
go func() {
mockServer.Start()
}()
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
return
}
client := resp.NewConn(conn)
t.Cleanup(func() {
_ = conn.Close()
mockServer.ShutDown()
})
// Add users to be used in test cases.
users := []echovault.User{
{
// User with nokeys flag enables.
Username: "test_nokeys",
Enabled: true,
NoKeys: true,
AddPlainPasswords: []string{"test_nokeys_password"},
},
{
// This use will be used to test authorization failure when trying to access resources that are not
// in their "included" rules.
Username: "test_included",
Enabled: true,
AddPlainPasswords: []string{"test_included_password"},
IncludeCategories: []string{
constants.WriteCategory,
constants.ReadCategory,
constants.SlowCategory,
constants.PubSubCategory,
constants.ConnectionCategory,
constants.ListCategory,
},
IncludeCommands: []string{"set", "get", "subscribe", "lrange", "ltrim"},
IncludeChannels: []string{"channel[12]"},
IncludeReadWriteKeys: []string{"key1", "key2"},
},
{
// This use will be used to test authorization failure when trying to access resources that are
// in their "excluded" rules.
Username: "test_excluded",
Enabled: true,
AddPlainPasswords: []string{"test_excluded_password"},
IncludeCategories: []string{"*"},
ExcludeCategories: []string{constants.FastCategory, constants.HashCategory},
IncludeCommands: []string{"*"},
ExcludeCommands: []string{"set", "mset"},
IncludeChannels: []string{"*"},
ExcludeChannels: []string{"channel[12]"},
},
}
for _, user := range users {
if _, err := mockServer.ACLSetUser(user); err != nil {
t.Error(err)
return
}
}
tests := []struct {
name string
auth []resp.Value
cmd []resp.Value
wantErr string
}{
{
name: "1. Return error when the connection is not authenticated",
auth: []resp.Value{},
cmd: []resp.Value{resp.StringValue("SET"), resp.StringValue("key"), resp.StringValue("value")},
wantErr: "user must be authenticated",
},
{
name: "2. Return error when the command category is not in the included categories section",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_included"),
resp.StringValue("test_included_password"),
},
cmd: []resp.Value{
resp.StringValue("HSET"),
resp.StringValue("hash"),
resp.StringValue("field1"),
resp.StringValue("value1"),
},
wantErr: fmt.Sprintf("unauthorized access to the following categories: [@%s @%s]",
constants.FastCategory, constants.HashCategory),
},
{
name: "3. Return error when the command category is in the excluded categories section",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_excluded"),
resp.StringValue("test_excluded_password"),
},
cmd: []resp.Value{
resp.StringValue("HSET"),
resp.StringValue("hash"),
resp.StringValue("field1"),
resp.StringValue("value1"),
},
wantErr: fmt.Sprintf("unauthorized access to the following categories: [@%s]",
constants.HashCategory),
},
{
name: "4. Return error when the command is not in the included command category",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_included"),
resp.StringValue("test_included_password"),
},
cmd: []resp.Value{
resp.StringValue("MSET"),
resp.StringValue("key1"),
resp.StringValue("value1"),
},
wantErr: "not authorised to run MSET command",
},
{
name: "5. Return error when the command is in the excluded command category",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_excluded"),
resp.StringValue("test_excluded_password"),
},
cmd: []resp.Value{
resp.StringValue("SET"),
resp.StringValue("key1"),
resp.StringValue("value1"),
},
wantErr: "not authorised to run SET command",
},
{
name: "6. Return error when subscribing to channel that's not in included channels",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_included"),
resp.StringValue("test_included_password"),
},
cmd: []resp.Value{
resp.StringValue("SUBSCRIBE"),
resp.StringValue("channel3"),
},
wantErr: "not authorised to access channel &channel3",
},
{
name: "7. Return error when publishing to channel that's in excluded channels",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_excluded"),
resp.StringValue("test_excluded_password"),
},
cmd: []resp.Value{
resp.StringValue("SUBSCRIBE"),
resp.StringValue("channel2"),
},
wantErr: "not authorised to access channel &channel2",
},
{
name: "8. Return error when the user has nokeys flag",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_nokeys"),
resp.StringValue("test_nokeys_password"),
},
cmd: []resp.Value{resp.StringValue("GET"), resp.StringValue("key1")},
},
{
name: "9. Return error when trying to read from keys that are not in read keys list",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_included"),
resp.StringValue("test_included_password"),
},
cmd: []resp.Value{
resp.StringValue("LRANGE"),
resp.StringValue("key3"),
resp.StringValue("0"),
resp.StringValue("-1"),
},
wantErr: fmt.Sprintf("not authorised to access the following keys: [%s~%s]", "%R", "key3"),
},
{
name: "10. Return error when trying to write to keys that are not in write keys list",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_included"),
resp.StringValue("test_included_password"),
},
cmd: []resp.Value{
resp.StringValue("LTRIM"),
resp.StringValue("key3"),
resp.StringValue("0"),
resp.StringValue("3"),
},
wantErr: fmt.Sprintf("not authorised to access the following keys: [%s~%s]", "%W", "key3"),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Authenticate the user if the auth command is provided.
if len(test.auth) > 0 {
err := client.WriteArray(test.auth)
if err != nil {
t.Error(err)
return
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
return
}
if !strings.EqualFold(res.String(), "ok") {
t.Errorf("expected auth response to be OK, got \"%s\"", res.String())
}
}
if err := client.WriteArray(test.cmd); err != nil {
t.Error(err)
return
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
return
}
if !strings.Contains(res.Error().Error(), test.wantErr) {
t.Errorf("expected error to contain string \"%s\", got \"%s\"",
test.wantErr, res.Error().Error())
return
}
})
}
})
t.Run("Test_HandleCat", func(t *testing.T) {