Updated logic for loading acl config. If the config file does not exist, it will be created.

This commit is contained in:
Kelvin Clement Mwinuka
2024-06-01 22:16:16 +08:00
parent 4d56ee9083
commit bdfaf5446a
3 changed files with 236 additions and 36 deletions

View File

@@ -22,6 +22,8 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/constants"
"github.com/tidwall/resp"
"os"
"path"
"slices"
"strings"
"testing"
@@ -174,6 +176,7 @@ func Test_ACL(t *testing.T) {
})
t.Run("Test_HandleAuth", func(t *testing.T) {
t.Parallel()
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
@@ -281,6 +284,7 @@ func Test_ACL(t *testing.T) {
})
t.Run("Test_HandleCat", func(t *testing.T) {
t.Parallel()
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
@@ -387,6 +391,7 @@ func Test_ACL(t *testing.T) {
})
t.Run("Test_HandleUsers", func(t *testing.T) {
t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
@@ -450,6 +455,7 @@ func Test_ACL(t *testing.T) {
})
t.Run("Test_HandleSetUser", func(t *testing.T) {
t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
@@ -1044,6 +1050,7 @@ When nopass is provided, ignore any passwords that may have been provided in the
})
t.Run("Test_HandleGetUser", func(t *testing.T) {
t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
@@ -1211,6 +1218,7 @@ When nopass is provided, ignore any passwords that may have been provided in the
})
t.Run("Test_HandleDelUser", func(t *testing.T) {
t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
@@ -1319,6 +1327,7 @@ When nopass is provided, ignore any passwords that may have been provided in the
})
t.Run("Test_HandleWhoAmI", func(t *testing.T) {
t.Parallel()
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
@@ -1391,6 +1400,7 @@ When nopass is provided, ignore any passwords that may have been provided in the
})
t.Run("Test_HandleList", func(t *testing.T) {
t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
@@ -1543,9 +1553,188 @@ When nopass is provided, ignore any passwords that may have been provided in the
t.Errorf("could not find the following user in expected slice: %+v", resStr)
return
}
clear(resStr)
}
})
}
})
t.Run("Test_HandleSave", func(t *testing.T) {
t.Parallel()
baseDir := path.Join(".", "testdata", "save")
t.Cleanup(func() {
_ = os.RemoveAll(baseDir)
})
tests := []struct {
name string
path string
want []string // Response from ACL List command.
}{
{
name: "1. Save ACL config to .json file",
path: path.Join(baseDir, "json_test.json"),
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
},
},
{
name: "2. Save ACL config to .yaml file",
path: path.Join(baseDir, "yaml_test.yaml"),
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
},
},
{
name: "3. Save ACL config to .yml file",
path: path.Join(baseDir, "yml_test.yml"),
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
// Get free port.
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
// Create new server instance
mockServer, err := setUpServer(port, false, test.path)
if err != nil {
t.Error(err)
return
}
go func() {
mockServer.Start()
}()
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
mockServer.ShutDown()
return
}
client := resp.NewConn(conn)
if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("SAVE")}); err != nil {
t.Error(err)
mockServer.ShutDown()
return
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
mockServer.ShutDown()
return
}
if !strings.EqualFold(res.String(), "ok") {
t.Errorf("expected OK response, got \"%s\"", res.String())
mockServer.ShutDown()
return
}
// Close client connection
if err = conn.Close(); err != nil {
t.Error(err)
mockServer.ShutDown()
return
}
// Shutdown the mock server
mockServer.ShutDown()
// Restart server and create new client connection
port, err = internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err = setUpServer(port, false, test.path)
if err != nil {
t.Error(err)
return
}
go func() {
mockServer.Start()
}()
conn, err = internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
mockServer.ShutDown()
return
}
client = resp.NewConn(conn)
if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("LIST")}); err != nil {
t.Error(err)
mockServer.ShutDown()
return
}
res, _, err = client.ReadValue()
if err != nil {
t.Error(err)
mockServer.ShutDown()
return
}
// Check if ACL LIST returns the expected list of users.
resArr := res.Array()
if len(resArr) != len(test.want) {
t.Errorf("expected response of lenght %d, got lenght %d", len(test.want), len(resArr))
mockServer.ShutDown()
return
}
var resStr []string
for i := 0; i < len(resArr); i++ {
resStr = strings.Split(resArr[i].String(), " ")
if !slices.ContainsFunc(test.want, func(s string) bool {
expectedUserSlice := strings.Split(s, " ")
return compareSlices(resStr, expectedUserSlice) == nil
}) {
t.Errorf("could not find the following user in expected slice: %+v", resStr)
mockServer.ShutDown()
return
}
}
mockServer.ShutDown()
})
}
})
t.Run("Test_HandleLoad", func(t *testing.T) {
t.Parallel()
baseDir := path.Join(".", "testdata", "load")
t.Cleanup(func() {
_ = os.RemoveAll(baseDir)
})
})
}