diff --git a/internal/modules/acl/acl.go b/internal/modules/acl/acl.go index 752a14d..a3f87fd 100644 --- a/internal/modules/acl/acl.go +++ b/internal/modules/acl/acl.go @@ -49,6 +49,45 @@ type ACL struct { GlobPatterns map[string]glob.Glob } +func loadUsersFromConfigFile(users []*User, filePath string) { + if filePath != "" { + // Create the director if it does not exist. + if err := os.MkdirAll(path.Dir(filePath), os.ModePerm); err != nil { + log.Printf("mkdir ACL config: %v\n", err) + return + } + // Open the config file. Create it if it does not exist. + f, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, os.ModePerm) + if err != nil { + log.Printf("open ACL config: %v\n", err) + return + } + + defer func() { + if err := f.Close(); err != nil { + log.Printf("close ACL config: %v\n", err) + } + }() + + ext := path.Ext(f.Name()) + + if strings.ToLower(ext) == ".json" { + if err := json.NewDecoder(f).Decode(&users); err != nil { + log.Printf("load ACL config: %v\n", err) + return + } + } + + if slices.Contains([]string{".yaml", ".yml"}, strings.ToLower(ext)) { + if err := yaml.NewDecoder(f).Decode(&users); err != nil { + log.Printf("load ACL config: %v\n", err) + return + } + } + + } +} + func NewACL(config config.Config) *ACL { var users []*User @@ -65,32 +104,7 @@ func NewACL(config config.Config) *ACL { } // 2. Read and parse the ACL config file - if config.AclConfig != "" { - // Override acl configurations from file - if f, err := os.Open(config.AclConfig); err != nil { - panic(err) - } else { - defer func() { - if err := f.Close(); err != nil { - log.Printf("acl config file close: %v\n", err) - } - }() - - ext := path.Ext(f.Name()) - - if ext == ".json" { - if err := json.NewDecoder(f).Decode(&users); err != nil { - log.Printf("load ACL config: %v\n", err) - } - } - - if slices.Contains([]string{".yaml", ".yml"}, ext) { - if err := yaml.NewDecoder(f).Decode(&users); err != nil { - log.Printf("load ACL config: %v\n", err) - } - } - } - } + loadUsersFromConfigFile(users, config.AclConfig) // 3. If default user was not loaded from file, add the created one defaultLoaded := false diff --git a/internal/modules/acl/commands.go b/internal/modules/acl/commands.go index 432b596..839d94a 100644 --- a/internal/modules/acl/commands.go +++ b/internal/modules/acl/commands.go @@ -442,7 +442,7 @@ func handleSave(params internal.HandlerFuncParams) ([]byte, error) { acl.RLockUsers() acl.RUnlockUsers() - f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE, os.ModeAppend) + f, err := os.OpenFile(acl.Config.AclConfig, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) if err != nil { return nil, err } @@ -455,32 +455,29 @@ func handleSave(params internal.HandlerFuncParams) ([]byte, error) { ext := path.Ext(f.Name()) - if ext == ".json" { + if strings.ToLower(ext) == ".json" { // Write to JSON config file out, err := json.Marshal(acl.Users) if err != nil { return nil, err } - _, err = f.Write(out) - if err != nil { + if _, err = f.Write(out); err != nil { return nil, err } } - if ext == ".yaml" || ext == ".yml" { + if slices.Contains([]string{".yaml", ".yml"}, strings.ToLower(ext)) { // Write to yaml file out, err := yaml.Marshal(acl.Users) if err != nil { return nil, err } - _, err = f.Write(out) - if err != nil { + if _, err = f.Write(out); err != nil { return nil, err } } - err = f.Sync() - if err != nil { + if err = f.Sync(); err != nil { return nil, err } diff --git a/internal/modules/acl/commands_test.go b/internal/modules/acl/commands_test.go index 7e43efd..0b4c78e 100644 --- a/internal/modules/acl/commands_test.go +++ b/internal/modules/acl/commands_test.go @@ -22,6 +22,8 @@ import ( "github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/constants" "github.com/tidwall/resp" + "os" + "path" "slices" "strings" "testing" @@ -174,6 +176,7 @@ func Test_ACL(t *testing.T) { }) t.Run("Test_HandleAuth", func(t *testing.T) { + t.Parallel() conn, err := internal.GetConnection("localhost", port) if err != nil { t.Error(err) @@ -281,6 +284,7 @@ func Test_ACL(t *testing.T) { }) t.Run("Test_HandleCat", func(t *testing.T) { + t.Parallel() conn, err := internal.GetConnection("localhost", port) if err != nil { t.Error(err) @@ -387,6 +391,7 @@ func Test_ACL(t *testing.T) { }) t.Run("Test_HandleUsers", func(t *testing.T) { + t.Parallel() port, err := internal.GetFreePort() if err != nil { t.Error(err) @@ -450,6 +455,7 @@ func Test_ACL(t *testing.T) { }) t.Run("Test_HandleSetUser", func(t *testing.T) { + t.Parallel() port, err := internal.GetFreePort() if err != nil { t.Error(err) @@ -1044,6 +1050,7 @@ When nopass is provided, ignore any passwords that may have been provided in the }) t.Run("Test_HandleGetUser", func(t *testing.T) { + t.Parallel() port, err := internal.GetFreePort() if err != nil { t.Error(err) @@ -1211,6 +1218,7 @@ When nopass is provided, ignore any passwords that may have been provided in the }) t.Run("Test_HandleDelUser", func(t *testing.T) { + t.Parallel() port, err := internal.GetFreePort() if err != nil { t.Error(err) @@ -1319,6 +1327,7 @@ When nopass is provided, ignore any passwords that may have been provided in the }) t.Run("Test_HandleWhoAmI", func(t *testing.T) { + t.Parallel() conn, err := internal.GetConnection("localhost", port) if err != nil { t.Error(err) @@ -1391,6 +1400,7 @@ When nopass is provided, ignore any passwords that may have been provided in the }) t.Run("Test_HandleList", func(t *testing.T) { + t.Parallel() port, err := internal.GetFreePort() if err != nil { t.Error(err) @@ -1543,9 +1553,188 @@ When nopass is provided, ignore any passwords that may have been provided in the t.Errorf("could not find the following user in expected slice: %+v", resStr) return } - clear(resStr) } }) } }) + + t.Run("Test_HandleSave", func(t *testing.T) { + t.Parallel() + + baseDir := path.Join(".", "testdata", "save") + + t.Cleanup(func() { + _ = os.RemoveAll(baseDir) + }) + + tests := []struct { + name string + path string + want []string // Response from ACL List command. + }{ + { + name: "1. Save ACL config to .json file", + path: path.Join(baseDir, "json_test.json"), + want: []string{ + "default on +@all +all %RW~* +&*", + fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*", + generateSHA256Password("password3"), "%RW"), + "no_password_user on nopass +@all +all %RW~* +&*", + "disabled_user off >password5 +@all +all %RW~* +&*", + }, + }, + { + name: "2. Save ACL config to .yaml file", + path: path.Join(baseDir, "yaml_test.yaml"), + want: []string{ + "default on +@all +all %RW~* +&*", + fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*", + generateSHA256Password("password3"), "%RW"), + "no_password_user on nopass +@all +all %RW~* +&*", + "disabled_user off >password5 +@all +all %RW~* +&*", + }, + }, + { + name: "3. Save ACL config to .yml file", + path: path.Join(baseDir, "yml_test.yml"), + want: []string{ + "default on +@all +all %RW~* +&*", + fmt.Sprintf("with_password_user on >password2 #%s +@all +all %s~* +&*", + generateSHA256Password("password3"), "%RW"), + "no_password_user on nopass +@all +all %RW~* +&*", + "disabled_user off >password5 +@all +all %RW~* +&*", + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + // Get free port. + port, err := internal.GetFreePort() + if err != nil { + t.Error(err) + return + } + + // Create new server instance + mockServer, err := setUpServer(port, false, test.path) + if err != nil { + t.Error(err) + return + } + + go func() { + mockServer.Start() + }() + + conn, err := internal.GetConnection("localhost", port) + if err != nil { + t.Error(err) + mockServer.ShutDown() + return + } + client := resp.NewConn(conn) + + if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("SAVE")}); err != nil { + t.Error(err) + mockServer.ShutDown() + return + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + mockServer.ShutDown() + return + } + + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected OK response, got \"%s\"", res.String()) + mockServer.ShutDown() + return + } + + // Close client connection + if err = conn.Close(); err != nil { + t.Error(err) + mockServer.ShutDown() + return + } + + // Shutdown the mock server + mockServer.ShutDown() + + // Restart server and create new client connection + port, err = internal.GetFreePort() + if err != nil { + t.Error(err) + return + } + mockServer, err = setUpServer(port, false, test.path) + if err != nil { + t.Error(err) + return + } + go func() { + mockServer.Start() + }() + + conn, err = internal.GetConnection("localhost", port) + if err != nil { + t.Error(err) + mockServer.ShutDown() + return + } + client = resp.NewConn(conn) + + if err = client.WriteArray([]resp.Value{resp.StringValue("ACL"), resp.StringValue("LIST")}); err != nil { + t.Error(err) + mockServer.ShutDown() + return + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + mockServer.ShutDown() + return + } + + // Check if ACL LIST returns the expected list of users. + resArr := res.Array() + if len(resArr) != len(test.want) { + t.Errorf("expected response of lenght %d, got lenght %d", len(test.want), len(resArr)) + mockServer.ShutDown() + return + } + + var resStr []string + for i := 0; i < len(resArr); i++ { + resStr = strings.Split(resArr[i].String(), " ") + if !slices.ContainsFunc(test.want, func(s string) bool { + expectedUserSlice := strings.Split(s, " ") + return compareSlices(resStr, expectedUserSlice) == nil + }) { + t.Errorf("could not find the following user in expected slice: %+v", resStr) + mockServer.ShutDown() + return + } + } + + mockServer.ShutDown() + }) + } + }) + + t.Run("Test_HandleLoad", func(t *testing.T) { + t.Parallel() + + baseDir := path.Join(".", "testdata", "load") + + t.Cleanup(func() { + _ = os.RemoveAll(baseDir) + }) + }) }