Updated logic for loading acl config. If the config file does not exist, it will be created.

This commit is contained in:
Kelvin Clement Mwinuka
2024-06-01 22:16:16 +08:00
parent 4d56ee9083
commit bdfaf5446a
3 changed files with 236 additions and 36 deletions

View File

@@ -49,6 +49,45 @@ type ACL struct {
GlobPatterns map[string]glob.Glob
}
func loadUsersFromConfigFile(users []*User, filePath string) {
if filePath != "" {
// Create the director if it does not exist.
if err := os.MkdirAll(path.Dir(filePath), os.ModePerm); err != nil {
log.Printf("mkdir ACL config: %v\n", err)
return
}
// Open the config file. Create it if it does not exist.
f, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, os.ModePerm)
if err != nil {
log.Printf("open ACL config: %v\n", err)
return
}
defer func() {
if err := f.Close(); err != nil {
log.Printf("close ACL config: %v\n", err)
}
}()
ext := path.Ext(f.Name())
if strings.ToLower(ext) == ".json" {
if err := json.NewDecoder(f).Decode(&users); err != nil {
log.Printf("load ACL config: %v\n", err)
return
}
}
if slices.Contains([]string{".yaml", ".yml"}, strings.ToLower(ext)) {
if err := yaml.NewDecoder(f).Decode(&users); err != nil {
log.Printf("load ACL config: %v\n", err)
return
}
}
}
}
func NewACL(config config.Config) *ACL {
var users []*User
@@ -65,32 +104,7 @@ func NewACL(config config.Config) *ACL {
}
// 2. Read and parse the ACL config file
if config.AclConfig != "" {
// Override acl configurations from file
if f, err := os.Open(config.AclConfig); err != nil {
panic(err)
} else {
defer func() {
if err := f.Close(); err != nil {
log.Printf("acl config file close: %v\n", err)
}
}()
ext := path.Ext(f.Name())
if ext == ".json" {
if err := json.NewDecoder(f).Decode(&users); err != nil {
log.Printf("load ACL config: %v\n", err)
}
}
if slices.Contains([]string{".yaml", ".yml"}, ext) {
if err := yaml.NewDecoder(f).Decode(&users); err != nil {
log.Printf("load ACL config: %v\n", err)
}
}
}
}
loadUsersFromConfigFile(users, config.AclConfig)
// 3. If default user was not loaded from file, add the created one
defaultLoaded := false

View File

@@ -442,7 +442,7 @@ func handleSave(params internal.HandlerFuncParams) ([]byte, error) {
acl.RLockUsers()
acl.RUnlockUsers()
f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE, os.ModeAppend)
f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm)
if err != nil {
return nil, err
}
@@ -455,32 +455,29 @@ func handleSave(params internal.HandlerFuncParams) ([]byte, error) {
ext := path.Ext(f.Name())
if ext == ".json" {
if strings.ToLower(ext) == ".json" {
// Write to JSON config file
out, err := json.Marshal(acl.Users)
if err != nil {
return nil, err
}
_, err = f.Write(out)
if err != nil {
if _, err = f.Write(out); err != nil {
return nil, err
}
}
if ext == ".yaml" || ext == ".yml" {
if slices.Contains([]string{".yaml", ".yml"}, strings.ToLower(ext)) {
// Write to yaml file
out, err := yaml.Marshal(acl.Users)
if err != nil {
return nil, err
}
_, err = f.Write(out)
if err != nil {
if _, err = f.Write(out); err != nil {
return nil, err
}
}
err = f.Sync()
if err != nil {
if err = f.Sync(); err != nil {
return nil, err
}

View File

@@ -22,6 +22,8 @@ import (
"github.com/echovault/echovault/internal/config"
"github.com/echovault/echovault/internal/constants"
"github.com/tidwall/resp"
"os"
"path"
"slices"
"strings"
"testing"
@@ -174,6 +176,7 @@ func Test_ACL(t *testing.T) {
})
t.Run("Test_HandleAuth", func(t *testing.T) {
t.Parallel()
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
@@ -281,6 +284,7 @@ func Test_ACL(t *testing.T) {
})
t.Run("Test_HandleCat", func(t *testing.T) {
t.Parallel()
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
@@ -387,6 +391,7 @@ func Test_ACL(t *testing.T) {
})
t.Run("Test_HandleUsers", func(t *testing.T) {
t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
@@ -450,6 +455,7 @@ func Test_ACL(t *testing.T) {
})
t.Run("Test_HandleSetUser", func(t *testing.T) {
t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
@@ -1044,6 +1050,7 @@ When nopass is provided, ignore any passwords that may have been provided in the
})
t.Run("Test_HandleGetUser", func(t *testing.T) {
t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
@@ -1211,6 +1218,7 @@ When nopass is provided, ignore any passwords that may have been provided in the
})
t.Run("Test_HandleDelUser", func(t *testing.T) {
t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
@@ -1319,6 +1327,7 @@ When nopass is provided, ignore any passwords that may have been provided in the
})
t.Run("Test_HandleWhoAmI", func(t *testing.T) {
t.Parallel()
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
@@ -1391,6 +1400,7 @@ When nopass is provided, ignore any passwords that may have been provided in the
})
t.Run("Test_HandleList", func(t *testing.T) {
t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
@@ -1543,9 +1553,188 @@ When nopass is provided, ignore any passwords that may have been provided in the
t.Errorf("could not find the following user in expected slice: %+v", resStr)
return
}
clear(resStr)
}
})
}
})
t.Run("Test_HandleSave", func(t *testing.T) {
t.Parallel()
baseDir := path.Join(".", "testdata", "save")
t.Cleanup(func() {
_ = os.RemoveAll(baseDir)
})
tests := []struct {
name string
path string
want []string // Response from ACL List command.
}{
{
name: "1. Save ACL config to .json file",
path: path.Join(baseDir, "json_test.json"),
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
},
},
{
name: "2. Save ACL config to .yaml file",
path: path.Join(baseDir, "yaml_test.yaml"),
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
},
},
{
name: "3. Save ACL config to .yml file",
path: path.Join(baseDir, "yml_test.yml"),
want: []string{
"default on +@all +all %RW~* +&*",
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
generateSHA256Password("password3"), "%RW"),
"no_password_user on nopass +@all +all %RW~* +&*",
"disabled_user off >password5 +@all +all %RW~* +&*",
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
// Get free port.
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
// Create new server instance
mockServer, err := setUpServer(port, false, test.path)
if err != nil {
t.Error(err)
return
}
go func() {
mockServer.Start()
}()
conn, err := internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
mockServer.ShutDown()
return
}
client := resp.NewConn(conn)
if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("SAVE")}); err != nil {
t.Error(err)
mockServer.ShutDown()
return
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
mockServer.ShutDown()
return
}
if !strings.EqualFold(res.String(), "ok") {
t.Errorf("expected OK response, got \"%s\"", res.String())
mockServer.ShutDown()
return
}
// Close client connection
if err = conn.Close(); err != nil {
t.Error(err)
mockServer.ShutDown()
return
}
// Shutdown the mock server
mockServer.ShutDown()
// Restart server and create new client connection
port, err = internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err = setUpServer(port, false, test.path)
if err != nil {
t.Error(err)
return
}
go func() {
mockServer.Start()
}()
conn, err = internal.GetConnection("localhost", port)
if err != nil {
t.Error(err)
mockServer.ShutDown()
return
}
client = resp.NewConn(conn)
if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("LIST")}); err != nil {
t.Error(err)
mockServer.ShutDown()
return
}
res, _, err = client.ReadValue()
if err != nil {
t.Error(err)
mockServer.ShutDown()
return
}
// Check if ACL LIST returns the expected list of users.
resArr := res.Array()
if len(resArr) != len(test.want) {
t.Errorf("expected response of lenght %d, got lenght %d", len(test.want), len(resArr))
mockServer.ShutDown()
return
}
var resStr []string
for i := 0; i < len(resArr); i++ {
resStr = strings.Split(resArr[i].String(), " ")
if !slices.ContainsFunc(test.want, func(s string) bool {
expectedUserSlice := strings.Split(s, " ")
return compareSlices(resStr, expectedUserSlice) == nil
}) {
t.Errorf("could not find the following user in expected slice: %+v", resStr)
mockServer.ShutDown()
return
}
}
mockServer.ShutDown()
})
}
})
t.Run("Test_HandleLoad", func(t *testing.T) {
t.Parallel()
baseDir := path.Join(".", "testdata", "load")
t.Cleanup(func() {
_ = os.RemoveAll(baseDir)
})
})
}