ZINTERSTORE command handler now uses new Intersect divide & conquer function instead of the old Intersect receiver function.

Intersect receiver function on SortedSet reference has been deleted as it's no longer in use.
Added test for ZINTERSTORE command handler.
This commit is contained in:
Kelvin Clement Mwinuka
2024-02-20 23:25:22 +08:00
parent a3bb3e9b34
commit 0e657baa2e
3 changed files with 374 additions and 73 deletions

View File

@@ -526,8 +526,8 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn *
var setParams []SortedSetParam
for i := 0; i < len(keys); i++ {
// If key does not exist, return an empty array
if !server.KeyExists(keys[i]) {
// If any of the keys is non-existent, return an empty array as there's no intersect
return []byte("*0\r\n\r\n"), nil
}
if _, err = server.KeyRLock(ctx, keys[i]); err != nil {
@@ -565,12 +565,14 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn *
}
func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
if len(cmd) < 3 {
return nil, errors.New(utils.WRONG_ARGS_RESPONSE)
keys, err := zinterstoreKeyFunc(cmd)
if err != nil {
return nil, err
}
destination := cmd[1]
destination := keys[0]
// Remove the destination keys from the command before parsing it
cmd = slices.DeleteFunc(cmd, func(s string) bool {
return s == destination
})
@@ -589,39 +591,34 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c
}
}()
var sets []*SortedSet
var setParams []SortedSetParam
for _, key := range keys {
_, err := server.KeyRLock(ctx, key)
if err != nil {
for i := 0; i < len(keys); i++ {
if !server.KeyExists(keys[i]) {
return []byte(":0\r\n\r\n"), nil
}
if _, err = server.KeyRLock(ctx, keys[i]); err != nil {
return nil, err
}
locks[key] = true
set, ok := server.GetValue(key).(*SortedSet)
locks[keys[i]] = true
set, ok := server.GetValue(keys[i]).(*SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
}
sets = append(sets, set)
setParams = append(setParams, SortedSetParam{
set: set,
weight: weights[i],
})
}
var intersect *SortedSet
intersect := Intersect(aggregate, setParams...)
if len(sets) > 1 {
if intersect, err = sets[0].Intersect(sets[1:], weights, aggregate); err != nil {
if server.KeyExists(destination) && intersect.Cardinality() > 0 {
if _, err = server.KeyLock(ctx, destination); err != nil {
return nil, err
}
} else if len(sets) == 1 {
intersect = sets[0]
} else {
return nil, errors.New("not enough sets to form an intersect")
}
if server.KeyExists(destination) {
if _, err := server.KeyLock(ctx, destination); err != nil {
return nil, err
}
} else {
if _, err := server.CreateKeyAndLock(ctx, destination); err != nil {
} else if intersect.Cardinality() > 0 {
if _, err = server.CreateKeyAndLock(ctx, destination); err != nil {
return nil, err
}
}

View File

@@ -1186,7 +1186,7 @@ func Test_HandleZINTER(t *testing.T) {
expectedResponse: nil,
expectedError: errors.New("value at key30 is not a sorted set"),
},
{ // 5. If any of the keys does not exist, return an empty array.
{ // 12. If any of the keys does not exist, return an empty array.
preset: true,
presetValues: map[string]interface{}{
"key32": NewSortedSet([]MemberParam{
@@ -1245,7 +1245,353 @@ func Test_HandleZINTER(t *testing.T) {
}
}
func Test_HandleZINTERSTORE(t *testing.T) {}
func Test_HandleZINTERSTORE(t *testing.T) {
mockServer := server.NewServer(server.Opts{})
tests := []struct {
preset bool
presetValues map[string]interface{}
destination string
command []string
expectedValue *SortedSet
expectedResponse int
expectedError error
}{
{ // 1. Get the intersection between 2 sorted sets.
preset: true,
presetValues: map[string]interface{}{
"key1": NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "two", score: 2},
{value: "three", score: 3}, {value: "four", score: 4},
{value: "five", score: 5},
}),
"key2": NewSortedSet([]MemberParam{
{value: "three", score: 3}, {value: "four", score: 4},
{value: "five", score: 5}, {value: "six", score: 6},
{value: "seven", score: 7}, {value: "eight", score: 8},
}),
},
destination: "destination1",
command: []string{"ZINTERSTORE", "destination1", "key1", "key2"},
expectedValue: NewSortedSet([]MemberParam{
{value: "three", score: 3}, {value: "four", score: 4},
{value: "five", score: 5},
}),
expectedResponse: 3,
expectedError: nil,
},
{
// 2. Get the intersection between 3 sorted sets with scores.
// By default, the SUM aggregate will be used.
preset: true,
presetValues: map[string]interface{}{
"key3": NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "two", score: 2},
{value: "three", score: 3}, {value: "four", score: 4},
{value: "five", score: 5}, {value: "six", score: 6},
{value: "seven", score: 7}, {value: "eight", score: 8},
}),
"key4": NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "two", score: 2},
{value: "thirty-six", score: 36}, {value: "twelve", score: 12},
{value: "eleven", score: 11}, {value: "eight", score: 8},
}),
"key5": NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "eight", score: 8},
{value: "nine", score: 9}, {value: "ten", score: 10},
{value: "twelve", score: 12},
}),
},
destination: "destination2",
command: []string{"ZINTERSTORE", "destination2", "key3", "key4", "key5", "WITHSCORES"},
expectedValue: NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "eight", score: 24},
}),
expectedResponse: 2,
expectedError: nil,
},
{
// 3. Get the intersection between 3 sorted sets with scores.
// Use MIN aggregate.
preset: true,
presetValues: map[string]interface{}{
"key6": NewSortedSet([]MemberParam{
{value: "one", score: 100}, {value: "two", score: 2},
{value: "three", score: 3}, {value: "four", score: 4},
{value: "five", score: 5}, {value: "six", score: 6},
{value: "seven", score: 7}, {value: "eight", score: 8},
}),
"key7": NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "two", score: 2},
{value: "thirty-six", score: 36}, {value: "twelve", score: 12},
{value: "eleven", score: 11}, {value: "eight", score: 80},
}),
"key8": NewSortedSet([]MemberParam{
{value: "one", score: 1000}, {value: "eight", score: 800},
{value: "nine", score: 9}, {value: "ten", score: 10},
{value: "twelve", score: 12},
}),
},
destination: "destination3",
command: []string{"ZINTERSTORE", "destination3", "key6", "key7", "key8", "WITHSCORES", "AGGREGATE", "MIN"},
expectedValue: NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "eight", score: 8},
}),
expectedResponse: 2,
expectedError: nil,
},
{
// 4. Get the intersection between 3 sorted sets with scores.
// Use MAX aggregate.
preset: true,
presetValues: map[string]interface{}{
"key9": NewSortedSet([]MemberParam{
{value: "one", score: 100}, {value: "two", score: 2},
{value: "three", score: 3}, {value: "four", score: 4},
{value: "five", score: 5}, {value: "six", score: 6},
{value: "seven", score: 7}, {value: "eight", score: 8},
}),
"key10": NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "two", score: 2},
{value: "thirty-six", score: 36}, {value: "twelve", score: 12},
{value: "eleven", score: 11}, {value: "eight", score: 80},
}),
"key11": NewSortedSet([]MemberParam{
{value: "one", score: 1000}, {value: "eight", score: 800},
{value: "nine", score: 9}, {value: "ten", score: 10},
{value: "twelve", score: 12},
}),
},
destination: "destination4",
command: []string{"ZINTERSTORE", "destination4", "key9", "key10", "key11", "WITHSCORES", "AGGREGATE", "MAX"},
expectedValue: NewSortedSet([]MemberParam{
{value: "one", score: 1000}, {value: "eight", score: 800},
}),
expectedResponse: 2,
expectedError: nil,
},
{
// 5. Get the intersection between 3 sorted sets with scores.
// Use SUM aggregate with weights modifier.
preset: true,
presetValues: map[string]interface{}{
"key12": NewSortedSet([]MemberParam{
{value: "one", score: 100}, {value: "two", score: 2},
{value: "three", score: 3}, {value: "four", score: 4},
{value: "five", score: 5}, {value: "six", score: 6},
{value: "seven", score: 7}, {value: "eight", score: 8},
}),
"key13": NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "two", score: 2},
{value: "thirty-six", score: 36}, {value: "twelve", score: 12},
{value: "eleven", score: 11}, {value: "eight", score: 80},
}),
"key14": NewSortedSet([]MemberParam{
{value: "one", score: 1000}, {value: "eight", score: 800},
{value: "nine", score: 9}, {value: "ten", score: 10},
{value: "twelve", score: 12},
}),
},
destination: "destination5",
command: []string{"ZINTERSTORE", "destination5", "key12", "key13", "key14", "WITHSCORES", "AGGREGATE", "SUM", "WEIGHTS", "1", "5", "3"},
expectedValue: NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "eight", score: 2808},
}),
expectedResponse: 2,
expectedError: nil,
},
{
// 6. Get the intersection between 3 sorted sets with scores.
// Use MAX aggregate with added weights.
preset: true,
presetValues: map[string]interface{}{
"key15": NewSortedSet([]MemberParam{
{value: "one", score: 100}, {value: "two", score: 2},
{value: "three", score: 3}, {value: "four", score: 4},
{value: "five", score: 5}, {value: "six", score: 6},
{value: "seven", score: 7}, {value: "eight", score: 8},
}),
"key16": NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "two", score: 2},
{value: "thirty-six", score: 36}, {value: "twelve", score: 12},
{value: "eleven", score: 11}, {value: "eight", score: 80},
}),
"key17": NewSortedSet([]MemberParam{
{value: "one", score: 1000}, {value: "eight", score: 800},
{value: "nine", score: 9}, {value: "ten", score: 10},
{value: "twelve", score: 12},
}),
},
destination: "destination6",
command: []string{"ZINTERSTORE", "destination6", "key15", "key16", "key17", "WITHSCORES", "AGGREGATE", "MAX", "WEIGHTS", "1", "5", "3"},
expectedValue: NewSortedSet([]MemberParam{
{value: "one", score: 3000}, {value: "eight", score: 2400},
}),
expectedResponse: 2,
expectedError: nil,
},
{
// 7. Get the intersection between 3 sorted sets with scores.
// Use MIN aggregate with added weights.
preset: true,
presetValues: map[string]interface{}{
"key18": NewSortedSet([]MemberParam{
{value: "one", score: 100}, {value: "two", score: 2},
{value: "three", score: 3}, {value: "four", score: 4},
{value: "five", score: 5}, {value: "six", score: 6},
{value: "seven", score: 7}, {value: "eight", score: 8},
}),
"key19": NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "two", score: 2},
{value: "thirty-six", score: 36}, {value: "twelve", score: 12},
{value: "eleven", score: 11}, {value: "eight", score: 80},
}),
"key20": NewSortedSet([]MemberParam{
{value: "one", score: 1000}, {value: "eight", score: 800},
{value: "nine", score: 9}, {value: "ten", score: 10},
{value: "twelve", score: 12},
}),
},
destination: "destination7",
command: []string{"ZINTERSTORE", "destination7", "key18", "key19", "key20", "WITHSCORES", "AGGREGATE", "MIN", "WEIGHTS", "1", "5", "3"},
expectedValue: NewSortedSet([]MemberParam{
{value: "one", score: 5}, {value: "eight", score: 8},
}),
expectedResponse: 2,
expectedError: nil,
},
{ // 8. Throw an error if there are more weights than keys
preset: true,
presetValues: map[string]interface{}{
"key21": NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "two", score: 2},
{value: "three", score: 3}, {value: "four", score: 4},
{value: "five", score: 5}, {value: "six", score: 6},
{value: "seven", score: 7}, {value: "eight", score: 8},
}),
"key22": NewSortedSet([]MemberParam{{value: "one", score: 1}}),
},
command: []string{"ZINTERSTORE", "destination8", "key21", "key22", "WEIGHTS", "1", "2", "3"},
expectedResponse: 0,
expectedError: errors.New("number of weights should match number of keys"),
},
{ // 9. Throw an error if there are fewer weights than keys
preset: true,
presetValues: map[string]interface{}{
"key23": NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "two", score: 2},
{value: "three", score: 3}, {value: "four", score: 4},
{value: "five", score: 5}, {value: "six", score: 6},
{value: "seven", score: 7}, {value: "eight", score: 8},
}),
"key24": NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "two", score: 2},
}),
"key25": NewSortedSet([]MemberParam{{value: "one", score: 1}}),
},
command: []string{"ZINTERSTORE", "destination9", "key23", "key24", "key25", "WEIGHTS", "5", "4"},
expectedResponse: 0,
expectedError: errors.New("number of weights should match number of keys"),
},
{ // 10. Throw an error if there are no keys provided
preset: true,
presetValues: map[string]interface{}{
"key26": NewSortedSet([]MemberParam{{value: "one", score: 1}}),
"key27": NewSortedSet([]MemberParam{{value: "one", score: 1}}),
"key28": NewSortedSet([]MemberParam{{value: "one", score: 1}}),
},
command: []string{"ZINTERSTORE", "WEIGHTS", "5", "4"},
expectedResponse: 0,
expectedError: errors.New(utils.WRONG_ARGS_RESPONSE),
},
{ // 11. Throw an error if any of the provided keys are not sorted sets
preset: true,
presetValues: map[string]interface{}{
"key29": NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "two", score: 2},
{value: "three", score: 3}, {value: "four", score: 4},
{value: "five", score: 5}, {value: "six", score: 6},
{value: "seven", score: 7}, {value: "eight", score: 8},
}),
"key30": "Default value",
"key31": NewSortedSet([]MemberParam{{value: "one", score: 1}}),
},
command: []string{"ZINTERSTORE", "key29", "key30", "key31"},
expectedResponse: 0,
expectedError: errors.New("value at key30 is not a sorted set"),
},
{ // 12. If any of the keys does not exist, return an empty array.
preset: true,
presetValues: map[string]interface{}{
"key32": NewSortedSet([]MemberParam{
{value: "one", score: 1}, {value: "two", score: 2},
{value: "thirty-six", score: 36}, {value: "twelve", score: 12},
{value: "eleven", score: 11},
}),
"key33": NewSortedSet([]MemberParam{
{value: "seven", score: 7}, {value: "eight", score: 8},
{value: "nine", score: 9}, {value: "ten", score: 10},
{value: "twelve", score: 12},
}),
},
command: []string{"ZINTERSTORE", "destination12", "non-existent", "key32", "key33"},
expectedResponse: 0,
expectedError: nil,
},
{ // 13. Command too short
preset: false,
command: []string{"ZINTERSTORE"},
expectedResponse: 0,
expectedError: errors.New(utils.WRONG_ARGS_RESPONSE),
},
}
for _, test := range tests {
if test.preset {
for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil {
t.Error(err)
}
mockServer.SetValue(context.Background(), key, value)
mockServer.KeyUnlock(key)
}
}
res, err := handleZINTERSTORE(context.Background(), test.command, mockServer, nil)
if test.expectedError != nil {
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
}
continue
}
if err != nil {
t.Error(err)
}
rd := resp.NewReader(bytes.NewBuffer(res))
rv, _, err := rd.ReadValue()
if err != nil {
t.Error(err)
}
if rv.Integer() != test.expectedResponse {
t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer())
}
if test.expectedValue != nil {
if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil {
t.Error(err)
}
set, ok := mockServer.GetValue(test.destination).(*SortedSet)
if !ok {
t.Errorf("expected vaule at key %s to be set, got another type", test.destination)
}
for _, elem := range set.GetAll() {
if !test.expectedValue.Contains(elem.value) {
t.Errorf("could not find element %s in the expected values", elem.value)
}
}
mockServer.KeyRUnlock(test.destination)
}
}
}
func Test_HandleZMPOP(t *testing.T) {}

View File

@@ -302,6 +302,8 @@ type SortedSetParam struct {
func Intersect(aggregate string, setParams ...SortedSetParam) *SortedSet {
switch len(setParams) {
case 0:
return NewSortedSet([]MemberParam{})
case 1:
var params []MemberParam
for _, member := range setParams[0].set.GetAll() {
@@ -370,47 +372,3 @@ func Intersect(aggregate string, setParams ...SortedSetParam) *SortedSet {
return NewSortedSet(params)
}
}
func (set *SortedSet) Intersect(others []*SortedSet, weights []int, aggregate string) (*SortedSet, error) {
res := NewSortedSet([]MemberParam{})
// Find intersect between this set and the first set in others
var score Score
for _, m := range set.GetAll() {
if others[0].Contains(m.value) {
switch strings.ToLower(aggregate) {
case "sum":
score = m.score*Score(weights[0]) + (others[0].Get(m.value).score * Score(weights[1]))
case "min":
score = compareScores(m.score*Score(weights[0]), others[0].Get(m.value).score*Score(weights[1]), "lt")
case "max":
score = compareScores(m.score*Score(weights[0]), others[0].Get(m.value).score*Score(weights[1]), "gt")
}
if _, err := res.AddOrUpdate([]MemberParam{
{value: m.value, score: score},
}, nil, nil, nil, nil); err != nil {
return nil, err
}
}
}
// Calculate intersect with the remaining sets in others
for setIdx, sortedSet := range others[1:] {
for _, m := range sortedSet.GetAll() {
if res.Contains(m.value) {
switch strings.ToLower(aggregate) {
case "sum":
score = res.Get(m.value).score + (m.score * Score(weights[setIdx+1]))
case "min":
score = compareScores(res.Get(m.value).score, m.score*Score(weights[setIdx+1]), "lt")
case "max":
score = compareScores(res.Get(m.value).score, m.score*Score(weights[setIdx+1]), "gt")
}
if _, err := res.AddOrUpdate([]MemberParam{
{value: m.value, score: score},
}, nil, nil, nil, nil); err != nil {
return nil, err
}
}
}
}
return res, nil
}