mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-09 09:50:09 +08:00
Initialise default IncludedCategories to allCategories("*").
Implemented test for ACL USERS command.
This commit is contained in:
@@ -202,7 +202,7 @@ func handleCat(ctx context.Context, cmd []string, server utils.Server, _ *net.Co
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("category not found")
|
||||
return nil, fmt.Errorf("category %s not found", strings.ToUpper(cmd[2]))
|
||||
}
|
||||
|
||||
func handleUsers(_ context.Context, _ []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||
|
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/echovault/echovault/src/utils"
|
||||
"github.com/tidwall/resp"
|
||||
"net"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -21,6 +22,14 @@ func init() {
|
||||
bindAddr = "localhost"
|
||||
port = 7490
|
||||
|
||||
mockServer = setUpServer(bindAddr, port)
|
||||
|
||||
go func() {
|
||||
mockServer.Start(context.Background())
|
||||
}()
|
||||
}
|
||||
|
||||
func setUpServer(bindAddr string, port uint16) *server.Server {
|
||||
config := utils.Config{
|
||||
BindAddr: bindAddr,
|
||||
Port: port,
|
||||
@@ -33,15 +42,11 @@ func init() {
|
||||
acl = NewACL(config)
|
||||
acl.Users = append(acl.Users, generateInitialTestUsers()...)
|
||||
|
||||
mockServer = server.NewServer(server.Opts{
|
||||
return server.NewServer(server.Opts{
|
||||
Config: config,
|
||||
ACL: acl,
|
||||
Commands: Commands(),
|
||||
})
|
||||
|
||||
go func() {
|
||||
mockServer.Start(context.Background())
|
||||
}()
|
||||
}
|
||||
|
||||
func generateInitialTestUsers() []*User {
|
||||
@@ -80,6 +85,9 @@ func Test_HandleAuth(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
r := resp.NewConn(conn)
|
||||
|
||||
tests := []struct {
|
||||
@@ -175,11 +183,174 @@ func Test_HandleAuth(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_HandleCat(t *testing.T) {
|
||||
// Since only ACL commands are loaded in this test suite, this test will only test against the
|
||||
// list of categories and commands available in the ACL module.
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
r := resp.NewConn(conn)
|
||||
|
||||
// Authenticate connection
|
||||
if err = r.WriteArray([]resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
rv, _, err := r.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if rv.String() != "OK" {
|
||||
t.Error("could not authenticate user")
|
||||
}
|
||||
|
||||
func Test_HandleUsers(t *testing.T) {}
|
||||
// Since only ACL commands are loaded in this test suite, this test will only test against the
|
||||
// list of categories and commands available in the ACL module.
|
||||
tests := []struct {
|
||||
cmd []resp.Value
|
||||
wantRes []string
|
||||
wantErr string
|
||||
}{
|
||||
{ // 1. Return list of categories
|
||||
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT")},
|
||||
wantRes: []string{
|
||||
utils.ConnectionCategory,
|
||||
utils.SlowCategory,
|
||||
utils.FastCategory,
|
||||
utils.AdminCategory,
|
||||
utils.DangerousCategory,
|
||||
},
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 2. Return list of commands in connection category
|
||||
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.ConnectionCategory)},
|
||||
wantRes: []string{"auth"},
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 3. Return list of commands in slow category
|
||||
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.SlowCategory)},
|
||||
wantRes: []string{"auth", "acl|cat", "acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 4. Return list of commands in fast category
|
||||
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.FastCategory)},
|
||||
wantRes: []string{"acl|whoami"},
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 5. Return list of commands in admin category
|
||||
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.AdminCategory)},
|
||||
wantRes: []string{"acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 6. Return list of commands in dangerous category
|
||||
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.DangerousCategory)},
|
||||
wantRes: []string{"acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
|
||||
wantErr: "",
|
||||
},
|
||||
{ // 7. Return error when category does not exist
|
||||
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue("non-existent")},
|
||||
wantRes: nil,
|
||||
wantErr: "Error category NON-EXISTENT not found",
|
||||
},
|
||||
{ // 8. Command too long
|
||||
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue("category1"), resp.StringValue("category2")},
|
||||
wantRes: nil,
|
||||
wantErr: fmt.Sprintf("Error %s", utils.WrongArgsResponse),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
if err = r.WriteArray(test.cmd); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
rv, _, err = r.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if test.wantErr != "" {
|
||||
if rv.Error().Error() != test.wantErr {
|
||||
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
|
||||
}
|
||||
continue
|
||||
}
|
||||
resArr := rv.Array()
|
||||
// Check if all the elements in the expected array are in the response array
|
||||
for _, expected := range test.wantRes {
|
||||
if !slices.ContainsFunc(resArr, func(value resp.Value) bool {
|
||||
return value.String() == expected
|
||||
}) {
|
||||
t.Errorf("could not find expected command \"%s\" in the response array for category", expected)
|
||||
}
|
||||
}
|
||||
// Check if all the elements in the response array are in the expected array
|
||||
for _, value := range resArr {
|
||||
if !slices.ContainsFunc(test.wantRes, func(expected string) bool {
|
||||
return value.String() == expected
|
||||
}) {
|
||||
t.Errorf("could not find response command \"%s\" in the expected array", value.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleUsers(t *testing.T) {
|
||||
var port uint16 = 7491
|
||||
mockServer := setUpServer(bindAddr, port)
|
||||
go func() {
|
||||
mockServer.Start(context.Background())
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
r := resp.NewConn(conn)
|
||||
if err = r.WriteArray([]resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
rv, _, err := r.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if rv.String() != "OK" {
|
||||
t.Errorf("expected OK response, got \"%s\"", rv.String())
|
||||
}
|
||||
|
||||
users := []string{"default", "with_password_user", "no_password_user", "disabled_user"}
|
||||
|
||||
if err = r.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("USERS")}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
rv, _, err = r.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
resArr := rv.Array()
|
||||
|
||||
// Check if all the expected users are in the response array
|
||||
for _, user := range users {
|
||||
if !slices.ContainsFunc(resArr, func(value resp.Value) bool {
|
||||
return value.String() == user
|
||||
}) {
|
||||
t.Errorf("could not find expected user \"%s\" in response array", user)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all the users in the response array are in the expected users
|
||||
for _, value := range resArr {
|
||||
if !slices.ContainsFunc(users, func(user string) bool {
|
||||
return value.String() == user
|
||||
}) {
|
||||
t.Errorf("could not find response user \"%s\" in expected users array", value.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleSetUser(t *testing.T) {}
|
||||
|
||||
|
@@ -38,6 +38,9 @@ type User struct {
|
||||
|
||||
func (user *User) Normalise() {
|
||||
user.IncludedCategories = RemoveDuplicateEntries(user.IncludedCategories, "allCategories")
|
||||
if len(user.IncludedCategories) == 0 {
|
||||
user.IncludedCategories = []string{"*"}
|
||||
}
|
||||
user.ExcludedCategories = RemoveDuplicateEntries(user.ExcludedCategories, "allCategories")
|
||||
if slices.Contains(user.ExcludedCategories, "*") {
|
||||
user.IncludedCategories = []string{}
|
||||
|
Reference in New Issue
Block a user