diff --git a/src/modules/sorted_set/commands.go b/src/modules/sorted_set/commands.go index 1b63236..d4ec877 100644 --- a/src/modules/sorted_set/commands.go +++ b/src/modules/sorted_set/commands.go @@ -300,14 +300,11 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.Server, con } func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { - if len(cmd) < 3 { - return nil, errors.New(utils.WRONG_ARGS_RESPONSE) + keys, err := zdiffKeyFunc(cmd) + if err != nil { + return nil, err } - keys := utils.Filter(cmd[1:], func(s string) bool { - return !strings.EqualFold(s, "withscores") - }) - withscoresIndex := slices.IndexFunc(cmd, func(s string) bool { return strings.EqualFold(s, "withscores") }) @@ -324,34 +321,36 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n } }() + // Extract base set + if _, err = server.KeyRLock(ctx, keys[0]); err != nil { + return nil, err + } + defer server.KeyRUnlock(keys[0]) + baseSortedSet, ok := server.GetValue(keys[0]).(*SortedSet) + if !ok { + return nil, fmt.Errorf("value at %s is not a sorted set", keys[0]) + } + + // Extract the remaining sets var sets []*SortedSet - for _, key := range keys { - if !server.KeyExists(key) { + for i := 1; i < len(keys); i++ { + if !server.KeyExists(keys[i]) { continue } - locked, err := server.KeyRLock(ctx, key) + locked, err := server.KeyRLock(ctx, keys[i]) if err != nil { return nil, err } - locks[key] = locked - set, ok := server.GetValue(key).(*SortedSet) + locks[keys[i]] = locked + set, ok := server.GetValue(keys[i]).(*SortedSet) if !ok { - return nil, fmt.Errorf("value at error %s is not a sorted set", key) + return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } sets = append(sets, set) } - var diff *SortedSet - - switch len(sets) { - case 0: - return []byte("*0\r\n\r\n"), nil - case 1: - diff = sets[0] - default: - diff = sets[0].Subtract(sets[1:]) - } + var diff = baseSortedSet.Subtract(sets) res := fmt.Sprintf("*%d", diff.Cardinality()) includeScores := withscoresIndex != -1 && withscoresIndex >= 2 @@ -359,7 +358,7 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n var str string for i, m := range diff.GetAll() { if includeScores { - str = fmt.Sprintf("%s %f", m.value, m.score) + str = fmt.Sprintf("%s %s", m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64)) res += fmt.Sprintf("\r\n$%d\r\n%s", len(str), str) } else { str = string(m.value) diff --git a/src/modules/sorted_set/commands_test.go b/src/modules/sorted_set/commands_test.go index 3a6cebc..f128fd0 100644 --- a/src/modules/sorted_set/commands_test.go +++ b/src/modules/sorted_set/commands_test.go @@ -8,6 +8,7 @@ import ( "github.com/echovault/echovault/src/utils" "github.com/tidwall/resp" "math" + "slices" "testing" ) @@ -610,7 +611,157 @@ func Test_HandleZLEXCOUNT(t *testing.T) { } } -func Test_HandleZDIFF(t *testing.T) {} +func Test_HandleZDIFF(t *testing.T) { + mockServer := server.NewServer(server.Opts{}) + + tests := []struct { + preset bool + presetValues map[string]interface{} + command []string + expectedResponse []string + expectedError error + }{ + { // 1. Get the difference between 2 sorted sets without scores. + preset: true, + presetValues: map[string]interface{}{ + "key1": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, + {value: "two", score: 2}, + {value: "three", score: 3}, + {value: "four", score: 4}, + }), + "key2": NewSortedSet([]MemberParam{ + {value: "three", score: 3}, + {value: "four", score: 4}, + {value: "five", score: 5}, + {value: "six", score: 6}, + {value: "seven", score: 7}, + {value: "eight", score: 8}, + }), + }, + command: []string{"ZDIFF", "key1", "key2"}, + expectedResponse: []string{"one", "two"}, + expectedError: nil, + }, + { // 2. Get the difference between 2 sorted sets with scores. + preset: true, + presetValues: map[string]interface{}{ + "key1": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, + {value: "two", score: 2}, + {value: "three", score: 3}, + {value: "four", score: 4}, + }), + "key2": NewSortedSet([]MemberParam{ + {value: "three", score: 3}, + {value: "four", score: 4}, + {value: "five", score: 5}, + {value: "six", score: 6}, + {value: "seven", score: 7}, + {value: "eight", score: 8}, + }), + }, + command: []string{"ZDIFF", "key1", "key2", "WITHSCORES"}, + expectedResponse: []string{"one 1", "two 2"}, + expectedError: nil, + }, + { // 3. Get the difference between 3 sets with scores. + preset: true, + presetValues: map[string]interface{}{ + "key3": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key4": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, + }), + "key5": NewSortedSet([]MemberParam{ + {value: "seven", score: 7}, {value: "eight", score: 8}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + command: []string{"ZDIFF", "key3", "key4", "key5", "WITHSCORES"}, + expectedResponse: []string{"three 3", "four 4", "five 5", "six 6"}, + expectedError: nil, + }, + { // 3. Return sorted set if only one key exists and is a sorted set + preset: true, + presetValues: map[string]interface{}{ + "key6": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + }, + command: []string{"ZDIFF", "key6", "key7", "key8", "WITHSCORES"}, + expectedResponse: []string{"one 1", "two 2", "three 3", "four 4", "five 5", "six 6", "seven 7", "eight 8"}, + expectedError: nil, + }, + { // 4. Throw error when one of the keys is not a sorted set. + preset: true, + presetValues: map[string]interface{}{ + "key9": "Default value", + "key10": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, + }), + "key11": NewSortedSet([]MemberParam{ + {value: "seven", score: 7}, {value: "eight", score: 8}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + command: []string{"ZDIFF", "key9", "key10", "key11"}, + expectedResponse: nil, + expectedError: errors.New("value at key9 is not a sorted set"), + }, + { // 6. Command too short + preset: false, + command: []string{"ZDIFF"}, + expectedResponse: []string{}, + expectedError: errors.New(utils.WRONG_ARGS_RESPONSE), + }, + } + + for _, test := range tests { + if test.preset { + for key, value := range test.presetValues { + if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + t.Error(err) + } + mockServer.SetValue(context.Background(), key, value) + mockServer.KeyUnlock(key) + } + } + res, err := handleZDIFF(context.Background(), test.command, mockServer, nil) + if test.expectedError != nil { + if err.Error() != test.expectedError.Error() { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + continue + } + if err != nil { + t.Error(err) + } + rd := resp.NewReader(bytes.NewBuffer(res)) + rv, _, err := rd.ReadValue() + if err != nil { + t.Error(err) + } + for _, responseElement := range rv.Array() { + if !slices.Contains(test.expectedResponse, responseElement.String()) { + t.Errorf("could not find response element \"%s\" from expected response array", responseElement.String()) + } + } + } +} func Test_HandleZDIFFSTORE(t *testing.T) {} diff --git a/src/modules/sorted_set/key_funcs.go b/src/modules/sorted_set/key_funcs.go index 1991739..81a77b0 100644 --- a/src/modules/sorted_set/key_funcs.go +++ b/src/modules/sorted_set/key_funcs.go @@ -32,10 +32,16 @@ func zdiffKeyFunc(cmd []string) ([]string, error) { if len(cmd) < 2 { return nil, errors.New(utils.WRONG_ARGS_RESPONSE) } - keys := utils.Filter(cmd[1:], func(elem string) bool { - return !strings.EqualFold(elem, "WITHSCORES") + + withscoresIndex := slices.IndexFunc(cmd, func(s string) bool { + return strings.EqualFold(s, "withscores") }) - return keys, nil + + if withscoresIndex == -1 { + return cmd[1:], nil + } + + return cmd[1:withscoresIndex], nil } func zdiffstoreKeyFunc(cmd []string) ([]string, error) { diff --git a/src/modules/sorted_set/sorted_set.go b/src/modules/sorted_set/sorted_set.go index 0369493..5cf8a7f 100644 --- a/src/modules/sorted_set/sorted_set.go +++ b/src/modules/sorted_set/sorted_set.go @@ -201,6 +201,42 @@ func (set *SortedSet) Remove(v Value) bool { return false } +func (set *SortedSet) Pop(count int, policy string) (*SortedSet, error) { + popped := NewSortedSet([]MemberParam{}) + if !slices.Contains([]string{"min", "max"}, strings.ToLower(policy)) { + return nil, errors.New("policy must be MIN or MAX") + } + if count < 0 { + return nil, errors.New("count must be a positive integer") + } + if count == 0 { + return popped, nil + } + + members := set.GetAll() + + slices.SortFunc(members, func(a, b MemberParam) int { + if strings.EqualFold(policy, "min") { + return cmp.Compare(a.score, b.score) + } + return cmp.Compare(b.score, a.score) + }) + + for i := 0; i < count; i++ { + if i < len(members) { + set.Remove(members[i].value) + _, err := popped.AddOrUpdate([]MemberParam{members[i]}, nil, nil, nil, nil) + if err != nil { + fmt.Println(err.Error()) + // TODO: Add all the removed elements back if we encounter an error + return nil, err + } + } + } + + return popped, nil +} + func (set *SortedSet) Subtract(others []*SortedSet) *SortedSet { res := NewSortedSet(set.GetAll()) for _, ss := range others { @@ -301,39 +337,3 @@ func (set *SortedSet) Intersect(others []*SortedSet, weights []int, aggregate st } return res, nil } - -func (set *SortedSet) Pop(count int, policy string) (*SortedSet, error) { - popped := NewSortedSet([]MemberParam{}) - if !slices.Contains([]string{"min", "max"}, strings.ToLower(policy)) { - return nil, errors.New("policy must be MIN or MAX") - } - if count < 0 { - return nil, errors.New("count must be a positive integer") - } - if count == 0 { - return popped, nil - } - - members := set.GetAll() - - slices.SortFunc(members, func(a, b MemberParam) int { - if strings.EqualFold(policy, "min") { - return cmp.Compare(a.score, b.score) - } - return cmp.Compare(b.score, a.score) - }) - - for i := 0; i < count; i++ { - if i < len(members) { - set.Remove(members[i].value) - _, err := popped.AddOrUpdate([]MemberParam{members[i]}, nil, nil, nil, nil) - if err != nil { - fmt.Println(err.Error()) - // TODO: Add all the removed elements back if we encounter an error - return nil, err - } - } - } - - return popped, nil -}