Implemented tests for sorted set API

This commit is contained in:
Kelvin Mwinuka
2024-03-31 05:55:26 +08:00
parent 8af093741f
commit f061af6de6
12 changed files with 146548 additions and 2785 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -28,17 +28,17 @@ type Value string
type Score float64
// MemberObject is the shape of the object as it's stored in the map that represents the set
// MemberObject is the shape of the object as it's stored in the map that represents the Set
type MemberObject struct {
value Value
score Score
exists bool
Value Value
Score Score
Exists bool
}
// MemberParam is the shape of the object passed as a parameter to NewSortedSet and the Add method
type MemberParam struct {
value Value
score Score
Value Value
Score Score
}
type SortedSet struct {
@@ -50,17 +50,17 @@ func NewSortedSet(members []MemberParam) *SortedSet {
members: make(map[Value]MemberObject),
}
for _, m := range members {
s.members[m.value] = MemberObject{
value: m.value,
score: m.score,
exists: true,
s.members[m.Value] = MemberObject{
Value: m.Value,
Score: m.Score,
Exists: true,
}
}
return s
}
func (set *SortedSet) Contains(m Value) bool {
return set.members[m].exists
return set.members[m].Exists
}
func (set *SortedSet) Get(v Value) MemberObject {
@@ -89,11 +89,11 @@ func (set *SortedSet) GetRandom(count int) []MemberParam {
for i := 0; i < internal.AbsInt(count); {
n = rand.Intn(len(members))
if !slices.ContainsFunc(res, func(m MemberParam) bool {
return m.value == members[n].value
return m.Value == members[n].Value
}) {
res = append(res, members[n])
slices.DeleteFunc(members, func(m MemberParam) bool {
return m.value == members[n].value
return m.Value == members[n].Value
})
i++
}
@@ -107,8 +107,8 @@ func (set *SortedSet) GetAll() []MemberParam {
var res []MemberParam
for k, v := range set.members {
res = append(res, MemberParam{
value: k,
score: v.score,
Value: k,
Score: v.Score,
})
}
return res
@@ -141,31 +141,31 @@ func (set *SortedSet) AddOrUpdate(
return 0, errors.New("cannot use GT or LT when update policy is NX")
}
if strings.EqualFold(inc, "incr") && len(members) != 1 {
return 0, errors.New("INCR can only be used with one member/score pair")
return 0, errors.New("INCR can only be used with one member/Score pair")
}
count := 0
if strings.EqualFold(inc, "incr") {
for _, m := range members {
if !set.Contains(m.value) {
// If the member is not contained, add it with the increment as its score
set.members[m.value] = MemberObject{
value: m.value,
score: m.score,
exists: true,
if !set.Contains(m.Value) {
// If the member is not contained, add it with the increment as its Score
set.members[m.Value] = MemberObject{
Value: m.Value,
Score: m.Score,
Exists: true,
}
// Always add count because this is the addition of a new element
count += 1
return count, err
}
if slices.Contains([]Score{Score(math.Inf(-1)), Score(math.Inf(1))}, set.members[m.value].score) {
if slices.Contains([]Score{Score(math.Inf(-1)), Score(math.Inf(1))}, set.members[m.Value].Score) {
return count, errors.New("cannot increment -inf or +inf")
}
set.members[m.value] = MemberObject{
value: m.value,
score: set.members[m.value].score + m.score,
exists: true,
set.members[m.Value] = MemberObject{
Value: m.Value,
Score: set.members[m.Value].Score + m.Score,
Exists: true,
}
if strings.EqualFold(ch, "ch") {
count += 1
@@ -177,11 +177,11 @@ func (set *SortedSet) AddOrUpdate(
for _, m := range members {
if strings.EqualFold(policy, "xx") {
// Only update existing elements, do not add new elements
if set.Contains(m.value) {
set.members[m.value] = MemberObject{
value: m.value,
score: compareScores(set.members[m.value].score, m.score, comp),
exists: true,
if set.Contains(m.Value) {
set.members[m.Value] = MemberObject{
Value: m.Value,
Score: compareScores(set.members[m.Value].Score, m.Score, comp),
Exists: true,
}
if strings.EqualFold(ch, "ch") {
count += 1
@@ -191,24 +191,24 @@ func (set *SortedSet) AddOrUpdate(
}
if strings.EqualFold(policy, "nx") {
// Only add new elements, do not update existing elements
if !set.Contains(m.value) {
set.members[m.value] = MemberObject{
value: m.value,
score: m.score,
exists: true,
if !set.Contains(m.Value) {
set.members[m.Value] = MemberObject{
Value: m.Value,
Score: m.Score,
Exists: true,
}
count += 1
}
continue
}
// Policy not specified, just set the elements and scores
if set.members[m.value].score != m.score || !set.members[m.value].exists {
// Policy not specified, just Set the elements and scores
if set.members[m.Value].Score != m.Score || !set.members[m.Value].Exists {
count += 1
}
set.members[m.value] = MemberObject{
value: m.value,
score: compareScores(set.members[m.value].score, m.score, comp),
exists: true,
set.members[m.Value] = MemberObject{
Value: m.Value,
Score: compareScores(set.members[m.Value].Score, m.Score, comp),
Exists: true,
}
}
return count, nil
@@ -238,16 +238,16 @@ func (set *SortedSet) Pop(count int, policy string) (*SortedSet, error) {
slices.SortFunc(members, func(a, b MemberParam) int {
if strings.EqualFold(policy, "min") {
return cmp.Compare(a.score, b.score)
return cmp.Compare(a.Score, b.Score)
}
return cmp.Compare(b.score, a.score)
return cmp.Compare(b.Score, a.Score)
})
for i := 0; i < count; i++ {
if i >= len(members) {
break
}
set.Remove(members[i].value)
set.Remove(members[i].Value)
_, err := popped.AddOrUpdate([]MemberParam{members[i]}, nil, nil, nil, nil)
if err != nil {
return nil, err
@@ -261,8 +261,8 @@ func (set *SortedSet) Subtract(others []*SortedSet) *SortedSet {
res := NewSortedSet(set.GetAll())
for _, ss := range others {
for _, m := range ss.GetAll() {
if res.Contains(m.value) {
res.Remove(m.value)
if res.Contains(m.Value) {
res.Remove(m.Value)
}
}
}
@@ -271,8 +271,8 @@ func (set *SortedSet) Subtract(others []*SortedSet) *SortedSet {
// SortedSetParam is a composite object used for Intersect and Union function
type SortedSetParam struct {
set *SortedSet
weight int
Set *SortedSet
Weight int
}
func (set *SortedSet) Equals(other *SortedSet) bool {
@@ -283,10 +283,10 @@ func (set *SortedSet) Equals(other *SortedSet) bool {
return true
}
for _, member := range set.members {
if !other.Contains(member.value) {
if !other.Contains(member.Value) {
return false
}
if member.score != other.Get(member.value).score {
if member.Score != other.Get(member.Value).Score {
return false
}
}
@@ -300,29 +300,29 @@ func Union(aggregate string, setParams ...SortedSetParam) *SortedSet {
return NewSortedSet([]MemberParam{})
case 1:
var params []MemberParam
for _, member := range setParams[0].set.GetAll() {
for _, member := range setParams[0].Set.GetAll() {
params = append(params, MemberParam{
value: member.value,
score: member.score * Score(setParams[0].weight),
Value: member.Value,
Score: member.Score * Score(setParams[0].Weight),
})
}
return NewSortedSet(params)
case 2:
var params []MemberParam
// Traverse the params in the left sorted set
for _, member := range setParams[0].set.GetAll() {
// If the member does not exist in the other sorted set, add it to params along with the appropriate weight
if !setParams[1].set.Contains(member.value) {
// Traverse the params in the left sorted Set
for _, member := range setParams[0].Set.GetAll() {
// If the member does not exist in the other sorted Set, add it to params along with the appropriate Weight
if !setParams[1].Set.Contains(member.Value) {
params = append(params, MemberParam{
value: member.value,
score: member.score * Score(setParams[0].weight),
Value: member.Value,
Score: member.Score * Score(setParams[0].Weight),
})
continue
}
// If the member exists, get both elements and apply the weight
// If the member Exists, get both elements and apply the Weight
param := MemberParam{
value: member.value,
score: func(left, right Score) Score {
Value: member.Value,
Score: func(left, right Score) Score {
// Choose which param to add to params depending on the aggregate
switch aggregate {
case "sum":
@@ -334,21 +334,21 @@ func Union(aggregate string, setParams ...SortedSetParam) *SortedSet {
return compareScores(left, right, "gt")
}
}(
member.score*Score(setParams[0].weight),
setParams[1].set.Get(member.value).score*Score(setParams[1].weight),
member.Score*Score(setParams[0].Weight),
setParams[1].Set.Get(member.Value).Score*Score(setParams[1].Weight),
),
}
params = append(params, param)
}
// Traverse the params on the right sorted set and add all the elements that are not
// Traverse the params on the right sorted Set and add all the elements that are not
// already contained in params with their respective weights applied.
for _, member := range setParams[1].set.GetAll() {
for _, member := range setParams[1].Set.GetAll() {
if !slices.ContainsFunc(params, func(param MemberParam) bool {
return param.value == member.value
return param.Value == member.Value
}) {
params = append(params, MemberParam{
value: member.value,
score: member.score * Score(setParams[1].weight),
Value: member.Value,
Score: member.Score * Score(setParams[1].Weight),
})
}
}
@@ -359,16 +359,16 @@ func Union(aggregate string, setParams ...SortedSetParam) *SortedSet {
right := Union(aggregate, setParams[len(setParams)/2:]...)
var params []MemberParam
// Traverse left sub-set and add the union elements to params
// Traverse left sub-Set and add the union elements to params
for _, member := range left.GetAll() {
if !right.Contains(member.value) {
// If the right set does not contain the current element, just add it to params
if !right.Contains(member.Value) {
// If the right Set does not contain the current element, just add it to params
params = append(params, member)
continue
}
params = append(params, MemberParam{
value: member.value,
score: func(left, right Score) Score {
Value: member.Value,
Score: func(left, right Score) Score {
switch aggregate {
case "sum":
return left + right
@@ -378,13 +378,13 @@ func Union(aggregate string, setParams ...SortedSetParam) *SortedSet {
// Aggregate is "max"
return compareScores(left, right, "gt")
}
}(member.score, right.Get(member.value).score),
}(member.Score, right.Get(member.Value).Score),
})
}
// Traverse the right sub-set and add any remaining elements to params
// Traverse the right sub-Set and add any remaining elements to params
for _, member := range right.GetAll() {
if !slices.ContainsFunc(params, func(param MemberParam) bool {
return param.value == member.value
return param.Value == member.Value
}) {
params = append(params, member)
}
@@ -400,25 +400,25 @@ func Intersect(aggregate string, setParams ...SortedSetParam) *SortedSet {
return NewSortedSet([]MemberParam{})
case 1:
var params []MemberParam
for _, member := range setParams[0].set.GetAll() {
for _, member := range setParams[0].Set.GetAll() {
params = append(params, MemberParam{
value: member.value,
score: member.score * Score(setParams[0].weight),
Value: member.Value,
Score: member.Score * Score(setParams[0].Weight),
})
}
return NewSortedSet(params)
case 2:
var params []MemberParam
// Traverse the params in the left sorted set
for _, member := range setParams[0].set.GetAll() {
// Check if the member exists in the right sorted set
if !setParams[1].set.Contains(member.value) {
// Traverse the params in the left sorted Set
for _, member := range setParams[0].Set.GetAll() {
// Check if the member Exists in the right sorted Set
if !setParams[1].Set.Contains(member.Value) {
continue
}
// If the member exists, get both elements and apply the weight
// If the member Exists, get both elements and apply the Weight
param := MemberParam{
value: member.value,
score: func(left, right Score) Score {
Value: member.Value,
Score: func(left, right Score) Score {
// Choose which param to add to params depending on the aggregate
switch aggregate {
case "sum":
@@ -430,8 +430,8 @@ func Intersect(aggregate string, setParams ...SortedSetParam) *SortedSet {
return compareScores(left, right, "gt")
}
}(
member.score*Score(setParams[0].weight),
setParams[1].set.Get(member.value).score*Score(setParams[1].weight),
member.Score*Score(setParams[0].Weight),
setParams[1].Set.Get(member.Value).Score*Score(setParams[1].Weight),
),
}
params = append(params, param)
@@ -444,12 +444,12 @@ func Intersect(aggregate string, setParams ...SortedSetParam) *SortedSet {
var params []MemberParam
for _, member := range left.GetAll() {
if !right.Contains(member.value) {
if !right.Contains(member.Value) {
continue
}
params = append(params, MemberParam{
value: member.value,
score: func(left, right Score) Score {
Value: member.Value,
Score: func(left, right Score) Score {
switch aggregate {
case "sum":
return left + right
@@ -459,7 +459,7 @@ func Intersect(aggregate string, setParams ...SortedSetParam) *SortedSet {
// Aggregate is "max"
return compareScores(left, right, "gt")
}
}(member.score, right.Get(member.value).score),
}(member.Score, right.Get(member.Value).Score),
})
}

View File

@@ -0,0 +1,98 @@
// Copyright 2024 Kelvin Clement Mwinuka
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sorted_set
import (
"errors"
"slices"
"strings"
)
func validateUpdatePolicy(updatePolicy interface{}) (string, error) {
if updatePolicy == nil {
return "", nil
}
err := errors.New("update policy must be a string of Value NX or XX")
policy, ok := updatePolicy.(string)
if !ok {
return "", err
}
if !slices.Contains([]string{"nx", "xx"}, strings.ToLower(policy)) {
return "", err
}
return policy, nil
}
func validateComparison(comparison interface{}) (string, error) {
if comparison == nil {
return "", nil
}
err := errors.New("comparison condition must be a string of Value LT or GT")
comp, ok := comparison.(string)
if !ok {
return "", err
}
if !slices.Contains([]string{"lt", "gt"}, strings.ToLower(comp)) {
return "", err
}
return comp, nil
}
func validateChanged(changed interface{}) (string, error) {
if changed == nil {
return "", nil
}
err := errors.New("changed condition should be a string of Value CH")
ch, ok := changed.(string)
if !ok {
return "", err
}
if !strings.EqualFold(ch, "ch") {
return "", err
}
return ch, nil
}
func validateIncr(incr interface{}) (string, error) {
if incr == nil {
return "", nil
}
err := errors.New("incr condition should be a string of Value INCR")
i, ok := incr.(string)
if !ok {
return "", err
}
if !strings.EqualFold(i, "incr") {
return "", err
}
return i, nil
}
func compareScores(old Score, new Score, comp string) Score {
switch strings.ToLower(comp) {
default:
return new
case "lt":
if new < old {
return new
}
return old
case "gt":
if new > old {
return new
}
return old
}
}

View File

@@ -17,6 +17,7 @@ package internal
import (
"bufio"
"bytes"
"cmp"
"errors"
"fmt"
"github.com/echovault/echovault/pkg/utils"
@@ -24,6 +25,7 @@ import (
"log"
"math/big"
"net"
"reflect"
"runtime"
"slices"
"strconv"
@@ -220,6 +222,35 @@ func FilterExpiredKeys(state map[string]KeyData) map[string]KeyData {
return state
}
// CompareLex returns -1 when s2 is lexicographically greater than s1,
// 0 if they're equal and 1 if s2 is lexicographically less than s1.
func CompareLex(s1 string, s2 string) int {
if s1 == s2 {
return 0
}
if strings.Contains(s1, s2) {
return 1
}
if strings.Contains(s2, s1) {
return -1
}
limit := len(s1)
if len(s2) < limit {
limit = len(s2)
}
var c int
for i := 0; i < limit; i++ {
c = cmp.Compare(s1[i], s2[i])
if c != 0 {
break
}
}
return c
}
func EncodeCommand(cmd []string) []byte {
res := fmt.Sprintf("*%d\r\n", len(cmd))
for _, token := range cmd {
@@ -356,3 +387,21 @@ func ParseBooleanArrayResponse(b []byte) ([]bool, error) {
}
return arr, nil
}
func CompareNestedStringArrays(got [][]string, want [][]string) bool {
for _, wantItem := range want {
if !slices.ContainsFunc(got, func(gotItem []string) bool {
return reflect.DeepEqual(wantItem, gotItem)
}) {
return false
}
}
for _, gotItem := range got {
if !slices.ContainsFunc(want, func(wantItem []string) bool {
return reflect.DeepEqual(wantItem, gotItem)
}) {
return false
}
}
return true
}

File diff suppressed because it is too large Load Diff

View File

@@ -7,12 +7,6 @@ import (
"testing"
)
func presetValue(server *EchoVault, key string, value interface{}) {
_, _ = server.CreateKeyAndLock(server.context, key)
_ = server.SetValue(server.context, key, value)
server.KeyUnlock(server.context, key)
}
func TestEchoVault_LLEN(t *testing.T) {
server := NewEchoVault(WithCommands(commands.All()))

View File

@@ -40,15 +40,15 @@ type ZUNIONSTOREOptions ZINTEROptions
type ZMPOPOptions struct {
Min bool
Max bool
Count int
Count uint
}
type ZRANGEOptions struct {
ByScore bool
ByLex bool
Rev bool
Offset int
Count int
WithScores bool
ByScore bool
ByLex bool
Offset uint
Count uint
}
type ZRANGESTOREOptions ZRANGEOptions
@@ -87,8 +87,8 @@ func buildIntegerScoreMap(arr [][]string, withscores bool) (map[int]float64, err
return result, nil
}
func (server *EchoVault) ZADD(entries map[string]float64, options ZADDOptions) (int, error) {
cmd := []string{"ZADD"}
func (server *EchoVault) ZADD(key string, entries map[string]float64, options ZADDOptions) (int, error) {
cmd := []string{"ZADD", key}
switch {
case options.NX:
@@ -113,7 +113,7 @@ func (server *EchoVault) ZADD(entries map[string]float64, options ZADDOptions) (
}
for member, score := range entries {
cmd = append(cmd, []string{member, strconv.FormatFloat(score, 'f', -1, 64)}...)
cmd = append(cmd, []string{strconv.FormatFloat(score, 'f', -1, 64), member}...)
}
b, err := server.handleCommand(server.context, internal.EncodeCommand(cmd), nil, false)
@@ -191,8 +191,8 @@ func (server *EchoVault) ZINTER(keys []string, options ZINTEROptions) (map[strin
if len(options.Weights) > 0 {
cmd = append(cmd, "WEIGHTS")
for _, weight := range options.Weights {
cmd = append(cmd, strconv.FormatFloat(float64(weight), 'f', -1, 64))
for i := 0; i < len(options.Weights); i++ {
cmd = append(cmd, strconv.FormatFloat(options.Weights[i], 'f', -1, 64))
}
}
@@ -314,7 +314,7 @@ func (server *EchoVault) ZMPOP(keys []string, options ZMPOPOptions) ([][]string,
switch {
case options.Count != 0:
cmd = append(cmd, []string{"COUNT", strconv.Itoa(options.Count)}...)
cmd = append(cmd, []string{"COUNT", strconv.Itoa(int(options.Count))}...)
default:
cmd = append(cmd, []string{"COUNT", strconv.Itoa(1)}...)
}
@@ -327,7 +327,7 @@ func (server *EchoVault) ZMPOP(keys []string, options ZMPOPOptions) ([][]string,
return internal.ParseNestedStringArrayResponse(b)
}
func (server *EchoVault) ZMSCORE(key string, members ...string) ([]float64, error) {
func (server *EchoVault) ZMSCORE(key string, members ...string) ([]interface{}, error) {
cmd := []string{"ZMSCORE", key}
for _, member := range members {
cmd = append(cmd, member)
@@ -343,8 +343,12 @@ func (server *EchoVault) ZMSCORE(key string, members ...string) ([]float64, erro
return nil, err
}
scores := make([]float64, len(arr))
scores := make([]interface{}, len(arr))
for i, e := range arr {
if e == "" {
scores[i] = nil
continue
}
score, err := strconv.ParseFloat(e, 64)
if err != nil {
return nil, err
@@ -381,7 +385,7 @@ func (server *EchoVault) ZPOPMIN(key string, count int) ([][]string, error) {
return internal.ParseNestedStringArrayResponse(b)
}
func (server *EchoVault) ZRANDMEMBER(key string, count int, withscores bool) (map[string]float64, error) {
func (server *EchoVault) ZRANDMEMBER(key string, count int, withscores bool) ([][]string, error) {
cmd := []string{"ZRANDMEMBER", key}
if count != 0 {
cmd = append(cmd, strconv.Itoa(count))
@@ -395,12 +399,7 @@ func (server *EchoVault) ZRANDMEMBER(key string, count int, withscores bool) (ma
return nil, err
}
arr, err := internal.ParseNestedStringArrayResponse(b)
if err != nil {
return nil, err
}
return buildMemberScoreMap(arr, withscores)
return internal.ParseNestedStringArrayResponse(b)
}
func (server *EchoVault) ZRANK(key string, member string, withscores bool) (map[int]float64, error) {
@@ -414,9 +413,28 @@ func (server *EchoVault) ZRANK(key string, member string, withscores bool) (map[
return nil, err
}
arr, err := internal.ParseNestedStringArrayResponse(b)
arr, err := internal.ParseStringArrayResponse(b)
return buildIntegerScoreMap(arr, withscores)
if len(arr) == 0 {
return nil, nil
}
s, err := strconv.Atoi(arr[0])
if err != nil {
return nil, err
}
res := map[int]float64{s: 0}
if withscores {
f, err := strconv.ParseFloat(arr[1], 64)
if err != nil {
return nil, err
}
res[s] = f
}
return res, nil
}
func (server *EchoVault) ZREVRANK(key string, member string, withscores bool) (map[int]float64, error) {
@@ -508,11 +526,13 @@ func (server *EchoVault) ZRANGE(key, start, stop string, options ZRANGEOptions)
cmd = append(cmd, "BYSCORE")
}
if options.Rev {
cmd = append(cmd, "REV")
if options.WithScores {
cmd = append(cmd, "WITHSCORES")
}
cmd = append(cmd, []string{"LIMIT", strconv.Itoa(options.Offset), strconv.Itoa(options.Count)}...)
if options.Offset != 0 && options.Count != 0 {
cmd = append(cmd, []string{"LIMIT", strconv.Itoa(int(options.Offset)), strconv.Itoa(int(options.Count))}...)
}
b, err := server.handleCommand(server.context, internal.EncodeCommand(cmd), nil, false)
if err != nil {
@@ -524,7 +544,7 @@ func (server *EchoVault) ZRANGE(key, start, stop string, options ZRANGEOptions)
return nil, err
}
return buildMemberScoreMap(arr, true)
return buildMemberScoreMap(arr, options.WithScores)
}
func (server *EchoVault) ZRANGESTORE(destination, source, start, stop string, options ZRANGESTOREOptions) (int, error) {
@@ -539,12 +559,10 @@ func (server *EchoVault) ZRANGESTORE(destination, source, start, stop string, op
cmd = append(cmd, "BYSCORE")
}
if options.Rev {
cmd = append(cmd, "REV")
if options.Offset != 0 && options.Count != 0 {
cmd = append(cmd, []string{"LIMIT", strconv.Itoa(int(options.Offset)), strconv.Itoa(int(options.Count))}...)
}
cmd = append(cmd, []string{"LIMIT", strconv.Itoa(options.Offset), strconv.Itoa(options.Count)}...)
b, err := server.handleCommand(server.context, internal.EncodeCommand(cmd), nil, false)
if err != nil {
return 0, err

File diff suppressed because it is too large Load Diff

View File

@@ -565,3 +565,9 @@ func (server *EchoVault) evictKeysWithExpiredTTL(ctx context.Context) error {
return nil
}
func presetValue(server *EchoVault, key string, value interface{}) {
_, _ = server.CreateKeyAndLock(server.context, key)
_ = server.SetValue(server.context, key, value)
server.KeyUnlock(server.context, key)
}

View File

@@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/sorted_set"
"github.com/echovault/echovault/pkg/utils"
"math"
"net"
@@ -63,7 +64,7 @@ func handleZADD(ctx context.Context, cmd []string, server utils.EchoVault, conn
return nil, errors.New("score/member pairs must be float/string")
}
var members []MemberParam
var members []sorted_set.MemberParam
for i := 0; i < len(cmd[membersStartIndex:]); i++ {
if i%2 != 0 {
@@ -77,29 +78,29 @@ func handleZADD(ctx context.Context, cmd []string, server utils.EchoVault, conn
var s float64
if strings.ToLower(score.(string)) == "-inf" {
s = math.Inf(-1)
members = append(members, MemberParam{
value: Value(cmd[membersStartIndex:][i+1]),
score: Score(s),
members = append(members, sorted_set.MemberParam{
Value: sorted_set.Value(cmd[membersStartIndex:][i+1]),
Score: sorted_set.Score(s),
})
}
if strings.ToLower(score.(string)) == "+inf" {
s = math.Inf(1)
members = append(members, MemberParam{
value: Value(cmd[membersStartIndex:][i+1]),
score: Score(s),
members = append(members, sorted_set.MemberParam{
Value: sorted_set.Value(cmd[membersStartIndex:][i+1]),
Score: sorted_set.Score(s),
})
}
case float64:
s, _ := score.(float64)
members = append(members, MemberParam{
value: Value(cmd[membersStartIndex:][i+1]),
score: Score(s),
members = append(members, sorted_set.MemberParam{
Value: sorted_set.Value(cmd[membersStartIndex:][i+1]),
Score: sorted_set.Score(s),
})
case int:
s, _ := score.(int)
members = append(members, MemberParam{
value: Value(cmd[membersStartIndex:][i+1]),
score: Score(s),
members = append(members, sorted_set.MemberParam{
Value: sorted_set.Value(cmd[membersStartIndex:][i+1]),
Score: sorted_set.Score(s),
})
}
}
@@ -148,7 +149,7 @@ func handleZADD(ctx context.Context, cmd []string, server utils.EchoVault, conn
return nil, err
}
defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -158,8 +159,8 @@ func handleZADD(ctx context.Context, cmd []string, server utils.EchoVault, conn
}
// If INCR option is provided, return the new score value
if incr != nil {
m := set.Get(members[0].value)
return []byte(fmt.Sprintf("+%f\r\n", m.score)), nil
m := set.Get(members[0].Value)
return []byte(fmt.Sprintf("+%f\r\n", m.Score)), nil
}
return []byte(fmt.Sprintf(":%d\r\n", count)), nil
@@ -171,7 +172,7 @@ func handleZADD(ctx context.Context, cmd []string, server utils.EchoVault, conn
}
defer server.KeyUnlock(ctx, key)
set := NewSortedSet(members)
set := sorted_set.NewSortedSet(members)
if err = server.SetValue(ctx, key, set); err != nil {
return nil, err
}
@@ -195,7 +196,7 @@ func handleZCARD(ctx context.Context, cmd []string, server utils.EchoVault, conn
}
defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -211,40 +212,40 @@ func handleZCOUNT(ctx context.Context, cmd []string, server utils.EchoVault, con
key := keys[0]
minimum := Score(math.Inf(-1))
minimum := sorted_set.Score(math.Inf(-1))
switch internal.AdaptType(cmd[2]).(type) {
default:
return nil, errors.New("min constraint must be a double")
case string:
if strings.ToLower(cmd[2]) == "+inf" {
minimum = Score(math.Inf(1))
minimum = sorted_set.Score(math.Inf(1))
} else {
return nil, errors.New("min constraint must be a double")
}
case float64:
s, _ := internal.AdaptType(cmd[2]).(float64)
minimum = Score(s)
minimum = sorted_set.Score(s)
case int:
s, _ := internal.AdaptType(cmd[2]).(int)
minimum = Score(s)
minimum = sorted_set.Score(s)
}
maximum := Score(math.Inf(1))
maximum := sorted_set.Score(math.Inf(1))
switch internal.AdaptType(cmd[3]).(type) {
default:
return nil, errors.New("max constraint must be a double")
case string:
if strings.ToLower(cmd[3]) == "-inf" {
maximum = Score(math.Inf(-1))
maximum = sorted_set.Score(math.Inf(-1))
} else {
return nil, errors.New("max constraint must be a double")
}
case float64:
s, _ := internal.AdaptType(cmd[3]).(float64)
maximum = Score(s)
maximum = sorted_set.Score(s)
case int:
s, _ := internal.AdaptType(cmd[3]).(int)
maximum = Score(s)
maximum = sorted_set.Score(s)
}
if !server.KeyExists(ctx, key) {
@@ -256,14 +257,14 @@ func handleZCOUNT(ctx context.Context, cmd []string, server utils.EchoVault, con
}
defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
var members []MemberParam
var members []sorted_set.MemberParam
for _, m := range set.GetAll() {
if m.score >= minimum && m.score <= maximum {
if m.Score >= minimum && m.Score <= maximum {
members = append(members, m)
}
}
@@ -271,7 +272,7 @@ func handleZCOUNT(ctx context.Context, cmd []string, server utils.EchoVault, con
return []byte(fmt.Sprintf(":%d\r\n", len(members))), nil
}
func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zlexcountKeyFunc(cmd)
if err != nil {
return nil, err
@@ -290,7 +291,7 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.EchoVault,
}
defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -299,7 +300,7 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.EchoVault,
// Check if all members has the same score
for i := 0; i < len(members)-2; i++ {
if members[i].score != members[i+1].score {
if members[i].Score != members[i+1].Score {
return []byte(":0\r\n"), nil
}
}
@@ -307,8 +308,8 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.EchoVault,
count := 0
for _, m := range members {
if slices.Contains([]int{1, 0}, compareLex(string(m.value), minimum)) &&
slices.Contains([]int{-1, 0}, compareLex(string(m.value), maximum)) {
if slices.Contains([]int{1, 0}, internal.CompareLex(string(m.Value), minimum)) &&
slices.Contains([]int{-1, 0}, internal.CompareLex(string(m.Value), maximum)) {
count += 1
}
}
@@ -347,13 +348,13 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.EchoVault, conn
return nil, err
}
defer server.KeyRUnlock(ctx, keys[0])
baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*SortedSet)
baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[0])
}
// Extract the remaining sets
var sets []*SortedSet
var sets []*sorted_set.SortedSet
for i := 1; i < len(keys); i++ {
if !server.KeyExists(ctx, keys[i]) {
@@ -364,7 +365,7 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.EchoVault, conn
return nil, err
}
locks[keys[i]] = locked
set, ok := server.GetValue(ctx, keys[i]).(*SortedSet)
set, ok := server.GetValue(ctx, keys[i]).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
}
@@ -378,9 +379,9 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.EchoVault, conn
for _, m := range diff.GetAll() {
if includeScores {
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.value), m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64))
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.Value), m.Value, strconv.FormatFloat(float64(m.Score), 'f', -1, 64))
} else {
res += fmt.Sprintf("\r\n*1\r\n$%d\r\n%s", len(m.value), m.value)
res += fmt.Sprintf("\r\n*1\r\n$%d\r\n%s", len(m.Value), m.Value)
}
}
@@ -415,19 +416,19 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.EchoVault,
return nil, err
}
defer server.KeyRUnlock(ctx, keys[0])
baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*SortedSet)
baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[0])
}
var sets []*SortedSet
var sets []*sorted_set.SortedSet
for i := 1; i < len(keys); i++ {
if server.KeyExists(ctx, keys[i]) {
if _, err = server.KeyRLock(ctx, keys[i]); err != nil {
return nil, err
}
set, ok := server.GetValue(ctx, keys[i]).(*SortedSet)
set, ok := server.GetValue(ctx, keys[i]).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
}
@@ -462,26 +463,26 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.EchoVault, co
}
key := keys[0]
member := Value(cmd[3])
var increment Score
member := sorted_set.Value(cmd[3])
var increment sorted_set.Score
switch internal.AdaptType(cmd[2]).(type) {
default:
return nil, errors.New("increment must be a double")
case string:
if strings.EqualFold("-inf", strings.ToLower(cmd[2])) {
increment = Score(math.Inf(-1))
increment = sorted_set.Score(math.Inf(-1))
} else if strings.EqualFold("+inf", strings.ToLower(cmd[2])) {
increment = Score(math.Inf(1))
increment = sorted_set.Score(math.Inf(1))
} else {
return nil, errors.New("increment must be a double")
}
case float64:
s, _ := internal.AdaptType(cmd[2]).(float64)
increment = Score(s)
increment = sorted_set.Score(s)
case int:
s, _ := internal.AdaptType(cmd[2]).(int)
increment = Score(s)
increment = sorted_set.Score(s)
}
if !server.KeyExists(ctx, key) {
@@ -490,7 +491,11 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.EchoVault, co
if _, err = server.CreateKeyAndLock(ctx, key); err != nil {
return nil, err
}
if err = server.SetValue(ctx, key, NewSortedSet([]MemberParam{{value: member, score: increment}})); err != nil {
if err = server.SetValue(
ctx,
key,
sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: member, Score: increment}}),
); err != nil {
return nil, err
}
server.KeyUnlock(ctx, key)
@@ -501,13 +506,13 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.EchoVault, co
return nil, err
}
defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
if _, err = set.AddOrUpdate(
[]MemberParam{
{value: member, score: increment}},
[]sorted_set.MemberParam{
{Value: member, Score: increment}},
"xx",
nil,
nil,
@@ -515,7 +520,7 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.EchoVault, co
return nil, err
}
return []byte(fmt.Sprintf("+%s\r\n",
strconv.FormatFloat(float64(set.Get(member).score), 'f', -1, 64))), nil
strconv.FormatFloat(float64(set.Get(member).Score), 'f', -1, 64))), nil
}
func handleZINTER(ctx context.Context, cmd []string, server utils.EchoVault, conn *net.Conn) ([]byte, error) {
@@ -538,7 +543,7 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.EchoVault, con
}
}()
var setParams []SortedSetParam
var setParams []sorted_set.SortedSetParam
for i := 0; i < len(keys); i++ {
if !server.KeyExists(ctx, keys[i]) {
@@ -549,26 +554,26 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.EchoVault, con
return nil, err
}
locks[keys[i]] = true
set, ok := server.GetValue(ctx, keys[i]).(*SortedSet)
set, ok := server.GetValue(ctx, keys[i]).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
}
setParams = append(setParams, SortedSetParam{
set: set,
weight: weights[i],
setParams = append(setParams, sorted_set.SortedSetParam{
Set: set,
Weight: weights[i],
})
}
intersect := Intersect(aggregate, setParams...)
intersect := sorted_set.Intersect(aggregate, setParams...)
res := fmt.Sprintf("*%d", intersect.Cardinality())
if intersect.Cardinality() > 0 {
for _, m := range intersect.GetAll() {
if withscores {
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.value), m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64))
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.Value), m.Value, strconv.FormatFloat(float64(m.Score), 'f', -1, 64))
} else {
res += fmt.Sprintf("\r\n*1\r\n$%d\r\n%s", len(m.value), m.value)
res += fmt.Sprintf("\r\n*1\r\n$%d\r\n%s", len(m.Value), m.Value)
}
}
}
@@ -605,7 +610,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.EchoVault
}
}()
var setParams []SortedSetParam
var setParams []sorted_set.SortedSetParam
for i := 0; i < len(keys); i++ {
if !server.KeyExists(ctx, keys[i]) {
@@ -615,17 +620,17 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.EchoVault
return nil, err
}
locks[keys[i]] = true
set, ok := server.GetValue(ctx, keys[i]).(*SortedSet)
set, ok := server.GetValue(ctx, keys[i]).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
}
setParams = append(setParams, SortedSetParam{
set: set,
weight: weights[i],
setParams = append(setParams, sorted_set.SortedSetParam{
Set: set,
Weight: weights[i],
})
}
intersect := Intersect(aggregate, setParams...)
intersect := sorted_set.Intersect(aggregate, setParams...)
if server.KeyExists(ctx, destination) && intersect.Cardinality() > 0 {
if _, err = server.KeyLock(ctx, destination); err != nil {
@@ -696,7 +701,7 @@ func handleZMPOP(ctx context.Context, cmd []string, server utils.EchoVault, conn
if _, err = server.KeyLock(ctx, keys[i]); err != nil {
continue
}
v, ok := server.GetValue(ctx, keys[i]).(*SortedSet)
v, ok := server.GetValue(ctx, keys[i]).(*sorted_set.SortedSet)
if !ok || v.Cardinality() == 0 {
server.KeyUnlock(ctx, keys[i])
continue
@@ -711,7 +716,7 @@ func handleZMPOP(ctx context.Context, cmd []string, server utils.EchoVault, conn
res := fmt.Sprintf("*%d", popped.Cardinality())
for _, m := range popped.GetAll() {
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.value), m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64))
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.Value), m.Value, strconv.FormatFloat(float64(m.Score), 'f', -1, 64))
}
res += "\r\n"
@@ -754,7 +759,7 @@ func handleZPOP(ctx context.Context, cmd []string, server utils.EchoVault, conn
}
defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at key %s is not a sorted set", key)
}
@@ -766,7 +771,7 @@ func handleZPOP(ctx context.Context, cmd []string, server utils.EchoVault, conn
res := fmt.Sprintf("*%d", popped.Cardinality())
for _, m := range popped.GetAll() {
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.value), m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64))
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.Value), m.Value, strconv.FormatFloat(float64(m.Score), 'f', -1, 64))
}
res += "\r\n"
@@ -791,7 +796,7 @@ func handleZMSCORE(ctx context.Context, cmd []string, server utils.EchoVault, co
}
defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -800,14 +805,14 @@ func handleZMSCORE(ctx context.Context, cmd []string, server utils.EchoVault, co
res := fmt.Sprintf("*%d", len(members))
var member MemberObject
var member sorted_set.MemberObject
for i := 0; i < len(members); i++ {
member = set.Get(Value(members[i]))
if !member.exists {
member = set.Get(sorted_set.Value(members[i]))
if !member.Exists {
res = fmt.Sprintf("%s\r\n$-1", res)
} else {
res = fmt.Sprintf("%s\r\n+%s", res, strconv.FormatFloat(float64(member.score), 'f', -1, 64))
res = fmt.Sprintf("%s\r\n+%s", res, strconv.FormatFloat(float64(member.Score), 'f', -1, 64))
}
}
@@ -850,7 +855,7 @@ func handleZRANDMEMBER(ctx context.Context, cmd []string, server utils.EchoVault
}
defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -860,9 +865,9 @@ func handleZRANDMEMBER(ctx context.Context, cmd []string, server utils.EchoVault
res := fmt.Sprintf("*%d", len(members))
for _, m := range members {
if withscores {
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.value), m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64))
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.Value), m.Value, strconv.FormatFloat(float64(m.Score), 'f', -1, 64))
} else {
res += fmt.Sprintf("\r\n*1\r\n$%d\r\n%s", len(m.value), m.value)
res += fmt.Sprintf("\r\n*1\r\n$%d\r\n%s", len(m.Value), m.Value)
}
}
@@ -871,7 +876,7 @@ func handleZRANDMEMBER(ctx context.Context, cmd []string, server utils.EchoVault
return []byte(res), nil
}
func handleZRANK(ctx context.Context, cmd []string, server utils.EchoVault, conn *net.Conn) ([]byte, error) {
func handleZRANK(ctx context.Context, cmd []string, server utils.EchoVault, _ *net.Conn) ([]byte, error) {
keys, err := zrankKeyFunc(cmd)
if err != nil {
return nil, err
@@ -894,23 +899,23 @@ func handleZRANK(ctx context.Context, cmd []string, server utils.EchoVault, conn
}
defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
members := set.GetAll()
slices.SortFunc(members, func(a, b MemberParam) int {
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
if strings.EqualFold(cmd[0], "zrevrank") {
return cmp.Compare(b.score, a.score)
return cmp.Compare(b.Score, a.Score)
}
return cmp.Compare(a.score, b.score)
return cmp.Compare(a.Score, b.Score)
})
for i := 0; i < len(members); i++ {
if members[i].value == Value(member) {
if members[i].Value == sorted_set.Value(member) {
if withscores {
score := strconv.FormatFloat(float64(members[i].score), 'f', -1, 64)
score := strconv.FormatFloat(float64(members[i].Score), 'f', -1, 64)
return []byte(fmt.Sprintf("*2\r\n:%d\r\n$%d\r\n%s\r\n", i, len(score), score)), nil
} else {
return []byte(fmt.Sprintf("*1\r\n:%d\r\n", i)), nil
@@ -938,14 +943,14 @@ func handleZREM(ctx context.Context, cmd []string, server utils.EchoVault, conn
}
defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
deletedCount := 0
for _, m := range cmd[2:] {
if set.Remove(Value(m)) {
if set.Remove(sorted_set.Value(m)) {
deletedCount += 1
}
}
@@ -968,16 +973,16 @@ func handleZSCORE(ctx context.Context, cmd []string, server utils.EchoVault, con
return nil, err
}
defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
member := set.Get(Value(cmd[2]))
if !member.exists {
member := set.Get(sorted_set.Value(cmd[2]))
if !member.Exists {
return []byte("$-1\r\n"), nil
}
score := strconv.FormatFloat(float64(member.score), 'f', -1, 64)
score := strconv.FormatFloat(float64(member.Score), 'f', -1, 64)
return []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(score), score)), nil
}
@@ -1011,14 +1016,14 @@ func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server utils.Echo
}
defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
for _, m := range set.GetAll() {
if m.score >= Score(minimum) && m.score <= Score(maximum) {
set.Remove(m.value)
if m.Score >= sorted_set.Score(minimum) && m.Score <= sorted_set.Score(maximum) {
set.Remove(m.Value)
deletedCount += 1
}
}
@@ -1053,7 +1058,7 @@ func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server utils.EchoV
}
defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -1070,20 +1075,20 @@ func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server utils.EchoV
}
members := set.GetAll()
slices.SortFunc(members, func(a, b MemberParam) int {
return cmp.Compare(a.score, b.score)
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
return cmp.Compare(a.Score, b.Score)
})
deletedCount := 0
if start < stop {
for i := start; i <= stop; i++ {
set.Remove(members[i].value)
set.Remove(members[i].Value)
deletedCount += 1
}
} else {
for i := stop; i <= start; i++ {
set.Remove(members[i].value)
set.Remove(members[i].Value)
deletedCount += 1
}
}
@@ -1110,7 +1115,7 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.EchoVa
}
defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -1119,7 +1124,7 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.EchoVa
// Check if all the members have the same score. If not, return 0
for i := 0; i < len(members)-1; i++ {
if members[i].score != members[i+1].score {
if members[i].Score != members[i+1].Score {
return []byte(":0\r\n"), nil
}
}
@@ -1128,9 +1133,9 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.EchoVa
// All the members have the same score
for _, m := range members {
if slices.Contains([]int{1, 0}, compareLex(string(m.value), minimum)) &&
slices.Contains([]int{-1, 0}, compareLex(string(m.value), maximum)) {
set.Remove(m.value)
if slices.Contains([]int{1, 0}, internal.CompareLex(string(m.Value), minimum)) &&
slices.Contains([]int{-1, 0}, internal.CompareLex(string(m.Value), maximum)) {
set.Remove(m.Value)
deletedCount += 1
}
}
@@ -1208,7 +1213,7 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.EchoVault, con
}
defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet)
set, ok := server.GetValue(ctx, key).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key)
}
@@ -1222,43 +1227,43 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.EchoVault, con
members := set.GetAll()
if strings.EqualFold(policy, "byscore") {
slices.SortFunc(members, func(a, b MemberParam) int {
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
// Do a score sort
if reverse {
return cmp.Compare(b.score, a.score)
return cmp.Compare(b.Score, a.Score)
}
return cmp.Compare(a.score, b.score)
return cmp.Compare(a.Score, b.Score)
})
}
if strings.EqualFold(policy, "bylex") {
// If policy is BYLEX, all the elements must have the same score
for i := 0; i < len(members)-1; i++ {
if members[i].score != members[i+1].score {
if members[i].Score != members[i+1].Score {
return []byte("*0\r\n"), nil
}
}
slices.SortFunc(members, func(a, b MemberParam) int {
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
if reverse {
return compareLex(string(b.value), string(a.value))
return internal.CompareLex(string(b.Value), string(a.Value))
}
return compareLex(string(a.value), string(b.value))
return internal.CompareLex(string(a.Value), string(b.Value))
})
}
var resultMembers []MemberParam
var resultMembers []sorted_set.MemberParam
for i := offset; i <= count; i++ {
if i >= len(members) {
break
}
if strings.EqualFold(policy, "byscore") {
if members[i].score >= Score(scoreStart) && members[i].score <= Score(scoreStop) {
if members[i].Score >= sorted_set.Score(scoreStart) && members[i].Score <= sorted_set.Score(scoreStop) {
resultMembers = append(resultMembers, members[i])
}
continue
}
if slices.Contains([]int{1, 0}, compareLex(string(members[i].value), lexStart)) &&
slices.Contains([]int{-1, 0}, compareLex(string(members[i].value), lexStop)) {
if slices.Contains([]int{1, 0}, internal.CompareLex(string(members[i].Value), lexStart)) &&
slices.Contains([]int{-1, 0}, internal.CompareLex(string(members[i].Value), lexStop)) {
resultMembers = append(resultMembers, members[i])
}
}
@@ -1267,9 +1272,9 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.EchoVault, con
for _, m := range resultMembers {
if withscores {
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.value), m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64))
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.Value), m.Value, strconv.FormatFloat(float64(m.Score), 'f', -1, 64))
} else {
res += fmt.Sprintf("\r\n*1\r\n$%d\r\n%s", len(m.value), m.value)
res += fmt.Sprintf("\r\n*1\r\n$%d\r\n%s", len(m.Value), m.Value)
}
}
@@ -1345,7 +1350,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.EchoVault
}
defer server.KeyRUnlock(ctx, source)
set, ok := server.GetValue(ctx, source).(*SortedSet)
set, ok := server.GetValue(ctx, source).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", source)
}
@@ -1359,48 +1364,48 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.EchoVault
members := set.GetAll()
if strings.EqualFold(policy, "byscore") {
slices.SortFunc(members, func(a, b MemberParam) int {
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
// Do a score sort
if reverse {
return cmp.Compare(b.score, a.score)
return cmp.Compare(b.Score, a.Score)
}
return cmp.Compare(a.score, b.score)
return cmp.Compare(a.Score, b.Score)
})
}
if strings.EqualFold(policy, "bylex") {
// If policy is BYLEX, all the elements must have the same score
for i := 0; i < len(members)-1; i++ {
if members[i].score != members[i+1].score {
if members[i].Score != members[i+1].Score {
return []byte(":0\r\n"), nil
}
}
slices.SortFunc(members, func(a, b MemberParam) int {
slices.SortFunc(members, func(a, b sorted_set.MemberParam) int {
if reverse {
return compareLex(string(b.value), string(a.value))
return internal.CompareLex(string(b.Value), string(a.Value))
}
return compareLex(string(a.value), string(b.value))
return internal.CompareLex(string(a.Value), string(b.Value))
})
}
var resultMembers []MemberParam
var resultMembers []sorted_set.MemberParam
for i := offset; i <= count; i++ {
if i >= len(members) {
break
}
if strings.EqualFold(policy, "byscore") {
if members[i].score >= Score(scoreStart) && members[i].score <= Score(scoreStop) {
if members[i].Score >= sorted_set.Score(scoreStart) && members[i].Score <= sorted_set.Score(scoreStop) {
resultMembers = append(resultMembers, members[i])
}
continue
}
if slices.Contains([]int{1, 0}, compareLex(string(members[i].value), lexStart)) &&
slices.Contains([]int{-1, 0}, compareLex(string(members[i].value), lexStop)) {
if slices.Contains([]int{1, 0}, internal.CompareLex(string(members[i].Value), lexStart)) &&
slices.Contains([]int{-1, 0}, internal.CompareLex(string(members[i].Value), lexStop)) {
resultMembers = append(resultMembers, members[i])
}
}
newSortedSet := NewSortedSet(resultMembers)
newSortedSet := sorted_set.NewSortedSet(resultMembers)
if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil {
@@ -1439,7 +1444,7 @@ func handleZUNION(ctx context.Context, cmd []string, server utils.EchoVault, con
}
}()
var setParams []SortedSetParam
var setParams []sorted_set.SortedSetParam
for i := 0; i < len(keys); i++ {
if server.KeyExists(ctx, keys[i]) {
@@ -1447,25 +1452,25 @@ func handleZUNION(ctx context.Context, cmd []string, server utils.EchoVault, con
return nil, err
}
locks[keys[i]] = true
set, ok := server.GetValue(ctx, keys[i]).(*SortedSet)
set, ok := server.GetValue(ctx, keys[i]).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
}
setParams = append(setParams, SortedSetParam{
set: set,
weight: weights[i],
setParams = append(setParams, sorted_set.SortedSetParam{
Set: set,
Weight: weights[i],
})
}
}
union := Union(aggregate, setParams...)
union := sorted_set.Union(aggregate, setParams...)
res := fmt.Sprintf("*%d", union.Cardinality())
for _, m := range union.GetAll() {
if withscores {
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.value), m.value, strconv.FormatFloat(float64(m.score), 'f', -1, 64))
res += fmt.Sprintf("\r\n*2\r\n$%d\r\n%s\r\n+%s", len(m.Value), m.Value, strconv.FormatFloat(float64(m.Score), 'f', -1, 64))
} else {
res += fmt.Sprintf("\r\n*1\r\n$%d\r\n%s", len(m.value), m.value)
res += fmt.Sprintf("\r\n*1\r\n$%d\r\n%s", len(m.Value), m.Value)
}
}
@@ -1501,7 +1506,7 @@ func handleZUNIONSTORE(ctx context.Context, cmd []string, server utils.EchoVault
}
}()
var setParams []SortedSetParam
var setParams []sorted_set.SortedSetParam
for i := 0; i < len(keys); i++ {
if server.KeyExists(ctx, keys[i]) {
@@ -1509,18 +1514,18 @@ func handleZUNIONSTORE(ctx context.Context, cmd []string, server utils.EchoVault
return nil, err
}
locks[keys[i]] = true
set, ok := server.GetValue(ctx, keys[i]).(*SortedSet)
set, ok := server.GetValue(ctx, keys[i]).(*sorted_set.SortedSet)
if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[i])
}
setParams = append(setParams, SortedSetParam{
set: set,
weight: weights[i],
setParams = append(setParams, sorted_set.SortedSetParam{
Set: set,
Weight: weights[i],
})
}
}
union := Union(aggregate, setParams...)
union := sorted_set.Union(aggregate, setParams...)
if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil {

File diff suppressed because it is too large Load Diff

View File

@@ -15,7 +15,6 @@
package sorted_set
import (
"cmp"
"errors"
"slices"
"strconv"
@@ -91,109 +90,3 @@ func extractKeysWeightsAggregateWithScores(cmd []string) ([]string, []int, strin
return keys, weights, aggregate, withscores, nil
}
func validateUpdatePolicy(updatePolicy interface{}) (string, error) {
if updatePolicy == nil {
return "", nil
}
err := errors.New("update policy must be a string of value NX or XX")
policy, ok := updatePolicy.(string)
if !ok {
return "", err
}
if !slices.Contains([]string{"nx", "xx"}, strings.ToLower(policy)) {
return "", err
}
return policy, nil
}
func validateComparison(comparison interface{}) (string, error) {
if comparison == nil {
return "", nil
}
err := errors.New("comparison condition must be a string of value LT or GT")
comp, ok := comparison.(string)
if !ok {
return "", err
}
if !slices.Contains([]string{"lt", "gt"}, strings.ToLower(comp)) {
return "", err
}
return comp, nil
}
func validateChanged(changed interface{}) (string, error) {
if changed == nil {
return "", nil
}
err := errors.New("changed condition should be a string of value CH")
ch, ok := changed.(string)
if !ok {
return "", err
}
if !strings.EqualFold(ch, "ch") {
return "", err
}
return ch, nil
}
func validateIncr(incr interface{}) (string, error) {
if incr == nil {
return "", nil
}
err := errors.New("incr condition should be a string of value INCR")
i, ok := incr.(string)
if !ok {
return "", err
}
if !strings.EqualFold(i, "incr") {
return "", err
}
return i, nil
}
func compareScores(old Score, new Score, comp string) Score {
switch strings.ToLower(comp) {
default:
return new
case "lt":
if new < old {
return new
}
return old
case "gt":
if new > old {
return new
}
return old
}
}
// compareLex returns -1 when s2 is lexicographically greater than s1,
// 0 if they're equal and 1 if s2 is lexicographically less than s1.
func compareLex(s1 string, s2 string) int {
if s1 == s2 {
return 0
}
if strings.Contains(s1, s2) {
return 1
}
if strings.Contains(s2, s1) {
return -1
}
limit := len(s1)
if len(s2) < limit {
limit = len(s2)
}
var c int
for i := 0; i < limit; i++ {
c = cmp.Compare(s1[i], s2[i])
if c != 0 {
break
}
}
return c
}