mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-24 00:14:08 +08:00
Implemented tests for ACL permissions
This commit is contained in:
@@ -178,12 +178,12 @@ func Test_ACL(t *testing.T) {
|
||||
|
||||
t.Run("Test_HandleAuth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
@@ -193,16 +193,19 @@ func Test_ACL(t *testing.T) {
|
||||
r := resp.NewConn(conn)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cmd []resp.Value
|
||||
wantRes string
|
||||
wantErr string
|
||||
}{
|
||||
{ // 1. Authenticate with default user without specifying username
|
||||
{
|
||||
name: "1. Authenticate with default user without specifying username",
|
||||
cmd: []resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")},
|
||||
wantRes: "OK",
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 2. Authenticate with plaintext password
|
||||
{
|
||||
name: "2. Authenticate with plaintext password",
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("with_password_user"),
|
||||
@@ -211,7 +214,8 @@ func Test_ACL(t *testing.T) {
|
||||
wantRes: "OK",
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 3. Authenticate with SHA256 password
|
||||
{
|
||||
name: "3. Authenticate with SHA256 password",
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("with_password_user"),
|
||||
@@ -220,7 +224,8 @@ func Test_ACL(t *testing.T) {
|
||||
wantRes: "OK",
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 4. Authenticate with no password user
|
||||
{
|
||||
name: "4. Authenticate with no password user",
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("no_password_user"),
|
||||
@@ -229,7 +234,8 @@ func Test_ACL(t *testing.T) {
|
||||
wantRes: "OK",
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 5. Fail to authenticate with disabled user
|
||||
{
|
||||
name: "5. Fail to authenticate with disabled user",
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("disabled_user"),
|
||||
@@ -238,7 +244,8 @@ func Test_ACL(t *testing.T) {
|
||||
wantRes: "",
|
||||
wantErr: "Error user disabled_user is disabled",
|
||||
},
|
||||
{ // 6. Fail to authenticate with non-existent user
|
||||
{
|
||||
name: "6. Fail to authenticate with non-existent user",
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("non_existent_user"),
|
||||
@@ -247,12 +254,24 @@ func Test_ACL(t *testing.T) {
|
||||
wantRes: "",
|
||||
wantErr: "Error no user with username non_existent_user",
|
||||
},
|
||||
{ // 7. Command too short
|
||||
{
|
||||
name: "7. Fail to authenticate with the wrong password",
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("with_password_user"),
|
||||
resp.StringValue("wrong_password"),
|
||||
},
|
||||
wantRes: "",
|
||||
wantErr: "Error could not authenticate user",
|
||||
},
|
||||
{
|
||||
name: "8. Command too short",
|
||||
cmd: []resp.Value{resp.StringValue("AUTH")},
|
||||
wantRes: "",
|
||||
wantErr: fmt.Sprintf("Error %s", constants.WrongArgsResponse),
|
||||
},
|
||||
{ // 8. Command too long
|
||||
{
|
||||
name: "9. Command too long",
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("user"),
|
||||
@@ -265,23 +284,279 @@ func Test_ACL(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
if err = r.WriteArray(test.cmd); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
rv, _, err := r.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if test.wantErr != "" {
|
||||
if rv.Error().Error() != test.wantErr {
|
||||
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
if err = r.WriteArray(test.cmd); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if rv.String() != test.wantRes {
|
||||
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
|
||||
rv, _, err := r.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if test.wantErr != "" {
|
||||
if rv.Error().Error() != test.wantErr {
|
||||
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
if rv.String() != test.wantRes {
|
||||
t.Errorf("expected response \"%s\", got \"%s\"", test.wantRes, rv.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Test_Permissions", func(t *testing.T) {
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
mockServer, err := setUpServer(port, true, "")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
mockServer.Start()
|
||||
}()
|
||||
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = conn.Close()
|
||||
mockServer.ShutDown()
|
||||
})
|
||||
|
||||
// Add users to be used in test cases.
|
||||
users := []echovault.User{
|
||||
{
|
||||
// User with nokeys flag enables.
|
||||
Username: "test_nokeys",
|
||||
Enabled: true,
|
||||
NoKeys: true,
|
||||
AddPlainPasswords: []string{"test_nokeys_password"},
|
||||
},
|
||||
{
|
||||
// This use will be used to test authorization failure when trying to access resources that are not
|
||||
// in their "included" rules.
|
||||
Username: "test_included",
|
||||
Enabled: true,
|
||||
AddPlainPasswords: []string{"test_included_password"},
|
||||
IncludeCategories: []string{
|
||||
constants.WriteCategory,
|
||||
constants.ReadCategory,
|
||||
constants.SlowCategory,
|
||||
constants.PubSubCategory,
|
||||
constants.ConnectionCategory,
|
||||
constants.ListCategory,
|
||||
},
|
||||
IncludeCommands: []string{"set", "get", "subscribe", "lrange", "ltrim"},
|
||||
IncludeChannels: []string{"channel[12]"},
|
||||
IncludeReadWriteKeys: []string{"key1", "key2"},
|
||||
},
|
||||
{
|
||||
// This use will be used to test authorization failure when trying to access resources that are
|
||||
// in their "excluded" rules.
|
||||
Username: "test_excluded",
|
||||
Enabled: true,
|
||||
AddPlainPasswords: []string{"test_excluded_password"},
|
||||
IncludeCategories: []string{"*"},
|
||||
ExcludeCategories: []string{constants.FastCategory, constants.HashCategory},
|
||||
IncludeCommands: []string{"*"},
|
||||
ExcludeCommands: []string{"set", "mset"},
|
||||
IncludeChannels: []string{"*"},
|
||||
ExcludeChannels: []string{"channel[12]"},
|
||||
},
|
||||
}
|
||||
for _, user := range users {
|
||||
if _, err := mockServer.ACLSetUser(user); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
auth []resp.Value
|
||||
cmd []resp.Value
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "1. Return error when the connection is not authenticated",
|
||||
auth: []resp.Value{},
|
||||
cmd: []resp.Value{resp.StringValue("SET"), resp.StringValue("key"), resp.StringValue("value")},
|
||||
wantErr: "user must be authenticated",
|
||||
},
|
||||
{
|
||||
name: "2. Return error when the command category is not in the included categories section",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_included"),
|
||||
resp.StringValue("test_included_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("HSET"),
|
||||
resp.StringValue("hash"),
|
||||
resp.StringValue("field1"),
|
||||
resp.StringValue("value1"),
|
||||
},
|
||||
wantErr: fmt.Sprintf("unauthorized access to the following categories: [@%s @%s]",
|
||||
constants.FastCategory, constants.HashCategory),
|
||||
},
|
||||
{
|
||||
name: "3. Return error when the command category is in the excluded categories section",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_excluded"),
|
||||
resp.StringValue("test_excluded_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("HSET"),
|
||||
resp.StringValue("hash"),
|
||||
resp.StringValue("field1"),
|
||||
resp.StringValue("value1"),
|
||||
},
|
||||
wantErr: fmt.Sprintf("unauthorized access to the following categories: [@%s]",
|
||||
constants.HashCategory),
|
||||
},
|
||||
{
|
||||
name: "4. Return error when the command is not in the included command category",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_included"),
|
||||
resp.StringValue("test_included_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("MSET"),
|
||||
resp.StringValue("key1"),
|
||||
resp.StringValue("value1"),
|
||||
},
|
||||
wantErr: "not authorised to run MSET command",
|
||||
},
|
||||
{
|
||||
name: "5. Return error when the command is in the excluded command category",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_excluded"),
|
||||
resp.StringValue("test_excluded_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("SET"),
|
||||
resp.StringValue("key1"),
|
||||
resp.StringValue("value1"),
|
||||
},
|
||||
wantErr: "not authorised to run SET command",
|
||||
},
|
||||
{
|
||||
name: "6. Return error when subscribing to channel that's not in included channels",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_included"),
|
||||
resp.StringValue("test_included_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("SUBSCRIBE"),
|
||||
resp.StringValue("channel3"),
|
||||
},
|
||||
wantErr: "not authorised to access channel &channel3",
|
||||
},
|
||||
{
|
||||
name: "7. Return error when publishing to channel that's in excluded channels",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_excluded"),
|
||||
resp.StringValue("test_excluded_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("SUBSCRIBE"),
|
||||
resp.StringValue("channel2"),
|
||||
},
|
||||
wantErr: "not authorised to access channel &channel2",
|
||||
},
|
||||
{
|
||||
name: "8. Return error when the user has nokeys flag",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_nokeys"),
|
||||
resp.StringValue("test_nokeys_password"),
|
||||
},
|
||||
cmd: []resp.Value{resp.StringValue("GET"), resp.StringValue("key1")},
|
||||
},
|
||||
{
|
||||
name: "9. Return error when trying to read from keys that are not in read keys list",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_included"),
|
||||
resp.StringValue("test_included_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("LRANGE"),
|
||||
resp.StringValue("key3"),
|
||||
resp.StringValue("0"),
|
||||
resp.StringValue("-1"),
|
||||
},
|
||||
wantErr: fmt.Sprintf("not authorised to access the following keys: [%s~%s]", "%R", "key3"),
|
||||
},
|
||||
{
|
||||
name: "10. Return error when trying to write to keys that are not in write keys list",
|
||||
auth: []resp.Value{
|
||||
resp.StringValue("AUTH"),
|
||||
resp.StringValue("test_included"),
|
||||
resp.StringValue("test_included_password"),
|
||||
},
|
||||
cmd: []resp.Value{
|
||||
resp.StringValue("LTRIM"),
|
||||
resp.StringValue("key3"),
|
||||
resp.StringValue("0"),
|
||||
resp.StringValue("3"),
|
||||
},
|
||||
wantErr: fmt.Sprintf("not authorised to access the following keys: [%s~%s]", "%W", "key3"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Authenticate the user if the auth command is provided.
|
||||
if len(test.auth) > 0 {
|
||||
err := client.WriteArray(test.auth)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(res.String(), "ok") {
|
||||
t.Errorf("expected auth response to be OK, got \"%s\"", res.String())
|
||||
}
|
||||
}
|
||||
|
||||
if err := client.WriteArray(test.cmd); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.Contains(res.Error().Error(), test.wantErr) {
|
||||
t.Errorf("expected error to contain string \"%s\", got \"%s\"",
|
||||
test.wantErr, res.Error().Error())
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Test_HandleCat", func(t *testing.T) {
|
||||
|
Reference in New Issue
Block a user