mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-25 08:50:24 +08:00
Implemented tests for ACL permissions
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -184,13 +184,6 @@ func (acl *ACL) SetUser(cmd []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acl *ACL) AddUsers(users []*User) {
|
|
||||||
acl.LockUsers()
|
|
||||||
defer acl.UnlockUsers()
|
|
||||||
|
|
||||||
acl.Users = append(acl.Users, users...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error {
|
func (acl *ACL) DeleteUser(_ context.Context, usernames []string) error {
|
||||||
acl.LockUsers()
|
acl.LockUsers()
|
||||||
defer acl.UnlockUsers()
|
defer acl.UnlockUsers()
|
||||||
@@ -329,8 +322,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip connection
|
// Skip PING
|
||||||
if strings.EqualFold(comm, "connection") {
|
if strings.EqualFold(comm, "ping") {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -352,21 +345,23 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
|
|||||||
return errors.New("user must be authenticated")
|
return errors.New("user must be authenticated")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Check if all categories are in IncludedCategories
|
|
||||||
var notAllowed []string
|
var notAllowed []string
|
||||||
if !slices.ContainsFunc(categories, func(category string) bool {
|
|
||||||
return slices.ContainsFunc(connection.User.IncludedCategories, func(includedCategory string) bool {
|
// 2. Check if all categories are in IncludedCategories
|
||||||
if includedCategory == "*" || includedCategory == category {
|
count := make(map[string]int, len(categories))
|
||||||
return true
|
if !slices.Contains(connection.User.IncludedCategories, "*") {
|
||||||
}
|
for _, category := range categories {
|
||||||
notAllowed = append(notAllowed, fmt.Sprintf("@%s", category))
|
count[category] = 0
|
||||||
return false
|
}
|
||||||
})
|
for _, category := range connection.User.IncludedCategories {
|
||||||
}) {
|
if _, ok := count[category]; ok {
|
||||||
if len(notAllowed) == 0 {
|
count[category] += 1
|
||||||
notAllowed = []string{"@all"}
|
}
|
||||||
|
}
|
||||||
|
notAllowed = getUnauthorized(count, "@")
|
||||||
|
if len(notAllowed) > 0 {
|
||||||
|
return fmt.Errorf("unauthorized access to the following categories: %+v", notAllowed)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("unauthorized access to the following categories: %+v", notAllowed)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Check if commands category is in ExcludedCategories
|
// 3. Check if commands category is in ExcludedCategories
|
||||||
@@ -386,14 +381,14 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
|
|||||||
if !slices.ContainsFunc(connection.User.IncludedCommands, func(includedCommand string) bool {
|
if !slices.ContainsFunc(connection.User.IncludedCommands, func(includedCommand string) bool {
|
||||||
return includedCommand == "*" || includedCommand == comm
|
return includedCommand == "*" || includedCommand == comm
|
||||||
}) {
|
}) {
|
||||||
return fmt.Errorf("not authorised to run %s command", comm)
|
return fmt.Errorf("not authorised to run %s command", strings.ToUpper(comm))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Check if command are in ExcludedCommands
|
// 5. Check if command are in ExcludedCommands
|
||||||
if slices.ContainsFunc(connection.User.ExcludedCommands, func(excludedCommand string) bool {
|
if slices.ContainsFunc(connection.User.ExcludedCommands, func(excludedCommand string) bool {
|
||||||
return excludedCommand == "*" || excludedCommand == comm
|
return excludedCommand == "*" || excludedCommand == comm
|
||||||
}) {
|
}) {
|
||||||
return fmt.Errorf("not authorised to run %s command", comm)
|
return fmt.Errorf("not authorised to run %s command", strings.ToUpper(comm))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. PUBSUB authorisation.
|
// 6. PUBSUB authorisation.
|
||||||
@@ -428,24 +423,32 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command intern
|
|||||||
if acl.GlobPatterns[readKeyGlob].Match(key) {
|
if acl.GlobPatterns[readKeyGlob].Match(key) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%R", key))
|
if !slices.Contains(notAllowed, fmt.Sprintf("%s~%s", "%R", key)) {
|
||||||
|
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%R", key))
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
})
|
})
|
||||||
}) {
|
}) {
|
||||||
return fmt.Errorf("not authorised to access the following keys %+v", notAllowed)
|
if len(notAllowed) > 0 {
|
||||||
|
return fmt.Errorf("not authorised to access the following keys: %+v", notAllowed)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 9. Check if keys are in IncludedWriteKeys
|
// 9. Check if write keys are in IncludedWriteKeys
|
||||||
|
fmt.Println("KEYS: ", writeKeys)
|
||||||
|
fmt.Println("ALLOWED KEYS: ", connection.User.IncludedWriteKeys)
|
||||||
if !slices.ContainsFunc(writeKeys, func(key string) bool {
|
if !slices.ContainsFunc(writeKeys, func(key string) bool {
|
||||||
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool {
|
return slices.ContainsFunc(connection.User.IncludedWriteKeys, func(writeKeyGlob string) bool {
|
||||||
if acl.GlobPatterns[writeKeyGlob].Match(key) {
|
if acl.GlobPatterns[writeKeyGlob].Match(key) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%W", key))
|
if !slices.Contains(notAllowed, fmt.Sprintf("%s~%s", "%W", key)) {
|
||||||
|
notAllowed = append(notAllowed, fmt.Sprintf("%s~%s", "%W", key))
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
})
|
})
|
||||||
}) {
|
}) {
|
||||||
return fmt.Errorf("not authorised to access the following keys %+v", notAllowed)
|
return fmt.Errorf("not authorised to access the following keys: %+v", notAllowed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -491,3 +494,17 @@ func (acl *ACL) RLockUsers() {
|
|||||||
func (acl *ACL) RUnlockUsers() {
|
func (acl *ACL) RUnlockUsers() {
|
||||||
acl.UsersMutex.RUnlock()
|
acl.UsersMutex.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getUnauthorized(count map[string]int, prefix string) []string {
|
||||||
|
var notAllowed []string
|
||||||
|
for member, c := range count {
|
||||||
|
if c == 0 {
|
||||||
|
notAllowed = append(notAllowed, fmt.Sprintf("%s%s", prefix, member))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Sort the members in alphabetical order.
|
||||||
|
slices.SortStableFunc(notAllowed, func(a, b string) int {
|
||||||
|
return internal.CompareLex(a, b)
|
||||||
|
})
|
||||||
|
return notAllowed
|
||||||
|
}
|
||||||
|
|||||||
@@ -178,12 +178,12 @@ func Test_ACL(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Test_HandleAuth", func(t *testing.T) {
|
t.Run("Test_HandleAuth", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
conn, err := internal.GetConnection("localhost", port)
|
conn, err := internal.GetConnection("localhost", port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
@@ -193,16 +193,19 @@ func Test_ACL(t *testing.T) {
|
|||||||
r := resp.NewConn(conn)
|
r := resp.NewConn(conn)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
name string
|
||||||
cmd []resp.Value
|
cmd []resp.Value
|
||||||
wantRes string
|
wantRes string
|
||||||
wantErr string
|
wantErr string
|
||||||
}{
|
}{
|
||||||
{ // 1. Authenticate with default user without specifying username
|
{
|
||||||
|
name: "1. Authenticate with default user without specifying username",
|
||||||
cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")},
|
cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")},
|
||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
},
|
},
|
||||||
{ // 2. Authenticate with plaintext password
|
{
|
||||||
|
name: "2. Authenticate with plaintext password",
|
||||||
cmd: []resp.Value{
|
cmd: []resp.Value{
|
||||||
resp.StringValue("AUTH"),
|
resp.StringValue("AUTH"),
|
||||||
resp.StringValue("with_password_user"),
|
resp.StringValue("with_password_user"),
|
||||||
@@ -211,7 +214,8 @@ func Test_ACL(t *testing.T) {
|
|||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
},
|
},
|
||||||
{ // 3. Authenticate with SHA256 password
|
{
|
||||||
|
name: "3. Authenticate with SHA256 password",
|
||||||
cmd: []resp.Value{
|
cmd: []resp.Value{
|
||||||
resp.StringValue("AUTH"),
|
resp.StringValue("AUTH"),
|
||||||
resp.StringValue("with_password_user"),
|
resp.StringValue("with_password_user"),
|
||||||
@@ -220,7 +224,8 @@ func Test_ACL(t *testing.T) {
|
|||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
},
|
},
|
||||||
{ // 4. Authenticate with no password user
|
{
|
||||||
|
name: "4. Authenticate with no password user",
|
||||||
cmd: []resp.Value{
|
cmd: []resp.Value{
|
||||||
resp.StringValue("AUTH"),
|
resp.StringValue("AUTH"),
|
||||||
resp.StringValue("no_password_user"),
|
resp.StringValue("no_password_user"),
|
||||||
@@ -229,7 +234,8 @@ func Test_ACL(t *testing.T) {
|
|||||||
wantRes: "OK",
|
wantRes: "OK",
|
||||||
wantErr: "",
|
wantErr: "",
|
||||||
},
|
},
|
||||||
{ // 5. Fail to authenticate with disabled user
|
{
|
||||||
|
name: "5. Fail to authenticate with disabled user",
|
||||||
cmd: []resp.Value{
|
cmd: []resp.Value{
|
||||||
resp.StringValue("AUTH"),
|
resp.StringValue("AUTH"),
|
||||||
resp.StringValue("disabled_user"),
|
resp.StringValue("disabled_user"),
|
||||||
@@ -238,7 +244,8 @@ func Test_ACL(t *testing.T) {
|
|||||||
wantRes: "",
|
wantRes: "",
|
||||||
wantErr: "Error user disabled_user is disabled",
|
wantErr: "Error user disabled_user is disabled",
|
||||||
},
|
},
|
||||||
{ // 6. Fail to authenticate with non-existent user
|
{
|
||||||
|
name: "6. Fail to authenticate with non-existent user",
|
||||||
cmd: []resp.Value{
|
cmd: []resp.Value{
|
||||||
resp.StringValue("AUTH"),
|
resp.StringValue("AUTH"),
|
||||||
resp.StringValue("non_existent_user"),
|
resp.StringValue("non_existent_user"),
|
||||||
@@ -247,12 +254,24 @@ func Test_ACL(t *testing.T) {
|
|||||||
wantRes: "",
|
wantRes: "",
|
||||||
wantErr: "Error no user with username non_existent_user",
|
wantErr: "Error no user with username non_existent_user",
|
||||||
},
|
},
|
||||||
{ // 7. Command too short
|
{
|
||||||
|
name: "7. Fail to authenticate with the wrong password",
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("with_password_user"),
|
||||||
|
resp.StringValue("wrong_password"),
|
||||||
|
},
|
||||||
|
wantRes: "",
|
||||||
|
wantErr: "Error could not authenticate user",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "8. Command too short",
|
||||||
cmd: []resp.Value{resp.StringValue("AUTH")},
|
cmd: []resp.Value{resp.StringValue("AUTH")},
|
||||||
wantRes: "",
|
wantRes: "",
|
||||||
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
|
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
|
||||||
},
|
},
|
||||||
{ // 8. Command too long
|
{
|
||||||
|
name: "9. Command too long",
|
||||||
cmd: []resp.Value{
|
cmd: []resp.Value{
|
||||||
resp.StringValue("AUTH"),
|
resp.StringValue("AUTH"),
|
||||||
resp.StringValue("user"),
|
resp.StringValue("user"),
|
||||||
@@ -265,23 +284,279 @@ func Test_ACL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
if err = r.WriteArray(test.cmd); err != nil {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
t.Error(err)
|
if err = r.WriteArray(test.cmd); err != nil {
|
||||||
}
|
t.Error(err)
|
||||||
rv, _, err := r.ReadValue()
|
|
||||||
if err != nil {
|
|
||||||
t.Error(err)
|
|
||||||
}
|
|
||||||
if test.wantErr != "" {
|
|
||||||
if rv.Error().Error() != test.wantErr {
|
|
||||||
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
|
|
||||||
}
|
}
|
||||||
continue
|
rv, _, err := r.ReadValue()
|
||||||
}
|
if err != nil {
|
||||||
if rv.String() != test.wantRes {
|
t.Error(err)
|
||||||
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
|
}
|
||||||
|
if test.wantErr != "" {
|
||||||
|
if rv.Error().Error() != test.wantErr {
|
||||||
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rv.String() != test.wantRes {
|
||||||
|
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Test_Permissions", func(t *testing.T) {
|
||||||
|
port, err := internal.GetFreePort()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mockServer, err := setUpServer(port, true, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
mockServer.Start()
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := internal.GetConnection("localhost", port)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
client := resp.NewConn(conn)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = conn.Close()
|
||||||
|
mockServer.ShutDown()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Add users to be used in test cases.
|
||||||
|
users := []echovault.User{
|
||||||
|
{
|
||||||
|
// User with nokeys flag enables.
|
||||||
|
Username: "test_nokeys",
|
||||||
|
Enabled: true,
|
||||||
|
NoKeys: true,
|
||||||
|
AddPlainPasswords: []string{"test_nokeys_password"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// This use will be used to test authorization failure when trying to access resources that are not
|
||||||
|
// in their "included" rules.
|
||||||
|
Username: "test_included",
|
||||||
|
Enabled: true,
|
||||||
|
AddPlainPasswords: []string{"test_included_password"},
|
||||||
|
IncludeCategories: []string{
|
||||||
|
constants.WriteCategory,
|
||||||
|
constants.ReadCategory,
|
||||||
|
constants.SlowCategory,
|
||||||
|
constants.PubSubCategory,
|
||||||
|
constants.ConnectionCategory,
|
||||||
|
constants.ListCategory,
|
||||||
|
},
|
||||||
|
IncludeCommands: []string{"set", "get", "subscribe", "lrange", "ltrim"},
|
||||||
|
IncludeChannels: []string{"channel[12]"},
|
||||||
|
IncludeReadWriteKeys: []string{"key1", "key2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// This use will be used to test authorization failure when trying to access resources that are
|
||||||
|
// in their "excluded" rules.
|
||||||
|
Username: "test_excluded",
|
||||||
|
Enabled: true,
|
||||||
|
AddPlainPasswords: []string{"test_excluded_password"},
|
||||||
|
IncludeCategories: []string{"*"},
|
||||||
|
ExcludeCategories: []string{constants.FastCategory, constants.HashCategory},
|
||||||
|
IncludeCommands: []string{"*"},
|
||||||
|
ExcludeCommands: []string{"set", "mset"},
|
||||||
|
IncludeChannels: []string{"*"},
|
||||||
|
ExcludeChannels: []string{"channel[12]"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, user := range users {
|
||||||
|
if _, err := mockServer.ACLSetUser(user); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
auth []resp.Value
|
||||||
|
cmd []resp.Value
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "1. Return error when the connection is not authenticated",
|
||||||
|
auth: []resp.Value{},
|
||||||
|
cmd: []resp.Value{resp.StringValue("SET"), resp.StringValue("key"), resp.StringValue("value")},
|
||||||
|
wantErr: "user must be authenticated",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "2. Return error when the command category is not in the included categories section",
|
||||||
|
auth: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("test_included"),
|
||||||
|
resp.StringValue("test_included_password"),
|
||||||
|
},
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("HSET"),
|
||||||
|
resp.StringValue("hash"),
|
||||||
|
resp.StringValue("field1"),
|
||||||
|
resp.StringValue("value1"),
|
||||||
|
},
|
||||||
|
wantErr: fmt.Sprintf("unauthorized access to the following categories: [@%s @%s]",
|
||||||
|
constants.FastCategory, constants.HashCategory),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "3. Return error when the command category is in the excluded categories section",
|
||||||
|
auth: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("test_excluded"),
|
||||||
|
resp.StringValue("test_excluded_password"),
|
||||||
|
},
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("HSET"),
|
||||||
|
resp.StringValue("hash"),
|
||||||
|
resp.StringValue("field1"),
|
||||||
|
resp.StringValue("value1"),
|
||||||
|
},
|
||||||
|
wantErr: fmt.Sprintf("unauthorized access to the following categories: [@%s]",
|
||||||
|
constants.HashCategory),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "4. Return error when the command is not in the included command category",
|
||||||
|
auth: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("test_included"),
|
||||||
|
resp.StringValue("test_included_password"),
|
||||||
|
},
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("MSET"),
|
||||||
|
resp.StringValue("key1"),
|
||||||
|
resp.StringValue("value1"),
|
||||||
|
},
|
||||||
|
wantErr: "not authorised to run MSET command",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "5. Return error when the command is in the excluded command category",
|
||||||
|
auth: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("test_excluded"),
|
||||||
|
resp.StringValue("test_excluded_password"),
|
||||||
|
},
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("SET"),
|
||||||
|
resp.StringValue("key1"),
|
||||||
|
resp.StringValue("value1"),
|
||||||
|
},
|
||||||
|
wantErr: "not authorised to run SET command",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "6. Return error when subscribing to channel that's not in included channels",
|
||||||
|
auth: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("test_included"),
|
||||||
|
resp.StringValue("test_included_password"),
|
||||||
|
},
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("SUBSCRIBE"),
|
||||||
|
resp.StringValue("channel3"),
|
||||||
|
},
|
||||||
|
wantErr: "not authorised to access channel &channel3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "7. Return error when publishing to channel that's in excluded channels",
|
||||||
|
auth: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("test_excluded"),
|
||||||
|
resp.StringValue("test_excluded_password"),
|
||||||
|
},
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("SUBSCRIBE"),
|
||||||
|
resp.StringValue("channel2"),
|
||||||
|
},
|
||||||
|
wantErr: "not authorised to access channel &channel2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "8. Return error when the user has nokeys flag",
|
||||||
|
auth: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("test_nokeys"),
|
||||||
|
resp.StringValue("test_nokeys_password"),
|
||||||
|
},
|
||||||
|
cmd: []resp.Value{resp.StringValue("GET"), resp.StringValue("key1")},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "9. Return error when trying to read from keys that are not in read keys list",
|
||||||
|
auth: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("test_included"),
|
||||||
|
resp.StringValue("test_included_password"),
|
||||||
|
},
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("LRANGE"),
|
||||||
|
resp.StringValue("key3"),
|
||||||
|
resp.StringValue("0"),
|
||||||
|
resp.StringValue("-1"),
|
||||||
|
},
|
||||||
|
wantErr: fmt.Sprintf("not authorised to access the following keys: [%s~%s]", "%R", "key3"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "10. Return error when trying to write to keys that are not in write keys list",
|
||||||
|
auth: []resp.Value{
|
||||||
|
resp.StringValue("AUTH"),
|
||||||
|
resp.StringValue("test_included"),
|
||||||
|
resp.StringValue("test_included_password"),
|
||||||
|
},
|
||||||
|
cmd: []resp.Value{
|
||||||
|
resp.StringValue("LTRIM"),
|
||||||
|
resp.StringValue("key3"),
|
||||||
|
resp.StringValue("0"),
|
||||||
|
resp.StringValue("3"),
|
||||||
|
},
|
||||||
|
wantErr: fmt.Sprintf("not authorised to access the following keys: [%s~%s]", "%W", "key3"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
// Authenticate the user if the auth command is provided.
|
||||||
|
if len(test.auth) > 0 {
|
||||||
|
err := client.WriteArray(test.auth)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res, _, err := client.ReadValue()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(res.String(), "ok") {
|
||||||
|
t.Errorf("expected auth response to be OK, got \"%s\"", res.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.WriteArray(test.cmd); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res, _, err := client.ReadValue()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(res.Error().Error(), test.wantErr) {
|
||||||
|
t.Errorf("expected error to contain string \"%s\", got \"%s\"",
|
||||||
|
test.wantErr, res.Error().Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Test_HandleCat", func(t *testing.T) {
|
t.Run("Test_HandleCat", func(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user