mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-24 16:30:21 +08:00
Implemented tests for ACL permissions
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -184,13 +184,6 @@ func (acl *ACL) SetUser(cmd []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (acl *ACL) AddUsers(users []*User) {
|
||||
acl.LockUsers()
|
||||
defer acl.UnlockUsers()
|
||||
|
||||
acl.Users = append(acl.Users, users...)
|
||||
}
|
||||
|
||||
func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error {
|
||||
acl.LockUsers()
|
||||
defer acl.UnlockUsers()
|
||||
@@ -329,8 +322,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip connection
|
||||
if strings.EqualFold(comm, "connection") {
|
||||
// Skip PING
|
||||
if strings.EqualFold(comm, "ping") {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -352,22 +345,24 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
|
||||
return errors.New("user must be authenticated")
|
||||
}
|
||||
|
||||
// 2. Check if all categories are in IncludedCategories
|
||||
var notAllowed []string
|
||||
if !slices.ContainsFunc(categories, func(category string) bool {
|
||||
return slices.ContainsFunc(connection.User.IncludedCategories, func(includedCategory string) bool {
|
||||
if includedCategory == "*" || includedCategory == category {
|
||||
return true
|
||||
|
||||
// 2. Check if all categories are in IncludedCategories
|
||||
count := make(map[string]int, len(categories))
|
||||
if !slices.Contains(connection.User.IncludedCategories, "*") {
|
||||
for _, category := range categories {
|
||||
count[category] = 0
|
||||
}
|
||||
notAllowed = append(notAllowed, fmt.Sprintf("@%s", category))
|
||||
return false
|
||||
})
|
||||
}) {
|
||||
if len(notAllowed) == 0 {
|
||||
notAllowed = []string{"@all"}
|
||||
for _, category := range connection.User.IncludedCategories {
|
||||
if _, ok := count[category]; ok {
|
||||
count[category] += 1
|
||||
}
|
||||
}
|
||||
notAllowed = getUnauthorized(count, "@")
|
||||
if len(notAllowed) > 0 {
|
||||
return fmt.Errorf("unauthorized access to the following categories: %+v", notAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Check if commands category is in ExcludedCategories
|
||||
if slices.ContainsFunc(categories, func(category string) bool {
|
||||
@@ -386,14 +381,14 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
|
||||
if !slices.ContainsFunc(connection.User.IncludedCommands, func(includedCommand string) bool {
|
||||
return includedCommand == "*" || includedCommand == comm
|
||||
}) {
|
||||
return fmt.Errorf("not authorised to run %s command", comm)
|
||||
return fmt.Errorf("not authorised to run %s command", strings.ToUpper(comm))
|
||||
}
|
||||
|
||||
// 5. Check if command are in ExcludedCommands
|
||||
if slices.ContainsFunc(connection.User.ExcludedCommands, func(excludedCommand string) bool {
|
||||
return excludedCommand == "*" || excludedCommand == comm
|
||||
}) {
|
||||
return fmt.Errorf("not authorised to run %s command", comm)
|
||||
return fmt.Errorf("not authorised to run %s command", strings.ToUpper(comm))
|
||||
}
|
||||
|
||||
// 6. PUBSUB authorisation.
|
||||
@@ -428,24 +423,32 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
|
||||
if acl.GlobPatterns[readKeyGlob].Match(key) {
|
||||
return true
|
||||
}
|
||||
if !slices.Contains(notAllowed, fmt.Sprintf("%s~%s", "%R", key)) {
|
||||
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%R", key))
|
||||
}
|
||||
return false
|
||||
})
|
||||
}) {
|
||||
return fmt.Errorf("not authorised to access the following keys %+v", notAllowed)
|
||||
if len(notAllowed) > 0 {
|
||||
return fmt.Errorf("not authorised to access the following keys: %+v", notAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// 9. Check if keys are in IncludedWriteKeys
|
||||
// 9. Check if write keys are in IncludedWriteKeys
|
||||
fmt.Println("KEYS: ", writeKeys)
|
||||
fmt.Println("ALLOWED KEYS: ", connection.User.IncludedWriteKeys)
|
||||
if !slices.ContainsFunc(writeKeys, func(key string) bool {
|
||||
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool {
|
||||
if acl.GlobPatterns[writeKeyGlob].Match(key) {
|
||||
return true
|
||||
}
|
||||
if !slices.Contains(notAllowed, fmt.Sprintf("%s~%s", "%W", key)) {
|
||||
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%W", key))
|
||||
}
|
||||
return false
|
||||
})
|
||||
}) {
|
||||
return fmt.Errorf("not authorised to access the following keys %+v", notAllowed)
|
||||
return fmt.Errorf("not authorised to access the following keys: %+v", notAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -491,3 +494,17 @@ func (acl *ACL) RLockUsers() {
|
||||
func (acl *ACL) RUnlockUsers() {
|
||||
acl.UsersMutex.RUnlock()
|
||||
}
|
||||
|
||||
func getUnauthorized(count map[string]int, prefix string) []string {
|
||||
var notAllowed []string
|
||||
for member, c := range count {
|
||||
if c == 0 {
|
||||
notAllowed = append(notAllowed, fmt.Sprintf("%s%s", prefix, member))
|
||||
}
|
||||
}
|
||||
// Sort the members in alphabetical order.
|
||||
slices.SortStableFunc(notAllowed, func(a, b string) int {
|
||||
return internal.CompareLex(a, b)
|
||||
})
|
||||
return notAllowed
|
||||
}
|
||||
|
||||
@@ -178,12 +178,12 @@ func Test_ACL(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleAuth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
@@ -193,16 +193,19 @@ func Test_ACL(t *testing.T) {
|
||||
r := resp.NewConn(conn)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cmd []resp.Value
|
||||
wantRes string
|
||||
wantErr string
|
||||
}{
|
||||
{ // 1. Authenticate with default user without specifying username
|
||||
{
|
||||
name: "1. Authenticate with default user without specifying username",
|
||||
cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")},
|
||||
wantRes: "OK",
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 2. Authenticate with plaintext password
|
||||
{
|
||||
name: "2. Authenticate with plaintext password",
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("with_password_user"),
|
||||
@@ -211,7 +214,8 @@ func Test_ACL(t *testing.T) {
|
||||
wantRes: "OK",
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 3. Authenticate with SHA256 password
|
||||
{
|
||||
name: "3. Authenticate with SHA256 password",
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("with_password_user"),
|
||||
@@ -220,7 +224,8 @@ func Test_ACL(t *testing.T) {
|
||||
wantRes: "OK",
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 4. Authenticate with no password user
|
||||
{
|
||||
name: "4. Authenticate with no password user",
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("no_password_user"),
|
||||
@@ -229,7 +234,8 @@ func Test_ACL(t *testing.T) {
|
||||
wantRes: "OK",
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 5. Fail to authenticate with disabled user
|
||||
{
|
||||
name: "5. Fail to authenticate with disabled user",
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("disabled_user"),
|
||||
@@ -238,7 +244,8 @@ func Test_ACL(t *testing.T) {
|
||||
wantRes: "",
|
||||
wantErr: "Error user disabled_user is disabled",
|
||||
},
|
||||
{ // 6. Fail to authenticate with non-existent user
|
||||
{
|
||||
name: "6. Fail to authenticate with non-existent user",
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("non_existent_user"),
|
||||
@@ -247,12 +254,24 @@ func Test_ACL(t *testing.T) {
|
||||
wantRes: "",
|
||||
wantErr: "Error no user with username non_existent_user",
|
||||
},
|
||||
{ // 7. Command too short
|
||||
{
|
||||
name: "7. Fail to authenticate with the wrong password",
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("with_password_user"),
|
||||
resp.StringValue("wrong_password"),
|
||||
},
|
||||
wantRes: "",
|
||||
wantErr: "Error could not authenticate user",
|
||||
},
|
||||
{
|
||||
name: "8. Command too short",
|
||||
cmd: []resp.Value{resp.StringValue("AUTH")},
|
||||
wantRes: "",
|
||||
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
|
||||
},
|
||||
{ // 8. Command too long
|
||||
{
|
||||
name: "9. Command too long",
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("user"),
|
||||
@@ -265,6 +284,7 @@ func Test_ACL(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if err = r.WriteArray(test.cmd); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@@ -276,11 +296,266 @@ func Test_ACL(t *testing.T) {
|
||||
if rv.Error().Error() != test.wantErr {
|
||||
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
|
||||
}
|
||||
continue
|
||||
return
|
||||
}
|
||||
if rv.String() != test.wantRes {
|
||||
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Test_Permissions", func(t *testing.T) {
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
mockServer, err := setUpServer(port, true, "")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
mockServer.Start()
|
||||
}()
|
||||
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = conn.Close()
|
||||
mockServer.ShutDown()
|
||||
})
|
||||
|
||||
// Add users to be used in test cases.
|
||||
users := []echovault.User{
|
||||
{
|
||||
// User with nokeys flag enables.
|
||||
Username: "test_nokeys",
|
||||
Enabled: true,
|
||||
NoKeys: true,
|
||||
AddPlainPasswords: []string{"test_nokeys_password"},
|
||||
},
|
||||
{
|
||||
// This use will be used to test authorization failure when trying to access resources that are not
|
||||
// in their "included" rules.
|
||||
Username: "test_included",
|
||||
Enabled: true,
|
||||
AddPlainPasswords: []string{"test_included_password"},
|
||||
IncludeCategories: []string{
|
||||
constants.WriteCategory,
|
||||
constants.ReadCategory,
|
||||
constants.SlowCategory,
|
||||
constants.PubSubCategory,
|
||||
constants.ConnectionCategory,
|
||||
constants.ListCategory,
|
||||
},
|
||||
IncludeCommands: []string{"set", "get", "subscribe", "lrange", "ltrim"},
|
||||
IncludeChannels: []string{"channel[12]"},
|
||||
IncludeReadWriteKeys: []string{"key1", "key2"},
|
||||
},
|
||||
{
|
||||
// This use will be used to test authorization failure when trying to access resources that are
|
||||
// in their "excluded" rules.
|
||||
Username: "test_excluded",
|
||||
Enabled: true,
|
||||
AddPlainPasswords: []string{"test_excluded_password"},
|
||||
IncludeCategories: []string{"*"},
|
||||
ExcludeCategories: []string{constants.FastCategory, constants.HashCategory},
|
||||
IncludeCommands: []string{"*"},
|
||||
ExcludeCommands: []string{"set", "mset"},
|
||||
IncludeChannels: []string{"*"},
|
||||
ExcludeChannels: []string{"channel[12]"},
|
||||
},
|
||||
}
|
||||
for _, user := range users {
|
||||
if _, err := mockServer.ACLSetUser(user); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
auth []resp.Value
|
||||
cmd []resp.Value
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "1. Return error when the connection is not authenticated",
|
||||
auth: []resp.Value{},
|
||||
cmd: []resp.Value{resp.StringValue("SET"), resp.StringValue("key"), resp.StringValue("value")},
|
||||
wantErr: "user must be authenticated",
|
||||
},
|
||||
{
|
||||
name: "2. Return error when the command category is not in the included categories section",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_included"),
|
||||
resp.StringValue("test_included_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("HSET"),
|
||||
resp.StringValue("hash"),
|
||||
resp.StringValue("field1"),
|
||||
resp.StringValue("value1"),
|
||||
},
|
||||
wantErr: fmt.Sprintf("unauthorized access to the following categories: [@%s @%s]",
|
||||
constants.FastCategory, constants.HashCategory),
|
||||
},
|
||||
{
|
||||
name: "3. Return error when the command category is in the excluded categories section",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_excluded"),
|
||||
resp.StringValue("test_excluded_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("HSET"),
|
||||
resp.StringValue("hash"),
|
||||
resp.StringValue("field1"),
|
||||
resp.StringValue("value1"),
|
||||
},
|
||||
wantErr: fmt.Sprintf("unauthorized access to the following categories: [@%s]",
|
||||
constants.HashCategory),
|
||||
},
|
||||
{
|
||||
name: "4. Return error when the command is not in the included command category",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_included"),
|
||||
resp.StringValue("test_included_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("MSET"),
|
||||
resp.StringValue("key1"),
|
||||
resp.StringValue("value1"),
|
||||
},
|
||||
wantErr: "not authorised to run MSET command",
|
||||
},
|
||||
{
|
||||
name: "5. Return error when the command is in the excluded command category",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_excluded"),
|
||||
resp.StringValue("test_excluded_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("SET"),
|
||||
resp.StringValue("key1"),
|
||||
resp.StringValue("value1"),
|
||||
},
|
||||
wantErr: "not authorised to run SET command",
|
||||
},
|
||||
{
|
||||
name: "6. Return error when subscribing to channel that's not in included channels",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_included"),
|
||||
resp.StringValue("test_included_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("SUBSCRIBE"),
|
||||
resp.StringValue("channel3"),
|
||||
},
|
||||
wantErr: "not authorised to access channel &channel3",
|
||||
},
|
||||
{
|
||||
name: "7. Return error when publishing to channel that's in excluded channels",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_excluded"),
|
||||
resp.StringValue("test_excluded_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("SUBSCRIBE"),
|
||||
resp.StringValue("channel2"),
|
||||
},
|
||||
wantErr: "not authorised to access channel &channel2",
|
||||
},
|
||||
{
|
||||
name: "8. Return error when the user has nokeys flag",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_nokeys"),
|
||||
resp.StringValue("test_nokeys_password"),
|
||||
},
|
||||
cmd: []resp.Value{resp.StringValue("GET"), resp.StringValue("key1")},
|
||||
},
|
||||
{
|
||||
name: "9. Return error when trying to read from keys that are not in read keys list",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_included"),
|
||||
resp.StringValue("test_included_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("LRANGE"),
|
||||
resp.StringValue("key3"),
|
||||
resp.StringValue("0"),
|
||||
resp.StringValue("-1"),
|
||||
},
|
||||
wantErr: fmt.Sprintf("not authorised to access the following keys: [%s~%s]", "%R", "key3"),
|
||||
},
|
||||
{
|
||||
name: "10. Return error when trying to write to keys that are not in write keys list",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_included"),
|
||||
resp.StringValue("test_included_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("LTRIM"),
|
||||
resp.StringValue("key3"),
|
||||
resp.StringValue("0"),
|
||||
resp.StringValue("3"),
|
||||
},
|
||||
wantErr: fmt.Sprintf("not authorised to access the following keys: [%s~%s]", "%W", "key3"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Authenticate the user if the auth command is provided.
|
||||
if len(test.auth) > 0 {
|
||||
err := client.WriteArray(test.auth)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(res.String(), "ok") {
|
||||
t.Errorf("expected auth response to be OK, got \"%s\"", res.String())
|
||||
}
|
||||
}
|
||||
|
||||
if err := client.WriteArray(test.cmd); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.Contains(res.Error().Error(), test.wantErr) {
|
||||
t.Errorf("expected error to contain string \"%s\", got \"%s\"",
|
||||
test.wantErr, res.Error().Error())
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user