Add sscan, hscan, zscan

This commit is contained in:
lhpqaq
2024-07-23 21:48:31 +08:00
committed by finley
parent c1dd65d84f
commit 75030407cb
13 changed files with 432 additions and 10 deletions

View File

@@ -496,6 +496,56 @@ func execHRandField(db *DB, args [][]byte) redis.Reply {
return &protocol.EmptyMultiBulkReply{}
}
func execHScan(db *DB, args [][]byte) redis.Reply {
var count int = 10
var pattern string = "*"
if len(args) > 2 {
for i := 2; i < len(args); i++ {
arg := strings.ToLower(string(args[i]))
if arg == "count" {
count0, err := strconv.Atoi(string(args[i+1]))
if err != nil {
return &protocol.SyntaxErrReply{}
}
count = count0
i++
} else if arg == "match" {
pattern = string(args[i+1])
i++
} else {
return &protocol.SyntaxErrReply{}
}
}
}
if len(args) < 2 {
return &protocol.SyntaxErrReply{}
}
key := string(args[0])
// get entity
dict, errReply := db.getAsDict(key)
if errReply != nil {
return errReply
}
if dict == nil {
return &protocol.NullBulkReply{}
}
cursor, err := strconv.Atoi(string(args[1]))
if err != nil {
return protocol.MakeErrReply("ERR invalid cursor")
}
keysReply, nextCursor := dict.DictScan(cursor, count, pattern)
if nextCursor < 0 {
return protocol.MakeErrReply("Invalid argument")
}
result := make([]redis.Reply, 2)
result[0] = protocol.MakeBulkReply([]byte(strconv.FormatInt(int64(nextCursor), 10)))
result[1] = protocol.MakeMultiBulkReply(keysReply)
return protocol.MakeMultiRawReply(result)
}
func init() {
registerCommand("HSet", execHSet, writeFirstKey, undoHSet, 4, flagWrite).
attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM, redisFlagFast}, 1, 1, 1)
@@ -529,4 +579,6 @@ func init() {
attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM, redisFlagFast}, 1, 1, 1)
registerCommand("HRandField", execHRandField, readFirstKey, nil, -2, flagReadOnly).
attachCommandExtra([]string{redisFlagRandom, redisFlagReadonly}, 1, 1, 1)
registerCommand("HScan", execHScan, readFirstKey, nil, -2, flagReadOnly).
attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, 1, 1)
}

View File

@@ -346,3 +346,48 @@ func TestUndoHIncr(t *testing.T) {
result := testDB.Exec(nil, utils.ToCmdLine("hget", key, field))
asserts.AssertBulkReply(t, result, "1")
}
func TestHScan(t *testing.T) {
testDB.Flush()
hashKey := "test:hash"
for i := 0; i < 3; i++ {
key := string(rune(i))
value := key
testDB.Exec(nil, utils.ToCmdLine("hset", hashKey, "a"+key, value))
}
for i := 0; i < 3; i++ {
key := string(rune(i))
value := key
testDB.Exec(nil, utils.ToCmdLine("hset", hashKey, "b"+key, value))
}
result := testDB.Exec(nil, utils.ToCmdLine("hscan", hashKey, "0", "count", "10"))
cursorStr := string(result.(*protocol.MultiRawReply).Replies[0].(*protocol.BulkReply).Arg)
cursor, err := strconv.Atoi(cursorStr)
if err == nil {
if cursor != 0 {
t.Errorf("expect cursor 0, actually %d", cursor)
return
}
} else {
t.Errorf("get scan result error")
return
}
// test hscan 0 match a*
result = testDB.Exec(nil, utils.ToCmdLine("hscan", hashKey, "0", "match", "a*"))
returnKeys := result.(*protocol.MultiRawReply).Replies[1].(*protocol.MultiBulkReply).Args
i := 0
for i < len(returnKeys) {
if i%2 != 0 {
i++
continue // pass value
}
key := string(returnKeys[i])
i++
if key[0] != 'a' {
t.Errorf("The key %s should match a*", key)
return
}
}
}

View File

@@ -505,6 +505,6 @@ func init() {
attachCommandExtra([]string{redisFlagWrite, redisFlagFast}, 1, 1, 1)
registerCommand("Keys", execKeys, noPrepare, nil, 2, flagReadOnly).
attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, 1, 1)
registerCommand("Scan", execScan, readAllKeys, nil, -2, flagReadOnly).
registerCommand("Scan", execScan, noPrepare, nil, -2, flagReadOnly).
attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, 1, 1)
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/protocol"
"strconv"
"strings"
)
func (db *DB) getAsSet(key string) (*HashSet.Set, protocol.ErrorReply) {
@@ -354,6 +355,53 @@ func execSRandMember(db *DB, args [][]byte) redis.Reply {
return &protocol.EmptyMultiBulkReply{}
}
func execSScan(db *DB, args [][]byte) redis.Reply {
var count int = 10
var pattern string = "*"
if len(args) > 2 {
for i := 2; i < len(args); i++ {
arg := strings.ToLower(string(args[i]))
if arg == "count" {
count0, err := strconv.Atoi(string(args[i+1]))
if err != nil {
return &protocol.SyntaxErrReply{}
}
count = count0
i++
} else if arg == "match" {
pattern = string(args[i+1])
i++
} else {
return &protocol.SyntaxErrReply{}
}
}
}
key := string(args[0])
// get entity
set, errReply := db.getAsSet(key)
if errReply != nil {
return errReply
}
if set == nil {
return &protocol.EmptyMultiBulkReply{}
}
cursor, err := strconv.Atoi(string(args[1]))
if err != nil {
return protocol.MakeErrReply("ERR invalid cursor")
}
keysReply, nextCursor := set.SetScan(cursor, count, pattern)
if nextCursor < 0 {
return protocol.MakeErrReply("Invalid argument")
}
result := make([]redis.Reply, 2)
result[0] = protocol.MakeBulkReply([]byte(strconv.FormatInt(int64(nextCursor), 10)))
result[1] = protocol.MakeMultiBulkReply(keysReply)
return protocol.MakeMultiRawReply(result)
}
func init() {
registerCommand("SAdd", execSAdd, writeFirstKey, undoSetChange, -3, flagWrite).
attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM, redisFlagFast}, 1, 1, 1)
@@ -381,4 +429,6 @@ func init() {
attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM}, 1, 1, 1)
registerCommand("SRandMember", execSRandMember, readFirstKey, nil, -2, flagReadOnly).
attachCommandExtra([]string{redisFlagReadonly, redisFlagRandom}, 1, 1, 1)
registerCommand("SScan", execSScan, readFirstKey, nil, -2, flagReadOnly).
attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, 1, 1)
}

View File

@@ -248,3 +248,40 @@ func TestSRandMember(t *testing.T) {
result = testDB.Exec(nil, utils.ToCmdLine("SRandMember", key, "-110"))
asserts.AssertMultiBulkReplySize(t, result, 110)
}
func TestSScan(t *testing.T) {
testDB.Flush()
setKey := "test:set"
for i := 0; i < 3; i++ {
key := string(rune(i))
testDB.Exec(nil, utils.ToCmdLine("sadd", setKey, "a"+key))
}
for i := 0; i < 3; i++ {
key := string(rune(i))
testDB.Exec(nil, utils.ToCmdLine("sadd", setKey, "b"+key))
}
result := testDB.Exec(nil, utils.ToCmdLine("sscan", setKey, "0", "count", "10"))
cursorStr := string(result.(*protocol.MultiRawReply).Replies[0].(*protocol.BulkReply).Arg)
cursor, err := strconv.Atoi(cursorStr)
if err == nil {
if cursor != 0 {
t.Errorf("expect cursor 0, actually %d", cursor)
return
}
} else {
t.Errorf("get scan result error")
return
}
// test sscan 0 match a*
result = testDB.Exec(nil, utils.ToCmdLine("sscan", setKey, "0", "match", "a*"))
returnKeys := result.(*protocol.MultiRawReply).Replies[1].(*protocol.MultiBulkReply).Args
for i := range returnKeys {
key := string(returnKeys[i])
if key[0] != 'a' {
t.Errorf("The key %s should match a*", key)
return
}
}
}

View File

@@ -1,15 +1,14 @@
package database
import (
"math"
"strconv"
"strings"
SortedSet "github.com/hdt3213/godis/datastruct/sortedset"
"github.com/hdt3213/godis/interface/database"
"github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/protocol"
"math"
"strconv"
"strings"
)
func (db *DB) getAsSortedSet(key string) (*SortedSet.SortedSet, protocol.ErrorReply) {
@@ -796,6 +795,53 @@ func execZRevRangeByLex(db *DB, args [][]byte) redis.Reply {
return protocol.MakeMultiBulkReply(result)
}
func execZScan(db *DB, args [][]byte) redis.Reply {
var count int = 10
var pattern string = "*"
if len(args) > 2 {
for i := 2; i < len(args); i++ {
arg := strings.ToLower(string(args[i]))
if arg == "count" {
count0, err := strconv.Atoi(string(args[i+1]))
if err != nil {
return &protocol.SyntaxErrReply{}
}
count = count0
i++
} else if arg == "match" {
pattern = string(args[i+1])
i++
} else {
return &protocol.SyntaxErrReply{}
}
}
}
key := string(args[0])
// get entity
set, errReply := db.getAsSortedSet(key)
if errReply != nil {
return errReply
}
if set == nil {
return &protocol.EmptyMultiBulkReply{}
}
cursor, err := strconv.Atoi(string(args[1]))
if err != nil {
return protocol.MakeErrReply("ERR invalid cursor")
}
keysReply, nextCursor := set.ZSetScan(cursor, count, pattern)
if nextCursor < 0 {
return protocol.MakeErrReply("Invalid argument")
}
result := make([]redis.Reply, 2)
result[0] = protocol.MakeBulkReply([]byte(strconv.FormatInt(int64(nextCursor), 10)))
result[1] = protocol.MakeMultiBulkReply(keysReply)
return protocol.MakeMultiRawReply(result)
}
func init() {
registerCommand("ZAdd", execZAdd, writeFirstKey, undoZAdd, -4, flagWrite).
attachCommandExtra([]string{redisFlagWrite, redisFlagDenyOOM, redisFlagFast}, 1, 1, 1)
@@ -835,4 +881,6 @@ func init() {
attachCommandExtra([]string{redisFlagWrite}, 1, 1, 1)
registerCommand("ZRevRangeByLex", execZRevRangeByLex, readFirstKey, nil, -4, flagReadOnly).
attachCommandExtra([]string{redisFlagReadonly}, 1, 1, 1)
registerCommand("ZScan", execZScan, readFirstKey, nil, -2, flagReadOnly).
attachCommandExtra([]string{redisFlagReadonly}, 1, 1, 1)
}

View File

@@ -1,12 +1,12 @@
package database
import (
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/protocol"
"github.com/hdt3213/godis/redis/protocol/asserts"
"math/rand"
"strconv"
"testing"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/protocol/asserts"
)
func TestZAdd(t *testing.T) {
@@ -762,3 +762,49 @@ func TestZRevRangeByLex(t *testing.T) {
result30 := testDB.Exec(nil, utils.ToCmdLine("ZRevRangeByLex", key, "+", "-", "limit", "2", "2"))
asserts.AssertMultiBulkReply(t, result30, []string{"c", "b"})
}
func TestZScan(t *testing.T) {
testDB.Flush()
zsetKey := "zsetkey"
expectKeyScore := make(map[string]float64)
for i := 0; i < 3; i++ {
key := string(rune(i))
expectKeyScore[key] = float64(i)
testDB.Exec(nil, utils.ToCmdLine("zadd", zsetKey, strconv.FormatInt(int64(i), 10), "a"+key))
}
for i := 0; i < 3; i++ {
key := string(rune(i))
expectKeyScore[key] = float64(i + 3)
testDB.Exec(nil, utils.ToCmdLine("zadd", zsetKey, strconv.FormatInt(int64(i), 10), "b"+key))
}
result := testDB.Exec(nil, utils.ToCmdLine("zscan", zsetKey, "0", "count", "10"))
cursorStr := string(result.(*protocol.MultiRawReply).Replies[0].(*protocol.BulkReply).Arg)
cursor, err := strconv.Atoi(cursorStr)
if err == nil {
if cursor != 0 {
t.Errorf("expect cursor 0, actually %d", cursor)
return
}
} else {
t.Errorf("get scan result error")
return
}
// test zscan 0 match a*
result = testDB.Exec(nil, utils.ToCmdLine("zscan", zsetKey, "0", "match", "a*"))
returnKeys := result.(*protocol.MultiRawReply).Replies[1].(*protocol.MultiBulkReply).Args
i := 0
for i < len(returnKeys) {
if i%2 != 0 {
i++
continue // pass score
}
key := string(returnKeys[i])
i++
if key[0] != 'a' {
t.Errorf("The key %s should match a*", key)
return
}
}
}

View File

@@ -1,5 +1,9 @@
package dict
import (
"github.com/hdt3213/godis/lib/wildcard"
)
// SimpleDict wraps a map, it is not thread safe
type SimpleDict struct {
m map[string]interface{}
@@ -122,5 +126,20 @@ func (dict *SimpleDict) Clear() {
}
func (dict *SimpleDict) DictScan(cursor int, count int, pattern string) ([][]byte, int) {
return stringsToBytes(dict.Keys()), 0
result := make([][]byte, 0)
matchKey, err := wildcard.CompilePattern(pattern)
if err != nil {
return result, -1
}
for k := range dict.m {
if pattern == "*" || matchKey.IsMatch(k) {
raw, exists := dict.Get(k)
if !exists {
continue
}
result = append(result, []byte(k))
result = append(result, raw.([]byte))
}
}
return result, 0
}

View File

@@ -53,3 +53,30 @@ func TestSimpleDict_PutIfExists(t *testing.T) {
return
}
}
func TestSimpleDict_Scan(t *testing.T) {
d := MakeSimple()
size := 10
for i := 0; i < size; i++ {
str := "a" + utils.RandString(5)
d.Put(str, []byte(str))
}
keys, nextCursor := d.DictScan(0, size, "*")
if len(keys) != size*2 {
t.Errorf("expect %d keys, actual: %d", size*2, len(keys))
return
}
if nextCursor != 0 {
t.Errorf("expect 0, actual: %d", nextCursor)
return
}
for i := 0; i < size; i++ {
str := "b" + utils.RandString(5)
d.Put(str, str)
}
keys, _ = d.DictScan(0, size*2, "a*")
if len(keys) != size*2 {
t.Errorf("expect %d keys, actual: %d", size*2, len(keys))
return
}
}

View File

@@ -2,6 +2,7 @@ package set
import (
"github.com/hdt3213/godis/datastruct/dict"
"github.com/hdt3213/godis/lib/wildcard"
)
// Set is a set of elements based on hash table
@@ -149,3 +150,20 @@ func (set *Set) RandomMembers(limit int) []string {
func (set *Set) RandomDistinctMembers(limit int) []string {
return set.dict.RandomDistinctKeys(limit)
}
// Scan set with cursor and pattern
func (set *Set) SetScan(cursor int, count int, pattern string) ([][]byte, int) {
result := make([][]byte, 0)
matchKey, err := wildcard.CompilePattern(pattern)
if err != nil {
return result, -1
}
set.ForEach(func(member string) bool {
if pattern == "*" || matchKey.IsMatch(member) {
result = append(result, []byte(member))
}
return true
})
return result, 0
}

View File

@@ -1,6 +1,7 @@
package set
import (
"github.com/hdt3213/godis/lib/utils"
"strconv"
"testing"
)
@@ -30,3 +31,30 @@ func TestSet(t *testing.T) {
}
}
}
func TestSetScan(t *testing.T) {
set := Make()
size := 10
for i := 0; i < size; i++ {
str := "a" + utils.RandString(5)
set.Add(str)
}
keys, nextCursor := set.SetScan(0, size, "*")
if len(keys) != size {
t.Errorf("expect %d keys, actual: %d", size, len(keys))
return
}
if nextCursor != 0 {
t.Errorf("expect 0, actual: %d", nextCursor)
return
}
for i := 0; i < size; i++ {
str := "b" + utils.RandString(5)
set.Add(str)
}
keys, _ = set.SetScan(0, size*2, "a*")
if len(keys) != size {
t.Errorf("expect %d keys, actual: %d", size, len(keys))
return
}
}

View File

@@ -2,6 +2,8 @@ package sortedset
import (
"strconv"
"github.com/hdt3213/godis/lib/wildcard"
)
// SortedSet is a set which keys sorted by bound score
@@ -236,3 +238,22 @@ func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64) int64 {
}
return int64(len(removed))
}
func (sortedSet *SortedSet) ZSetScan(cursor int, count int, pattern string) ([][]byte, int) {
result := make([][]byte, 0)
matchKey, err := wildcard.CompilePattern(pattern)
if err != nil {
return result, -1
}
for k := range sortedSet.dict {
if pattern == "*" || matchKey.IsMatch(k) {
elem, exists := sortedSet.dict[k]
if !exists {
continue
}
result = append(result, []byte(k))
result = append(result, []byte(strconv.FormatFloat(elem.Score, 'f', 10, 64)))
}
}
return result, 0
}

View File

@@ -1,6 +1,10 @@
package sortedset
import "testing"
import (
"testing"
"github.com/hdt3213/godis/lib/utils"
)
func TestSortedSet_PopMin(t *testing.T) {
var set = Make()
@@ -14,3 +18,30 @@ func TestSortedSet_PopMin(t *testing.T) {
t.Fail()
}
}
func TestSetScan(t *testing.T) {
set := Make()
size := 10
for i := 0; i < size; i++ {
str := "a" + utils.RandString(5)
set.Add(str, float64(i))
}
keys, nextCursor := set.ZSetScan(0, size, "*")
if len(keys) != size*2 {
t.Errorf("expect %d keys, actual: %d", size*2, len(keys))
return
}
if nextCursor != 0 {
t.Errorf("expect 0, actual: %d", nextCursor)
return
}
for i := 0; i < size; i++ {
str := "b" + utils.RandString(5)
set.Add(str, float64(i+size))
}
keys, _ = set.ZSetScan(0, size*2, "a*")
if len(keys) != size*2 {
t.Errorf("expect %d keys, actual: %d", size*2, len(keys))
return
}
}