diff --git a/src/datastruct/sortedset/border.go b/src/datastruct/sortedset/border.go new file mode 100644 index 0000000..e5d644e --- /dev/null +++ b/src/datastruct/sortedset/border.go @@ -0,0 +1,89 @@ +package sortedset + +import ( + "errors" + "strconv" +) + +/* + * ScoreBorder is a struct represents `min` `max` parameter of redis command `ZRANGEBYSCORE` + * can accept: + * int or float value, such as 2.718, 2, -2.718, -2 ... + * exclusive int or float value, such as (2.718, (2, (-2.718, (-2 ... + * infinity: +inf, -inf, inf(same as +inf) + */ + +const ( + negativeInf int8 = -1 + positiveInf int8 = 1 +) + +type ScoreBorder struct { + Inf int8 + Value float64 + Exclude bool +} + +func (border *ScoreBorder)greater(value float64)bool { + if border.Inf == negativeInf { + return false + } else if border.Inf == positiveInf { + return true + } + if border.Exclude { + return border.Value > value + } else { + return border.Value >= value + } +} + +func (border *ScoreBorder)less(value float64)bool { + if border.Inf == negativeInf { + return true + } else if border.Inf == positiveInf { + return false + } + if border.Exclude { + return border.Value < value + } else { + return border.Value <= value + } +} + +var positiveInfBorder = &ScoreBorder { + Inf: positiveInf, +} + +var negativeInfBorder = &ScoreBorder { + Inf: negativeInf, +} + +func ParseScoreBorder(s string)(*ScoreBorder, error) { + if s == "inf" || s == "+inf" { + return positiveInfBorder, nil + } + if s == "-inf" { + return negativeInfBorder, nil + } + if s[0] == '(' { + value, err := strconv.ParseFloat(s[1:], 64) + if err != nil { + return nil, errors.New("ERR min or max is not a float") + } + return &ScoreBorder{ + Inf: 0, + Value: value, + Exclude: true, + }, nil + } else { + value, err := strconv.ParseFloat(s, 64) + if err != nil { + return nil, errors.New("ERR min or max is not a float") + } + return &ScoreBorder{ + Inf: 0, + Value: value, + Exclude: false, + }, nil + } +} \ No newline at end of file diff --git a/src/datastruct/sortedset/skiplist.go b/src/datastruct/sortedset/skiplist.go new file mode 100644 index 0000000..2820a51 --- /dev/null +++ b/src/datastruct/sortedset/skiplist.go @@ -0,0 +1,290 @@ +package sortedset + +import "math/rand" + +const ( + maxLevel = 16 +) + + +type Element struct { + Member string + Score float64 +} + +// level aspect of a Node +type Level struct { + forward *Node // forward node has greater score + span int64 +} + +type Node struct { + Element + backward *Node + level []*Level // level[0] is base level +} + +type skiplist struct { + header *Node + tail *Node + length int64 + level int16 +} + +func makeNode(level int16, score float64, member string)*Node { + n := &Node{ + Element: Element{ + Score: score, + Member: member, + }, + level: make([]*Level, level), + } + for i := range n.level { + n.level[i] = new(Level) + } + return n +} + +func makeSkiplist()*skiplist { + return &skiplist{ + level: 1, + header: makeNode(maxLevel, 0, ""), + } +} + +func randomLevel() int16 { + level := int16(1) + for float32(rand.Int31()&0xFFFF) < (0.25 * 0xFFFF) { + level++ + } + if level < maxLevel { + return level + } + return maxLevel +} + +func (skiplist *skiplist)insert(member string, score float64)*Node { + update := make([]*Node, maxLevel) // link new node with node in `update` + rank := make([]int64, maxLevel) + + // find position to insert + node := skiplist.header + for i := skiplist.level - 1; i >= 0; i-- { + if i == skiplist.level - 1 { + rank[i] = 0 + } else { + rank[i] = rank[i + 1] // store rank that is crossed to reach the insert position + } + if node.level[i] != nil { + // traverse the skip list + for node.level[i].forward != nil && + (node.level[i].forward.Score < score || + (node.level[i].forward.Score == score && node.level[i].forward.Member < member)) { // same score, different key + rank[i] += node.level[i].span + node = node.level[i].forward + } + } + update[i] = node + } + + level := randomLevel() + // extend skiplist level + if level > skiplist.level { + for i := skiplist.level; i < level; i++ { + rank[i] = 0 + update[i] = skiplist.header + update[i].level[i].span = skiplist.length + } + skiplist.level = level + } + + // make node and link into skiplist + node = makeNode(level, score, member) + for i := int16(0); i < level; i++ { + node.level[i].forward = update[i].level[i].forward + update[i].level[i].forward = node + + // update span covered by update[i] as node is inserted here + node.level[i].span = update[i].level[i].span - (rank[0] - rank[i]) + update[i].level[i].span = (rank[0] - rank[i]) + 1 + } + + // increment span for untouched levels + for i := level; i < skiplist.level; i++ { + update[i].level[i].span++ + } + + // set backward node + if update[0] == skiplist.header { + node.backward = nil + } else { + node.backward = update[0] + } + if node.level[0].forward != nil { + node.level[0].forward.backward = node + } else { + skiplist.tail = node + } + skiplist.length++ + return node +} + +/* + * param node: node to delete + * param update: backward node (of target) or last node of each level + */ +func (skiplist *skiplist) removeNode(node *Node, update []*Node) { + for i := int16(0); i < skiplist.level; i++ { + if update[i].level[i].forward == node { + update[i].level[i].span += node.level[i].span - 1 + update[i].level[i].forward = node.level[i].forward + } else { + update[i].level[i].span-- + } + } + if node.level[0].forward != nil { + node.level[0].forward.backward = node.backward + } else { + skiplist.tail = node.backward + } + for skiplist.level > 1 && skiplist.header.level[skiplist.level-1].forward == nil { + skiplist.level-- + } + skiplist.length-- +} + +/* + * return: has found and removed node + */ +func (skiplist *skiplist) remove(member string, score float64)bool { + /* + * find backward node (of target) or last node of each level + * their forward need to be updated + */ + update := make([]*Node, maxLevel) + node := skiplist.header + for i := skiplist.level - 1; i >= 0; i-- { + for node.level[i].forward != nil && + (node.level[i].forward.Score < score || + (node.level[i].forward.Score == score && + node.level[i].forward.Member < member)) { + node = node.level[i].forward + } + update[i] = node + } + node = node.level[0].forward + if node != nil && score == node.Score && node.Member == member { + skiplist.removeNode(node, update) + // free x + return true + } + return false +} + +/* + * return: 1 based rank, 0 means member not found + */ +func (skiplist *skiplist) getRank(member string, score float64)int64 { + var rank int64 = 0 + x := skiplist.header + for i := skiplist.level - 1; i >= 0; i-- { + for x.level[i].forward != nil && + (x.level[i].forward.Score < score || + (x.level[i].forward.Score == score && + x.level[i].forward.Member <= member)) { + rank += x.level[i].span + x = x.level[i].forward + } + + /* x might be equal to zsl->header, so test if obj is non-NULL */ + if x.Member == member { + return rank + } + } + return 0 +} + +/* + * 1-based rank + */ +func (skiplist *skiplist) getByRank(rank int64)*Node { + var i int64 = 0 + n := skiplist.header + // scan from top level + for level := skiplist.level - 1; level >= 0; level-- { + for n.level[level].forward != nil && (i+n.level[level].span) <= rank { + i += n.level[level].span + n = n.level[level].forward + } + if i == rank { + return n + } + } + return nil +} + +/* + * return removed elements + */ +func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder)(removed []*Element) { + update := make([]*Node, maxLevel) + removed = make([]*Element, 0) + // find backward nodes (of target range) or last node of each level + node := skiplist.header + for i := skiplist.level - 1; i >= 0; i-- { + for node.level[i].forward != nil { + if min.less(node.level[i].forward.Score) { // already in range + break + } + node = node.level[i].forward + } + update[i] = node + } + + // node is the first one within range + node = node.level[0].forward + + // remove nodes in range + for node != nil { + if !max.greater(node.Score) { // already out of range + break + } + next := node.level[0].forward + removedElement := node.Element + removed = append(removed, &removedElement) + skiplist.removeNode(node, update) + node = next + } + return removed +} + +// 1-based rank, including start, exclude stop +func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64)(removed []*Element) { + var i int64 = 0 // rank of iterator + update := make([]*Node, maxLevel) + removed = make([]*Element, 0) + + // scan from top level + node := skiplist.header + for level := skiplist.level - 1; level >= 0; level-- { + for node.level[level].forward != nil && (i+node.level[level].span) < start { + i += node.level[level].span + node = node.level[level].forward + } + update[level] = node + } + + i++ + node = node.level[0].forward // first node in range + + // remove nodes in range + for node != nil && i < stop { + next := node.level[0].forward + removedElement := node.Element + removed = append(removed, &removedElement) + skiplist.removeNode(node, update) + node = next + i++ + } + return removed +} diff --git a/src/datastruct/sortedset/sortedset.go b/src/datastruct/sortedset/sortedset.go new file mode 100644 index 0000000..48ba062 --- /dev/null +++ b/src/datastruct/sortedset/sortedset.go @@ -0,0 +1,203 @@ +package sortedset + +import ( + "strconv" +) + +type SortedSet struct { + dict map[string]*Element + skiplist *skiplist +} + +func Make()*SortedSet { + return &SortedSet{ + dict: make(map[string]*Element), + skiplist: makeSkiplist(), + } +} + +/* + * return: has inserted new node + */ +func (sortedSet *SortedSet)Add(member string, score float64)bool { + element, ok := sortedSet.dict[member] + sortedSet.dict[member] = &Element{ + Member: member, + Score: score, + } + if ok { + if score != element.Score { + sortedSet.skiplist.remove(member, score) + sortedSet.skiplist.insert(member, score) + } + return false + } else { + sortedSet.skiplist.insert(member, score) + return true + } +} + +func (sortedSet *SortedSet) Len()int64 { + return int64(len(sortedSet.dict)) +} + +func (sortedSet *SortedSet) Get(member string) (element *Element, ok bool) { + element, ok = sortedSet.dict[member] + if !ok { + return nil, false + } + return element, true +} + +func (sortedSet *SortedSet) Remove(member string)bool { + v, ok := sortedSet.dict[member] + if ok { + sortedSet.skiplist.remove(member, v.Score) + delete(sortedSet.dict, member) + return true + } + return false +} + +/** + * get 0-based rank + */ +func (sortedSet *SortedSet) GetRank(member string, desc bool) (rank int64) { + element, ok := sortedSet.dict[member] + if !ok { + return -1 + } + r := sortedSet.skiplist.getRank(member, element.Score) + if desc { + r = sortedSet.skiplist.length - r + } else { + r-- + } + return r +} + +/** + * traverse [start, stop), 0-based rank + */ +func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element)bool) { + size := int64(sortedSet.Len()) + if start < 0 || start >= size { + panic("illegal start " + strconv.FormatInt(start, 10)) + } + if stop < start || stop > size { + panic("illegal end " + strconv.FormatInt(stop, 10)) + } + + // find start node + var node *Node + if desc { + node = sortedSet.skiplist.tail + if start > 0 { + node = sortedSet.skiplist.getByRank(int64(size - start)) + } + } else { + node = sortedSet.skiplist.header.level[0].forward + if start > 0 { + node = sortedSet.skiplist.getByRank(int64(start + 1)) + } + } + + sliceSize := int(stop - start) + for i := 0; i < sliceSize; i++ { + if !consumer(&node.Element) { + break + } + if desc { + node = node.backward + } else { + node = node.level[0].forward + } + } +} + +/** + * return [start, stop), 0-based rank + * assert start in [0, size), stop in [start, size] + */ +func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool)[]*Element { + sliceSize := int(stop - start) + slice := make([]*Element, sliceSize) + i := 0 + sortedSet.ForEach(start, stop, desc, func(element *Element)bool { + slice[i] = element + i++ + return true + }) + return slice +} + +func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder)int64 { + var i int64 = 0 + // ascending order + sortedSet.ForEach(0, sortedSet.Len(), false, func(element *Element) bool { + gtMin := min.less(element.Score) // greater than min + if !gtMin { + // has not into range, continue foreach + return true + } + ltMax := max.greater(element.Score) // less than max + if !ltMax { + // break through score border, break foreach + return false + } + // gtMin && ltMax + i++ + return true + }) + return i +} + +/* + * param limit: <0 means no limit + */ +func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool)[]*Element { + if limit == 0 || offset < 0{ + return make([]*Element, 0) + } + slice := make([]*Element, 0) + var skipped int64 = 0 + sortedSet.ForEach(0, sortedSet.Len(), desc, func(element *Element)bool { + gtMin := min.less(element.Score) // greater than min + ltMax := max.greater(element.Score) // less than max + if gtMin && ltMax { // in score range + if skipped < offset { + skipped++ + return true + } + slice = append(slice, element) + if len(slice) == int(limit) { // reach limit + return false + } + } + if (desc && !gtMin) || (!desc && !ltMax) { + return false // break through score border + } + return true + }) + return slice +} + +func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder)int64 { + removed := sortedSet.skiplist.RemoveRangeByScore(min, max) + for _, element := range removed { + delete(sortedSet.dict, element.Member) + } + return int64(len(removed)) +} + + +/* + * 0-based rank, [start, stop) + */ +func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64)int64 { + removed := sortedSet.skiplist.RemoveRangeByRank(start + 1, stop + 1) + for _, element := range removed { + delete(sortedSet.dict, element.Member) + } + return int64(len(removed)) +} \ No newline at end of file diff --git a/src/db/db.go b/src/db/db.go index 8384983..aa97269 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -117,6 +117,21 @@ func MakeCmdMap()map[string]CmdFunc { cmdMap["sdiffstore"] = SDiffStore cmdMap["srandmember"] = SRandMember + cmdMap["zadd"] = ZAdd + cmdMap["zscore"] = ZScore + cmdMap["zincrby"] = ZIncrBy + cmdMap["zrank"] = ZRank + cmdMap["zcount"] = ZCount + cmdMap["zrevrank"] = ZRevRank + cmdMap["zcard"] = ZCard + cmdMap["zrange"] = ZRange + cmdMap["zrevrange"] = ZRevRange + cmdMap["zrangebyscore"] = ZRangeByScore + cmdMap["zrevrangebyscore"] = ZRevRangeByScore + cmdMap["zrem"] = ZRem + cmdMap["zremrangebyscore"] = ZRemRangeByScore + cmdMap["zremrangebyrank"] = ZRemRangeByRank + return cmdMap } diff --git a/src/db/sortedset.go b/src/db/sortedset.go new file mode 100644 index 0000000..0d4cf21 --- /dev/null +++ b/src/db/sortedset.go @@ -0,0 +1,602 @@ +package db + +import ( + SortedSet "github.com/HDT3213/godis/src/datastruct/sortedset" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" + "strconv" + "strings" +) + +func ZAdd(db *DB, args [][]byte)redis.Reply { + if len(args) < 3 || len(args) % 2 != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zadd' command") + } + key := string(args[0]) + size := (len(args) - 1) / 2 + elements := make([]*SortedSet.Element, size) + for i := 0; i < size; i++ { + scoreValue := args[2 * i + 1] + member := string(args[2 * i + 2]) + score, err := strconv.ParseFloat(string(scoreValue), 64) + if err != nil { + return reply.MakeErrReply("ERR value is not a valid float") + } + elements[i] = &SortedSet.Element{ + Member:member, + Score:score, + } + } + + // lock + db.Locks.Lock(key) + defer db.Locks.UnLock(key) + + // get or init entity + entity, exists := db.Get(key) + if !exists { + entity = &DataEntity{ + Code: SortedSetCode, + Data: SortedSet.Make(), + } + db.Data.Put(key, entity) + } + + // check type + if entity.Code != SortedSetCode { + return &reply.WrongTypeErrReply{} + } + + // insert + sortedSet, _ := entity.Data.(*SortedSet.SortedSet) + i := 0 + for _, e := range elements { + if sortedSet.Add(e.Member, e.Score) { + i++ + } + } + + return reply.MakeIntReply(int64(i)) +} + +func ZScore(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zscore' command") + } + key := string(args[0]) + member := string(args[1]) + + // get entity + entity, exists := db.Get(key) + if !exists { + return &reply.NullBulkReply{} + } + // check type + if entity.Code != SortedSetCode { + return &reply.WrongTypeErrReply{} + } + + sortedSet, _ := entity.Data.(*SortedSet.SortedSet) + element, exists := sortedSet.Get(member) + if !exists { + return &reply.NullBulkReply{} + } + value := strconv.FormatFloat(element.Score, 'f', -1, 64) + return reply.MakeBulkReply([]byte(value)) +} + +func ZRank(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zrank' command") + } + key := string(args[0]) + member := string(args[1]) + + // get entity + entity, exists := db.Get(key) + if !exists { + return &reply.NullBulkReply{} + } + // check type + if entity.Code != SortedSetCode { + return &reply.WrongTypeErrReply{} + } + + sortedSet, _ := entity.Data.(*SortedSet.SortedSet) + rank := sortedSet.GetRank(member, false) + if rank < 0 { + return &reply.NullBulkReply{} + } + return reply.MakeIntReply(rank) +} + +func ZRevRank(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zrevrank' command") + } + key := string(args[0]) + member := string(args[1]) + + // get entity + entity, exists := db.Get(key) + if !exists { + return &reply.NullBulkReply{} + } + // check type + if entity.Code != SortedSetCode { + return &reply.WrongTypeErrReply{} + } + + sortedSet, _ := entity.Data.(*SortedSet.SortedSet) + rank := sortedSet.GetRank(member, true) + if rank < 0 { + return &reply.NullBulkReply{} + } + return reply.MakeIntReply(rank) +} + +func ZCard(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zcard' command") + } + key := string(args[0]) + + // get entity + entity, exists := db.Get(key) + if !exists { + return reply.MakeIntReply(0) + } + // check type + if entity.Code != SortedSetCode { + return &reply.WrongTypeErrReply{} + } + + sortedSet, _ := entity.Data.(*SortedSet.SortedSet) + return reply.MakeIntReply(int64(sortedSet.Len())) +} + +func ZRange(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) != 3 && len(args) != 4 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zrange' command") + } + withScores := false + if len(args) == 4 { + if strings.ToUpper(string(args[3])) != "WITHSCORES" { + return reply.MakeErrReply("syntax error") + } else { + withScores = true + } + } + key := string(args[0]) + start, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + stop, err := strconv.ParseInt(string(args[2]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + return range0(db, key, start, stop, withScores, false) +} + +func ZRevRange(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) != 3 && len(args) != 4 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zrevrange' command") + } + withScores := false + if len(args) == 4 { + if string(args[3]) != "WITHSCORES" { + return reply.MakeErrReply("syntax error") + } else { + withScores = true + } + } + key := string(args[0]) + start, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + stop, err := strconv.ParseInt(string(args[2]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + return range0(db, key, start, stop, withScores, true) +} + +func range0(db *DB, key string, start int64, stop int64, withScores bool, desc bool)redis.Reply { + // lock key + db.Locks.RLock(key) + defer db.Locks.RUnLock(key) + + // get data + entity, exists := db.Get(key) + if !exists { + return &reply.EmptyMultiBulkReply{} + } + if entity.Code != SortedSetCode { + return &reply.WrongTypeErrReply{} + } + + // compute index + sortedSet, _ := entity.Data.(*SortedSet.SortedSet) + size := sortedSet.Len() // assert: size > 0 + if start < -1 * size { + start = 0 + } else if start < 0 { + start = size + start + } else if start >= size { + return &reply.EmptyMultiBulkReply{} + } + if stop < -1 * size { + stop = 0 + } else if stop < 0 { + stop = size + stop + 1 + } else if stop < size { + stop = stop + 1 + } else { + stop = size + } + if stop < start { + stop = start + } + + // assert: start in [0, size - 1], stop in [start, size] + slice := sortedSet.Range(start, stop, desc) + if withScores { + result := make([][]byte, len(slice) * 2) + i := 0 + for _, element := range slice { + result[i] = []byte(element.Member) + i++ + scoreStr := strconv.FormatFloat(element.Score, 'f', -1, 64) + result[i] = []byte(scoreStr) + i++ + } + return reply.MakeMultiBulkReply(result) + } else { + result := make([][]byte, len(slice)) + i := 0 + for _, element := range slice { + result[i] = []byte(element.Member) + i++ + } + return reply.MakeMultiBulkReply(result) + } +} + +func ZCount(db *DB, args [][]byte)redis.Reply { + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zcount' command") + } + key := string(args[0]) + + min, err := SortedSet.ParseScoreBorder(string(args[1])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + max, err := SortedSet.ParseScoreBorder(string(args[2])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + db.Locks.RLock(key) + defer db.Locks.RUnLock(key) + + // get data + entity, exists := db.Get(key) + if !exists { + return reply.MakeIntReply(0) + } + if entity.Code != SortedSetCode { + return &reply.WrongTypeErrReply{} + } + + sortedSet, _ := entity.Data.(*SortedSet.SortedSet) + return reply.MakeIntReply(sortedSet.Count(min, max)) +} + +/* + * param limit: limit < 0 means no limit + */ +func rangeByScore0(db *DB, key string, min *SortedSet.ScoreBorder, max *SortedSet.ScoreBorder, offset int64, limit int64, withScores bool, desc bool)redis.Reply { + // lock key + db.Locks.RLock(key) + defer db.Locks.RUnLock(key) + + // get data + entity, exists := db.Get(key) + if !exists { + return &reply.EmptyMultiBulkReply{} + } + if entity.Code != SortedSetCode { + return &reply.WrongTypeErrReply{} + } + + sortedSet, _ := entity.Data.(*SortedSet.SortedSet) + slice := sortedSet.RangeByScore(min, max, offset, limit, desc) + if withScores { + result := make([][]byte, len(slice) * 2) + i := 0 + for _, element := range slice { + result[i] = []byte(element.Member) + i++ + scoreStr := strconv.FormatFloat(element.Score, 'f', -1, 64) + result[i] = []byte(scoreStr) + i++ + } + return reply.MakeMultiBulkReply(result) + } else { + result := make([][]byte, len(slice)) + i := 0 + for _, element := range slice { + result[i] = []byte(element.Member) + i++ + } + return reply.MakeMultiBulkReply(result) + } +} + +func ZRangeByScore(db *DB, args [][]byte)redis.Reply { + if len(args) < 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zrangebyscore' command") + } + key := string(args[0]) + + min, err := SortedSet.ParseScoreBorder(string(args[1])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + max, err := SortedSet.ParseScoreBorder(string(args[2])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + withScores := false + var offset int64 = 0 + var limit int64 = -1 + if len(args) > 3 { + for i := 3; i < len(args); { + s := string(args[i]) + if strings.ToUpper(s) == "WITHSCORES" { + withScores = true + i++ + } else if strings.ToUpper(s) == "LIMIT" { + if len(args) < i+3 { + return reply.MakeErrReply("ERR syntax error") + } + offset, err = strconv.ParseInt(string(args[i+1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + limit, err = strconv.ParseInt(string(args[i+2]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + i += 3 + } else { + return reply.MakeErrReply("ERR syntax error") + } + } + } + return rangeByScore0(db, key, min, max, offset, limit, withScores, false) +} + +func ZRevRangeByScore(db *DB, args [][]byte)redis.Reply { + if len(args) < 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zrangebyscore' command") + } + key := string(args[0]) + + min, err := SortedSet.ParseScoreBorder(string(args[1])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + max, err := SortedSet.ParseScoreBorder(string(args[2])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + withScores := false + var offset int64 = 0 + var limit int64 = -1 + if len(args) > 3 { + for i := 3; i < len(args); { + s := string(args[i]) + if strings.ToUpper(s) == "WITHSCORES" { + withScores = true + i++ + } else if strings.ToUpper(s) == "LIMIT" { + if len(args) < i+3 { + return reply.MakeErrReply("ERR syntax error") + } + offset, err = strconv.ParseInt(string(args[i+1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + limit, err = strconv.ParseInt(string(args[i+2]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + i += 3 + } else { + return reply.MakeErrReply("ERR syntax error") + } + } + } + return rangeByScore0(db, key, min, max, offset, limit, withScores, true) +} + +func ZRemRangeByScore(db *DB, args [][]byte)redis.Reply { + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zremrangebyscore' command") + } + key := string(args[0]) + + min, err := SortedSet.ParseScoreBorder(string(args[1])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + max, err := SortedSet.ParseScoreBorder(string(args[2])) + if err != nil { + return reply.MakeErrReply(err.Error()) + } + + db.Locks.Lock(key) + defer db.Locks.UnLock(key) + + // get data + entity, exists := db.Get(key) + if !exists { + return &reply.EmptyMultiBulkReply{} + } + if entity.Code != SortedSetCode { + return &reply.WrongTypeErrReply{} + } + + sortedSet, _ := entity.Data.(*SortedSet.SortedSet) + removed := sortedSet.RemoveByScore(min, max) + return reply.MakeIntReply(removed) +} + +func ZRemRangeByRank(db *DB, args [][]byte)redis.Reply { + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zremrangebyrank' command") + } + key := string(args[0]) + start, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + stop, err := strconv.ParseInt(string(args[2]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + + db.Locks.Lock(key) + defer db.Locks.UnLock(key) + + // get data + entity, exists := db.Get(key) + if !exists { + return reply.MakeIntReply(0) + } + if entity.Code != SortedSetCode { + return &reply.WrongTypeErrReply{} + } + + // compute index + sortedSet, _ := entity.Data.(*SortedSet.SortedSet) + size := sortedSet.Len() // assert: size > 0 + if start < -1 * size { + start = 0 + } else if start < 0 { + start = size + start + } else if start >= size { + return reply.MakeIntReply(0) + } + if stop < -1 * size { + stop = 0 + } else if stop < 0 { + stop = size + stop + 1 + } else if stop < size { + stop = stop + 1 + } else { + stop = size + } + if stop < start { + stop = start + } + + // assert: start in [0, size - 1], stop in [start, size] + removed := sortedSet.RemoveByRank(start, stop) + return reply.MakeIntReply(removed) +} + +func ZRem(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zrem' command") + } + key := string(args[0]) + fields := make([]string, len(args)-1) + fieldArgs := args[1:] + for i, v := range fieldArgs { + fields[i] = string(v) + } + + db.Locks.Lock(key) + defer db.Locks.UnLock(key) + + // get entity + entity, exists := db.Get(key) + if !exists { + return reply.MakeIntReply(0) + } + // check type + if entity.Code != SortedSetCode { + return &reply.WrongTypeErrReply{} + } + + sortedSet, _ := entity.Data.(*SortedSet.SortedSet) + var deleted int64 = 0 + for _, field := range fields { + if sortedSet.Remove(field) { + deleted++ + } + } + return reply.MakeIntReply(deleted) +} + +func ZIncrBy(db *DB, args [][]byte)redis.Reply { + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'zincrby' command") + } + key := string(args[0]) + rawDelta := string(args[1]) + field := string(args[2]) + delta, err := strconv.ParseFloat(rawDelta, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not a valid float") + } + + db.Locks.Lock(key) + defer db.Locks.UnLock(key) + + // get or init entity + entity, exists := db.Get(key) + if !exists { + entity = &DataEntity{ + Code: SortedSetCode, + Data: SortedSet.Make(), + } + db.Data.Put(key, entity) + } + + // check type + if entity.Code != SortedSetCode { + return &reply.WrongTypeErrReply{} + } + + // put data + sortedSet, _ := entity.Data.(*SortedSet.SortedSet) + element, exists := sortedSet.Get(field) + if !exists { + sortedSet.Add(field, delta) + return reply.MakeBulkReply(args[1]) + } else { + score := element.Score + delta + sortedSet.Add(field, score) + bytes := []byte(strconv.FormatFloat(score, 'f', -1, 64)) + return reply.MakeBulkReply(bytes) + } +} \ No newline at end of file