Implemented tests for ACL permissions

This commit is contained in:
Kelvin Clement Mwinuka
2024-06-02 22:52:02 +08:00
parent 555387494b
commit 66b2842e11
3 changed files with 1397 additions and 1095 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -184,13 +184,6 @@ func (acl *ACL) SetUser(cmd []string) error {
return nil return nil
} }
func (acl *ACL) AddUsers(users []*User) {
acl.LockUsers()
defer acl.UnlockUsers()
acl.Users = append(acl.Users, users...)
}
func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error { func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error {
acl.LockUsers() acl.LockUsers()
defer acl.UnlockUsers() defer acl.UnlockUsers()
@@ -329,8 +322,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
return nil return nil
} }
// Skip connection // Skip PING
if strings.EqualFold(comm, "connection") { if strings.EqualFold(comm, "ping") {
return nil return nil
} }
@@ -352,21 +345,23 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
return errors.New("user must be authenticated") return errors.New("user must be authenticated")
} }
// 2. Check if all categories are in IncludedCategories
var notAllowed []string var notAllowed []string
if !slices.ContainsFunc(categories, func(category string) bool {
return slices.ContainsFunc(connection.User.IncludedCategories, func(includedCategory string) bool { // 2. Check if all categories are in IncludedCategories
if includedCategory == "*" || includedCategory == category { count := make(map[string]int, len(categories))
return true if !slices.Contains(connection.User.IncludedCategories, "*") {
} for _, category := range categories {
notAllowed = append(notAllowed, fmt.Sprintf("@%s", category)) count[category] = 0
return false }
}) for _, category := range connection.User.IncludedCategories {
}) { if _, ok := count[category]; ok {
if len(notAllowed) == 0 { count[category] += 1
notAllowed = []string{"@all"} }
}
notAllowed = getUnauthorized(count, "@")
if len(notAllowed) > 0 {
return fmt.Errorf("unauthorized access to the following categories: %+v", notAllowed)
} }
return fmt.Errorf("unauthorized access to the following categories: %+v", notAllowed)
} }
// 3. Check if commands category is in ExcludedCategories // 3. Check if commands category is in ExcludedCategories
@@ -386,14 +381,14 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
if !slices.ContainsFunc(connection.User.IncludedCommands, func(includedCommand string) bool { if !slices.ContainsFunc(connection.User.IncludedCommands, func(includedCommand string) bool {
return includedCommand == "*" || includedCommand == comm return includedCommand == "*" || includedCommand == comm
}) { }) {
return fmt.Errorf("not authorised to run %s command", comm) return fmt.Errorf("not authorised to run %s command", strings.ToUpper(comm))
} }
// 5. Check if command are in ExcludedCommands // 5. Check if command are in ExcludedCommands
if slices.ContainsFunc(connection.User.ExcludedCommands, func(excludedCommand string) bool { if slices.ContainsFunc(connection.User.ExcludedCommands, func(excludedCommand string) bool {
return excludedCommand == "*" || excludedCommand == comm return excludedCommand == "*" || excludedCommand == comm
}) { }) {
return fmt.Errorf("not authorised to run %s command", comm) return fmt.Errorf("not authorised to run %s command", strings.ToUpper(comm))
} }
// 6. PUBSUB authorisation. // 6. PUBSUB authorisation.
@@ -428,24 +423,32 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
if acl.GlobPatterns[readKeyGlob].Match(key) { if acl.GlobPatterns[readKeyGlob].Match(key) {
return true return true
} }
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%R", key)) if !slices.Contains(notAllowed, fmt.Sprintf("%s~%s", "%R", key)) {
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%R", key))
}
return false return false
}) })
}) { }) {
return fmt.Errorf("not authorised to access the following keys %+v", notAllowed) if len(notAllowed) > 0 {
return fmt.Errorf("not authorised to access the following keys: %+v", notAllowed)
}
} }
// 9. Check if keys are in IncludedWriteKeys // 9. Check if write keys are in IncludedWriteKeys
fmt.Println("KEYS: ", writeKeys)
fmt.Println("ALLOWED KEYS: ", connection.User.IncludedWriteKeys)
if !slices.ContainsFunc(writeKeys, func(key string) bool { if !slices.ContainsFunc(writeKeys, func(key string) bool {
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool { return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool {
if acl.GlobPatterns[writeKeyGlob].Match(key) { if acl.GlobPatterns[writeKeyGlob].Match(key) {
return true return true
} }
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%W", key)) if !slices.Contains(notAllowed, fmt.Sprintf("%s~%s", "%W", key)) {
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%W", key))
}
return false return false
}) })
}) { }) {
return fmt.Errorf("not authorised to access the following keys %+v", notAllowed) return fmt.Errorf("not authorised to access the following keys: %+v", notAllowed)
} }
} }
@@ -491,3 +494,17 @@ func (acl *ACL) RLockUsers() {
func (acl *ACL) RUnlockUsers() { func (acl *ACL) RUnlockUsers() {
acl.UsersMutex.RUnlock() acl.UsersMutex.RUnlock()
} }
func getUnauthorized(count map[string]int, prefix string) []string {
var notAllowed []string
for member, c := range count {
if c == 0 {
notAllowed = append(notAllowed, fmt.Sprintf("%s%s", prefix, member))
}
}
// Sort the members in alphabetical order.
slices.SortStableFunc(notAllowed, func(a, b string) int {
return internal.CompareLex(a, b)
})
return notAllowed
}

View File

@@ -178,12 +178,12 @@ func Test_ACL(t *testing.T) {
t.Run("Test_HandleAuth", func(t *testing.T) { t.Run("Test_HandleAuth", func(t *testing.T) {
t.Parallel() t.Parallel()
conn, err := internal.GetConnection("localhost", port) conn, err := internal.GetConnection("localhost", port)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer func() { defer func() {
if conn != nil { if conn != nil {
_ = conn.Close() _ = conn.Close()
@@ -193,16 +193,19 @@ func Test_ACL(t *testing.T) {
r := resp.NewConn(conn) r := resp.NewConn(conn)
tests := []struct { tests := []struct {
name string
cmd []resp.Value cmd []resp.Value
wantRes string wantRes string
wantErr string wantErr string
}{ }{
{ // 1. Authenticate with default user without specifying username {
name: "1. Authenticate with default user without specifying username",
cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}, cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")},
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
}, },
{ // 2. Authenticate with plaintext password {
name: "2. Authenticate with plaintext password",
cmd: []resp.Value{ cmd: []resp.Value{
resp.StringValue("AUTH"), resp.StringValue("AUTH"),
resp.StringValue("with_password_user"), resp.StringValue("with_password_user"),
@@ -211,7 +214,8 @@ func Test_ACL(t *testing.T) {
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
}, },
{ // 3. Authenticate with SHA256 password {
name: "3. Authenticate with SHA256 password",
cmd: []resp.Value{ cmd: []resp.Value{
resp.StringValue("AUTH"), resp.StringValue("AUTH"),
resp.StringValue("with_password_user"), resp.StringValue("with_password_user"),
@@ -220,7 +224,8 @@ func Test_ACL(t *testing.T) {
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
}, },
{ // 4. Authenticate with no password user {
name: "4. Authenticate with no password user",
cmd: []resp.Value{ cmd: []resp.Value{
resp.StringValue("AUTH"), resp.StringValue("AUTH"),
resp.StringValue("no_password_user"), resp.StringValue("no_password_user"),
@@ -229,7 +234,8 @@ func Test_ACL(t *testing.T) {
wantRes: "OK", wantRes: "OK",
wantErr: "", wantErr: "",
}, },
{ // 5. Fail to authenticate with disabled user {
name: "5. Fail to authenticate with disabled user",
cmd: []resp.Value{ cmd: []resp.Value{
resp.StringValue("AUTH"), resp.StringValue("AUTH"),
resp.StringValue("disabled_user"), resp.StringValue("disabled_user"),
@@ -238,7 +244,8 @@ func Test_ACL(t *testing.T) {
wantRes: "", wantRes: "",
wantErr: "Error user disabled_user is disabled", wantErr: "Error user disabled_user is disabled",
}, },
{ // 6. Fail to authenticate with non-existent user {
name: "6. Fail to authenticate with non-existent user",
cmd: []resp.Value{ cmd: []resp.Value{
resp.StringValue("AUTH"), resp.StringValue("AUTH"),
resp.StringValue("non_existent_user"), resp.StringValue("non_existent_user"),
@@ -247,12 +254,24 @@ func Test_ACL(t *testing.T) {
wantRes: "", wantRes: "",
wantErr: "Error no user with username non_existent_user", wantErr: "Error no user with username non_existent_user",
}, },
{ // 7. Command too short {
name: "7. Fail to authenticate with the wrong password",
cmd: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("with_password_user"),
resp.StringValue("wrong_password"),
},
wantRes: "",
wantErr: "Error could not authenticate user",
},
{
name: "8. Command too short",
cmd: []resp.Value{resp.StringValue("AUTH")}, cmd: []resp.Value{resp.StringValue("AUTH")},
wantRes: "", wantRes: "",
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse), wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
}, },
{ // 8. Command too long {
name: "9. Command too long",
cmd: []resp.Value{ cmd: []resp.Value{
resp.StringValue("AUTH"), resp.StringValue("AUTH"),
resp.StringValue("user"), resp.StringValue("user"),
@@ -265,23 +284,279 @@ func Test_ACL(t *testing.T) {
} }
for _, test := range tests { for _, test := range tests {
if err = r.WriteArray(test.cmd); err != nil { t.Run(test.name, func(t *testing.T) {
t.Error(err) if err = r.WriteArray(test.cmd); err != nil {
} t.Error(err)
rv, _, err := r.ReadValue()
if err != nil {
t.Error(err)
}
if test.wantErr != "" {
if rv.Error().Error() != test.wantErr {
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
} }
continue rv, _, err := r.ReadValue()
} if err != nil {
if rv.String() != test.wantRes { t.Error(err)
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String()) }
if test.wantErr != "" {
if rv.Error().Error() != test.wantErr {
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
}
return
}
if rv.String() != test.wantRes {
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
}
})
}
})
t.Run("Test_Permissions", func(t *testing.T) {
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err := setUpServer(port, true, "")
if err != nil {
t.Error(err)
return
}
go func() {
mockServer.Start()
}()
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
return
}
client := resp.NewConn(conn)
t.Cleanup(func() {
_ = conn.Close()
mockServer.ShutDown()
})
// Add users to be used in test cases.
users := []echovault.User{
{
// User with nokeys flag enables.
Username: "test_nokeys",
Enabled: true,
NoKeys: true,
AddPlainPasswords: []string{"test_nokeys_password"},
},
{
// This use will be used to test authorization failure when trying to access resources that are not
// in their "included" rules.
Username: "test_included",
Enabled: true,
AddPlainPasswords: []string{"test_included_password"},
IncludeCategories: []string{
constants.WriteCategory,
constants.ReadCategory,
constants.SlowCategory,
constants.PubSubCategory,
constants.ConnectionCategory,
constants.ListCategory,
},
IncludeCommands: []string{"set", "get", "subscribe", "lrange", "ltrim"},
IncludeChannels: []string{"channel[12]"},
IncludeReadWriteKeys: []string{"key1", "key2"},
},
{
// This use will be used to test authorization failure when trying to access resources that are
// in their "excluded" rules.
Username: "test_excluded",
Enabled: true,
AddPlainPasswords: []string{"test_excluded_password"},
IncludeCategories: []string{"*"},
ExcludeCategories: []string{constants.FastCategory, constants.HashCategory},
IncludeCommands: []string{"*"},
ExcludeCommands: []string{"set", "mset"},
IncludeChannels: []string{"*"},
ExcludeChannels: []string{"channel[12]"},
},
}
for _, user := range users {
if _, err := mockServer.ACLSetUser(user); err != nil {
t.Error(err)
return
} }
} }
tests := []struct {
name string
auth []resp.Value
cmd []resp.Value
wantErr string
}{
{
name: "1. Return error when the connection is not authenticated",
auth: []resp.Value{},
cmd: []resp.Value{resp.StringValue("SET"), resp.StringValue("key"), resp.StringValue("value")},
wantErr: "user must be authenticated",
},
{
name: "2. Return error when the command category is not in the included categories section",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_included"),
resp.StringValue("test_included_password"),
},
cmd: []resp.Value{
resp.StringValue("HSET"),
resp.StringValue("hash"),
resp.StringValue("field1"),
resp.StringValue("value1"),
},
wantErr: fmt.Sprintf("unauthorized access to the following categories: [@%s @%s]",
constants.FastCategory, constants.HashCategory),
},
{
name: "3. Return error when the command category is in the excluded categories section",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_excluded"),
resp.StringValue("test_excluded_password"),
},
cmd: []resp.Value{
resp.StringValue("HSET"),
resp.StringValue("hash"),
resp.StringValue("field1"),
resp.StringValue("value1"),
},
wantErr: fmt.Sprintf("unauthorized access to the following categories: [@%s]",
constants.HashCategory),
},
{
name: "4. Return error when the command is not in the included command category",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_included"),
resp.StringValue("test_included_password"),
},
cmd: []resp.Value{
resp.StringValue("MSET"),
resp.StringValue("key1"),
resp.StringValue("value1"),
},
wantErr: "not authorised to run MSET command",
},
{
name: "5. Return error when the command is in the excluded command category",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_excluded"),
resp.StringValue("test_excluded_password"),
},
cmd: []resp.Value{
resp.StringValue("SET"),
resp.StringValue("key1"),
resp.StringValue("value1"),
},
wantErr: "not authorised to run SET command",
},
{
name: "6. Return error when subscribing to channel that's not in included channels",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_included"),
resp.StringValue("test_included_password"),
},
cmd: []resp.Value{
resp.StringValue("SUBSCRIBE"),
resp.StringValue("channel3"),
},
wantErr: "not authorised to access channel &channel3",
},
{
name: "7. Return error when publishing to channel that's in excluded channels",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_excluded"),
resp.StringValue("test_excluded_password"),
},
cmd: []resp.Value{
resp.StringValue("SUBSCRIBE"),
resp.StringValue("channel2"),
},
wantErr: "not authorised to access channel &channel2",
},
{
name: "8. Return error when the user has nokeys flag",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_nokeys"),
resp.StringValue("test_nokeys_password"),
},
cmd: []resp.Value{resp.StringValue("GET"), resp.StringValue("key1")},
},
{
name: "9. Return error when trying to read from keys that are not in read keys list",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_included"),
resp.StringValue("test_included_password"),
},
cmd: []resp.Value{
resp.StringValue("LRANGE"),
resp.StringValue("key3"),
resp.StringValue("0"),
resp.StringValue("-1"),
},
wantErr: fmt.Sprintf("not authorised to access the following keys: [%s~%s]", "%R", "key3"),
},
{
name: "10. Return error when trying to write to keys that are not in write keys list",
auth: []resp.Value{
resp.StringValue("AUTH"),
resp.StringValue("test_included"),
resp.StringValue("test_included_password"),
},
cmd: []resp.Value{
resp.StringValue("LTRIM"),
resp.StringValue("key3"),
resp.StringValue("0"),
resp.StringValue("3"),
},
wantErr: fmt.Sprintf("not authorised to access the following keys: [%s~%s]", "%W", "key3"),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Authenticate the user if the auth command is provided.
if len(test.auth) > 0 {
err := client.WriteArray(test.auth)
if err != nil {
t.Error(err)
return
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
return
}
if !strings.EqualFold(res.String(), "ok") {
t.Errorf("expected auth response to be OK, got \"%s\"", res.String())
}
}
if err := client.WriteArray(test.cmd); err != nil {
t.Error(err)
return
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
return
}
if !strings.Contains(res.Error().Error(), test.wantErr) {
t.Errorf("expected error to contain string \"%s\", got \"%s\"",
test.wantErr, res.Error().Error())
return
}
})
}
}) })
t.Run("Test_HandleCat", func(t *testing.T) { t.Run("Test_HandleCat", func(t *testing.T) {