From a806f8e64f0c3872cbc9c9afbd894f3fd9b5adf2 Mon Sep 17 00:00:00 2001 From: hdt3213 Date: Sun, 30 Aug 2020 21:04:58 +0800 Subject: [PATCH] sorted set bug fix and unit tests --- src/datastruct/sortedset/skiplist.go | 1 - src/datastruct/sortedset/sortedset.go | 3 + src/db/sortedset.go | 4 +- src/db/sortedset_test.go | 298 ++++++++++++++++++++++++++ src/db/{test.go => util_test.go} | 0 src/redis/reply/asserts/assert.go | 49 +++++ 6 files changed, 352 insertions(+), 3 deletions(-) create mode 100644 src/db/sortedset_test.go rename src/db/{test.go => util_test.go} (100%) create mode 100644 src/redis/reply/asserts/assert.go diff --git a/src/datastruct/sortedset/skiplist.go b/src/datastruct/sortedset/skiplist.go index e7c2b47..0304c51 100644 --- a/src/datastruct/sortedset/skiplist.go +++ b/src/datastruct/sortedset/skiplist.go @@ -272,7 +272,6 @@ func (skiplist *skiplist) getLastInScoreRange(min *ScoreBorder, max *ScoreBorder n = n.level[level].forward } } - n = n.level[0].forward if !min.less(n.Score) { return nil } diff --git a/src/datastruct/sortedset/sortedset.go b/src/datastruct/sortedset/sortedset.go index ddd52fb..50e4bf7 100644 --- a/src/datastruct/sortedset/sortedset.go +++ b/src/datastruct/sortedset/sortedset.go @@ -180,6 +180,9 @@ func (sortedSet *SortedSet) ForEachByScore(min *ScoreBorder, max *ScoreBorder, o } else { node = node.level[0].forward } + if node == nil { + break + } gtMin := min.less(node.Element.Score) // greater than min ltMax := max.greater(node.Element.Score) if !gtMin || !ltMax { diff --git a/src/db/sortedset.go b/src/db/sortedset.go index ee97f54..fc9607f 100644 --- a/src/db/sortedset.go +++ b/src/db/sortedset.go @@ -402,12 +402,12 @@ func ZRevRangeByScore(db *DB, args [][]byte) redis.Reply { } key := string(args[0]) - min, err := SortedSet.ParseScoreBorder(string(args[1])) + min, err := SortedSet.ParseScoreBorder(string(args[2])) if err != nil { return reply.MakeErrReply(err.Error()) } - max, err := SortedSet.ParseScoreBorder(string(args[2])) + max, err := SortedSet.ParseScoreBorder(string(args[1])) if err != nil { return reply.MakeErrReply(err.Error()) } diff --git a/src/db/sortedset_test.go b/src/db/sortedset_test.go new file mode 100644 index 0000000..f27adcf --- /dev/null +++ b/src/db/sortedset_test.go @@ -0,0 +1,298 @@ +package db + +import ( + "github.com/HDT3213/godis/src/redis/reply/asserts" + "math/rand" + "strconv" + "testing" +) + +func TestZAdd(t *testing.T) { + FlushAll(testDB, [][]byte{}) + size := 100 + + // add new members + key := strconv.FormatInt(int64(rand.Int()), 10) + members := make([]string, size) + scores := make([]float64, size) + setArgs := []string{key} + for i := 0; i < size; i++ { + members[i] = strconv.FormatInt(int64(rand.Int()), 10) + scores[i] = rand.Float64() + setArgs = append(setArgs, strconv.FormatFloat(scores[i], 'f', -1, 64), members[i]) + } + result := ZAdd(testDB, toArgs(setArgs...)) + asserts.AssertIntReply(t, result, size) + + // test zscore and zrank + for i, member := range members { + result := ZScore(testDB, toArgs(key, member)) + score := strconv.FormatFloat(scores[i], 'f', -1, 64) + asserts.AssertBulkReply(t, result, score) + } + + // test zcard + result = ZCard(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, size) + + // update members + setArgs = []string{key} + for i := 0; i < size; i++ { + scores[i] = rand.Float64() + 100 + setArgs = append(setArgs, strconv.FormatFloat(scores[i], 'f', -1, 64), members[i]) + } + result = ZAdd(testDB, toArgs(setArgs...)) + asserts.AssertIntReply(t, result, 0) // return number of new members + + // test updated score + for i, member := range members { + result := ZScore(testDB, toArgs(key, member)) + score := strconv.FormatFloat(scores[i], 'f', -1, 64) + asserts.AssertBulkReply(t, result, score) + } +} + +func TestZRank(t *testing.T) { + FlushAll(testDB, [][]byte{}) + size := 100 + key := strconv.FormatInt(int64(rand.Int()), 10) + members := make([]string, size) + scores := make([]int, size) + setArgs := []string{key} + for i := 0; i < size; i++ { + members[i] = strconv.FormatInt(int64(rand.Int()), 10) + scores[i] = i + setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) + } + ZAdd(testDB, toArgs(setArgs...)) + + // test zrank + for i, member := range members { + result := ZRank(testDB, toArgs(key, member)) + asserts.AssertIntReply(t, result, i) + + result = ZRevRank(testDB, toArgs(key, member)) + asserts.AssertIntReply(t, result, size-i-1) + } +} + +func TestZRange(t *testing.T) { + // prepare + FlushAll(testDB, [][]byte{}) + size := 100 + key := strconv.FormatInt(int64(rand.Int()), 10) + members := make([]string, size) + scores := make([]int, size) + setArgs := []string{key} + for i := 0; i < size; i++ { + members[i] = strconv.FormatInt(int64(rand.Int()), 10) + scores[i] = i + setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) + } + result := ZAdd(testDB, toArgs(setArgs...)) + reverseMembers := make([]string, size) + for i, v := range members { + reverseMembers[size-i-1] = v + } + + start := "0" + end := "9" + result = ZRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, members[0:10]) + result = ZRevRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, reverseMembers[0:10]) + + start = "0" + end = "200" + result = ZRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, members) + result = ZRevRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, reverseMembers) + + start = "0" + end = "-10" + result = ZRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, members[0:size-10+1]) + result = ZRevRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, reverseMembers[0:size-10+1]) + + start = "0" + end = "-200" + result = ZRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, members[0:0]) + result = ZRevRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, reverseMembers[0:0]) + + start = "-10" + end = "-1" + result = ZRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, members[90:]) + result = ZRevRange(testDB, toArgs(key, start, end)) + asserts.AssertMultiBulkReply(t, result, reverseMembers[90:]) +} + +func reverse(src []string) []string { + result := make([]string, len(src)) + for i, v := range src { + result[len(src)-i-1] = v + } + return result +} + +func TestZRangeByScore(t *testing.T) { + // prepare + FlushAll(testDB, [][]byte{}) + size := 100 + key := strconv.FormatInt(int64(rand.Int()), 10) + members := make([]string, size) + scores := make([]int, size) + setArgs := []string{key} + for i := 0; i < size; i++ { + members[i] = strconv.FormatInt(int64(i), 10) + scores[i] = i + setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) + } + result := ZAdd(testDB, toArgs(setArgs...)) + + min := "20" + max := "30" + result = ZRangeByScore(testDB, toArgs(key, min, max)) + asserts.AssertMultiBulkReply(t, result, members[20:31]) + result = ZRevRangeByScore(testDB, toArgs(key, max, min)) + asserts.AssertMultiBulkReply(t, result, reverse(members[20:31])) + + min = "-10" + max = "10" + result = ZRangeByScore(testDB, toArgs(key, min, max)) + asserts.AssertMultiBulkReply(t, result, members[0:11]) + result = ZRevRangeByScore(testDB, toArgs(key, max, min)) + asserts.AssertMultiBulkReply(t, result, reverse(members[0:11])) + + min = "90" + max = "110" + result = ZRangeByScore(testDB, toArgs(key, min, max)) + asserts.AssertMultiBulkReply(t, result, members[90:]) + result = ZRevRangeByScore(testDB, toArgs(key, max, min)) + asserts.AssertMultiBulkReply(t, result, reverse(members[90:])) + + min = "(20" + max = "(30" + result = ZRangeByScore(testDB, toArgs(key, min, max)) + asserts.AssertMultiBulkReply(t, result, members[21:30]) + result = ZRevRangeByScore(testDB, toArgs(key, max, min)) + asserts.AssertMultiBulkReply(t, result, reverse(members[21:30])) + + min = "20" + max = "40" + result = ZRangeByScore(testDB, toArgs(key, min, max, "LIMIT", "5", "5")) + asserts.AssertMultiBulkReply(t, result, members[25:30]) + result = ZRevRangeByScore(testDB, toArgs(key, max, min, "LIMIT", "5", "5")) + asserts.AssertMultiBulkReply(t, result, reverse(members[31:36])) +} + +func TestZRem(t *testing.T) { + FlushAll(testDB, [][]byte{}) + size := 100 + key := strconv.FormatInt(int64(rand.Int()), 10) + members := make([]string, size) + scores := make([]int, size) + setArgs := []string{key} + for i := 0; i < size; i++ { + members[i] = strconv.FormatInt(int64(i), 10) + scores[i] = i + setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) + } + ZAdd(testDB, toArgs(setArgs...)) + + args := []string{key} + args = append(args, members[0:10]...) + result := ZRem(testDB, toArgs(args...)) + asserts.AssertIntReply(t, result, 10) + result = ZCard(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, size-10) + + // test ZRemRangeByRank + FlushAll(testDB, [][]byte{}) + size = 100 + key = strconv.FormatInt(int64(rand.Int()), 10) + members = make([]string, size) + scores = make([]int, size) + setArgs = []string{key} + for i := 0; i < size; i++ { + members[i] = strconv.FormatInt(int64(i), 10) + scores[i] = i + setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) + } + ZAdd(testDB, toArgs(setArgs...)) + + result = ZRemRangeByRank(testDB, toArgs(key, "0", "9")) + asserts.AssertIntReply(t, result, 10) + result = ZCard(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, size-10) + + // test ZRemRangeByScore + FlushAll(testDB, [][]byte{}) + size = 100 + key = strconv.FormatInt(int64(rand.Int()), 10) + members = make([]string, size) + scores = make([]int, size) + setArgs = []string{key} + for i := 0; i < size; i++ { + members[i] = strconv.FormatInt(int64(i), 10) + scores[i] = i + setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) + } + ZAdd(testDB, toArgs(setArgs...)) + + result = ZRemRangeByScore(testDB, toArgs(key, "0", "9")) + asserts.AssertIntReply(t, result, 10) + result = ZCard(testDB, toArgs(key)) + asserts.AssertIntReply(t, result, size-10) +} + +func TestZCount(t *testing.T) { + // prepare + FlushAll(testDB, [][]byte{}) + size := 100 + key := strconv.FormatInt(int64(rand.Int()), 10) + members := make([]string, size) + scores := make([]int, size) + setArgs := []string{key} + for i := 0; i < size; i++ { + members[i] = strconv.FormatInt(int64(i), 10) + scores[i] = i + setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) + } + result := ZAdd(testDB, toArgs(setArgs...)) + + min := "20" + max := "30" + result = ZCount(testDB, toArgs(key, min, max)) + asserts.AssertIntReply(t, result, 11) + + min = "-10" + max = "10" + result = ZCount(testDB, toArgs(key, min, max)) + asserts.AssertIntReply(t, result, 11) + + min = "90" + max = "110" + result = ZCount(testDB, toArgs(key, min, max)) + asserts.AssertIntReply(t, result, 10) + + min = "(20" + max = "(30" + result = ZCount(testDB, toArgs(key, min, max)) + asserts.AssertIntReply(t, result, 9) +} + +func TestZIncrBy(t *testing.T) { + FlushAll(testDB, [][]byte{}) + key := strconv.FormatInt(int64(rand.Int()), 10) + ZAdd(testDB, toArgs(key, "10", "a")) + result := ZIncrBy(testDB, toArgs(key, "10", "a")) + asserts.AssertBulkReply(t, result, "20") + + result = ZScore(testDB, toArgs(key, "a")) + asserts.AssertBulkReply(t, result, "20") +} diff --git a/src/db/test.go b/src/db/util_test.go similarity index 100% rename from src/db/test.go rename to src/db/util_test.go diff --git a/src/redis/reply/asserts/assert.go b/src/redis/reply/asserts/assert.go new file mode 100644 index 0000000..7cda3ef --- /dev/null +++ b/src/redis/reply/asserts/assert.go @@ -0,0 +1,49 @@ +package asserts + +import ( + "fmt" + "github.com/HDT3213/godis/src/datastruct/utils" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" + "testing" +) + +func AssertIntReply(t *testing.T, actual redis.Reply, expected int) { + intResult, ok := actual.(*reply.IntReply) + if !ok { + t.Error(fmt.Sprintf("expected int reply, actually %s", actual.ToBytes())) + return + } + if intResult.Code != int64(expected) { + t.Error(fmt.Sprintf("expected %d, actually %d", expected, intResult.Code)) + } +} + +func AssertBulkReply(t *testing.T, actual redis.Reply, expected string) { + bulkReply, ok := actual.(*reply.BulkReply) + if !ok { + t.Error(fmt.Sprintf("expected bulk reply, actually %s", actual.ToBytes())) + return + } + if !utils.BytesEquals(bulkReply.Arg, []byte(expected)) { + t.Error(fmt.Sprintf("expected %s, actually %s", expected, actual.ToBytes())) + } +} + +func AssertMultiBulkReply(t *testing.T, actual redis.Reply, expected []string) { + multiBulk, ok := actual.(*reply.MultiBulkReply) + if !ok { + t.Error(fmt.Sprintf("expected bulk reply, actually %s", actual.ToBytes())) + return + } + if len(multiBulk.Args) != len(expected) { + t.Error(fmt.Sprintf("expected %d elements, actually %d", len(expected), len(multiBulk.Args))) + return + } + for i, v := range multiBulk.Args { + actual := string(v) + if actual != expected[i] { + t.Error(fmt.Sprintf("expected %s, actually %s", expected[i], actual)) + } + } +}