mirror of
https://github.com/EchoVault/SugarDB.git
synced 2025-10-05 16:06:57 +08:00
Updated logic for loading acl config. If the config file does not exist, it will be created.
This commit is contained in:
@@ -49,6 +49,45 @@ type ACL struct {
|
||||
GlobPatterns map[string]glob.Glob
|
||||
}
|
||||
|
||||
func loadUsersFromConfigFile(users []*User, filePath string) {
|
||||
if filePath != "" {
|
||||
// Create the director if it does not exist.
|
||||
if err := os.MkdirAll(path.Dir(filePath), os.ModePerm); err != nil {
|
||||
log.Printf("mkdir ACL config: %v\n", err)
|
||||
return
|
||||
}
|
||||
// Open the config file. Create it if it does not exist.
|
||||
f, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, os.ModePerm)
|
||||
if err != nil {
|
||||
log.Printf("open ACL config: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := f.Close(); err != nil {
|
||||
log.Printf("close ACL config: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ext := path.Ext(f.Name())
|
||||
|
||||
if strings.ToLower(ext) == ".json" {
|
||||
if err := json.NewDecoder(f).Decode(&users); err != nil {
|
||||
log.Printf("load ACL config: %v\n", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if slices.Contains([]string{".yaml", ".yml"}, strings.ToLower(ext)) {
|
||||
if err := yaml.NewDecoder(f).Decode(&users); err != nil {
|
||||
log.Printf("load ACL config: %v\n", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func NewACL(config config.Config) *ACL {
|
||||
var users []*User
|
||||
|
||||
@@ -65,32 +104,7 @@ func NewACL(config config.Config) *ACL {
|
||||
}
|
||||
|
||||
// 2. Read and parse the ACL config file
|
||||
if config.AclConfig != "" {
|
||||
// Override acl configurations from file
|
||||
if f, err := os.Open(config.AclConfig); err != nil {
|
||||
panic(err)
|
||||
} else {
|
||||
defer func() {
|
||||
if err := f.Close(); err != nil {
|
||||
log.Printf("acl config file close: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ext := path.Ext(f.Name())
|
||||
|
||||
if ext == ".json" {
|
||||
if err := json.NewDecoder(f).Decode(&users); err != nil {
|
||||
log.Printf("load ACL config: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
if slices.Contains([]string{".yaml", ".yml"}, ext) {
|
||||
if err := yaml.NewDecoder(f).Decode(&users); err != nil {
|
||||
log.Printf("load ACL config: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
loadUsersFromConfigFile(users, config.AclConfig)
|
||||
|
||||
// 3. If default user was not loaded from file, add the created one
|
||||
defaultLoaded := false
|
||||
|
@@ -442,7 +442,7 @@ func handleSave(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
acl.RLockUsers()
|
||||
acl.RUnlockUsers()
|
||||
|
||||
f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE, os.ModeAppend)
|
||||
f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -455,32 +455,29 @@ func handleSave(params internal.HandlerFuncParams) ([]byte, error) {
|
||||
|
||||
ext := path.Ext(f.Name())
|
||||
|
||||
if ext == ".json" {
|
||||
if strings.ToLower(ext) == ".json" {
|
||||
// Write to JSON config file
|
||||
out, err := json.Marshal(acl.Users)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = f.Write(out)
|
||||
if err != nil {
|
||||
if _, err = f.Write(out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if ext == ".yaml" || ext == ".yml" {
|
||||
if slices.Contains([]string{".yaml", ".yml"}, strings.ToLower(ext)) {
|
||||
// Write to yaml file
|
||||
out, err := yaml.Marshal(acl.Users)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = f.Write(out)
|
||||
if err != nil {
|
||||
if _, err = f.Write(out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = f.Sync()
|
||||
if err != nil {
|
||||
if err = f.Sync(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@@ -22,6 +22,8 @@ import (
|
||||
"github.com/echovault/echovault/internal/config"
|
||||
"github.com/echovault/echovault/internal/constants"
|
||||
"github.com/tidwall/resp"
|
||||
"os"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -174,6 +176,7 @@ func Test_ACL(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Test_HandleAuth", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -281,6 +284,7 @@ func Test_ACL(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Test_HandleCat", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -387,6 +391,7 @@ func Test_ACL(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Test_HandleUsers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -450,6 +455,7 @@ func Test_ACL(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Test_HandleSetUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -1044,6 +1050,7 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
})
|
||||
|
||||
t.Run("Test_HandleGetUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -1211,6 +1218,7 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
})
|
||||
|
||||
t.Run("Test_HandleDelUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -1319,6 +1327,7 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
})
|
||||
|
||||
t.Run("Test_HandleWhoAmI", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -1391,6 +1400,7 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
})
|
||||
|
||||
t.Run("Test_HandleList", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -1543,9 +1553,188 @@ When nopass is provided, ignore any passwords that may have been provided in the
|
||||
t.Errorf("could not find the following user in expected slice: %+v", resStr)
|
||||
return
|
||||
}
|
||||
clear(resStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Test_HandleSave", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
baseDir := path.Join(".", "testdata", "save")
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(baseDir)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
want []string // Response from ACL List command.
|
||||
}{
|
||||
{
|
||||
name: "1. Save ACL config to .json file",
|
||||
path: path.Join(baseDir, "json_test.json"),
|
||||
want: []string{
|
||||
"default on +@all +all %RW~* +&*",
|
||||
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
|
||||
generateSHA256Password("password3"), "%RW"),
|
||||
"no_password_user on nopass +@all +all %RW~* +&*",
|
||||
"disabled_user off >password5 +@all +all %RW~* +&*",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "2. Save ACL config to .yaml file",
|
||||
path: path.Join(baseDir, "yaml_test.yaml"),
|
||||
want: []string{
|
||||
"default on +@all +all %RW~* +&*",
|
||||
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
|
||||
generateSHA256Password("password3"), "%RW"),
|
||||
"no_password_user on nopass +@all +all %RW~* +&*",
|
||||
"disabled_user off >password5 +@all +all %RW~* +&*",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "3. Save ACL config to .yml file",
|
||||
path: path.Join(baseDir, "yml_test.yml"),
|
||||
want: []string{
|
||||
"default on +@all +all %RW~* +&*",
|
||||
fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*",
|
||||
generateSHA256Password("password3"), "%RW"),
|
||||
"no_password_user on nopass +@all +all %RW~* +&*",
|
||||
"disabled_user off >password5 +@all +all %RW~* +&*",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Get free port.
|
||||
port, err := internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create new server instance
|
||||
mockServer, err := setUpServer(port, false, test.path)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
mockServer.Start()
|
||||
}()
|
||||
|
||||
conn, err := internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
client := resp.NewConn(conn)
|
||||
|
||||
if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("SAVE")}); err != nil {
|
||||
t.Error(err)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
|
||||
res, _, err := client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.EqualFold(res.String(), "ok") {
|
||||
t.Errorf("expected OK response, got \"%s\"", res.String())
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
|
||||
// Close client connection
|
||||
if err = conn.Close(); err != nil {
|
||||
t.Error(err)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
|
||||
// Shutdown the mock server
|
||||
mockServer.ShutDown()
|
||||
|
||||
// Restart server and create new client connection
|
||||
port, err = internal.GetFreePort()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
mockServer, err = setUpServer(port, false, test.path)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
mockServer.Start()
|
||||
}()
|
||||
|
||||
conn, err = internal.GetConnection("localhost", port)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
client = resp.NewConn(conn)
|
||||
|
||||
if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("LIST")}); err != nil {
|
||||
t.Error(err)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
|
||||
res, _, err = client.ReadValue()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if ACL LIST returns the expected list of users.
|
||||
resArr := res.Array()
|
||||
if len(resArr) != len(test.want) {
|
||||
t.Errorf("expected response of lenght %d, got lenght %d", len(test.want), len(resArr))
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
|
||||
var resStr []string
|
||||
for i := 0; i < len(resArr); i++ {
|
||||
resStr = strings.Split(resArr[i].String(), " ")
|
||||
if !slices.ContainsFunc(test.want, func(s string) bool {
|
||||
expectedUserSlice := strings.Split(s, " ")
|
||||
return compareSlices(resStr, expectedUserSlice) == nil
|
||||
}) {
|
||||
t.Errorf("could not find the following user in expected slice: %+v", resStr)
|
||||
mockServer.ShutDown()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
mockServer.ShutDown()
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Test_HandleLoad", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
baseDir := path.Join(".", "testdata", "load")
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = os.RemoveAll(baseDir)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user