mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-09 18:00:23 +08:00
Initialise default IncludedCategories to allCategories("*").
Implemented test for ACL USERS command.
This commit is contained in:
@@ -202,7 +202,7 @@ func handleCat(ctx context.Context, cmd []string, server utils.Server, _ *net.Co
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errors.New("category not found")
|
return nil, fmt.Errorf("category %s not found", strings.ToUpper(cmd[2]))
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleUsers(_ context.Context, _ []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
func handleUsers(_ context.Context, _ []string, server utils.Server, _ *net.Conn) ([]byte, error) {
|
||||||
|
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/echovault/echovault/src/utils"
|
"github.com/echovault/echovault/src/utils"
|
||||||
"github.com/tidwall/resp"
|
"github.com/tidwall/resp"
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,6 +22,14 @@ func init() {
|
|||||||
bindAddr = "localhost"
|
bindAddr = "localhost"
|
||||||
port = 7490
|
port = 7490
|
||||||
|
|
||||||
|
mockServer = setUpServer(bindAddr, port)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
mockServer.Start(context.Background())
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func setUpServer(bindAddr string, port uint16) *server.Server {
|
||||||
config := utils.Config{
|
config := utils.Config{
|
||||||
BindAddr: bindAddr,
|
BindAddr: bindAddr,
|
||||||
Port: port,
|
Port: port,
|
||||||
@@ -33,15 +42,11 @@ func init() {
|
|||||||
acl = NewACL(config)
|
acl = NewACL(config)
|
||||||
acl.Users = append(acl.Users, generateInitialTestUsers()...)
|
acl.Users = append(acl.Users, generateInitialTestUsers()...)
|
||||||
|
|
||||||
mockServer = server.NewServer(server.Opts{
|
return server.NewServer(server.Opts{
|
||||||
Config: config,
|
Config: config,
|
||||||
ACL: acl,
|
ACL: acl,
|
||||||
Commands: Commands(),
|
Commands: Commands(),
|
||||||
})
|
})
|
||||||
|
|
||||||
go func() {
|
|
||||||
mockServer.Start(context.Background())
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateInitialTestUsers() []*User {
|
func generateInitialTestUsers() []*User {
|
||||||
@@ -80,6 +85,9 @@ func Test_HandleAuth(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.Close()
|
||||||
|
}()
|
||||||
r := resp.NewConn(conn)
|
r := resp.NewConn(conn)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -175,11 +183,174 @@ func Test_HandleAuth(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_HandleCat(t *testing.T) {
|
func Test_HandleCat(t *testing.T) {
|
||||||
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.Close()
|
||||||
|
}()
|
||||||
|
r := resp.NewConn(conn)
|
||||||
|
|
||||||
|
// Authenticate connection
|
||||||
|
if err = r.WriteArray([]resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
rv, _, err := r.ReadValue()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if rv.String() != "OK" {
|
||||||
|
t.Error("could not authenticate user")
|
||||||
|
}
|
||||||
|
|
||||||
// Since only ACL commands are loaded in this test suite, this test will only test against the
|
// Since only ACL commands are loaded in this test suite, this test will only test against the
|
||||||
// list of categories and commands available in the ACL module.
|
// list of categories and commands available in the ACL module.
|
||||||
|
tests := []struct {
|
||||||
|
cmd []resp.Value
|
||||||
|
wantRes []string
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{ // 1. Return list of categories
|
||||||
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT")},
|
||||||
|
wantRes: []string{
|
||||||
|
utils.ConnectionCategory,
|
||||||
|
utils.SlowCategory,
|
||||||
|
utils.FastCategory,
|
||||||
|
utils.AdminCategory,
|
||||||
|
utils.DangerousCategory,
|
||||||
|
},
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{ // 2. Return list of commands in connection category
|
||||||
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.ConnectionCategory)},
|
||||||
|
wantRes: []string{"auth"},
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{ // 3. Return list of commands in slow category
|
||||||
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.SlowCategory)},
|
||||||
|
wantRes: []string{"auth", "acl|cat", "acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{ // 4. Return list of commands in fast category
|
||||||
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.FastCategory)},
|
||||||
|
wantRes: []string{"acl|whoami"},
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{ // 5. Return list of commands in admin category
|
||||||
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.AdminCategory)},
|
||||||
|
wantRes: []string{"acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{ // 6. Return list of commands in dangerous category
|
||||||
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue(utils.DangerousCategory)},
|
||||||
|
wantRes: []string{"acl|users", "acl|setuser", "acl|getuser", "acl|deluser", "acl|list", "acl|load", "acl|save"},
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{ // 7. Return error when category does not exist
|
||||||
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue("non-existent")},
|
||||||
|
wantRes: nil,
|
||||||
|
wantErr: "Error category NON-EXISTENT not found",
|
||||||
|
},
|
||||||
|
{ // 8. Command too long
|
||||||
|
cmd: []resp.Value{resp.StringValue("ACL"), resp.StringValue("CAT"), resp.StringValue("category1"), resp.StringValue("category2")},
|
||||||
|
wantRes: nil,
|
||||||
|
wantErr: fmt.Sprintf("Error %s", utils.WrongArgsResponse),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
if err = r.WriteArray(test.cmd); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
rv, _, err = r.ReadValue()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if test.wantErr != "" {
|
||||||
|
if rv.Error().Error() != test.wantErr {
|
||||||
|
t.Errorf("expected error response \"%s\", got \"%s\"", test.wantErr, rv.Error().Error())
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
resArr := rv.Array()
|
||||||
|
// Check if all the elements in the expected array are in the response array
|
||||||
|
for _, expected := range test.wantRes {
|
||||||
|
if !slices.ContainsFunc(resArr, func(value resp.Value) bool {
|
||||||
|
return value.String() == expected
|
||||||
|
}) {
|
||||||
|
t.Errorf("could not find expected command \"%s\" in the response array for category", expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check if all the elements in the response array are in the expected array
|
||||||
|
for _, value := range resArr {
|
||||||
|
if !slices.ContainsFunc(test.wantRes, func(expected string) bool {
|
||||||
|
return value.String() == expected
|
||||||
|
}) {
|
||||||
|
t.Errorf("could not find response command \"%s\" in the expected array", value.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_HandleUsers(t *testing.T) {}
|
func Test_HandleUsers(t *testing.T) {
|
||||||
|
var port uint16 = 7491
|
||||||
|
mockServer := setUpServer(bindAddr, port)
|
||||||
|
go func() {
|
||||||
|
mockServer.Start(context.Background())
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", bindAddr, port))
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
r := resp.NewConn(conn)
|
||||||
|
if err = r.WriteArray([]resp.Value{resp.StringValue("AUTH"), resp.StringValue("password1")}); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
rv, _, err := r.ReadValue()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
if rv.String() != "OK" {
|
||||||
|
t.Errorf("expected OK response, got \"%s\"", rv.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
users := []string{"default", "with_password_user", "no_password_user", "disabled_user"}
|
||||||
|
|
||||||
|
if err = r.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("USERS")}); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rv, _, err = r.ReadValue()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resArr := rv.Array()
|
||||||
|
|
||||||
|
// Check if all the expected users are in the response array
|
||||||
|
for _, user := range users {
|
||||||
|
if !slices.ContainsFunc(resArr, func(value resp.Value) bool {
|
||||||
|
return value.String() == user
|
||||||
|
}) {
|
||||||
|
t.Errorf("could not find expected user \"%s\" in response array", user)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if all the users in the response array are in the expected users
|
||||||
|
for _, value := range resArr {
|
||||||
|
if !slices.ContainsFunc(users, func(user string) bool {
|
||||||
|
return value.String() == user
|
||||||
|
}) {
|
||||||
|
t.Errorf("could not find response user \"%s\" in expected users array", value.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_HandleSetUser(t *testing.T) {}
|
func Test_HandleSetUser(t *testing.T) {}
|
||||||
|
|
||||||
|
@@ -38,6 +38,9 @@ type User struct {
|
|||||||
|
|
||||||
func (user *User) Normalise() {
|
func (user *User) Normalise() {
|
||||||
user.IncludedCategories = RemoveDuplicateEntries(user.IncludedCategories, "allCategories")
|
user.IncludedCategories = RemoveDuplicateEntries(user.IncludedCategories, "allCategories")
|
||||||
|
if len(user.IncludedCategories) == 0 {
|
||||||
|
user.IncludedCategories = []string{"*"}
|
||||||
|
}
|
||||||
user.ExcludedCategories = RemoveDuplicateEntries(user.ExcludedCategories, "allCategories")
|
user.ExcludedCategories = RemoveDuplicateEntries(user.ExcludedCategories, "allCategories")
|
||||||
if slices.Contains(user.ExcludedCategories, "*") {
|
if slices.Contains(user.ExcludedCategories, "*") {
|
||||||
user.IncludedCategories = []string{}
|
user.IncludedCategories = []string{}
|
||||||
|
Reference in New Issue
Block a user