diff --git a/echovault/echovault.go b/echovault/echovault.go index 5258389..07bce2d 100644 --- a/echovault/echovault.go +++ b/echovault/echovault.go @@ -35,6 +35,7 @@ import ( "github.com/echovault/echovault/internal/modules/list" "github.com/echovault/echovault/internal/modules/pubsub" "github.com/echovault/echovault/internal/modules/set" + "github.com/echovault/echovault/internal/modules/sorted_set" "github.com/echovault/echovault/internal/raft" "github.com/echovault/echovault/internal/snapshot" "io" @@ -141,7 +142,7 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) { commands = append(commands, list.Commands()...) commands = append(commands, pubsub.Commands()...) commands = append(commands, set.Commands()...) - // commands = append(commands, sorted_set.Commands()...) + commands = append(commands, sorted_set.Commands()...) // commands = append(commands, str.Commands()...) return commands }(), diff --git a/internal/modules/sorted_set/commands.go b/internal/modules/sorted_set/commands.go index f53c963..0c6d60e 100644 --- a/internal/modules/sorted_set/commands.go +++ b/internal/modules/sorted_set/commands.go @@ -33,6 +33,7 @@ func handleZADD(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] var updatePolicy interface{} = nil var comparison interface{} = nil @@ -139,14 +140,9 @@ func handleZADD(params internal.HandlerFuncParams) ([]byte, error) { } } - if params.KeyExists(params.Context, key) { + if keyExists { // Key exists - _, err = params.KeyLock(params.Context, key) - if err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - set, ok := params.GetValue(params.Context, key).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -163,14 +159,9 @@ func handleZADD(params internal.HandlerFuncParams) ([]byte, error) { return []byte(fmt.Sprintf(":%d\r\n", count)), nil } - // Key does not exist - if _, err = params.CreateKeyAndLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - + // Key does not exist. set := NewSortedSet(members) - if err = params.SetValue(params.Context, key, set); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{key: set}); err != nil { return nil, err } @@ -182,18 +173,15 @@ func handleZCARD(params internal.HandlerFuncParams) ([]byte, error) { if err != nil { return nil, err } - key := keys.ReadKeys[0] - if !params.KeyExists(params.Context, key) { + key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] + + if !keyExists { return []byte(":0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - set, ok := params.GetValue(params.Context, key).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -208,6 +196,7 @@ func handleZCOUNT(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] minimum := Score(math.Inf(-1)) switch internal.AdaptType(params.Command[2]).(type) { @@ -245,16 +234,11 @@ func handleZCOUNT(params internal.HandlerFuncParams) ([]byte, error) { maximum = Score(s) } - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - set, ok := params.GetValue(params.Context, key).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -276,19 +260,15 @@ func handleZLEXCOUNT(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] minimum := params.Command[2] maximum := params.Command[3] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - set, ok := params.GetValue(params.Context, key).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -320,6 +300,8 @@ func handleZDIFF(params internal.HandlerFuncParams) ([]byte, error) { return nil, err } + keyExists := params.KeysExist(keys.ReadKeys) + withscoresIndex := slices.IndexFunc(params.Command, func(s string) bool { return strings.EqualFold(s, "withscores") }) @@ -327,25 +309,13 @@ func handleZDIFF(params internal.HandlerFuncParams) ([]byte, error) { return nil, errors.New(constants.WrongArgsResponse) } - locks := make(map[string]bool) - defer func() { - for key, locked := range locks { - if locked { - params.KeyRUnlock(params.Context, key) - } - } - }() - // Extract base set - if !params.KeyExists(params.Context, keys.ReadKeys[0]) { + if !keyExists[keys.ReadKeys[0]] { // If base set does not exist, return an empty array return []byte("*0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, keys.ReadKeys[0]); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, keys.ReadKeys[0]) - baseSortedSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*SortedSet) + + baseSortedSet, ok := params.GetValues(params.Context, []string{keys.ReadKeys[0]})[keys.ReadKeys[0]].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0]) } @@ -354,15 +324,10 @@ func handleZDIFF(params internal.HandlerFuncParams) ([]byte, error) { var sets []*SortedSet for i := 1; i < len(keys.ReadKeys); i++ { - if !params.KeyExists(params.Context, keys.ReadKeys[i]) { + if !keyExists[keys.ReadKeys[i]] { continue } - locked, err := params.KeyRLock(params.Context, keys.ReadKeys[i]) - if err != nil { - return nil, err - } - locks[keys.ReadKeys[i]] = locked - set, ok := params.GetValue(params.Context, keys.ReadKeys[i]).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{keys.ReadKeys[i]})[keys.ReadKeys[i]].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i]) } @@ -376,7 +341,8 @@ func handleZDIFF(params internal.HandlerFuncParams) ([]byte, error) { for _, m := range diff.GetAll() { if includeScores { - res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.Value), m.Value, strconv.FormatFloat(float64(m.Score), 'f', -1, 64)) + res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", + len(m.Value), m.Value, strconv.FormatFloat(float64(m.Score), 'f', -1, 64)) } else { res += fmt.Sprintf("\r\n*1\r\n$%d\r\n%s", len(m.Value), m.Value) } @@ -393,27 +359,16 @@ func handleZDIFFSTORE(params internal.HandlerFuncParams) ([]byte, error) { return nil, err } + keyExists := params.KeysExist(keys.ReadKeys) destination := keys.WriteKeys[0] - locks := make(map[string]bool) - defer func() { - for key, locked := range locks { - if locked { - params.KeyRUnlock(params.Context, key) - } - } - }() - // Extract base set - if !params.KeyExists(params.Context, keys.ReadKeys[0]) { + if !keyExists[keys.ReadKeys[0]] { // If base set does not exist, return 0 return []byte(":0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, keys.ReadKeys[0]); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, keys.ReadKeys[0]) - baseSortedSet, ok := params.GetValue(params.Context, keys.ReadKeys[0]).(*SortedSet) + + baseSortedSet, ok := params.GetValues(params.Context, []string{keys.ReadKeys[0]})[keys.ReadKeys[0]].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[0]) } @@ -421,11 +376,8 @@ func handleZDIFFSTORE(params internal.HandlerFuncParams) ([]byte, error) { var sets []*SortedSet for i := 1; i < len(keys.ReadKeys); i++ { - if params.KeyExists(params.Context, keys.ReadKeys[i]) { - if _, err = params.KeyRLock(params.Context, keys.ReadKeys[i]); err != nil { - return nil, err - } - set, ok := params.GetValue(params.Context, keys.ReadKeys[i]).(*SortedSet) + if keyExists[keys.ReadKeys[i]] { + set, ok := params.GetValues(params.Context, []string{keys.ReadKeys[i]})[keys.ReadKeys[i]].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys.ReadKeys[i]) } @@ -434,19 +386,7 @@ func handleZDIFFSTORE(params internal.HandlerFuncParams) ([]byte, error) { } diff := baseSortedSet.Subtract(sets) - - if params.KeyExists(params.Context, destination) { - if _, err = params.KeyLock(params.Context, destination); err != nil { - return nil, err - } - } else { - if _, err = params.CreateKeyAndLock(params.Context, destination); err != nil { - return nil, err - } - } - defer params.KeyUnlock(params.Context, destination) - - if err = params.SetValue(params.Context, destination, diff); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{destination: diff}); err != nil { return nil, err } @@ -460,6 +400,8 @@ func handleZINCRBY(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] + member := Value(params.Command[3]) var increment Score @@ -482,28 +424,21 @@ func handleZINCRBY(params internal.HandlerFuncParams) ([]byte, error) { increment = Score(s) } - if !params.KeyExists(params.Context, key) { + if !keyExists { // If the key does not exist, create a new sorted set at the key with // the member and increment as the first value - if _, err = params.CreateKeyAndLock(params.Context, key); err != nil { - return nil, err - } - if err = params.SetValue( + if err = params.SetValues( params.Context, - key, - NewSortedSet([]MemberParam{{Value: member, Score: increment}}), + map[string]interface{}{ + key: NewSortedSet([]MemberParam{{Value: member, Score: increment}}), + }, ); err != nil { return nil, err } - params.KeyUnlock(params.Context, key) return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(float64(increment), 'f', -1, 64))), nil } - if _, err = params.KeyLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - set, ok := params.GetValue(params.Context, key).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -530,28 +465,17 @@ func handleZINTER(params internal.HandlerFuncParams) ([]byte, error) { if err != nil { return nil, err } - - locks := make(map[string]bool) - defer func() { - for key, locked := range locks { - if locked { - params.KeyRUnlock(params.Context, key) - } - } - }() + keyExists := params.KeysExist(keys) var setParams []SortedSetParam + values := params.GetValues(params.Context, keys) for i := 0; i < len(keys); i++ { - if !params.KeyExists(params.Context, keys[i]) { + if !keyExists[keys[i]] { // If any of the keys is non-existent, return an empty array as there's no intersect return []byte("*0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, keys[i]); err != nil { - return nil, err - } - locks[keys[i]] = true - set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet) + set, ok := values[keys[i]].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } @@ -586,6 +510,7 @@ func handleZINTERSTORE(params internal.HandlerFuncParams) ([]byte, error) { return nil, err } + keyExists := params.KeysExist(k.ReadKeys) destination := k.WriteKeys[0] // Remove the destination keys from the command before parsing it @@ -598,26 +523,14 @@ func handleZINTERSTORE(params internal.HandlerFuncParams) ([]byte, error) { return nil, err } - locks := make(map[string]bool) - defer func() { - for key, locked := range locks { - if locked { - params.KeyRUnlock(params.Context, key) - } - } - }() - var setParams []SortedSetParam + values := params.GetValues(params.Context, keys) for i := 0; i < len(keys); i++ { - if !params.KeyExists(params.Context, keys[i]) { + if !keyExists[keys[i]] { return []byte(":0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, keys[i]); err != nil { - return nil, err - } - locks[keys[i]] = true - set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet) + set, ok := values[keys[i]].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } @@ -628,19 +541,9 @@ func handleZINTERSTORE(params internal.HandlerFuncParams) ([]byte, error) { } intersect := Intersect(aggregate, setParams...) - - if params.KeyExists(params.Context, destination) && intersect.Cardinality() > 0 { - if _, err = params.KeyLock(params.Context, destination); err != nil { - return nil, err - } - } else if intersect.Cardinality() > 0 { - if _, err = params.CreateKeyAndLock(params.Context, destination); err != nil { - return nil, err - } - } - defer params.KeyUnlock(params.Context, destination) - - if err = params.SetValue(params.Context, destination, intersect); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{ + destination: intersect, + }); err != nil { return nil, err } @@ -653,6 +556,8 @@ func handleZMPOP(params internal.HandlerFuncParams) ([]byte, error) { return nil, err } + keyExists := params.KeysExist(keys.WriteKeys) + count := 1 policy := "min" modifierIdx := -1 @@ -694,21 +599,15 @@ func handleZMPOP(params internal.HandlerFuncParams) ([]byte, error) { } for i := 0; i < len(keys.WriteKeys); i++ { - if params.KeyExists(params.Context, keys.WriteKeys[i]) { - if _, err = params.KeyLock(params.Context, keys.WriteKeys[i]); err != nil { - continue - } - v, ok := params.GetValue(params.Context, keys.WriteKeys[i]).(*SortedSet) + if keyExists[keys.WriteKeys[i]] { + v, ok := params.GetValues(params.Context, []string{keys.WriteKeys[i]})[keys.WriteKeys[i]].(*SortedSet) if !ok || v.Cardinality() == 0 { - params.KeyUnlock(params.Context, keys.WriteKeys[i]) continue } popped, err := v.Pop(count, policy) if err != nil { - params.KeyUnlock(params.Context, keys.WriteKeys[i]) return nil, err } - params.KeyUnlock(params.Context, keys.WriteKeys[i]) res := fmt.Sprintf("*%d", popped.Cardinality()) @@ -732,6 +631,7 @@ func handleZPOP(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] count := 1 policy := "min" @@ -749,16 +649,11 @@ func handleZPOP(params internal.HandlerFuncParams) ([]byte, error) { } } - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte("*0\r\n"), nil } - if _, err = params.KeyLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - - set, ok := params.GetValue(params.Context, key).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at key %s is not a sorted set", key) } @@ -770,7 +665,8 @@ func handleZPOP(params internal.HandlerFuncParams) ([]byte, error) { res := fmt.Sprintf("*%d", popped.Cardinality()) for _, m := range popped.GetAll() { - res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.Value), m.Value, strconv.FormatFloat(float64(m.Score), 'f', -1, 64)) + res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", + len(m.Value), m.Value, strconv.FormatFloat(float64(m.Score), 'f', -1, 64)) } res += "\r\n" @@ -785,17 +681,13 @@ func handleZMSCORE(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte("*0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - set, ok := params.GetValue(params.Context, key).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -827,6 +719,7 @@ func handleZRANDMEMBER(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] count := 1 if len(params.Command) >= 3 { @@ -848,16 +741,11 @@ func handleZRANDMEMBER(params internal.HandlerFuncParams) ([]byte, error) { } } - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte("$-1\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - set, ok := params.GetValue(params.Context, key).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -885,6 +773,7 @@ func handleZRANK(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] member := params.Command[2] withscores := false @@ -892,16 +781,11 @@ func handleZRANK(params internal.HandlerFuncParams) ([]byte, error) { withscores = true } - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte("$-1\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - set, ok := params.GetValue(params.Context, key).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -935,17 +819,13 @@ func handleZREM(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":0\r\n"), nil } - if _, err = params.KeyLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - - set, ok := params.GetValue(params.Context, key).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -967,15 +847,13 @@ func handleZSCORE(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte("$-1\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - set, ok := params.GetValue(params.Context, key).(*SortedSet) + + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -996,6 +874,7 @@ func handleZREMRANGEBYSCORE(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] deletedCount := 0 @@ -1009,16 +888,11 @@ func handleZREMRANGEBYSCORE(params internal.HandlerFuncParams) ([]byte, error) { return nil, err } - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":0\r\n"), nil } - if _, err = params.KeyLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - - set, ok := params.GetValue(params.Context, key).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -1040,6 +914,7 @@ func handleZREMRANGEBYRANK(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] start, err := strconv.Atoi(params.Command[2]) if err != nil { @@ -1051,16 +926,11 @@ func handleZREMRANGEBYRANK(params internal.HandlerFuncParams) ([]byte, error) { return nil, err } - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":0\r\n"), nil } - if _, err = params.KeyLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - - set, ok := params.GetValue(params.Context, key).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -1105,19 +975,15 @@ func handleZREMRANGEBYLEX(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.WriteKeys[0] + keyExists := params.KeysExist(keys.WriteKeys)[key] minimum := params.Command[2] maximum := params.Command[3] - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte(":0\r\n"), nil } - if _, err = params.KeyLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyUnlock(params.Context, key) - - set, ok := params.GetValue(params.Context, key).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -1152,6 +1018,8 @@ func handleZRANGE(params internal.HandlerFuncParams) ([]byte, error) { } key := keys.ReadKeys[0] + keyExists := params.KeysExist(keys.ReadKeys)[key] + policy := "byscore" scoreStart := math.Inf(-1) // Lower bound if policy is "byscore" scoreStop := math.Inf(1) // Upper bound if policy is "byscore" @@ -1206,16 +1074,11 @@ func handleZRANGE(params internal.HandlerFuncParams) ([]byte, error) { } } - if !params.KeyExists(params.Context, key) { + if !keyExists { return []byte("*0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, key); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, key) - - set, ok := params.GetValue(params.Context, key).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{key})[key].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -1293,6 +1156,7 @@ func handleZRANGESTORE(params internal.HandlerFuncParams) ([]byte, error) { destination := keys.WriteKeys[0] source := keys.ReadKeys[0] + sourceExists := params.KeysExist(keys.ReadKeys)[source] policy := "byscore" scoreStart := math.Inf(-1) // Lower bound if policy is "byscore" scoreStop := math.Inf(1) // Upper bound if policy is "byfloat" @@ -1343,16 +1207,11 @@ func handleZRANGESTORE(params internal.HandlerFuncParams) ([]byte, error) { } } - if !params.KeyExists(params.Context, source) { + if !sourceExists { return []byte("*0\r\n"), nil } - if _, err = params.KeyRLock(params.Context, source); err != nil { - return nil, err - } - defer params.KeyRUnlock(params.Context, source) - - set, ok := params.GetValue(params.Context, source).(*SortedSet) + set, ok := params.GetValues(params.Context, []string{source})[source].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", source) } @@ -1408,19 +1267,9 @@ func handleZRANGESTORE(params internal.HandlerFuncParams) ([]byte, error) { } newSortedSet := NewSortedSet(resultMembers) - - if params.KeyExists(params.Context, destination) { - if _, err = params.KeyLock(params.Context, destination); err != nil { - return nil, err - } - } else { - if _, err = params.CreateKeyAndLock(params.Context, destination); err != nil { - return nil, err - } - } - defer params.KeyUnlock(params.Context, destination) - - if err = params.SetValue(params.Context, destination, newSortedSet); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{ + destination: newSortedSet, + }); err != nil { return nil, err } @@ -1437,24 +1286,14 @@ func handleZUNION(params internal.HandlerFuncParams) ([]byte, error) { return nil, err } - locks := make(map[string]bool) - defer func() { - for key, locked := range locks { - if locked { - params.KeyRUnlock(params.Context, key) - } - } - }() + keyExists := params.KeysExist(keys) var setParams []SortedSetParam + values := params.GetValues(params.Context, keys) for i := 0; i < len(keys); i++ { - if params.KeyExists(params.Context, keys[i]) { - if _, err = params.KeyRLock(params.Context, keys[i]); err != nil { - return nil, err - } - locks[keys[i]] = true - set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet) + if keyExists[keys[i]] { + set, ok := values[keys[i]].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } @@ -1499,24 +1338,14 @@ func handleZUNIONSTORE(params internal.HandlerFuncParams) ([]byte, error) { return nil, err } - locks := make(map[string]bool) - defer func() { - for key, locked := range locks { - if locked { - params.KeyRUnlock(params.Context, key) - } - } - }() + keyExists := params.KeysExist(keys) var setParams []SortedSetParam + values := params.GetValues(params.Context, keys) for i := 0; i < len(keys); i++ { - if params.KeyExists(params.Context, keys[i]) { - if _, err = params.KeyRLock(params.Context, keys[i]); err != nil { - return nil, err - } - locks[keys[i]] = true - set, ok := params.GetValue(params.Context, keys[i]).(*SortedSet) + if keyExists[keys[i]] { + set, ok := values[keys[i]].(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } @@ -1528,19 +1357,9 @@ func handleZUNIONSTORE(params internal.HandlerFuncParams) ([]byte, error) { } union := Union(aggregate, setParams...) - - if params.KeyExists(params.Context, destination) { - if _, err = params.KeyLock(params.Context, destination); err != nil { - return nil, err - } - } else { - if _, err = params.CreateKeyAndLock(params.Context, destination); err != nil { - return nil, err - } - } - defer params.KeyUnlock(params.Context, destination) - - if err = params.SetValue(params.Context, destination, union); err != nil { + if err = params.SetValues(params.Context, map[string]interface{}{ + destination: union, + }); err != nil { return nil, err } diff --git a/internal/modules/sorted_set/commands_test.go b/internal/modules/sorted_set/commands_test.go index b190e12..5e97a43 100644 --- a/internal/modules/sorted_set/commands_test.go +++ b/internal/modules/sorted_set/commands_test.go @@ -15,8 +15,6 @@ package sorted_set_test import ( - "bytes" - "context" "errors" "fmt" "github.com/echovault/echovault/echovault" @@ -27,359 +25,264 @@ import ( "github.com/tidwall/resp" "math" "net" - "reflect" "slices" "strconv" "strings" + "sync" "testing" - "unsafe" ) var mockServer *echovault.EchoVault +var addr = "localhost" +var port int func init() { + port, _ = internal.GetFreePort() mockServer, _ = echovault.NewEchoVault( echovault.WithConfig(config.Config{ + BindAddr: addr, + Port: uint16(port), DataDir: "", EvictionPolicy: constants.NoEviction, }), ) -} - -func getUnexportedField(field reflect.Value) interface{} { - return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface() -} - -func getHandler(commands ...string) internal.HandlerFunc { - if len(commands) == 0 { - return nil - } - getCommands := - getUnexportedField(reflect.ValueOf(mockServer).Elem().FieldByName("getCommands")).(func() []internal.Command) - for _, c := range getCommands() { - if strings.EqualFold(commands[0], c.Command) && len(commands) == 1 { - // Get command handler - return c.HandlerFunc - } - if strings.EqualFold(commands[0], c.Command) { - // Get sub-command handler - for _, sc := range c.SubCommands { - if strings.EqualFold(commands[1], sc.Command) { - return sc.HandlerFunc - } - } - } - } - return nil -} - -func getHandlerFuncParams(ctx context.Context, cmd []string, conn *net.Conn) internal.HandlerFuncParams { - return internal.HandlerFuncParams{ - Context: ctx, - Command: cmd, - Connection: conn, - KeyExists: mockServer.KeyExists, - CreateKeyAndLock: mockServer.CreateKeyAndLock, - KeyLock: mockServer.KeyLock, - KeyRLock: mockServer.KeyRLock, - KeyUnlock: mockServer.KeyUnlock, - KeyRUnlock: mockServer.KeyRUnlock, - GetValue: mockServer.GetValue, - SetValue: mockServer.SetValue, - } + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + wg.Done() + mockServer.Start() + }() + wg.Wait() } func Test_HandleZADD(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValue *sorted_set.SortedSet key string command []string - expectedValue *sorted_set.SortedSet expectedResponse int expectedError error }{ { - name: "1. Create new sorted set and return the cardinality of the new sorted set", - preset: false, - presetValue: nil, - key: "ZaddKey1", - command: []string{"ZADD", "ZaddKey1", "5.5", "member1", "67.77", "member2", "10", "member3", "-inf", "member4", "+inf", "member5"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - {Value: "member4", Score: sorted_set.Score(math.Inf(-1))}, - {Value: "member5", Score: sorted_set.Score(math.Inf(1))}, - }), + name: "1. Create new sorted set and return the cardinality of the new sorted set", + presetValue: nil, + key: "ZaddKey1", + command: []string{"ZADD", "ZaddKey1", "5.5", "member1", "67.77", "member2", "10", "member3", "-inf", "member4", "+inf", "member5"}, expectedResponse: 5, expectedError: nil, }, { - name: "2. Only add the elements that do not currently exist in the sorted set when NX flag is provided", - preset: true, + name: "2. Only add the elements that do not currently exist in the sorted set when NX flag is provided", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "member1", Score: sorted_set.Score(5.5)}, {Value: "member2", Score: sorted_set.Score(67.77)}, {Value: "member3", Score: sorted_set.Score(10)}, }), - key: "ZaddKey2", - command: []string{"ZADD", "ZaddKey2", "NX", "5.5", "member1", "67.77", "member4", "10", "member5"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - {Value: "member4", Score: sorted_set.Score(67.77)}, - {Value: "member5", Score: sorted_set.Score(10)}, - }), + key: "ZaddKey2", + command: []string{"ZADD", "ZaddKey2", "NX", "5.5", "member1", "67.77", "member4", "10", "member5"}, expectedResponse: 2, expectedError: nil, }, { - name: "Do not add any elements when providing existing members with NX flag", - preset: true, + name: "Do not add any elements when providing existing members with NX flag", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "member1", Score: sorted_set.Score(5.5)}, {Value: "member2", Score: sorted_set.Score(67.77)}, {Value: "member3", Score: sorted_set.Score(10)}, }), - key: "ZaddKey3", - command: []string{"ZADD", "ZaddKey3", "NX", "5.5", "member1", "67.77", "member2", "10", "member3"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - }), + key: "ZaddKey3", + command: []string{"ZADD", "ZaddKey3", "NX", "5.5", "member1", "67.77", "member2", "10", "member3"}, expectedResponse: 0, expectedError: nil, }, { - name: "Successfully add elements to an existing set when XX flag is provided with existing elements", - preset: true, + name: "Successfully add elements to an existing set when XX flag is provided with existing elements", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "member1", Score: sorted_set.Score(5.5)}, {Value: "member2", Score: sorted_set.Score(67.77)}, {Value: "member3", Score: sorted_set.Score(10)}, }), - key: "ZaddKey4", - command: []string{"ZADD", "ZaddKey4", "XX", "CH", "55", "member1", "1005", "member2", "15", "member3", "99.75", "member4"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(55)}, - {Value: "member2", Score: sorted_set.Score(1005)}, - {Value: "member3", Score: sorted_set.Score(15)}, - }), + key: "ZaddKey4", + command: []string{"ZADD", "ZaddKey4", "XX", "CH", "55", "member1", "1005", "member2", "15", "member3", "99.75", "member4"}, expectedResponse: 3, expectedError: nil, }, { - name: "5. Fail to add element when providing XX flag with elements that do not exist in the sorted set.", - preset: true, + name: "5. Fail to add element when providing XX flag with elements that do not exist in the sorted set.", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "member1", Score: sorted_set.Score(5.5)}, {Value: "member2", Score: sorted_set.Score(67.77)}, {Value: "member3", Score: sorted_set.Score(10)}, }), - key: "ZaddKey5", - command: []string{"ZADD", "ZaddKey5", "XX", "5.5", "member4", "100.5", "member5", "15", "member6"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - }), + key: "ZaddKey5", + command: []string{"ZADD", "ZaddKey5", "XX", "5.5", "member4", "100.5", "member5", "15", "member6"}, expectedResponse: 0, expectedError: nil, }, { // 6. Only update the elements where provided score is greater than current score and GT flag is provided // Return only the new elements added by default - name: "6. Only update the elements where provided score is greater than current score and GT flag is provided", - preset: true, + name: "6. Only update the elements where provided score is greater than current score and GT flag is provided", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "member1", Score: sorted_set.Score(5.5)}, {Value: "member2", Score: sorted_set.Score(67.77)}, {Value: "member3", Score: sorted_set.Score(10)}, }), - key: "ZaddKey6", - command: []string{"ZADD", "ZaddKey6", "XX", "CH", "GT", "7.5", "member1", "100.5", "member4", "15", "member5"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(7.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - }), + key: "ZaddKey6", + command: []string{"ZADD", "ZaddKey6", "XX", "CH", "GT", "7.5", "member1", "100.5", "member4", "15", "member5"}, expectedResponse: 1, expectedError: nil, }, { // 7. Only update the elements where provided score is less than current score if LT flag is provided // Return only the new elements added by default. - name: "7. Only update the elements where provided score is less than current score if LT flag is provided", - preset: true, + name: "7. Only update the elements where provided score is less than current score if LT flag is provided", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "member1", Score: sorted_set.Score(5.5)}, {Value: "member2", Score: sorted_set.Score(67.77)}, {Value: "member3", Score: sorted_set.Score(10)}, }), - key: "ZaddKey7", - command: []string{"ZADD", "ZaddKey7", "XX", "LT", "3.5", "member1", "100.5", "member4", "15", "member5"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(3.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - }), + key: "ZaddKey7", + command: []string{"ZADD", "ZaddKey7", "XX", "LT", "3.5", "member1", "100.5", "member4", "15", "member5"}, expectedResponse: 0, expectedError: nil, }, { - name: "8. Return all the elements that were updated AND added when CH flag is provided", - preset: true, + name: "8. Return all the elements that were updated AND added when CH flag is provided", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "member1", Score: sorted_set.Score(5.5)}, {Value: "member2", Score: sorted_set.Score(67.77)}, {Value: "member3", Score: sorted_set.Score(10)}, }), - key: "ZaddKey8", - command: []string{"ZADD", "ZaddKey8", "XX", "LT", "CH", "3.5", "member1", "100.5", "member4", "15", "member5"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(3.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - }), + key: "ZaddKey8", + command: []string{"ZADD", "ZaddKey8", "XX", "LT", "CH", "3.5", "member1", "100.5", "member4", "15", "member5"}, expectedResponse: 1, expectedError: nil, }, { - name: "9. Increment the member by score", - preset: true, + name: "9. Increment the member by score", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "member1", Score: sorted_set.Score(5.5)}, {Value: "member2", Score: sorted_set.Score(67.77)}, {Value: "member3", Score: sorted_set.Score(10)}, }), - key: "ZaddKey9", - command: []string{"ZADD", "ZaddKey9", "INCR", "5.5", "member3"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(15.5)}, - }), + key: "ZaddKey9", + command: []string{"ZADD", "ZaddKey9", "INCR", "5.5", "member3"}, expectedResponse: 0, expectedError: nil, }, { name: "10. Fail when GT/LT flag is provided alongside NX flag", - preset: false, presetValue: nil, key: "ZaddKey10", command: []string{"ZADD", "ZaddKey10", "NX", "LT", "CH", "3.5", "member1", "100.5", "member4", "15", "member5"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New("GT/LT flags not allowed if NX flag is provided"), }, { name: "11. Command is too short", - preset: false, presetValue: nil, key: "ZaddKey11", command: []string{"ZADD", "ZaddKey11"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "12. Throw error when score/member entries are do not match", - preset: false, presetValue: nil, key: "ZaddKey11", command: []string{"ZADD", "ZaddKey12", "10.5", "member1", "12.5"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New("score/member pairs must be float/string"), }, { name: "13. Throw error when INCR flag is passed with more than one score/member pair", - preset: false, presetValue: nil, key: "ZaddKey13", command: []string{"ZADD", "ZaddKey13", "INCR", "10.5", "member1", "12.5", "member2"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New("cannot pass more than one score/member pair when INCR flag is provided"), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZADD, %d", i)) + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} + for _, member := range test.presetValue.GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if res.Integer() != test.presetValue.Cardinality() { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - if err != nil { - t.Error(err) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) } - rd := resp.NewReader(bytes.NewReader(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response %d at key \"%s\", got %d", test.expectedResponse, test.key, rv.Integer()) - } - // Fetch the sorted set from the echovault and check it against the expected result - if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { - t.Error(err) - } - sortedSet, ok := mockServer.GetValue(ctx, test.key).(*sorted_set.SortedSet) - if !ok { - t.Errorf("expected the value at key \"%s\" to be a sorted set, got another type", test.key) - } - if test.expectedValue == nil { - return - } - if !sortedSet.Equals(test.expectedValue) { - t.Errorf("expected sorted set %+v, got %+v", test.expectedValue, sortedSet) - } - mockServer.KeyRUnlock(ctx, test.key) }) } } func Test_HandleZCARD(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValue interface{} key string command []string - expectedValue *sorted_set.SortedSet expectedResponse int expectedError error }{ { - name: "1. Get cardinality of valid sorted set.", - preset: true, + name: "1. Get cardinality of valid sorted set.", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "member1", Score: sorted_set.Score(5.5)}, {Value: "member2", Score: sorted_set.Score(67.77)}, @@ -387,109 +290,125 @@ func Test_HandleZCARD(t *testing.T) { }), key: "ZcardKey1", command: []string{"ZCARD", "ZcardKey1"}, - expectedValue: nil, expectedResponse: 3, expectedError: nil, }, { name: "2. Return 0 when trying to get cardinality from non-existent key", - preset: false, presetValue: nil, key: "ZcardKey2", command: []string{"ZCARD", "ZcardKey2"}, - expectedValue: nil, expectedResponse: 0, expectedError: nil, }, { name: "3. Command is too short", - preset: false, presetValue: nil, key: "ZcardKey3", command: []string{"ZCARD"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New(constants.WrongArgsResponse), }, - { // + { name: "4. Command too long", - preset: false, presetValue: nil, key: "ZcardKey4", command: []string{"ZCARD", "ZcardKey4", "ZcardKey5"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "5. Return error when not a sorted set", - preset: true, presetValue: "Default value", key: "ZcardKey5", command: []string{"ZCARD", "ZcardKey5"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New("value at ZcardKey5 is not a sorted set"), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZCARD, %d", i)) + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} + for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - if err != nil { - t.Error(err) - } - rd := resp.NewReader(bytes.NewReader(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response %d at key \"%s\", got %d", test.expectedResponse, test.key, rv.Integer()) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) } }) } } func Test_HandleZCOUNT(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValue interface{} key string command []string - expectedValue *sorted_set.SortedSet expectedResponse int expectedError error }{ { - name: "1. Get entire count using infinity boundaries", - preset: true, + name: "1. Get entire count using infinity boundaries", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "member1", Score: sorted_set.Score(5.5)}, {Value: "member2", Score: sorted_set.Score(67.77)}, @@ -501,13 +420,11 @@ func Test_HandleZCOUNT(t *testing.T) { }), key: "ZcountKey1", command: []string{"ZCOUNT", "ZcountKey1", "-inf", "+inf"}, - expectedValue: nil, expectedResponse: 7, expectedError: nil, }, { - name: "2. Get count of sub-set from -inf to limit", - preset: true, + name: "2. Get count of sub-set from -inf to limit", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "member1", Score: sorted_set.Score(5.5)}, {Value: "member2", Score: sorted_set.Score(67.77)}, @@ -519,13 +436,11 @@ func Test_HandleZCOUNT(t *testing.T) { }), key: "ZcountKey2", command: []string{"ZCOUNT", "ZcountKey2", "-inf", "90"}, - expectedValue: nil, expectedResponse: 5, expectedError: nil, }, { - name: "3. Get count of sub-set from bottom boundary to +inf limit", - preset: true, + name: "3. Get count of sub-set from bottom boundary to +inf limit", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "member1", Score: sorted_set.Score(5.5)}, {Value: "member2", Score: sorted_set.Score(67.77)}, @@ -537,119 +452,133 @@ func Test_HandleZCOUNT(t *testing.T) { }), key: "ZcountKey3", command: []string{"ZCOUNT", "ZcountKey3", "1000", "+inf"}, - expectedValue: nil, expectedResponse: 2, expectedError: nil, }, { name: "4. Return error when bottom boundary is not a valid double/float", - preset: false, presetValue: nil, key: "ZcountKey4", command: []string{"ZCOUNT", "ZcountKey4", "min", "10"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New("min constraint must be a double"), }, { name: "5. Return error when top boundary is not a valid double/float", - preset: false, presetValue: nil, key: "ZcountKey5", command: []string{"ZCOUNT", "ZcountKey5", "-10", "max"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New("max constraint must be a double"), }, { name: "6. Command is too short", - preset: false, presetValue: nil, key: "ZcountKey6", command: []string{"ZCOUNT"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "7. Command too long", - preset: false, presetValue: nil, key: "ZcountKey7", command: []string{"ZCOUNT", "ZcountKey4", "min", "max", "count"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "8. Throw error when value at the key is not a sorted set", - preset: true, presetValue: "Default value", key: "ZcountKey8", command: []string{"ZCOUNT", "ZcountKey8", "1", "10"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New("value at ZcountKey8 is not a sorted set"), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZCARD, %d", i)) + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} + for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - if err != nil { - t.Error(err) - } - rd := resp.NewReader(bytes.NewReader(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response %d at key \"%s\", got %d", test.expectedResponse, test.key, rv.Integer()) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) } }) } } func Test_HandleZLEXCOUNT(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValue interface{} key string command []string - expectedValue *sorted_set.SortedSet expectedResponse int expectedError error }{ { - name: "1. Get entire count using infinity boundaries", - preset: true, + name: "1. Get entire count using infinity boundaries", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "e", Score: sorted_set.Score(1)}, {Value: "f", Score: sorted_set.Score(1)}, @@ -661,13 +590,11 @@ func Test_HandleZLEXCOUNT(t *testing.T) { }), key: "ZlexCountKey1", command: []string{"ZLEXCOUNT", "ZlexCountKey1", "f", "j"}, - expectedValue: nil, expectedResponse: 5, expectedError: nil, }, { - name: "2. Return 0 when the members do not have the same score", - preset: true, + name: "2. Return 0 when the members do not have the same score", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "a", Score: sorted_set.Score(5.5)}, {Value: "b", Score: sorted_set.Score(67.77)}, @@ -679,107 +606,124 @@ func Test_HandleZLEXCOUNT(t *testing.T) { }), key: "ZlexCountKey2", command: []string{"ZLEXCOUNT", "ZlexCountKey2", "a", "b"}, - expectedValue: nil, expectedResponse: 0, expectedError: nil, }, { name: "3. Return 0 when the key does not exist", - preset: false, presetValue: nil, key: "ZlexCountKey3", command: []string{"ZLEXCOUNT", "ZlexCountKey3", "a", "z"}, - expectedValue: nil, expectedResponse: 0, expectedError: nil, }, { name: "4. Return error when the value at the key is not a sorted set", - preset: true, presetValue: "Default value", key: "ZlexCountKey4", command: []string{"ZLEXCOUNT", "ZlexCountKey4", "a", "z"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New("value at ZlexCountKey4 is not a sorted set"), }, { name: "5. Command is too short", - preset: false, presetValue: nil, key: "ZlexCountKey5", command: []string{"ZLEXCOUNT"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "6. Command too long", - preset: false, presetValue: nil, key: "ZlexCountKey6", command: []string{"ZLEXCOUNT", "ZlexCountKey6", "min", "max", "count"}, - expectedValue: nil, expectedResponse: 0, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZLEXCOUNT, %d", i)) + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} + for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - if err != nil { - t.Error(err) - } - rd := resp.NewReader(bytes.NewReader(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response %d at key \"%s\", got %d", test.expectedResponse, test.key, rv.Integer()) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) } }) } } func Test_HandleZDIFF(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValues map[string]interface{} command []string expectedResponse [][]string expectedError error }{ { - name: "1. Get the difference between 2 sorted sets without scores.", - preset: true, + name: "1. Get the difference between 2 sorted sets without scores.", presetValues: map[string]interface{}{ "ZdiffKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, @@ -801,16 +745,15 @@ func Test_HandleZDIFF(t *testing.T) { expectedError: nil, }, { - name: "2. Get the difference between 2 sorted sets with scores.", - preset: true, + name: "2. Get the difference between 2 sorted sets with scores.", presetValues: map[string]interface{}{ - "ZdiffKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + "ZdiffKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, }), - "ZdiffKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + "ZdiffKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "three", Score: 3}, {Value: "four", Score: 4}, {Value: "five", Score: 5}, @@ -819,47 +762,45 @@ func Test_HandleZDIFF(t *testing.T) { {Value: "eight", Score: 8}, }), }, - command: []string{"ZDIFF", "ZdiffKey1", "ZdiffKey2", "WITHSCORES"}, + command: []string{"ZDIFF", "ZdiffKey3", "ZdiffKey4", "WITHSCORES"}, expectedResponse: [][]string{{"one", "1"}, {"two", "2"}}, expectedError: nil, }, { - name: "3. Get the difference between 3 sets with scores.", - preset: true, + name: "3. Get the difference between 3 sets with scores.", presetValues: map[string]interface{}{ - "ZdiffKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + "ZdiffKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, }), - "ZdiffKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + "ZdiffKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, {Value: "eleven", Score: 11}, }), - "ZdiffKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + "ZdiffKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, {Value: "twelve", Score: 12}, }), }, - command: []string{"ZDIFF", "ZdiffKey3", "ZdiffKey4", "ZdiffKey5", "WITHSCORES"}, + command: []string{"ZDIFF", "ZdiffKey5", "ZdiffKey6", "ZdiffKey7", "WITHSCORES"}, expectedResponse: [][]string{{"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}}, expectedError: nil, }, { - name: "4. Return sorted set if only one key exists and is a sorted set", - preset: true, + name: "4. Return sorted set if only one key exists and is a sorted set", presetValues: map[string]interface{}{ - "ZdiffKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + "ZdiffKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, }), }, - command: []string{"ZDIFF", "ZdiffKey6", "ZdiffKey7", "ZdiffKey8", "WITHSCORES"}, + command: []string{"ZDIFF", "ZdiffKey8", "ZdiffKey9", "ZdiffKey10", "WITHSCORES"}, expectedResponse: [][]string{ {"one", "1"}, {"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}, {"seven", "7"}, {"eight", "8"}, @@ -867,86 +808,115 @@ func Test_HandleZDIFF(t *testing.T) { expectedError: nil, }, { - name: "5. Throw error when one of the keys is not a sorted set.", - preset: true, + name: "5. Throw error when one of the keys is not a sorted set.", presetValues: map[string]interface{}{ - "ZdiffKey9": "Default value", - "ZdiffKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + "ZdiffKey11": "Default value", + "ZdiffKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, {Value: "eleven", Score: 11}, }), - "ZdiffKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + "ZdiffKey13": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, {Value: "twelve", Score: 12}, }), }, - command: []string{"ZDIFF", "ZdiffKey9", "ZdiffKey10", "ZdiffKey11"}, + command: []string{"ZDIFF", "ZdiffKey11", "ZdiffKey12", "ZdiffKey13"}, expectedResponse: nil, - expectedError: errors.New("value at ZdiffKey9 is not a sorted set"), + expectedError: errors.New("value at ZdiffKey11 is not a sorted set"), }, { name: "6. Command too short", - preset: false, command: []string{"ZDIFF"}, expectedResponse: [][]string{}, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZDIFF, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } + } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } - if err != nil { - t.Error(err) + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - for _, element := range rv.Array() { - if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { - // The current sub-slice is a different length, return false because they're not equal - if len(element.Array()) != len(expected) { - return false + + for _, item := range res.Array() { + value := item.Array()[0].String() + score := func() string { + if len(item.Array()) == 2 { + return item.Array()[1].String() } - for i := 0; i < len(expected); i++ { - if element.Array()[i].String() != expected[i] { - return false + return "" + }() + if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { + return expected[0] == value + }) { + t.Errorf("unexpected member \"%s\" in response", value) + } + if score != "" { + for _, expected := range test.expectedResponse { + if expected[0] == value && expected[1] != score { + t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) } } - return true - }) { - t.Errorf("expected response %+v, got %+v", test.expectedResponse, rv.Array()) } } }) @@ -954,9 +924,14 @@ func Test_HandleZDIFF(t *testing.T) { } func Test_HandleZDIFFSTORE(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValues map[string]interface{} destination string command []string @@ -965,8 +940,7 @@ func Test_HandleZDIFFSTORE(t *testing.T) { expectedError error }{ { - name: "1. Get the difference between 2 sorted sets.", - preset: true, + name: "1. Get the difference between 2 sorted sets.", presetValues: map[string]interface{}{ "ZdiffStoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -986,8 +960,7 @@ func Test_HandleZDIFFSTORE(t *testing.T) { expectedError: nil, }, { - name: "2. Get the difference between 3 sorted sets.", - preset: true, + name: "2. Get the difference between 3 sorted sets.", presetValues: map[string]interface{}{ "ZdiffStoreKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -1016,8 +989,7 @@ func Test_HandleZDIFFSTORE(t *testing.T) { expectedError: nil, }, { - name: "3. Return base sorted set element if base set is the only existing key provided and is a valid sorted set", - preset: true, + name: "3. Return base sorted set element if base set is the only existing key provided and is a valid sorted set", presetValues: map[string]interface{}{ "ZdiffStoreKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -1038,8 +1010,7 @@ func Test_HandleZDIFFSTORE(t *testing.T) { expectedError: nil, }, { - name: "4. Throw error when base sorted set is not a set.", - preset: true, + name: "4. Throw error when base sorted set is not a set.", presetValues: map[string]interface{}{ "ZdiffStoreKey9": "Default value", "ZdiffStoreKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{ @@ -1060,8 +1031,7 @@ func Test_HandleZDIFFSTORE(t *testing.T) { expectedError: errors.New("value at ZdiffStoreKey9 is not a sorted set"), }, { - name: "5. Throw error when base set is non-existent.", - preset: true, + name: "5. Return 0 when base set is non-existent.", destination: "ZdiffStoreDestinationKey5", presetValues: map[string]interface{}{ "ZdiffStoreKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ @@ -1082,77 +1052,124 @@ func Test_HandleZDIFFSTORE(t *testing.T) { }, { name: "6. Command too short", - preset: false, command: []string{"ZDIFFSTORE", "ZdiffStoreDestinationKey6"}, expectedResponse: 0, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZDIFFSTORE, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + + // Check if the resulting sorted set has the expected members/scores + if test.expectedValue == nil { + return + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(test.destination), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) + + if len(res.Array()) != test.expectedValue.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + test.destination, test.expectedValue.Cardinality(), len(res.Array())) } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) - } - if test.expectedValue != nil { - if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil { - t.Error(err) + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !test.expectedValue.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) } - set, ok := mockServer.GetValue(ctx, test.destination).(*sorted_set.SortedSet) - if !ok { - t.Errorf("expected vaule at key %s to be set, got another type", test.destination) + if test.expectedValue.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) } - for _, elem := range set.GetAll() { - if !test.expectedValue.Contains(elem.Value) { - t.Errorf("could not find element %s in the expected values", elem.Value) - } - } - mockServer.KeyRUnlock(ctx, test.destination) } }) } } func Test_HandleZINCRBY(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValue interface{} key string command []string @@ -1161,8 +1178,7 @@ func Test_HandleZINCRBY(t *testing.T) { expectedError error }{ { - name: "1. Successfully increment by int. Return the new score", - preset: true, + name: "1. Successfully increment by int. Return the new score", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, @@ -1179,8 +1195,7 @@ func Test_HandleZINCRBY(t *testing.T) { expectedError: nil, }, { - name: "2. Successfully increment by float. Return new score", - preset: true, + name: "2. Successfully increment by float. Return new score", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, @@ -1198,7 +1213,6 @@ func Test_HandleZINCRBY(t *testing.T) { }, { name: "3. Increment on non-existent sorted set will create the set with the member and increment as its score", - preset: false, presetValue: nil, key: "ZincrbyKey3", command: []string{"ZINCRBY", "ZincrbyKey3", "346.785", "one"}, @@ -1209,8 +1223,7 @@ func Test_HandleZINCRBY(t *testing.T) { expectedError: nil, }, { - name: "4. Increment score to +inf", - preset: true, + name: "4. Increment score to +inf", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, @@ -1227,8 +1240,7 @@ func Test_HandleZINCRBY(t *testing.T) { expectedError: nil, }, { - name: "5. Increment score to -inf", - preset: true, + name: "5. Increment score to -inf", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, @@ -1245,8 +1257,7 @@ func Test_HandleZINCRBY(t *testing.T) { expectedError: nil, }, { - name: "6. Incrementing score by negative increment should lower the score", - preset: true, + name: "6. Incrementing score by negative increment should lower the score", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, @@ -1264,7 +1275,6 @@ func Test_HandleZINCRBY(t *testing.T) { }, { name: "7. Return error when attempting to increment on a value that is not a valid sorted set", - preset: true, presetValue: "Default value", key: "ZincrbyKey7", command: []string{"ZINCRBY", "ZincrbyKey7", "-2.5", "five"}, @@ -1273,8 +1283,7 @@ func Test_HandleZINCRBY(t *testing.T) { expectedError: errors.New("value at ZincrbyKey7 is not a sorted set"), }, { - name: "8. Return error when trying to increment a member that already has score -inf", - preset: true, + name: "8. Return error when trying to increment a member that already has score -inf", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: sorted_set.Score(math.Inf(-1))}, }), @@ -1287,8 +1296,7 @@ func Test_HandleZINCRBY(t *testing.T) { expectedError: errors.New("cannot increment -inf or +inf"), }, { - name: "9. Return error when trying to increment a member that already has score +inf", - preset: true, + name: "9. Return error when trying to increment a member that already has score +inf", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: sorted_set.Score(math.Inf(1))}, }), @@ -1301,8 +1309,7 @@ func Test_HandleZINCRBY(t *testing.T) { expectedError: errors.New("cannot increment -inf or +inf"), }, { - name: "10. Return error when increment is not a valid number", - preset: true, + name: "10. Return error when increment is not a valid number", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, }), @@ -1330,71 +1337,115 @@ func Test_HandleZINCRBY(t *testing.T) { }, } - for i, test := range tests { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZINCRBY, %d", i)) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { - t.Error(err) - } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { - t.Error(err) - } - mockServer.KeyUnlock(ctx, test.key) - } - - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return - } - - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) - - if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - continue - } - if err != nil { - t.Error(err) - } - - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.String() != test.expectedResponse { - t.Errorf("expected response integer %s, got %s", test.expectedResponse, rv.String()) - } - if test.expectedValue != nil { - if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { - t.Error(err) - } - set, ok := mockServer.GetValue(ctx, test.key).(*sorted_set.SortedSet) - if !ok { - t.Errorf("expected vaule at key %s to be set, got another type", test.key) - } - for _, elem := range set.GetAll() { - if !test.expectedValue.Contains(elem.Value) { - t.Errorf("could not find element %s in the expected values", elem.Value) + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} + for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) } - if test.expectedValue.Get(elem.Value).Score != elem.Score { - t.Errorf("expected score of element \"%s\" from set at key \"%s\" to be %s, got %s", - elem.Value, test.key, - strconv.FormatFloat(float64(test.expectedValue.Get(elem.Value).Score), 'f', -1, 64), - strconv.FormatFloat(float64(elem.Score), 'f', -1, 64), - ) + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) } } - mockServer.KeyRUnlock(ctx, test.key) - } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.String() != test.expectedResponse { + t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) + } + + // Check if the resulting sorted set has the expected members/scores + if test.expectedValue == nil { + return + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(test.key), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != test.expectedValue.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + test.key, test.expectedValue.Cardinality(), len(res.Array())) + } + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !test.expectedValue.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) + } + if test.expectedValue.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) + } + } + }) } } func Test_HandleZMPOP(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string preset bool @@ -1517,7 +1568,6 @@ func Test_HandleZMPOP(t *testing.T) { name: "6. Successfully pop elements from the first set which is non-empty", preset: true, presetValues: map[string]interface{}{ - "ZmpopKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{}), "ZmpopKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, @@ -1538,9 +1588,8 @@ func Test_HandleZMPOP(t *testing.T) { name: "7. Skip the non-set items and pop elements from the first non-empty sorted set found", preset: true, presetValues: map[string]interface{}{ - "ZmpopKey8": "Default value", - "ZmpopKey9": 56, - "ZmpopKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{}), + "ZmpopKey8": "Default value", + "ZmpopKey9": 56, "ZmpopKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, @@ -1571,70 +1620,129 @@ func Test_HandleZMPOP(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZMPOP, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } + } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } - if err != nil { - t.Error(err) + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - for _, element := range rv.Array() { - if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { - // The current sub-slice is a different length, return false because they're not equal - if len(element.Array()) != len(expected) { - return false + + for _, item := range res.Array() { + value := item.Array()[0].String() + score := func() string { + if len(item.Array()) == 2 { + return item.Array()[1].String() } - for i := 0; i < len(expected); i++ { - if element.Array()[i].String() != expected[i] { - return false + return "" + }() + if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { + return expected[0] == value + }) { + t.Errorf("unexpected member \"%s\" in response", value) + } + if score != "" { + for _, expected := range test.expectedResponse { + if expected[0] == value && expected[1] != score { + t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) } } - return true - }) { - t.Errorf("expected response %+v, got %+v", test.expectedResponse, rv.Array()) } } + + // Check if the resulting sorted set has the expected members/scores for key, expectedSortedSet := range test.expectedValues { - if _, err = mockServer.KeyRLock(ctx, key); err != nil { + if expectedSortedSet == nil { + continue + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(key), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(ctx, key).(*sorted_set.SortedSet) - if !ok { - t.Errorf("expected key \"%s\" to be a sorted set, got another type", key) + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) } - if !set.Equals(expectedSortedSet) { - t.Errorf("expected sorted set at key \"%s\" %+v, got %+v", key, expectedSortedSet, set) + + if len(res.Array()) != expectedSortedSet.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + key, expectedSortedSet.Cardinality(), len(res.Array())) + } + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !expectedSortedSet.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) + } + if expectedSortedSet.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", + value, expectedSortedSet.Get(value).Score, score) + } } } }) @@ -1642,6 +1750,12 @@ func Test_HandleZMPOP(t *testing.T) { } func Test_HandleZPOP(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string preset bool @@ -1762,70 +1876,129 @@ func Test_HandleZPOP(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZPOPMIN/ZPOPMAX, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } + } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - if err != nil { - t.Error(err) + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - for _, element := range rv.Array() { - if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { - // The current sub-slice is a different length, return false because they're not equal - if len(element.Array()) != len(expected) { - return false + + for _, item := range res.Array() { + value := item.Array()[0].String() + score := func() string { + if len(item.Array()) == 2 { + return item.Array()[1].String() } - for i := 0; i < len(expected); i++ { - if element.Array()[i].String() != expected[i] { - return false + return "" + }() + if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { + return expected[0] == value + }) { + t.Errorf("unexpected member \"%s\" in response", value) + } + if score != "" { + for _, expected := range test.expectedResponse { + if expected[0] == value && expected[1] != score { + t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) } } - return true - }) { - t.Errorf("expected response %+v, got %+v", test.expectedResponse, rv.Array()) } } + + // Check if the resulting sorted set has the expected members/scores for key, expectedSortedSet := range test.expectedValues { - if _, err = mockServer.KeyRLock(ctx, key); err != nil { + if expectedSortedSet == nil { + continue + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(key), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(ctx, key).(*sorted_set.SortedSet) - if !ok { - t.Errorf("expected key \"%s\" to be a sorted set, got another type", key) + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) } - if !set.Equals(expectedSortedSet) { - t.Errorf("expected sorted set at key \"%s\" %+v, got %+v", key, expectedSortedSet, set) + + if len(res.Array()) != expectedSortedSet.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + key, expectedSortedSet.Cardinality(), len(res.Array())) + } + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !expectedSortedSet.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) + } + if expectedSortedSet.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", + value, expectedSortedSet.Get(value).Score, score) + } } } }) @@ -1833,19 +2006,23 @@ func Test_HandleZPOP(t *testing.T) { } func Test_HandleZMSCORE(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValues map[string]interface{} command []string - expectedResponse []interface{} + expectedResponse []string expectedError error }{ { // 1. Return multiple scores from the sorted set. // Return nil for elements that do not exist in the sorted set. - name: "Return multiple scores from the sorted set.", - preset: true, + name: "1. Return multiple scores from the sorted set.", presetValues: map[string]interface{}{ "ZmScoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1.1}, {Value: "two", Score: 245}, @@ -1854,79 +2031,97 @@ func Test_HandleZMSCORE(t *testing.T) { }), }, command: []string{"ZMSCORE", "ZmScoreKey1", "one", "none", "two", "one", "three", "four", "none", "five"}, - expectedResponse: []interface{}{"1.1", nil, "245", "1.1", "3", "4.055", nil, "5"}, + expectedResponse: []string{"1.1", "", "245", "1.1", "3", "4.055", "", "5"}, expectedError: nil, }, { name: "2. If key does not exist, return empty array", - preset: false, presetValues: nil, command: []string{"ZMSCORE", "ZmScoreKey2", "one", "two", "three", "four"}, - expectedResponse: []interface{}{}, + expectedResponse: []string{}, expectedError: nil, }, { name: "3. Throw error when trying to find scores from elements that are not sorted sets", - preset: true, presetValues: map[string]interface{}{"ZmScoreKey3": "Default value"}, command: []string{"ZMSCORE", "ZmScoreKey3", "one", "two", "three"}, expectedError: errors.New("value at ZmScoreKey3 is not a sorted set"), }, { name: "9. Command too short", - preset: false, command: []string{"ZMSCORE"}, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZMSCORE, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } + } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } - if err != nil { - t.Error(err) + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - for i := 0; i < len(rv.Array()); i++ { - if rv.Array()[i].IsNull() { - if test.expectedResponse[i] != nil { - t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedResponse[i], rv.Array()[i]) - } - continue - } - if rv.Array()[i].String() != test.expectedResponse[i] { - t.Errorf("expected \"%s\" at index %d, got %s", test.expectedResponse[i], i, rv.Array()[i].String()) + + for i := 0; i < len(res.Array()); i++ { + if test.expectedResponse[i] != res.Array()[i].String() { + t.Errorf("expected element at index %d to be \"%s\", got %s", + i, test.expectedResponse[i], res.Array()[i].String()) } } }) @@ -1934,17 +2129,21 @@ func Test_HandleZMSCORE(t *testing.T) { } func Test_HandleZSCORE(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValues map[string]interface{} command []string - expectedResponse interface{} + expectedResponse string expectedError error }{ { - name: "1. Return score from a sorted set.", - preset: true, + name: "1. Return score from a sorted set.", presetValues: map[string]interface{}{ "ZscoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1.1}, {Value: "two", Score: 245}, @@ -1958,15 +2157,13 @@ func Test_HandleZSCORE(t *testing.T) { }, { name: "2. If key does not exist, return nil value", - preset: false, presetValues: nil, command: []string{"ZSCORE", "ZscoreKey2", "one"}, - expectedResponse: nil, + expectedResponse: "", expectedError: nil, }, { - name: "3. If key exists and is a sorted set, but the member does not exist, return nil", - preset: true, + name: "3. If key exists and is a sorted set, but the member does not exist, return nil", presetValues: map[string]interface{}{ "ZscoreKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1.1}, {Value: "two", Score: 245}, @@ -1975,85 +2172,103 @@ func Test_HandleZSCORE(t *testing.T) { }), }, command: []string{"ZSCORE", "ZscoreKey3", "non-existent"}, - expectedResponse: nil, + expectedResponse: "", expectedError: nil, }, { name: "4. Throw error when trying to find scores from elements that are not sorted sets", - preset: true, presetValues: map[string]interface{}{"ZscoreKey4": "Default value"}, command: []string{"ZSCORE", "ZscoreKey4", "one"}, expectedError: errors.New("value at ZscoreKey4 is not a sorted set"), }, { name: "5. Command too short", - preset: false, command: []string{"ZSCORE"}, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "6. Command too long", - preset: false, command: []string{"ZSCORE", "ZscoreKey5", "one", "two"}, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZSCORE, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } + } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } - if err != nil { - t.Error(err) - } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if test.expectedResponse == nil { - if !rv.IsNull() { - t.Errorf("expected nil response, got %+v", rv) - } - return - } - if rv.String() != test.expectedResponse { - t.Errorf("expected response \"%s\", got %s", test.expectedResponse, rv.String()) + + if res.String() != test.expectedResponse { + t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) } }) } } func Test_HandleZRANDMEMBER(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool key string presetValue interface{} command []string @@ -2065,9 +2280,8 @@ func Test_HandleZRANDMEMBER(t *testing.T) { { // 1. Return multiple random elements without removing them. // Count is positive, do not allow repeated elements - name: "1. Return multiple random elements without removing them.", - preset: true, - key: "ZrandMemberKey1", + name: "1. Return multiple random elements without removing them.", + key: "ZrandMemberKey1", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, @@ -2084,9 +2298,8 @@ func Test_HandleZRANDMEMBER(t *testing.T) { { // 2. Return multiple random elements and their scores without removing them. // Count is negative, so allow repeated numbers. - name: "2. Return multiple random elements and their scores without removing them.", - preset: true, - key: "ZrandMemberKey2", + name: "2. Return multiple random elements and their scores without removing them.", + key: "ZrandMemberKey2", presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, @@ -2102,7 +2315,6 @@ func Test_HandleZRANDMEMBER(t *testing.T) { }, { name: "2. Return error when the source key is not a sorted set.", - preset: true, key: "ZrandMemberKey3", presetValue: "Default value", command: []string{"ZRANDMEMBER", "ZrandMemberKey3"}, @@ -2111,119 +2323,129 @@ func Test_HandleZRANDMEMBER(t *testing.T) { }, { name: "5. Command too short", - preset: false, command: []string{"ZRANDMEMBER"}, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "6. Command too long", - preset: false, command: []string{"ZRANDMEMBER", "source5", "source6", "member1", "member2"}, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "7. Throw error when count is not an integer", - preset: false, command: []string{"ZRANDMEMBER", "ZrandMemberKey1", "count"}, expectedError: errors.New("count must be an integer"), }, { name: "8. Throw error when the fourth argument is not WITHSCORES", - preset: false, command: []string{"ZRANDMEMBER", "ZrandMemberKey1", "8", "ANOTHER"}, expectedError: errors.New("last option must be WITHSCORES"), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZRANDMEMBER, %d", i)) + if test.presetValue != nil { + var command []resp.Value + var expected string - if test.preset { - if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} + for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, test.key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } return } - if err != nil { - t.Error(err) - } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - // 1. Check if the response array members are all included in test.expectedResponse. - for _, element := range rv.Array() { + + // Check that each of the returned elements is in the expected response. + for _, item := range res.Array() { + value := sorted_set.Value(item.Array()[0].String()) if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { - // The current sub-slice is a different length, return false because they're not equal - if len(element.Array()) != len(expected) { - return false - } - for i := 0; i < len(expected); i++ { - if element.Array()[i].String() != expected[i] { - return false - } - } - return true + return expected[0] == string(value) }) { - t.Errorf("expected response %+v, got %+v", test.expectedResponse, rv.Array()) + t.Errorf("unexected element \"%s\" in response", value) } - } - // 2. Fetch the set and check if its cardinality is what we expect. - if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { - t.Error(err) - } - set, ok := mockServer.GetValue(ctx, test.key).(*sorted_set.SortedSet) - if !ok { - t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) - } - if set.Cardinality() != test.expectedValue { - t.Errorf("expected cardinality of final set to be %d, got %d", test.expectedValue, set.Cardinality()) - } - // 3. Check if all the returned elements we received are still in the set. - for _, element := range rv.Array() { - if !set.Contains(sorted_set.Value(element.Array()[0].String())) { - t.Errorf("expected element \"%s\" to be in set but it was not found", element.String()) - } - } - // 4. If allowRepeat is false, check that all the elements make a valid set - if !test.allowRepeat { - var elems []sorted_set.MemberParam - for _, e := range rv.Array() { - if len(e.Array()) == 1 { - elems = append(elems, sorted_set.MemberParam{ - Value: sorted_set.Value(e.Array()[0].String()), - Score: 1, - }) + for _, expected := range test.expectedResponse { + if len(item.Array()) != len(expected) { + t.Errorf("expected response for element \"%s\" to have length %d, got %d", + value, len(expected), len(item.Array())) + } + if expected[0] != string(value) { continue } - elems = append(elems, sorted_set.MemberParam{ - Value: sorted_set.Value(e.Array()[0].String()), - Score: sorted_set.Score(e.Array()[1].Float()), - }) + if len(expected) == 2 { + score := item.Array()[1].String() + if expected[1] != score { + t.Errorf("expected score for memebr \"%s\" to be %s, got %s", value, expected[1], score) + } + } } - s := sorted_set.NewSortedSet(elems) - if s.Cardinality() != len(elems) { - t.Errorf("expected non-repeating elements for random elements at key \"%s\"", test.key) + } + + // Check that allowRepeat determines whether elements are repeated or not. + if !test.allowRepeat { + ss := sorted_set.NewSortedSet([]sorted_set.MemberParam{}) + for _, item := range res.Array() { + member := sorted_set.Value(item.Array()[0].String()) + score := func() sorted_set.Score { + if len(item.Array()) == 2 { + return sorted_set.Score(item.Array()[1].Float()) + } + return sorted_set.Score(0) + }() + _, err = ss.AddOrUpdate( + []sorted_set.MemberParam{{member, score}}, + nil, nil, nil, nil) + if err != nil { + t.Error(err) + } + } + if len(res.Array()) != ss.Cardinality() { + t.Error("unexpected repeated elements in response") } } }) @@ -2231,17 +2453,21 @@ func Test_HandleZRANDMEMBER(t *testing.T) { } func Test_HandleZRANK(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValues map[string]interface{} command []string expectedResponse []string expectedError error }{ { - name: "1. Return element's rank from a sorted set.", - preset: true, + name: "1. Return element's rank from a sorted set.", presetValues: map[string]interface{}{ "ZrankKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -2254,8 +2480,7 @@ func Test_HandleZRANK(t *testing.T) { expectedError: nil, }, { - name: "2. Return element's rank from a sorted set with its score.", - preset: true, + name: "2. Return element's rank from a sorted set with its score.", presetValues: map[string]interface{}{ "ZrankKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100.1}, {Value: "two", Score: 245}, @@ -2269,15 +2494,13 @@ func Test_HandleZRANK(t *testing.T) { }, { name: "3. If key does not exist, return nil value", - preset: false, presetValues: nil, command: []string{"ZRANK", "ZrankKey3", "one"}, expectedResponse: nil, expectedError: nil, }, { - name: "4. If key exists and is a sorted set, but the member does not exist, return nil", - preset: true, + name: "4. If key exists and is a sorted set, but the member does not exist, return nil", presetValues: map[string]interface{}{ "ZrankKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1.1}, {Value: "two", Score: 245}, @@ -2291,75 +2514,90 @@ func Test_HandleZRANK(t *testing.T) { }, { name: "5. Throw error when trying to find scores from elements that are not sorted sets", - preset: true, presetValues: map[string]interface{}{"ZrankKey5": "Default value"}, command: []string{"ZRANK", "ZrankKey5", "one"}, expectedError: errors.New("value at ZrankKey5 is not a sorted set"), }, { name: "5. Command too short", - preset: false, command: []string{"ZRANK"}, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "6. Command too long", - preset: false, command: []string{"ZRANK", "ZrankKey5", "one", "WITHSCORES", "two"}, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZRANK, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } + } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } - if err != nil { - t.Error(err) + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if test.expectedResponse == nil { - if !rv.IsNull() { - t.Errorf("expected nil response, got %+v", rv) - } - return - } - if len(rv.Array()) != len(test.expectedResponse) { - t.Errorf("expected response %+v, got %+v", test.expectedResponse, rv.Array()) - } - for i := 0; i < len(test.expectedResponse); i++ { - if rv.Array()[i].String() != test.expectedResponse[i] { - t.Errorf("expected element at index %d to be %s, got %s", i, test.expectedResponse[i], rv.Array()[i].String()) + + for i := 0; i < len(res.Array()); i++ { + if test.expectedResponse[i] != res.Array()[i].String() { + t.Errorf("expected element at index %d to be \"%s\", got %s", + i, test.expectedResponse[i], res.Array()[i].String()) } } }) @@ -2367,9 +2605,14 @@ func Test_HandleZRANK(t *testing.T) { } func Test_HandleZREM(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValues map[string]interface{} command []string expectedValues map[string]*sorted_set.SortedSet @@ -2379,8 +2622,7 @@ func Test_HandleZREM(t *testing.T) { { // Successfully remove multiple elements from sorted set, skipping non-existent members. // Return deleted count. - name: "1. Successfully remove multiple elements from sorted set, skipping non-existent members.", - preset: true, + name: "1. Successfully remove multiple elements from sorted set, skipping non-existent members.", presetValues: map[string]interface{}{ "ZremKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -2402,7 +2644,6 @@ func Test_HandleZREM(t *testing.T) { }, { name: "2. If key does not exist, return 0", - preset: false, presetValues: nil, command: []string{"ZREM", "ZremKey2", "member"}, expectedValues: nil, @@ -2410,8 +2651,7 @@ func Test_HandleZREM(t *testing.T) { expectedError: nil, }, { - name: "3. Return error key is not a sorted set", - preset: true, + name: "3. Return error key is not a sorted set", presetValues: map[string]interface{}{ "ZremKey3": "Default value", }, @@ -2420,65 +2660,111 @@ func Test_HandleZREM(t *testing.T) { }, { name: "9. Command too short", - preset: false, command: []string{"ZREM"}, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZREM, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } + } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } - if err != nil { - t.Error(err) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response array of length %d, got %d", test.expectedResponse, res.Integer()) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, rv.Integer()) - } - // Check if the expected sorted set is the same at the current one - if test.expectedValues != nil { - for key, expectedSet := range test.expectedValues { - if _, err = mockServer.KeyRLock(ctx, key); err != nil { - t.Error(err) + + // Check if the resulting sorted set has the expected members/scores + for key, expectedSortedSet := range test.expectedValues { + if expectedSortedSet == nil { + continue + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(key), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != expectedSortedSet.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + key, expectedSortedSet.Cardinality(), len(res.Array())) + } + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !expectedSortedSet.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) } - set, ok := mockServer.GetValue(ctx, key).(*sorted_set.SortedSet) - if !ok { - t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) - } - if !set.Equals(expectedSet) { - t.Errorf("exptected sorted set %+v, got %+v", expectedSet, set) + if expectedSortedSet.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", + value, expectedSortedSet.Get(value).Score, score) } } } @@ -2487,9 +2773,14 @@ func Test_HandleZREM(t *testing.T) { } func Test_HandleZREMRANGEBYSCORE(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValues map[string]interface{} command []string expectedValues map[string]*sorted_set.SortedSet @@ -2497,8 +2788,7 @@ func Test_HandleZREMRANGEBYSCORE(t *testing.T) { expectedError error }{ { - name: "1. Successfully remove multiple elements with scores inside the provided range", - preset: true, + name: "1. Successfully remove multiple elements with scores inside the provided range", presetValues: map[string]interface{}{ "ZremRangeByScoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -2520,7 +2810,6 @@ func Test_HandleZREMRANGEBYSCORE(t *testing.T) { }, { name: "2. If key does not exist, return 0", - preset: false, presetValues: nil, command: []string{"ZREMRANGEBYSCORE", "ZremRangeByScoreKey2", "2", "4"}, expectedValues: nil, @@ -2528,8 +2817,7 @@ func Test_HandleZREMRANGEBYSCORE(t *testing.T) { expectedError: nil, }, { - name: "3. Return error key is not a sorted set", - preset: true, + name: "3. Return error key is not a sorted set", presetValues: map[string]interface{}{ "ZremRangeByScoreKey3": "Default value", }, @@ -2538,71 +2826,116 @@ func Test_HandleZREMRANGEBYSCORE(t *testing.T) { }, { name: "4. Command too short", - preset: false, command: []string{"ZREMRANGEBYSCORE", "ZremRangeByScoreKey4", "3"}, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "5. Command too long", - preset: false, command: []string{"ZREMRANGEBYSCORE", "ZremRangeByScoreKey5", "4", "5", "8"}, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZREMRANGEBYSCORE, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } + } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } - if err != nil { - t.Error(err) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response array of length %d, got %d", test.expectedResponse, res.Integer()) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, rv.Integer()) - } - // Check if the expected values are the same - if test.expectedValues != nil { - for key, expectedSet := range test.expectedValues { - if _, err = mockServer.KeyRLock(ctx, key); err != nil { - t.Error(err) + + // Check if the resulting sorted set has the expected members/scores + for key, expectedSortedSet := range test.expectedValues { + if expectedSortedSet == nil { + continue + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(key), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != expectedSortedSet.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + key, expectedSortedSet.Cardinality(), len(res.Array())) + } + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !expectedSortedSet.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) } - set, ok := mockServer.GetValue(ctx, key).(*sorted_set.SortedSet) - if !ok { - t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) - } - if !set.Equals(expectedSet) { - t.Errorf("exptected sorted set %+v, got %+v", expectedSet, set) + if expectedSortedSet.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", + value, expectedSortedSet.Get(value).Score, score) } } } @@ -2611,9 +2944,14 @@ func Test_HandleZREMRANGEBYSCORE(t *testing.T) { } func Test_HandleZREMRANGEBYRANK(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValues map[string]interface{} command []string expectedValues map[string]*sorted_set.SortedSet @@ -2621,8 +2959,7 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) { expectedError error }{ { - name: "1. Successfully remove multiple elements within range", - preset: true, + name: "1. Successfully remove multiple elements within range", presetValues: map[string]interface{}{ "ZremRangeByRankKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -2643,8 +2980,7 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) { expectedError: nil, }, { - name: "2. Establish boundaries from the end of the set when negative boundaries are provided", - preset: true, + name: "2. Establish boundaries from the end of the set when negative boundaries are provided", presetValues: map[string]interface{}{ "ZremRangeByRankKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -2667,7 +3003,6 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) { }, { name: "2. If key does not exist, return 0", - preset: false, presetValues: nil, command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey3", "2", "4"}, expectedValues: nil, @@ -2675,8 +3010,7 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) { expectedError: nil, }, { - name: "3. Return error key is not a sorted set", - preset: true, + name: "3. Return error key is not a sorted set", presetValues: map[string]interface{}{ "ZremRangeByRankKey3": "Default value", }, @@ -2685,13 +3019,11 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) { }, { name: "4. Command too short", - preset: false, command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey4", "3"}, expectedError: errors.New(constants.WrongArgsResponse), }, { - name: "5. Return error when start index is out of bounds", - preset: true, + name: "5. Return error when start index is out of bounds", presetValues: map[string]interface{}{ "ZremRangeByRankKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -2707,8 +3039,7 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) { expectedError: errors.New("indices out of bounds"), }, { - name: "6. Return error when end index is out of bounds", - preset: true, + name: "6. Return error when end index is out of bounds", presetValues: map[string]interface{}{ "ZremRangeByRankKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -2725,65 +3056,111 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) { }, { name: "7. Command too long", - preset: false, command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey7", "4", "5", "8"}, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZREMRANGEBYRANK, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } + } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } - if err != nil { - t.Error(err) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response array of length %d, got %d", test.expectedResponse, res.Integer()) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, rv.Integer()) - } - // Check if the expected values are the same - if test.expectedValues != nil { - for key, expectedSet := range test.expectedValues { - if _, err = mockServer.KeyRLock(ctx, key); err != nil { - t.Error(err) + + // Check if the resulting sorted set has the expected members/scores + for key, expectedSortedSet := range test.expectedValues { + if expectedSortedSet == nil { + continue + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(key), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != expectedSortedSet.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + key, expectedSortedSet.Cardinality(), len(res.Array())) + } + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !expectedSortedSet.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) } - set, ok := mockServer.GetValue(ctx, key).(*sorted_set.SortedSet) - if !ok { - t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) - } - if !set.Equals(expectedSet) { - t.Errorf("exptected sorted set %+v, got %+v", expectedSet, set) + if expectedSortedSet.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", + value, expectedSortedSet.Get(value).Score, score) } } } @@ -2792,9 +3169,14 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) { } func Test_HandleZREMRANGEBYLEX(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValues map[string]interface{} command []string expectedValues map[string]*sorted_set.SortedSet @@ -2802,8 +3184,7 @@ func Test_HandleZREMRANGEBYLEX(t *testing.T) { expectedError error }{ { - name: "1. Successfully remove multiple elements with scores inside the provided range", - preset: true, + name: "1. Successfully remove multiple elements with scores inside the provided range", presetValues: map[string]interface{}{ "ZremRangeByLexKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "a", Score: 1}, {Value: "b", Score: 1}, @@ -2825,8 +3206,7 @@ func Test_HandleZREMRANGEBYLEX(t *testing.T) { expectedError: nil, }, { - name: "2. Return 0 if the members do not have the same score", - preset: true, + name: "2. Return 0 if the members do not have the same score", presetValues: map[string]interface{}{ "ZremRangeByLexKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "a", Score: 1}, {Value: "b", Score: 2}, @@ -2851,7 +3231,6 @@ func Test_HandleZREMRANGEBYLEX(t *testing.T) { }, { name: "3. If key does not exist, return 0", - preset: false, presetValues: nil, command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey3", "2", "4"}, expectedValues: nil, @@ -2859,8 +3238,7 @@ func Test_HandleZREMRANGEBYLEX(t *testing.T) { expectedError: nil, }, { - name: "3. Return error key is not a sorted set", - preset: true, + name: "3. Return error key is not a sorted set", presetValues: map[string]interface{}{ "ZremRangeByLexKey3": "Default value", }, @@ -2869,71 +3247,116 @@ func Test_HandleZREMRANGEBYLEX(t *testing.T) { }, { name: "4. Command too short", - preset: false, command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey4", "a"}, expectedError: errors.New(constants.WrongArgsResponse), }, { name: "5. Command too long", - preset: false, command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey5", "a", "b", "c"}, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZREMRANGEBYLEX, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } + } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } - if err != nil { - t.Error(err) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response array of length %d, got %d", test.expectedResponse, res.Integer()) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, rv.Integer()) - } - // Check if the expected values are the same - if test.expectedValues != nil { - for key, expectedSet := range test.expectedValues { - if _, err = mockServer.KeyRLock(ctx, key); err != nil { - t.Error(err) + + // Check if the resulting sorted set has the expected members/scores + for key, expectedSortedSet := range test.expectedValues { + if expectedSortedSet == nil { + continue + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(key), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != expectedSortedSet.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + key, expectedSortedSet.Cardinality(), len(res.Array())) + } + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !expectedSortedSet.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) } - set, ok := mockServer.GetValue(ctx, key).(*sorted_set.SortedSet) - if !ok { - t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) - } - if !set.Equals(expectedSet) { - t.Errorf("exptected sorted set %+v, got %+v", expectedSet, set) + if expectedSortedSet.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", + value, expectedSortedSet.Get(value).Score, score) } } } @@ -2942,17 +3365,21 @@ func Test_HandleZREMRANGEBYLEX(t *testing.T) { } func Test_HandleZRANGE(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValues map[string]interface{} command []string expectedResponse [][]string expectedError error }{ { - name: "1. Get elements withing score range without score.", - preset: true, + name: "1. Get elements withing score range without score.", presetValues: map[string]interface{}{ "ZrangeKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -2966,8 +3393,7 @@ func Test_HandleZRANGE(t *testing.T) { expectedError: nil, }, { - name: "2. Get elements within score range with score.", - preset: true, + name: "2. Get elements within score range with score.", presetValues: map[string]interface{}{ "ZrangeKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -2985,8 +3411,7 @@ func Test_HandleZRANGE(t *testing.T) { { // 3. Get elements within score range with offset and limit. // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). - name: "3. Get elements within score range with offset and limit.", - preset: true, + name: "3. Get elements within score range with offset and limit.", presetValues: map[string]interface{}{ "ZrangeKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -3003,8 +3428,7 @@ func Test_HandleZRANGE(t *testing.T) { // 4. Get elements within score range with offset and limit + reverse the results. // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). // REV reverses the original set before getting the range. - name: "4. Get elements within score range with offset and limit + reverse the results.", - preset: true, + name: "4. Get elements within score range with offset and limit + reverse the results.", presetValues: map[string]interface{}{ "ZrangeKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -3018,8 +3442,7 @@ func Test_HandleZRANGE(t *testing.T) { expectedError: nil, }, { - name: "5. Get elements within lex range without score.", - preset: true, + name: "5. Get elements within lex range without score.", presetValues: map[string]interface{}{ "ZrangeKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "a", Score: 1}, {Value: "e", Score: 1}, @@ -3033,8 +3456,7 @@ func Test_HandleZRANGE(t *testing.T) { expectedError: nil, }, { - name: "6. Get elements within lex range with score.", - preset: true, + name: "6. Get elements within lex range with score.", presetValues: map[string]interface{}{ "ZrangeKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "a", Score: 1}, {Value: "e", Score: 1}, @@ -3052,8 +3474,7 @@ func Test_HandleZRANGE(t *testing.T) { { // 7. Get elements within lex range with offset and limit. // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). - name: "7. Get elements within lex range with offset and limit.", - preset: true, + name: "7. Get elements within lex range with offset and limit.", presetValues: map[string]interface{}{ "ZrangeKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "a", Score: 1}, {Value: "b", Score: 1}, @@ -3070,8 +3491,7 @@ func Test_HandleZRANGE(t *testing.T) { // 8. Get elements within lex range with offset and limit + reverse the results. // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). // REV reverses the original set before getting the range. - name: "8. Get elements within lex range with offset and limit + reverse the results.", - preset: true, + name: "8. Get elements within lex range with offset and limit + reverse the results.", presetValues: map[string]interface{}{ "ZrangeKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "a", Score: 1}, {Value: "b", Score: 1}, @@ -3085,8 +3505,7 @@ func Test_HandleZRANGE(t *testing.T) { expectedError: nil, }, { - name: "9. Return an empty slice when we use BYLEX while elements have different scores", - preset: true, + name: "9. Return an empty slice when we use BYLEX while elements have different scores", presetValues: map[string]interface{}{ "ZrangeKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "a", Score: 1}, {Value: "b", Score: 5}, @@ -3101,7 +3520,6 @@ func Test_HandleZRANGE(t *testing.T) { }, { name: "10. Throw error when limit does not provide both offset and limit", - preset: false, presetValues: nil, command: []string{"ZRANGE", "ZrangeKey10", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2"}, expectedResponse: [][]string{}, @@ -3109,7 +3527,6 @@ func Test_HandleZRANGE(t *testing.T) { }, { name: "11. Throw error when offset is not a valid integer", - preset: false, presetValues: nil, command: []string{"ZRANGE", "ZrangeKey11", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "offset", "4"}, expectedResponse: [][]string{}, @@ -3117,7 +3534,6 @@ func Test_HandleZRANGE(t *testing.T) { }, { name: "12. Throw error when limit is not a valid integer", - preset: false, presetValues: nil, command: []string{"ZRANGE", "ZrangeKey12", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "4", "limit"}, expectedResponse: [][]string{}, @@ -3125,15 +3541,13 @@ func Test_HandleZRANGE(t *testing.T) { }, { name: "13. Throw error when offset is negative", - preset: false, presetValues: nil, command: []string{"ZRANGE", "ZrangeKey13", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9"}, expectedResponse: [][]string{}, expectedError: errors.New("limit offset must be >= 0"), }, { - name: "14. Throw error when the key does not hold a sorted set", - preset: true, + name: "14. Throw error when the key does not hold a sorted set", presetValues: map[string]interface{}{ "ZrangeKey14": "Default value", }, @@ -3143,15 +3557,13 @@ func Test_HandleZRANGE(t *testing.T) { }, { name: "15. Command too short", - preset: false, presetValues: nil, command: []string{"ZRANGE", "ZrangeKey15", "1"}, expectedResponse: [][]string{}, expectedError: errors.New(constants.WrongArgsResponse), }, { - name: "16 Command too long", - preset: false, + name: "16. Command too long", presetValues: nil, command: []string{"ZRANGE", "ZrangeKey16", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9", "REV", "WITHSCORES"}, expectedResponse: [][]string{}, @@ -3159,61 +3571,89 @@ func Test_HandleZRANGE(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZRANGE, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } + } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } - if err != nil { - t.Error(err) + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - if len(rv.Array()) != len(test.expectedResponse) { - t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(rv.Array())) - } - for _, element := range rv.Array() { - if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { - // The current sub-slice is a different length, return false because they're not equal - if len(element.Array()) != len(expected) { - return false + + for _, item := range res.Array() { + value := item.Array()[0].String() + score := func() string { + if len(item.Array()) == 2 { + return item.Array()[1].String() } - for i := 0; i < len(expected); i++ { - if element.Array()[i].String() != expected[i] { - return false + return "" + }() + if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { + return expected[0] == value + }) { + t.Errorf("unexpected member \"%s\" in response", value) + } + if score != "" { + for _, expected := range test.expectedResponse { + if expected[0] == value && expected[1] != score { + t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) } } - return true - }) { - t.Errorf("expected response %+v, got %+v", test.expectedResponse, rv.Array()) } } }) @@ -3221,9 +3661,14 @@ func Test_HandleZRANGE(t *testing.T) { } func Test_HandleZRANGESTORE(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValues map[string]interface{} destination string command []string @@ -3232,8 +3677,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { expectedError error }{ { - name: "1. Get elements withing score range without score.", - preset: true, + name: "1. Get elements withing score range without score.", presetValues: map[string]interface{}{ "ZrangeStoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -3252,8 +3696,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { expectedError: nil, }, { - name: "2. Get elements within score range with score.", - preset: true, + name: "2. Get elements within score range with score.", presetValues: map[string]interface{}{ "ZrangeStoreKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -3274,8 +3717,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { { // 3. Get elements within score range with offset and limit. // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). - name: "3. Get elements within score range with offset and limit.", - preset: true, + name: "3. Get elements within score range with offset and limit.", presetValues: map[string]interface{}{ "ZrangeStoreKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -3296,8 +3738,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { // 4. Get elements within score range with offset and limit + reverse the results. // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). // REV reverses the original set before getting the range. - name: "4. Get elements within score range with offset and limit + reverse the results.", - preset: true, + name: "4. Get elements within score range with offset and limit + reverse the results.", presetValues: map[string]interface{}{ "ZrangeStoreKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -3315,8 +3756,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { expectedError: nil, }, { - name: "5. Get elements within lex range without score.", - preset: true, + name: "5. Get elements within lex range without score.", presetValues: map[string]interface{}{ "ZrangeStoreKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "a", Score: 1}, {Value: "e", Score: 1}, @@ -3335,8 +3775,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { expectedError: nil, }, { - name: "6. Get elements within lex range with score.", - preset: true, + name: "6. Get elements within lex range with score.", presetValues: map[string]interface{}{ "ZrangeStoreKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "a", Score: 1}, {Value: "e", Score: 1}, @@ -3357,8 +3796,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { { // 7. Get elements within lex range with offset and limit. // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). - name: "7. Get elements within lex range with offset and limit.", - preset: true, + name: "7. Get elements within lex range with offset and limit.", presetValues: map[string]interface{}{ "ZrangeStoreKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "a", Score: 1}, {Value: "b", Score: 1}, @@ -3379,8 +3817,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { // 8. Get elements within lex range with offset and limit + reverse the results. // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). // REV reverses the original set before getting the range. - name: "8. Get elements within lex range with offset and limit + reverse the results.", - preset: true, + name: "8. Get elements within lex range with offset and limit + reverse the results.", presetValues: map[string]interface{}{ "ZrangeStoreKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "a", Score: 1}, {Value: "b", Score: 1}, @@ -3398,8 +3835,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { expectedError: nil, }, { - name: "9. Return an empty slice when we use BYLEX while elements have different scores", - preset: true, + name: "9. Return an empty slice when we use BYLEX while elements have different scores", presetValues: map[string]interface{}{ "ZrangeStoreKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "a", Score: 1}, {Value: "b", Score: 5}, @@ -3416,7 +3852,6 @@ func Test_HandleZRANGESTORE(t *testing.T) { }, { name: "10. Throw error when limit does not provide both offset and limit", - preset: false, presetValues: nil, command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey10", "ZrangeStoreKey10", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2"}, expectedResponse: 0, @@ -3424,7 +3859,6 @@ func Test_HandleZRANGESTORE(t *testing.T) { }, { name: "11. Throw error when offset is not a valid integer", - preset: false, presetValues: nil, command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey11", "ZrangeStoreKey11", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "offset", "4"}, expectedResponse: 0, @@ -3432,7 +3866,6 @@ func Test_HandleZRANGESTORE(t *testing.T) { }, { name: "12. Throw error when limit is not a valid integer", - preset: false, presetValues: nil, command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey12", "ZrangeStoreKey12", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "4", "limit"}, expectedResponse: 0, @@ -3440,15 +3873,13 @@ func Test_HandleZRANGESTORE(t *testing.T) { }, { name: "13. Throw error when offset is negative", - preset: false, presetValues: nil, command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey13", "ZrangeStoreKey13", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9"}, expectedResponse: 0, expectedError: errors.New("limit offset must be >= 0"), }, { - name: "14. Throw error when the key does not hold a sorted set", - preset: true, + name: "14. Throw error when the key does not hold a sorted set", presetValues: map[string]interface{}{ "ZrangeStoreKey14": "Default value", }, @@ -3458,7 +3889,6 @@ func Test_HandleZRANGESTORE(t *testing.T) { }, { name: "15. Command too short", - preset: false, presetValues: nil, command: []string{"ZRANGESTORE", "ZrangeStoreKey15", "1"}, expectedResponse: 0, @@ -3466,7 +3896,6 @@ func Test_HandleZRANGESTORE(t *testing.T) { }, { name: "16 Command too long", - preset: false, presetValues: nil, command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey16", "ZrangeStoreKey16", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9", "REV", "WITHSCORES"}, expectedResponse: 0, @@ -3474,76 +3903,125 @@ func Test_HandleZRANGESTORE(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZRANGESTORE, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + + // Check if the resulting sorted set has the expected members/scores + if test.expectedValue == nil { + return + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(test.destination), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) + + if len(res.Array()) != test.expectedValue.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + test.destination, test.expectedValue.Cardinality(), len(res.Array())) } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) - } - if test.expectedValue != nil { - if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil { - t.Error(err) + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !test.expectedValue.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) } - set, ok := mockServer.GetValue(ctx, test.destination).(*sorted_set.SortedSet) - if !ok { - t.Errorf("expected vaule at key %s to be set, got another type", test.destination) + if test.expectedValue.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) } - if !set.Equals(test.expectedValue) { - t.Errorf("expected sorted set %+v, got %+v", test.expectedValue, set) - } - mockServer.KeyRUnlock(ctx, test.destination) } }) } } func Test_HandleZINTER(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValues map[string]interface{} command []string expectedResponse [][]string expectedError error }{ { - name: "1. Get the intersection between 2 sorted sets.", - preset: true, + name: "1. Get the intersection between 2 sorted sets.", presetValues: map[string]interface{}{ "ZinterKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -3563,8 +4041,7 @@ func Test_HandleZINTER(t *testing.T) { { // 2. Get the intersection between 3 sorted sets with scores. // By default, the SUM aggregate will be used. - name: "2. Get the intersection between 3 sorted sets with scores.", - preset: true, + name: "2. Get the intersection between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZinterKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -3590,8 +4067,7 @@ func Test_HandleZINTER(t *testing.T) { { // 3. Get the intersection between 3 sorted sets with scores. // Use MIN aggregate. - name: "3. Get the intersection between 3 sorted sets with scores.", - preset: true, + name: "3. Get the intersection between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZinterKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -3617,8 +4093,7 @@ func Test_HandleZINTER(t *testing.T) { { // 4. Get the intersection between 3 sorted sets with scores. // Use MAX aggregate. - name: "4. Get the intersection between 3 sorted sets with scores.", - preset: true, + name: "4. Get the intersection between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZinterKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -3644,8 +4119,7 @@ func Test_HandleZINTER(t *testing.T) { { // 5. Get the intersection between 3 sorted sets with scores. // Use SUM aggregate with weights modifier. - name: "5. Get the intersection between 3 sorted sets with scores.", - preset: true, + name: "5. Get the intersection between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZinterKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -3671,8 +4145,7 @@ func Test_HandleZINTER(t *testing.T) { { // 6. Get the intersection between 3 sorted sets with scores. // Use MAX aggregate with added weights. - name: "6. Get the intersection between 3 sorted sets with scores.", - preset: true, + name: "6. Get the intersection between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZinterKey15": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -3698,8 +4171,7 @@ func Test_HandleZINTER(t *testing.T) { { // 7. Get the intersection between 3 sorted sets with scores. // Use MIN aggregate with added weights. - name: "7. Get the intersection between 3 sorted sets with scores.", - preset: true, + name: "7. Get the intersection between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZinterKey18": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -3723,8 +4195,7 @@ func Test_HandleZINTER(t *testing.T) { expectedError: nil, }, { - name: "8. Throw an error if there are more weights than keys", - preset: true, + name: "8. Throw an error if there are more weights than keys", presetValues: map[string]interface{}{ "ZinterKey21": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -3739,8 +4210,7 @@ func Test_HandleZINTER(t *testing.T) { expectedError: errors.New("number of weights should match number of keys"), }, { - name: "9. Throw an error if there are fewer weights than keys", - preset: true, + name: "9. Throw an error if there are fewer weights than keys", presetValues: map[string]interface{}{ "ZinterKey23": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -3758,8 +4228,7 @@ func Test_HandleZINTER(t *testing.T) { expectedError: errors.New("number of weights should match number of keys"), }, { - name: "10. Throw an error if there are no keys provided", - preset: true, + name: "10. Throw an error if there are no keys provided", presetValues: map[string]interface{}{ "ZinterKey26": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), "ZinterKey27": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), @@ -3770,8 +4239,7 @@ func Test_HandleZINTER(t *testing.T) { expectedError: errors.New(constants.WrongArgsResponse), }, { - name: "11. Throw an error if any of the provided keys are not sorted sets", - preset: true, + name: "11. Throw an error if any of the provided keys are not sorted sets", presetValues: map[string]interface{}{ "ZinterKey29": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -3787,8 +4255,7 @@ func Test_HandleZINTER(t *testing.T) { expectedError: errors.New("value at ZinterKey30 is not a sorted set"), }, { - name: "12. If any of the keys does not exist, return an empty array.", - preset: true, + name: "12. If any of the keys does not exist, return an empty array.", presetValues: map[string]interface{}{ "ZinterKey32": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -3807,65 +4274,95 @@ func Test_HandleZINTER(t *testing.T) { }, { name: "13. Command too short", - preset: false, command: []string{"ZINTER"}, expectedResponse: [][]string{}, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZINTER, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } + } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } - if err != nil { - t.Error(err) + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - for _, element := range rv.Array() { - if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { - // The current sub-slice is a different length, return false because they're not equal - if len(element.Array()) != len(expected) { - return false + + for _, item := range res.Array() { + value := item.Array()[0].String() + score := func() string { + if len(item.Array()) == 2 { + return item.Array()[1].String() } - for i := 0; i < len(expected); i++ { - if element.Array()[i].String() != expected[i] { - return false + return "" + }() + if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { + return expected[0] == value + }) { + t.Errorf("unexpected member \"%s\" in response", value) + } + if score != "" { + for _, expected := range test.expectedResponse { + if expected[0] == value && expected[1] != score { + t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) } } - return true - }) { - t.Errorf("expected response %+v, got %+v", test.expectedResponse, rv.Array()) } } }) @@ -3873,9 +4370,14 @@ func Test_HandleZINTER(t *testing.T) { } func Test_HandleZINTERSTORE(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValues map[string]interface{} destination string command []string @@ -3884,8 +4386,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { expectedError error }{ { - name: "1. Get the intersection between 2 sorted sets.", - preset: true, + name: "1. Get the intersection between 2 sorted sets.", presetValues: map[string]interface{}{ "ZinterStoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -3901,8 +4402,8 @@ func Test_HandleZINTERSTORE(t *testing.T) { destination: "ZinterStoreDestinationKey1", command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey1", "ZinterStoreKey1", "ZinterStoreKey2"}, expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, + {Value: "three", Score: 6}, {Value: "four", Score: 8}, + {Value: "five", Score: 10}, }), expectedResponse: 3, expectedError: nil, @@ -3910,8 +4411,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { { // 2. Get the intersection between 3 sorted sets with scores. // By default, the SUM aggregate will be used. - name: "2. Get the intersection between 3 sorted sets with scores.", - preset: true, + name: "2. Get the intersection between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZinterStoreKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -3931,9 +4431,11 @@ func Test_HandleZINTERSTORE(t *testing.T) { }), }, destination: "ZinterStoreDestinationKey2", - command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey2", "ZinterStoreKey3", "ZinterStoreKey4", "ZinterStoreKey5", "WITHSCORES"}, + command: []string{ + "ZINTERSTORE", "ZinterStoreDestinationKey2", "ZinterStoreKey3", "ZinterStoreKey4", "ZinterStoreKey5", "WITHSCORES", + }, expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "eight", Score: 24}, + {Value: "one", Score: 3}, {Value: "eight", Score: 24}, }), expectedResponse: 2, expectedError: nil, @@ -3941,8 +4443,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { { // 3. Get the intersection between 3 sorted sets with scores. // Use MIN aggregate. - name: "3. Get the intersection between 3 sorted sets with scores.", - preset: true, + name: "3. Get the intersection between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZinterStoreKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -3972,8 +4473,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { { // 4. Get the intersection between 3 sorted sets with scores. // Use MAX aggregate. - name: "4. Get the intersection between 3 sorted sets with scores.", - preset: true, + name: "4. Get the intersection between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZinterStoreKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -4003,8 +4503,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { { // 5. Get the intersection between 3 sorted sets with scores. // Use SUM aggregate with weights modifier. - name: "5. Get the intersection between 3 sorted sets with scores.", - preset: true, + name: "5. Get the intersection between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZinterStoreKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -4026,7 +4525,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { destination: "ZinterStoreDestinationKey5", command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey5", "ZinterStoreKey12", "ZinterStoreKey13", "ZinterStoreKey14", "WITHSCORES", "AGGREGATE", "SUM", "WEIGHTS", "1", "5", "3"}, expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "eight", Score: 2808}, + {Value: "one", Score: 3105}, {Value: "eight", Score: 2808}, }), expectedResponse: 2, expectedError: nil, @@ -4034,8 +4533,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { { // 6. Get the intersection between 3 sorted sets with scores. // Use MAX aggregate with added weights. - name: "6. Get the intersection between 3 sorted sets with scores.", - preset: true, + name: "6. Get the intersection between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZinterStoreKey15": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -4065,8 +4563,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { { // 7. Get the intersection between 3 sorted sets with scores. // Use MIN aggregate with added weights. - name: "7. Get the intersection between 3 sorted sets with scores.", - preset: true, + name: "7. Get the intersection between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZinterStoreKey18": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -4094,8 +4591,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { expectedError: nil, }, { - name: "8. Throw an error if there are more weights than keys", - preset: true, + name: "8. Throw an error if there are more weights than keys", presetValues: map[string]interface{}{ "ZinterStoreKey21": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -4110,8 +4606,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { expectedError: errors.New("number of weights should match number of keys"), }, { - name: "9. Throw an error if there are fewer weights than keys", - preset: true, + name: "9. Throw an error if there are fewer weights than keys", presetValues: map[string]interface{}{ "ZinterStoreKey23": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -4129,8 +4624,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { expectedError: errors.New("number of weights should match number of keys"), }, { - name: "10. Throw an error if there are no keys provided", - preset: true, + name: "10. Throw an error if there are no keys provided", presetValues: map[string]interface{}{ "ZinterStoreKey26": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), "ZinterStoreKey27": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), @@ -4141,8 +4635,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { expectedError: errors.New(constants.WrongArgsResponse), }, { - name: "11. Throw an error if any of the provided keys are not sorted sets", - preset: true, + name: "11. Throw an error if any of the provided keys are not sorted sets", presetValues: map[string]interface{}{ "ZinterStoreKey29": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -4158,8 +4651,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { expectedError: errors.New("value at ZinterStoreKey30 is not a sorted set"), }, { - name: "12. If any of the keys does not exist, return an empty array.", - preset: true, + name: "12. If any of the keys does not exist, return an empty array.", presetValues: map[string]interface{}{ "ZinterStoreKey32": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -4178,85 +4670,131 @@ func Test_HandleZINTERSTORE(t *testing.T) { }, { name: "13. Command too short", - preset: false, command: []string{"ZINTERSTORE"}, expectedResponse: 0, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZINTERSTORE, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + + // Check if the resulting sorted set has the expected members/scores + if test.expectedValue == nil { + return + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(test.destination), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) + + if len(res.Array()) != test.expectedValue.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + test.destination, test.expectedValue.Cardinality(), len(res.Array())) } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) - } - if test.expectedValue != nil { - if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil { - t.Error(err) + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !test.expectedValue.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) } - set, ok := mockServer.GetValue(ctx, test.destination).(*sorted_set.SortedSet) - if !ok { - t.Errorf("expected vaule at key %s to be set, got another type", test.destination) + if test.expectedValue.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) } - for _, elem := range set.GetAll() { - if !test.expectedValue.Contains(elem.Value) { - t.Errorf("could not find element %s in the expected values", elem.Value) - } - } - mockServer.KeyRUnlock(ctx, test.destination) } }) } } func Test_HandleZUNION(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string - preset bool presetValues map[string]interface{} command []string expectedResponse [][]string expectedError error }{ { - name: "1. Get the union between 2 sorted sets.", - preset: true, + name: "1. Get the union between 2 sorted sets.", presetValues: map[string]interface{}{ "ZunionKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -4276,8 +4814,7 @@ func Test_HandleZUNION(t *testing.T) { { // 2. Get the union between 3 sorted sets with scores. // By default, the SUM aggregate will be used. - name: "2. Get the union between 3 sorted sets with scores.", - preset: true, + name: "2. Get the union between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZunionKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -4307,8 +4844,7 @@ func Test_HandleZUNION(t *testing.T) { { // 3. Get the union between 3 sorted sets with scores. // Use MIN aggregate. - name: "3. Get the union between 3 sorted sets with scores.", - preset: true, + name: "3. Get the union between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZunionKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -4338,8 +4874,7 @@ func Test_HandleZUNION(t *testing.T) { { // 4. Get the union between 3 sorted sets with scores. // Use MAX aggregate. - name: "4. Get the union between 3 sorted sets with scores.", - preset: true, + name: "4. Get the union between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZunionKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -4369,8 +4904,7 @@ func Test_HandleZUNION(t *testing.T) { { // 5. Get the union between 3 sorted sets with scores. // Use SUM aggregate with weights modifier. - name: "5. Get the union between 3 sorted sets with scores.", - preset: true, + name: "5. Get the union between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZunionKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -4400,8 +4934,7 @@ func Test_HandleZUNION(t *testing.T) { { // 6. Get the union between 3 sorted sets with scores. // Use MAX aggregate with added weights. - name: "6. Get the union between 3 sorted sets with scores.", - preset: true, + name: "6. Get the union between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZunionKey15": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -4431,8 +4964,7 @@ func Test_HandleZUNION(t *testing.T) { { // 7. Get the union between 3 sorted sets with scores. // Use MIN aggregate with added weights. - name: "7. Get the union between 3 sorted sets with scores.", - preset: true, + name: "7. Get the union between 3 sorted sets with scores.", presetValues: map[string]interface{}{ "ZunionKey18": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 100}, {Value: "two", Score: 2}, @@ -4459,8 +4991,7 @@ func Test_HandleZUNION(t *testing.T) { expectedError: nil, }, { - name: "8. Throw an error if there are more weights than keys", - preset: true, + name: "8. Throw an error if there are more weights than keys", presetValues: map[string]interface{}{ "ZunionKey21": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -4475,8 +5006,7 @@ func Test_HandleZUNION(t *testing.T) { expectedError: errors.New("number of weights should match number of keys"), }, { - name: "9. Throw an error if there are fewer weights than keys", - preset: true, + name: "9. Throw an error if there are fewer weights than keys", presetValues: map[string]interface{}{ "ZunionKey23": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -4494,8 +5024,7 @@ func Test_HandleZUNION(t *testing.T) { expectedError: errors.New("number of weights should match number of keys"), }, { - name: "10. Throw an error if there are no keys provided", - preset: true, + name: "10. Throw an error if there are no keys provided", presetValues: map[string]interface{}{ "ZunionKey26": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), "ZunionKey27": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), @@ -4506,8 +5035,7 @@ func Test_HandleZUNION(t *testing.T) { expectedError: errors.New(constants.WrongArgsResponse), }, { - name: "11. Throw an error if any of the provided keys are not sorted sets", - preset: true, + name: "11. Throw an error if any of the provided keys are not sorted sets", presetValues: map[string]interface{}{ "ZunionKey29": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -4523,8 +5051,7 @@ func Test_HandleZUNION(t *testing.T) { expectedError: errors.New("value at ZunionKey30 is not a sorted set"), }, { - name: "12. If any of the keys does not exist, skip it.", - preset: true, + name: "12. If any of the keys does not exist, skip it.", presetValues: map[string]interface{}{ "ZunionKey32": sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, @@ -4546,64 +5073,94 @@ func Test_HandleZUNION(t *testing.T) { }, { name: "13. Command too short", - preset: false, command: []string{"ZUNION"}, expectedError: errors.New(constants.WrongArgsResponse), }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZUNION, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } + } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } - if err != nil { - t.Error(err) + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) - } - for _, element := range rv.Array() { - if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { - // The current sub-slice is a different length, return false because they're not equal - if len(element.Array()) != len(expected) { - return false + + for _, item := range res.Array() { + value := item.Array()[0].String() + score := func() string { + if len(item.Array()) == 2 { + return item.Array()[1].String() } - for i := 0; i < len(expected); i++ { - if element.Array()[i].String() != expected[i] { - return false + return "" + }() + if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { + return expected[0] == value + }) { + t.Errorf("unexpected member \"%s\" in response", value) + } + if score != "" { + for _, expected := range test.expectedResponse { + if expected[0] == value && expected[1] != score { + t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) } } - return true - }) { - t.Errorf("expected response %+v, got %+v", test.expectedResponse, rv.Array()) } } }) @@ -4611,6 +5168,12 @@ func Test_HandleZUNION(t *testing.T) { } func Test_HandleZUNIONSTORE(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + t.Error() + } + client := resp.NewConn(conn) + tests := []struct { name string preset bool @@ -4640,8 +5203,8 @@ func Test_HandleZUNIONSTORE(t *testing.T) { command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey1", "ZunionStoreKey1", "ZunionStoreKey2"}, expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "three", Score: 6}, {Value: "four", Score: 8}, + {Value: "five", Score: 10}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, }), expectedResponse: 8, @@ -4947,7 +5510,7 @@ func Test_HandleZUNIONSTORE(t *testing.T) { command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey12", "non-existent", "ZunionStoreKey32", "ZunionStoreKey33"}, expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, {Value: "eleven", Score: 11}, {Value: "twelve", Score: 12}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, {Value: "eleven", Score: 11}, {Value: "twelve", Score: 24}, {Value: "thirty-six", Score: 36}, }), expectedResponse: 9, @@ -4962,61 +5525,104 @@ func Test_HandleZUNIONSTORE(t *testing.T) { }, } - for i, test := range tests { + for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZUNIONSTORE, %d", i)) - - if test.preset { + if test.presetValues != nil { + var command []resp.Value + var expected string for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - if err := mockServer.SetValue(ctx, key, value); err != nil { + res, _, err := client.ReadValue() + if err != nil { t.Error(err) } - mockServer.KeyUnlock(ctx, key) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } } - handler := getHandler(test.command[0]) - if handler == nil { - t.Errorf("no handler found for command %s", test.command[0]) - return + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - res, err := handler(getHandlerFuncParams(ctx, test.command, nil)) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } if test.expectedError != nil { - if err.Error() != test.expectedError.Error() { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } return } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + + // Check if the resulting sorted set has the expected members/scores + if test.expectedValue == nil { + return + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(test.destination), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - rd := resp.NewReader(bytes.NewBuffer(res)) - rv, _, err := rd.ReadValue() - if err != nil { - t.Error(err) + + if len(res.Array()) != test.expectedValue.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + test.destination, test.expectedValue.Cardinality(), len(res.Array())) } - if rv.Integer() != test.expectedResponse { - t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) - } - if test.expectedValue != nil { - if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil { - t.Error(err) + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !test.expectedValue.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) } - set, ok := mockServer.GetValue(ctx, test.destination).(*sorted_set.SortedSet) - if !ok { - t.Errorf("expected vaule at key %s to be set, got another type", test.destination) + if test.expectedValue.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) } - for _, elem := range set.GetAll() { - if !test.expectedValue.Contains(elem.Value) { - t.Errorf("could not find element %s in the expected values", elem.Value) - } - } - mockServer.KeyRUnlock(ctx, test.destination) } }) } diff --git a/internal/modules/sorted_set/key_funcs.go b/internal/modules/sorted_set/key_funcs.go index 8c9a5cb..05ffff6 100644 --- a/internal/modules/sorted_set/key_funcs.go +++ b/internal/modules/sorted_set/key_funcs.go @@ -135,13 +135,11 @@ func zinterstoreKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) return internal.KeyExtractionFuncResult{}, errors.New(constants.WrongArgsResponse) } endIdx := slices.IndexFunc(cmd[1:], func(s string) bool { - if strings.EqualFold(s, "WEIGHTS") || + return strings.EqualFold(s, "WEIGHTS") || strings.EqualFold(s, "AGGREGATE") || - strings.EqualFold(s, "WITHSCORES") { - return true - } - return false + strings.EqualFold(s, "WITHSCORES") }) + if endIdx == -1 { return internal.KeyExtractionFuncResult{ Channels: make([]string, 0), @@ -149,13 +147,15 @@ func zinterstoreKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) WriteKeys: cmd[1:2], }, nil } + if endIdx >= 3 { return internal.KeyExtractionFuncResult{ Channels: make([]string, 0), - ReadKeys: cmd[2:endIdx], + ReadKeys: cmd[2 : endIdx+1], WriteKeys: cmd[1:2], }, nil } + return internal.KeyExtractionFuncResult{}, errors.New(constants.WrongArgsResponse) } @@ -377,7 +377,7 @@ func zunionstoreKeyFunc(cmd []string) (internal.KeyExtractionFuncResult, error) if endIdx >= 1 { return internal.KeyExtractionFuncResult{ Channels: make([]string, 0), - ReadKeys: cmd[2:endIdx], + ReadKeys: cmd[2 : endIdx+1], WriteKeys: cmd[1:2], }, nil }