diff --git a/src/modules/sorted_set/commands.go b/src/modules/sorted_set/commands.go index ed3d030..611d457 100644 --- a/src/modules/sorted_set/commands.go +++ b/src/modules/sorted_set/commands.go @@ -526,8 +526,8 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn * var setParams []SortedSetParam for i := 0; i < len(keys); i++ { - // If key does not exist, return an empty array if !server.KeyExists(keys[i]) { + // If any of the keys is non-existent, return an empty array as there's no intersect return []byte("*0\r\n\r\n"), nil } if _, err = server.KeyRLock(ctx, keys[i]); err != nil { @@ -565,12 +565,14 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn * } func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) { - if len(cmd) < 3 { - return nil, errors.New(utils.WRONG_ARGS_RESPONSE) + keys, err := zinterstoreKeyFunc(cmd) + if err != nil { + return nil, err } - destination := cmd[1] + destination := keys[0] + // Remove the destination keys from the command before parsing it cmd = slices.DeleteFunc(cmd, func(s string) bool { return s == destination }) @@ -589,39 +591,34 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c } }() - var sets []*SortedSet + var setParams []SortedSetParam - for _, key := range keys { - _, err := server.KeyRLock(ctx, key) - if err != nil { + for i := 0; i < len(keys); i++ { + if !server.KeyExists(keys[i]) { + return []byte(":0\r\n\r\n"), nil + } + if _, err = server.KeyRLock(ctx, keys[i]); err != nil { return nil, err } - locks[key] = true - set, ok := server.GetValue(key).(*SortedSet) + locks[keys[i]] = true + set, ok := server.GetValue(keys[i]).(*SortedSet) if !ok { - return nil, fmt.Errorf("value at %s is not a sorted set", key) + return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } - sets = append(sets, set) + setParams = append(setParams, SortedSetParam{ + set: set, + weight: weights[i], + }) } - var intersect *SortedSet + intersect := Intersect(aggregate, setParams...) - if len(sets) > 1 { - if intersect, err = sets[0].Intersect(sets[1:], weights, aggregate); err != nil { + if server.KeyExists(destination) && intersect.Cardinality() > 0 { + if _, err = server.KeyLock(ctx, destination); err != nil { return nil, err } - } else if len(sets) == 1 { - intersect = sets[0] - } else { - return nil, errors.New("not enough sets to form an intersect") - } - - if server.KeyExists(destination) { - if _, err := server.KeyLock(ctx, destination); err != nil { - return nil, err - } - } else { - if _, err := server.CreateKeyAndLock(ctx, destination); err != nil { + } else if intersect.Cardinality() > 0 { + if _, err = server.CreateKeyAndLock(ctx, destination); err != nil { return nil, err } } diff --git a/src/modules/sorted_set/commands_test.go b/src/modules/sorted_set/commands_test.go index 55c8548..b5341c5 100644 --- a/src/modules/sorted_set/commands_test.go +++ b/src/modules/sorted_set/commands_test.go @@ -1186,7 +1186,7 @@ func Test_HandleZINTER(t *testing.T) { expectedResponse: nil, expectedError: errors.New("value at key30 is not a sorted set"), }, - { // 5. If any of the keys does not exist, return an empty array. + { // 12. If any of the keys does not exist, return an empty array. preset: true, presetValues: map[string]interface{}{ "key32": NewSortedSet([]MemberParam{ @@ -1245,7 +1245,353 @@ func Test_HandleZINTER(t *testing.T) { } } -func Test_HandleZINTERSTORE(t *testing.T) {} +func Test_HandleZINTERSTORE(t *testing.T) { + mockServer := server.NewServer(server.Opts{}) + + tests := []struct { + preset bool + presetValues map[string]interface{} + destination string + command []string + expectedValue *SortedSet + expectedResponse int + expectedError error + }{ + { // 1. Get the intersection between 2 sorted sets. + preset: true, + presetValues: map[string]interface{}{ + "key1": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, + }), + "key2": NewSortedSet([]MemberParam{ + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + }, + destination: "destination1", + command: []string{"ZINTERSTORE", "destination1", "key1", "key2"}, + expectedValue: NewSortedSet([]MemberParam{ + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, + }), + expectedResponse: 3, + expectedError: nil, + }, + { + // 2. Get the intersection between 3 sorted sets with scores. + // By default, the SUM aggregate will be used. + preset: true, + presetValues: map[string]interface{}{ + "key3": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key4": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, {value: "eight", score: 8}, + }), + "key5": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "eight", score: 8}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + destination: "destination2", + command: []string{"ZINTERSTORE", "destination2", "key3", "key4", "key5", "WITHSCORES"}, + expectedValue: NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "eight", score: 24}, + }), + expectedResponse: 2, + expectedError: nil, + }, + { + // 3. Get the intersection between 3 sorted sets with scores. + // Use MIN aggregate. + preset: true, + presetValues: map[string]interface{}{ + "key6": NewSortedSet([]MemberParam{ + {value: "one", score: 100}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key7": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, {value: "eight", score: 80}, + }), + "key8": NewSortedSet([]MemberParam{ + {value: "one", score: 1000}, {value: "eight", score: 800}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + destination: "destination3", + command: []string{"ZINTERSTORE", "destination3", "key6", "key7", "key8", "WITHSCORES", "AGGREGATE", "MIN"}, + expectedValue: NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "eight", score: 8}, + }), + expectedResponse: 2, + expectedError: nil, + }, + { + // 4. Get the intersection between 3 sorted sets with scores. + // Use MAX aggregate. + preset: true, + presetValues: map[string]interface{}{ + "key9": NewSortedSet([]MemberParam{ + {value: "one", score: 100}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key10": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, {value: "eight", score: 80}, + }), + "key11": NewSortedSet([]MemberParam{ + {value: "one", score: 1000}, {value: "eight", score: 800}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + destination: "destination4", + command: []string{"ZINTERSTORE", "destination4", "key9", "key10", "key11", "WITHSCORES", "AGGREGATE", "MAX"}, + expectedValue: NewSortedSet([]MemberParam{ + {value: "one", score: 1000}, {value: "eight", score: 800}, + }), + expectedResponse: 2, + expectedError: nil, + }, + { + // 5. Get the intersection between 3 sorted sets with scores. + // Use SUM aggregate with weights modifier. + preset: true, + presetValues: map[string]interface{}{ + "key12": NewSortedSet([]MemberParam{ + {value: "one", score: 100}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key13": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, {value: "eight", score: 80}, + }), + "key14": NewSortedSet([]MemberParam{ + {value: "one", score: 1000}, {value: "eight", score: 800}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + destination: "destination5", + command: []string{"ZINTERSTORE", "destination5", "key12", "key13", "key14", "WITHSCORES", "AGGREGATE", "SUM", "WEIGHTS", "1", "5", "3"}, + expectedValue: NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "eight", score: 2808}, + }), + expectedResponse: 2, + expectedError: nil, + }, + { + // 6. Get the intersection between 3 sorted sets with scores. + // Use MAX aggregate with added weights. + preset: true, + presetValues: map[string]interface{}{ + "key15": NewSortedSet([]MemberParam{ + {value: "one", score: 100}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key16": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, {value: "eight", score: 80}, + }), + "key17": NewSortedSet([]MemberParam{ + {value: "one", score: 1000}, {value: "eight", score: 800}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + destination: "destination6", + command: []string{"ZINTERSTORE", "destination6", "key15", "key16", "key17", "WITHSCORES", "AGGREGATE", "MAX", "WEIGHTS", "1", "5", "3"}, + expectedValue: NewSortedSet([]MemberParam{ + {value: "one", score: 3000}, {value: "eight", score: 2400}, + }), + expectedResponse: 2, + expectedError: nil, + }, + { + // 7. Get the intersection between 3 sorted sets with scores. + // Use MIN aggregate with added weights. + preset: true, + presetValues: map[string]interface{}{ + "key18": NewSortedSet([]MemberParam{ + {value: "one", score: 100}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key19": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, {value: "eight", score: 80}, + }), + "key20": NewSortedSet([]MemberParam{ + {value: "one", score: 1000}, {value: "eight", score: 800}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + destination: "destination7", + command: []string{"ZINTERSTORE", "destination7", "key18", "key19", "key20", "WITHSCORES", "AGGREGATE", "MIN", "WEIGHTS", "1", "5", "3"}, + expectedValue: NewSortedSet([]MemberParam{ + {value: "one", score: 5}, {value: "eight", score: 8}, + }), + expectedResponse: 2, + expectedError: nil, + }, + { // 8. Throw an error if there are more weights than keys + preset: true, + presetValues: map[string]interface{}{ + "key21": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key22": NewSortedSet([]MemberParam{{value: "one", score: 1}}), + }, + command: []string{"ZINTERSTORE", "destination8", "key21", "key22", "WEIGHTS", "1", "2", "3"}, + expectedResponse: 0, + expectedError: errors.New("number of weights should match number of keys"), + }, + { // 9. Throw an error if there are fewer weights than keys + preset: true, + presetValues: map[string]interface{}{ + "key23": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key24": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + }), + "key25": NewSortedSet([]MemberParam{{value: "one", score: 1}}), + }, + command: []string{"ZINTERSTORE", "destination9", "key23", "key24", "key25", "WEIGHTS", "5", "4"}, + expectedResponse: 0, + expectedError: errors.New("number of weights should match number of keys"), + }, + { // 10. Throw an error if there are no keys provided + preset: true, + presetValues: map[string]interface{}{ + "key26": NewSortedSet([]MemberParam{{value: "one", score: 1}}), + "key27": NewSortedSet([]MemberParam{{value: "one", score: 1}}), + "key28": NewSortedSet([]MemberParam{{value: "one", score: 1}}), + }, + command: []string{"ZINTERSTORE", "WEIGHTS", "5", "4"}, + expectedResponse: 0, + expectedError: errors.New(utils.WRONG_ARGS_RESPONSE), + }, + { // 11. Throw an error if any of the provided keys are not sorted sets + preset: true, + presetValues: map[string]interface{}{ + "key29": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "three", score: 3}, {value: "four", score: 4}, + {value: "five", score: 5}, {value: "six", score: 6}, + {value: "seven", score: 7}, {value: "eight", score: 8}, + }), + "key30": "Default value", + "key31": NewSortedSet([]MemberParam{{value: "one", score: 1}}), + }, + command: []string{"ZINTERSTORE", "key29", "key30", "key31"}, + expectedResponse: 0, + expectedError: errors.New("value at key30 is not a sorted set"), + }, + { // 12. If any of the keys does not exist, return an empty array. + preset: true, + presetValues: map[string]interface{}{ + "key32": NewSortedSet([]MemberParam{ + {value: "one", score: 1}, {value: "two", score: 2}, + {value: "thirty-six", score: 36}, {value: "twelve", score: 12}, + {value: "eleven", score: 11}, + }), + "key33": NewSortedSet([]MemberParam{ + {value: "seven", score: 7}, {value: "eight", score: 8}, + {value: "nine", score: 9}, {value: "ten", score: 10}, + {value: "twelve", score: 12}, + }), + }, + command: []string{"ZINTERSTORE", "destination12", "non-existent", "key32", "key33"}, + expectedResponse: 0, + expectedError: nil, + }, + { // 13. Command too short + preset: false, + command: []string{"ZINTERSTORE"}, + expectedResponse: 0, + expectedError: errors.New(utils.WRONG_ARGS_RESPONSE), + }, + } + + for _, test := range tests { + if test.preset { + for key, value := range test.presetValues { + if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + t.Error(err) + } + mockServer.SetValue(context.Background(), key, value) + mockServer.KeyUnlock(key) + } + } + res, err := handleZINTERSTORE(context.Background(), test.command, mockServer, nil) + if test.expectedError != nil { + if err.Error() != test.expectedError.Error() { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + continue + } + if err != nil { + t.Error(err) + } + rd := resp.NewReader(bytes.NewBuffer(res)) + rv, _, err := rd.ReadValue() + if err != nil { + t.Error(err) + } + if rv.Integer() != test.expectedResponse { + t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) + } + if test.expectedValue != nil { + if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { + t.Error(err) + } + set, ok := mockServer.GetValue(test.destination).(*SortedSet) + if !ok { + t.Errorf("expected vaule at key %s to be set, got another type", test.destination) + } + for _, elem := range set.GetAll() { + if !test.expectedValue.Contains(elem.value) { + t.Errorf("could not find element %s in the expected values", elem.value) + } + } + mockServer.KeyRUnlock(test.destination) + } + } + +} func Test_HandleZMPOP(t *testing.T) {} diff --git a/src/modules/sorted_set/sorted_set.go b/src/modules/sorted_set/sorted_set.go index 3b170c6..e9d726d 100644 --- a/src/modules/sorted_set/sorted_set.go +++ b/src/modules/sorted_set/sorted_set.go @@ -302,6 +302,8 @@ type SortedSetParam struct { func Intersect(aggregate string, setParams ...SortedSetParam) *SortedSet { switch len(setParams) { + case 0: + return NewSortedSet([]MemberParam{}) case 1: var params []MemberParam for _, member := range setParams[0].set.GetAll() { @@ -370,47 +372,3 @@ func Intersect(aggregate string, setParams ...SortedSetParam) *SortedSet { return NewSortedSet(params) } } - -func (set *SortedSet) Intersect(others []*SortedSet, weights []int, aggregate string) (*SortedSet, error) { - res := NewSortedSet([]MemberParam{}) - // Find intersect between this set and the first set in others - var score Score - for _, m := range set.GetAll() { - if others[0].Contains(m.value) { - switch strings.ToLower(aggregate) { - case "sum": - score = m.score*Score(weights[0]) + (others[0].Get(m.value).score * Score(weights[1])) - case "min": - score = compareScores(m.score*Score(weights[0]), others[0].Get(m.value).score*Score(weights[1]), "lt") - case "max": - score = compareScores(m.score*Score(weights[0]), others[0].Get(m.value).score*Score(weights[1]), "gt") - } - if _, err := res.AddOrUpdate([]MemberParam{ - {value: m.value, score: score}, - }, nil, nil, nil, nil); err != nil { - return nil, err - } - } - } - // Calculate intersect with the remaining sets in others - for setIdx, sortedSet := range others[1:] { - for _, m := range sortedSet.GetAll() { - if res.Contains(m.value) { - switch strings.ToLower(aggregate) { - case "sum": - score = res.Get(m.value).score + (m.score * Score(weights[setIdx+1])) - case "min": - score = compareScores(res.Get(m.value).score, m.score*Score(weights[setIdx+1]), "lt") - case "max": - score = compareScores(res.Get(m.value).score, m.score*Score(weights[setIdx+1]), "gt") - } - if _, err := res.AddOrUpdate([]MemberParam{ - {value: m.value, score: score}, - }, nil, nil, nil, nil); err != nil { - return nil, err - } - } - } - } - return res, nil -}