diff --git a/internal/modules/acl/commands_test.go b/internal/modules/acl/commands_test.go index 1bf861e..8fbe635 100644 --- a/internal/modules/acl/commands_test.go +++ b/internal/modules/acl/commands_test.go @@ -51,6 +51,14 @@ func setUpServer(port int, requirePass bool, aclConfig string) (*echovault.EchoV // Add the initial test users to the ACL module. for _, user := range generateInitialTestUsers() { + // If the user already exists in the server, skip. + existingUsers, err := mockServer.ACLUsers() + if err != nil { + return nil, err + } + if slices.Contains(existingUsers, user.Username) { + continue + } if _, err := mockServer.ACLSetUser(user); err != nil { return nil, err } @@ -157,6 +165,8 @@ func generateSHA256Password(plain string) string { } func Test_ACL(t *testing.T) { + t.Parallel() + port, err := internal.GetFreePort() if err != nil { t.Error(err) @@ -372,7 +382,7 @@ func Test_ACL(t *testing.T) { resp.StringValue("0"), resp.StringValue("-1"), }, - wantErr: fmt.Sprintf("not authorised to access the following keys: [%s~%s]", "%R", "key3"), + wantErr: fmt.Sprintf("not authorised to access the following read keys: [%s~%s]", "%R", "key3"), }, { name: "10. Return error when trying to write to keys that are not in write keys list", @@ -387,7 +397,7 @@ func Test_ACL(t *testing.T) { resp.StringValue("0"), resp.StringValue("3"), }, - wantErr: fmt.Sprintf("not authorised to access the following keys: [%s~%s]", "%W", "key3"), + wantErr: fmt.Sprintf("not authorised to access the following write keys: [%s~%s]", "%W", "key3"), }, } @@ -1854,10 +1864,15 @@ func Test_ACL(t *testing.T) { // Check if ACL LIST returns the expected list of users. resArr := res.Array() if len(resArr) != len(test.want) { - t.Errorf("expected response of lenght %d, got lenght %d", len(test.want), len(resArr)) + t.Errorf("expected response of lenght %d, got length %d", len(test.want), len(resArr)) return } + fmt.Println("USER LIST: ") + for j, user := range resArr { + fmt.Printf("%d) %+v\n", j, user) + } + var resStr []string for i := 0; i < len(resArr); i++ { resStr = strings.Split(resArr[i].String(), " ") diff --git a/internal/raft/fsm.go b/internal/raft/fsm.go index 92cf9ac..71cf101 100644 --- a/internal/raft/fsm.go +++ b/internal/raft/fsm.go @@ -162,16 +162,16 @@ func (fsm *FSM) Restore(snapshot io.ReadCloser) error { } // Set state - ctx := context.Background() - for _, data := range internal.FilterExpiredKeys(time.Now(), data.State) { - for k, v := range data { - // TODO: Set values according to database. - if err = fsm.options.SetValues(ctx, map[string]interface{}{k: v.Value}); err != nil { + for database, data := range internal.FilterExpiredKeys(time.Now(), data.State) { + ctx := context.WithValue(context.Background(), "Database", database) + for key, keyData := range data { + if err = fsm.options.SetValues(ctx, map[string]interface{}{key: keyData.Value}); err != nil { log.Fatal(err) } - fsm.options.SetExpiry(ctx, k, v.ExpireAt, false) + fsm.options.SetExpiry(ctx, key, keyData.ExpireAt, false) } } + // Set latest snapshot milliseconds. fsm.options.SetLatestSnapshotTime(data.LatestSnapshotMilliseconds)