diff --git a/src/modules/sorted_set/commands.go b/src/modules/sorted_set/commands.go index 7df80f0..ed3d030 100644 --- a/src/modules/sorted_set/commands.go +++ b/src/modules/sorted_set/commands.go @@ -504,8 +504,9 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.Server, conn } func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { - if len(cmd) < 2 { - return nil, errors.New(utils.WRONG_ARGS_RESPONSE) + keys, err := zinterKeyFunc(cmd) + if err != nil { + return nil, err } keys, weights, aggregate, withscores, err := extractKeysWeightsAggregateWithScores(cmd) @@ -522,53 +523,44 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn * } }() - var sets []*SortedSet + var setParams []SortedSetParam - for _, key := range keys { - if server.KeyExists(key) { - _, err := server.KeyRLock(ctx, key) - if err != nil { - return nil, err - } - locks[key] = true - set, ok := server.GetValue(key).(*SortedSet) - if !ok { - return nil, fmt.Errorf("value at %s is not a sorted set", key) - } - sets = append(sets, set) + for i := 0; i < len(keys); i++ { + // If key does not exist, return an empty array + if !server.KeyExists(keys[i]) { + return []byte("*0\r\n\r\n"), nil } - } - - var intersect *SortedSet - - if len(sets) > 1 { - if intersect, err = sets[0].Intersect(sets[1:], weights, aggregate); err != nil { + if _, err = server.KeyRLock(ctx, keys[i]); err != nil { return nil, err } - } else if len(sets) == 1 { - intersect = sets[0] - } else { - return nil, errors.New("not enough sets to form an intersect") + locks[keys[i]] = true + set, ok := server.GetValue(keys[i]).(*SortedSet) + if !ok { + return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) + } + setParams = append(setParams, SortedSetParam{ + set: set, + weight: weights[i], + }) } + intersect := Intersect(aggregate, setParams...) + res := fmt.Sprintf("*%d", intersect.Cardinality()) if intersect.Cardinality() > 0 { - for i, m := range intersect.GetAll() { + for _, m := range intersect.GetAll() { if withscores { - s := fmt.Sprintf("%s %f", m.value, m.score) + s := fmt.Sprintf("%s %s", m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64)) res += fmt.Sprintf("\r\n$%d\r\n%s", len(s), s) } else { - res += fmt.Sprintf("\r\n%s", m.value) - } - if i == intersect.Cardinality()-1 { - res += "\r\n\r\n" + res += fmt.Sprintf("\r\n$%d\r\n%s", len(m.value), m.value) } } - } else { - res += "\r\n\r\n" } + res += "\r\n\r\n" + return []byte(res), nil } diff --git a/src/modules/sorted_set/commands_test.go b/src/modules/sorted_set/commands_test.go index 28f9c03..55c8548 100644 --- a/src/modules/sorted_set/commands_test.go +++ b/src/modules/sorted_set/commands_test.go @@ -942,7 +942,308 @@ func Test_HandleZDIFFSTORE(t *testing.T) { func Test_HandleZINCRBY(t *testing.T) {} -func Test_HandleZINTER(t *testing.T) {} +func Test_HandleZINTER(t *testing.T) { + mockServer := server.NewServer(server.Opts{}) + + tests := []struct { + preset bool + presetValues map[string]interface{} + command []string + expectedResponse []string + expectedError error + }{ + { // 1. Get the intersection between 2 sorted sets. + preset: true, + presetValues: map[string]interface{}{ + "key1": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, + }), + "key2": NewSortedSet([]MemberParam{ + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + }, + command: []string{"ZINTER", "key1", "key2"}, + expectedResponse: []string{"three", "four", "five"}, + expectedError: nil, + }, + { + // 2. Get the intersection between 3 sorted sets with scores. + // By default, the SUM aggregate will be used. + preset: true, + presetValues: map[string]interface{}{ + "key3": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key4": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, {value: "eight", score: 8}, + }), + "key5": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "eight", score: 8}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + command: []string{"ZINTER", "key3", "key4", "key5", "WITHSCORES"}, + expectedResponse: []string{"one 3", "eight 24"}, + expectedError: nil, + }, + { + // 3. Get the intersection between 3 sorted sets with scores. + // Use MIN aggregate. + preset: true, + presetValues: map[string]interface{}{ + "key6": NewSortedSet([]MemberParam{ + {value: "one", score: 100}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key7": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, {value: "eight", score: 80}, + }), + "key8": NewSortedSet([]MemberParam{ + {value: "one", score: 1000}, {value: "eight", score: 800}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + command: []string{"ZINTER", "key6", "key7", "key8", "WITHSCORES", "AGGREGATE", "MIN"}, + expectedResponse: []string{"one 1", "eight 8"}, + expectedError: nil, + }, + { + // 4. Get the intersection between 3 sorted sets with scores. + // Use MAX aggregate. + preset: true, + presetValues: map[string]interface{}{ + "key9": NewSortedSet([]MemberParam{ + {value: "one", score: 100}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key10": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, {value: "eight", score: 80}, + }), + "key11": NewSortedSet([]MemberParam{ + {value: "one", score: 1000}, {value: "eight", score: 800}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + command: []string{"ZINTER", "key9", "key10", "key11", "WITHSCORES", "AGGREGATE", "MAX"}, + expectedResponse: []string{"one 1000", "eight 800"}, + expectedError: nil, + }, + { + // 5. Get the intersection between 3 sorted sets with scores. + // Use SUM aggregate with weights modifier. + preset: true, + presetValues: map[string]interface{}{ + "key12": NewSortedSet([]MemberParam{ + {value: "one", score: 100}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key13": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, {value: "eight", score: 80}, + }), + "key14": NewSortedSet([]MemberParam{ + {value: "one", score: 1000}, {value: "eight", score: 800}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + command: []string{"ZINTER", "key12", "key13", "key14", "WITHSCORES", "AGGREGATE", "SUM", "WEIGHTS", "1", "5", "3"}, + expectedResponse: []string{"one 3105", "eight 2808"}, + expectedError: nil, + }, + { + // 6. Get the intersection between 3 sorted sets with scores. + // Use MAX aggregate with added weights. + preset: true, + presetValues: map[string]interface{}{ + "key15": NewSortedSet([]MemberParam{ + {value: "one", score: 100}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key16": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, {value: "eight", score: 80}, + }), + "key17": NewSortedSet([]MemberParam{ + {value: "one", score: 1000}, {value: "eight", score: 800}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + command: []string{"ZINTER", "key15", "key16", "key17", "WITHSCORES", "AGGREGATE", "MAX", "WEIGHTS", "1", "5", "3"}, + expectedResponse: []string{"one 3000", "eight 2400"}, + expectedError: nil, + }, + { + // 7. Get the intersection between 3 sorted sets with scores. + // Use MIN aggregate with added weights. + preset: true, + presetValues: map[string]interface{}{ + "key18": NewSortedSet([]MemberParam{ + {value: "one", score: 100}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key19": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, {value: "eight", score: 80}, + }), + "key20": NewSortedSet([]MemberParam{ + {value: "one", score: 1000}, {value: "eight", score: 800}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + command: []string{"ZINTER", "key18", "key19", "key20", "WITHSCORES", "AGGREGATE", "MIN", "WEIGHTS", "1", "5", "3"}, + expectedResponse: []string{"one 5", "eight 8"}, + expectedError: nil, + }, + { // 8. Throw an error if there are more weights than keys + preset: true, + presetValues: map[string]interface{}{ + "key21": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key22": NewSortedSet([]MemberParam{{value: "one", score: 1}}), + }, + command: []string{"ZINTER", "key21", "key22", "WEIGHTS", "1", "2", "3"}, + expectedResponse: nil, + expectedError: errors.New("number of weights should match number of keys"), + }, + { // 9. Throw an error if there are fewer weights than keys + preset: true, + presetValues: map[string]interface{}{ + "key23": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key24": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + }), + "key25": NewSortedSet([]MemberParam{{value: "one", score: 1}}), + }, + command: []string{"ZINTER", "key23", "key24", "key25", "WEIGHTS", "5", "4"}, + expectedResponse: nil, + expectedError: errors.New("number of weights should match number of keys"), + }, + { // 10. Throw an error if there are no keys provided + preset: true, + presetValues: map[string]interface{}{ + "key26": NewSortedSet([]MemberParam{{value: "one", score: 1}}), + "key27": NewSortedSet([]MemberParam{{value: "one", score: 1}}), + "key28": NewSortedSet([]MemberParam{{value: "one", score: 1}}), + }, + command: []string{"ZINTER", "WEIGHTS", "5", "4"}, + expectedResponse: nil, + expectedError: errors.New(utils.WRONG_ARGS_RESPONSE), + }, + { // 11. Throw an error if any of the provided keys are not sorted sets + preset: true, + presetValues: map[string]interface{}{ + "key29": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key30": "Default value", + "key31": NewSortedSet([]MemberParam{{value: "one", score: 1}}), + }, + command: []string{"ZINTER", "key29", "key30", "key31"}, + expectedResponse: nil, + expectedError: errors.New("value at key30 is not a sorted set"), + }, + { // 5. If any of the keys does not exist, return an empty array. + preset: true, + presetValues: map[string]interface{}{ + "key32": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, + }), + "key33": NewSortedSet([]MemberParam{ + {value: "seven", score: 7}, {value: "eight", score: 8}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + command: []string{"ZINTER", "non-existent", "key32", "key33"}, + expectedResponse: []string{}, + expectedError: nil, + }, + { // 13. Command too short + preset: false, + command: []string{"ZINTER"}, + expectedResponse: []string{}, + expectedError: errors.New(utils.WRONG_ARGS_RESPONSE), + }, + } + + for _, test := range tests { + if test.preset { + for key, value := range test.presetValues { + if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + t.Error(err) + } + mockServer.SetValue(context.Background(), key, value) + mockServer.KeyUnlock(key) + } + } + res, err := handleZINTER(context.Background(), test.command, mockServer, nil) + if test.expectedError != nil { + if err.Error() != test.expectedError.Error() { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + continue + } + if err != nil { + t.Error(err) + } + rd := resp.NewReader(bytes.NewBuffer(res)) + rv, _, err := rd.ReadValue() + if err != nil { + t.Error(err) + } + for _, responseElement := range rv.Array() { + if !slices.Contains(test.expectedResponse, responseElement.String()) { + t.Errorf("could not find response element \"%s\" from expected response array", responseElement.String()) + } + } + } +} func Test_HandleZINTERSTORE(t *testing.T) {} diff --git a/src/modules/sorted_set/key_funcs.go b/src/modules/sorted_set/key_funcs.go index 81a77b0..4c8fcfd 100644 --- a/src/modules/sorted_set/key_funcs.go +++ b/src/modules/sorted_set/key_funcs.go @@ -74,7 +74,7 @@ func zinterKeyFunc(cmd []string) ([]string, error) { return cmd[1:], nil } if endIdx >= 1 { - return cmd[1 : endIdx+1], nil + return cmd[1:endIdx], nil } return nil, errors.New(utils.WRONG_ARGS_RESPONSE) } diff --git a/src/modules/sorted_set/sorted_set.go b/src/modules/sorted_set/sorted_set.go index 5cf8a7f..3b170c6 100644 --- a/src/modules/sorted_set/sorted_set.go +++ b/src/modules/sorted_set/sorted_set.go @@ -294,6 +294,83 @@ func (set *SortedSet) Union(others []*SortedSet, weights []int, aggregate string return res, nil } +// SortedSetParam is a composite object used for Intersect and Union function +type SortedSetParam struct { + set *SortedSet + weight int +} + +func Intersect(aggregate string, setParams ...SortedSetParam) *SortedSet { + switch len(setParams) { + case 1: + var params []MemberParam + for _, member := range setParams[0].set.GetAll() { + params = append(params, MemberParam{ + value: member.value, + score: member.score * Score(setParams[0].weight), + }) + } + return NewSortedSet(params) + case 2: + var params []MemberParam + // Traverse the params in the left sorted set + for _, member := range setParams[0].set.GetAll() { + // Check if the member exists in the right sorted set + if !setParams[1].set.Contains(member.value) { + continue + } + // If the member exists, get both elements and apply the weight + param := MemberParam{ + value: member.value, + score: func(left, right Score) Score { + // Choose which param to add to params depending on the aggregate + switch aggregate { + case "sum": + return left + right + case "min": + return compareScores(left, right, "lt") + default: + // Aggregate is "max" + return compareScores(left, right, "gt") + } + }( + member.score*Score(setParams[0].weight), + setParams[1].set.Get(member.value).score*Score(setParams[1].weight), + ), + } + params = append(params, param) + } + return NewSortedSet(params) + default: + // Divide the sets into 2 and return the intersection + left := Intersect(aggregate, setParams[0:len(setParams)/2]...) + right := Intersect(aggregate, setParams[len(setParams)/2:]...) + + var params []MemberParam + for _, member := range left.GetAll() { + if !right.Contains(member.value) { + continue + } + params = append(params, MemberParam{ + value: member.value, + score: func(left, right Score) Score { + switch aggregate { + case "sum": + return left + right + case "min": + return compareScores(left, right, "lt") + default: + // Aggregate is "max" + return compareScores(left, right, "gt") + } + }(member.score, right.Get(member.value).score), + }) + } + + return NewSortedSet(params) + } +} + func (set *SortedSet) Intersect(others []*SortedSet, weights []int, aggregate string) (*SortedSet, error) { res := NewSortedSet([]MemberParam{}) // Find intersect between this set and the first set in others diff --git a/src/modules/sorted_set/utils.go b/src/modules/sorted_set/utils.go index c8ec3db..96e020d 100644 --- a/src/modules/sorted_set/utils.go +++ b/src/modules/sorted_set/utils.go @@ -9,16 +9,13 @@ import ( ) func extractKeysWeightsAggregateWithScores(cmd []string) ([]string, []int, string, bool, error) { - firstModifierIndex := -1 - var weights []int weightsIndex := slices.IndexFunc(cmd, func(s string) bool { return strings.EqualFold(s, "weights") }) if weightsIndex != -1 { - firstModifierIndex = weightsIndex for i := weightsIndex + 1; i < len(cmd); i++ { - if slices.Contains([]string{"aggregate", "withscores"}, cmd[i]) { + if slices.Contains([]string{"aggregate", "withscores"}, strings.ToLower(cmd[i])) { break } w, err := strconv.Atoi(cmd[i]) @@ -34,14 +31,6 @@ func extractKeysWeightsAggregateWithScores(cmd []string) ([]string, []int, strin return strings.EqualFold(s, "aggregate") }) if aggregateIndex != -1 { - if firstModifierIndex != -1 && (aggregateIndex != -1 && aggregateIndex < firstModifierIndex) { - firstModifierIndex = aggregateIndex - } else if firstModifierIndex == -1 { - firstModifierIndex = aggregateIndex - } - if aggregateIndex >= len(cmd)-1 { - return []string{}, []int{}, "", false, errors.New("aggregate must be SUM, MIN, or MAX") - } if !slices.Contains([]string{"sum", "min", "max"}, strings.ToLower(cmd[aggregateIndex+1])) { return []string{}, []int{}, "", false, errors.New("aggregate must be SUM, MIN, or MAX") } @@ -53,25 +42,31 @@ func extractKeysWeightsAggregateWithScores(cmd []string) ([]string, []int, strin return strings.EqualFold(s, "withscores") }) if withscoresIndex != -1 { - if firstModifierIndex != -1 && (withscoresIndex != -1 && withscoresIndex < firstModifierIndex) { - firstModifierIndex = withscoresIndex - } else if firstModifierIndex == -1 { - firstModifierIndex = withscoresIndex - } withscores = true } + // Get the first modifier index as this will be the upper boundary when extracting the keys + firstModifierIndex := -1 + for _, modifierIndex := range []int{weightsIndex, aggregateIndex, withscoresIndex} { + if modifierIndex == -1 { + continue + } + if firstModifierIndex == -1 { + firstModifierIndex = modifierIndex + continue + } + if modifierIndex < firstModifierIndex { + firstModifierIndex = modifierIndex + } + } + var keys []string if firstModifierIndex == -1 { keys = cmd[1:] - } else if firstModifierIndex != -1 && firstModifierIndex < 2 { - return []string{}, []int{}, "", false, errors.New("must provide at least 1 key") } else { keys = cmd[1:firstModifierIndex] } - if len(keys) < 1 { - return []string{}, []int{}, "", false, errors.New("must provide at least 1 key") - } + if weightsIndex != -1 && (len(keys) != len(weights)) { return []string{}, []int{}, "", false, errors.New("number of weights should match number of keys") } else if weightsIndex == -1 {