mirror of
https://github.com/HDT3213/godis.git
synced 2025-10-05 16:57:06 +08:00
add scan command
This commit is contained in:
@@ -59,27 +59,37 @@ func execFlushDB(db *DB, args [][]byte) redis.Reply {
|
||||
return &protocol.OkReply{}
|
||||
}
|
||||
|
||||
// execType returns the type of entity, including: string, list, hash, set and zset
|
||||
func execType(db *DB, args [][]byte) redis.Reply {
|
||||
key := string(args[0])
|
||||
// returns the type of entity, including: string, list, hash, set and zset
|
||||
func getType(db *DB, key string) string {
|
||||
entity, exists := db.GetEntity(key)
|
||||
if !exists {
|
||||
return protocol.MakeStatusReply("none")
|
||||
return "none"
|
||||
}
|
||||
switch entity.Data.(type) {
|
||||
case []byte:
|
||||
return protocol.MakeStatusReply("string")
|
||||
return "string"
|
||||
case list.List:
|
||||
return protocol.MakeStatusReply("list")
|
||||
return "list"
|
||||
case dict.Dict:
|
||||
return protocol.MakeStatusReply("hash")
|
||||
return "hash"
|
||||
case *set.Set:
|
||||
return protocol.MakeStatusReply("set")
|
||||
return "set"
|
||||
case *sortedset.SortedSet:
|
||||
return protocol.MakeStatusReply("zset")
|
||||
return "zset"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// execType returns the type of entity, including: string, list, hash, set and zset
|
||||
func execType(db *DB, args [][]byte) redis.Reply {
|
||||
key := string(args[0])
|
||||
result := getType(db, key)
|
||||
if len(result) > 0 {
|
||||
return protocol.MakeStatusReply(result)
|
||||
} else {
|
||||
return &protocol.UnknownErrReply{}
|
||||
}
|
||||
}
|
||||
|
||||
func prepareRename(args [][]byte) ([]string, []string) {
|
||||
src := string(args[0])
|
||||
@@ -413,6 +423,57 @@ func execCopy(mdb *Server, conn redis.Connection, args [][]byte) redis.Reply {
|
||||
return protocol.MakeIntReply(1)
|
||||
}
|
||||
|
||||
// execScan return the result of the scan
|
||||
func execScan(db *DB, args [][]byte) redis.Reply {
|
||||
var count int = 10
|
||||
var pattern string = "*"
|
||||
var scanType string = ""
|
||||
if len(args) > 1 {
|
||||
for i := 1; i < len(args); i++ {
|
||||
arg := strings.ToLower(string(args[i]))
|
||||
if arg == "count" {
|
||||
count0, err := strconv.Atoi(string(args[i+1]))
|
||||
if err != nil {
|
||||
return &protocol.SyntaxErrReply{}
|
||||
}
|
||||
count = count0
|
||||
i++
|
||||
} else if arg == "match" {
|
||||
pattern = string(args[i+1])
|
||||
i++
|
||||
} else if arg == "type" {
|
||||
scanType = strings.ToLower(string(args[i+1]))
|
||||
i++
|
||||
} else {
|
||||
return &protocol.SyntaxErrReply{}
|
||||
}
|
||||
}
|
||||
}
|
||||
cursor, err := strconv.Atoi(string(args[0]))
|
||||
if err != nil {
|
||||
return protocol.MakeErrReply("ERR invalid cursor")
|
||||
}
|
||||
keysReply, nextCursor := db.data.DictScan(cursor, count, pattern)
|
||||
if nextCursor < 0 {
|
||||
return protocol.MakeErrReply("Invalid argument")
|
||||
}
|
||||
|
||||
if len(scanType) != 0 {
|
||||
for i := 0; i < len(keysReply); {
|
||||
if getType(db, string(keysReply[i])) != scanType {
|
||||
keysReply = append(keysReply[:i], keysReply[i+1:]...)
|
||||
} else {
|
||||
i++
|
||||
}
|
||||
}
|
||||
}
|
||||
result := make([]redis.Reply, 2)
|
||||
result[0] = protocol.MakeBulkReply([]byte(strconv.FormatInt(int64(nextCursor), 10)))
|
||||
result[1] = protocol.MakeMultiBulkReply(keysReply)
|
||||
|
||||
return protocol.MakeMultiRawReply(result)
|
||||
}
|
||||
|
||||
func init() {
|
||||
registerCommand("Del", execDel, writeAllKeys, undoDel, -2, flagWrite).
|
||||
attachCommandExtra([]string{redisFlagWrite}, 1, -1, 1)
|
||||
@@ -444,4 +505,6 @@ func init() {
|
||||
attachCommandExtra([]string{redisFlagWrite, redisFlagFast}, 1, 1, 1)
|
||||
registerCommand("Keys", execKeys, noPrepare, nil, 2, flagReadOnly).
|
||||
attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, 1, 1)
|
||||
registerCommand("Scan", execScan, readAllKeys, nil, -2, flagReadOnly).
|
||||
attachCommandExtra([]string{redisFlagReadonly, redisFlagSortForScript}, 1, 1, 1)
|
||||
}
|
||||
|
@@ -313,3 +313,85 @@ func TestCopy(t *testing.T) {
|
||||
result = testMDB.Exec(conn, utils.ToCmdLine("ttl", destKey))
|
||||
asserts.AssertIntReplyGreaterThan(t, result, 0)
|
||||
}
|
||||
|
||||
func TestScan(t *testing.T) {
|
||||
testDB.Flush()
|
||||
for i := 0; i < 3; i++ {
|
||||
key := string(rune(i))
|
||||
value := key
|
||||
testDB.Exec(nil, utils.ToCmdLine("set", "a:"+key, value))
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
key := string(rune(i))
|
||||
value := key
|
||||
testDB.Exec(nil, utils.ToCmdLine("set", "b:"+key, value))
|
||||
}
|
||||
|
||||
// test scan 0 when keys < 10
|
||||
result := testDB.Exec(nil, utils.ToCmdLine("scan", "0"))
|
||||
cursorStr := string(result.(*protocol.MultiRawReply).Replies[0].(*protocol.BulkReply).Arg)
|
||||
cursor, err := strconv.Atoi(cursorStr)
|
||||
if err == nil {
|
||||
if cursor != 0 {
|
||||
t.Errorf("expect cursor 0, actually %d", cursor)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
t.Errorf("get scan result error")
|
||||
return
|
||||
}
|
||||
|
||||
// test scan 0 match a*
|
||||
result = testDB.Exec(nil, utils.ToCmdLine("scan", "0", "match", "a*"))
|
||||
returnKeys := result.(*protocol.MultiRawReply).Replies[1].(*protocol.MultiBulkReply).Args
|
||||
for i := range returnKeys {
|
||||
key := string(returnKeys[i])
|
||||
if key[0] != 'a' {
|
||||
t.Errorf("The key %s should match a*", key)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// test scan 0 type string
|
||||
testDB.Exec(nil, utils.ToCmdLine("hset", "hashkey", "hashkey", "1"))
|
||||
result = testDB.Exec(nil, utils.ToCmdLine("scan", "0", "type", "string"))
|
||||
returnKeys = result.(*protocol.MultiRawReply).Replies[1].(*protocol.MultiBulkReply).Args
|
||||
for i := range returnKeys {
|
||||
key := string(returnKeys[i])
|
||||
if key == "hashkey" {
|
||||
t.Errorf("expect type string, found hash")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// test returned cursor
|
||||
testDB.Flush()
|
||||
for i := 0; i < 100; i++ {
|
||||
key := string(rune(i))
|
||||
value := key
|
||||
testDB.Exec(nil, utils.ToCmdLine("set", "a"+key, value))
|
||||
}
|
||||
cursor = 0
|
||||
resultByte := make([][]byte, 0)
|
||||
for {
|
||||
scanCursor := strconv.Itoa(cursor)
|
||||
result = testDB.Exec(nil, utils.ToCmdLine("scan", scanCursor, "count", "20"))
|
||||
cursorStr := string(result.(*protocol.MultiRawReply).Replies[0].(*protocol.BulkReply).Arg)
|
||||
returnKeys = result.(*protocol.MultiRawReply).Replies[1].(*protocol.MultiBulkReply).Args
|
||||
resultByte = append(resultByte, returnKeys...)
|
||||
cursor, err = strconv.Atoi(cursorStr)
|
||||
if err == nil {
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
t.Errorf("get scan result error")
|
||||
return
|
||||
}
|
||||
}
|
||||
resultByte = utils.RemoveDuplicates(resultByte)
|
||||
if len(resultByte) != 100 {
|
||||
t.Errorf("expect result num 100, actually %d", len(resultByte))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package dict
|
||||
|
||||
import (
|
||||
"github.com/hdt3213/godis/lib/wildcard"
|
||||
"math"
|
||||
"math/rand"
|
||||
"sort"
|
||||
@@ -435,3 +436,47 @@ func (dict *ConcurrentDict) RWUnLocks(writeKeys []string, readKeys []string) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func stringsToBytes(strSlice []string) [][]byte {
|
||||
byteSlice := make([][]byte, len(strSlice))
|
||||
for i, str := range strSlice {
|
||||
byteSlice[i] = []byte(str)
|
||||
}
|
||||
return byteSlice
|
||||
}
|
||||
|
||||
func (dict *ConcurrentDict) DictScan(cursor int, count int, pattern string) ([][]byte, int) {
|
||||
size := dict.Len()
|
||||
result := make([][]byte, 0)
|
||||
|
||||
if pattern == "*" && count >= size {
|
||||
return stringsToBytes(dict.Keys()), 0
|
||||
}
|
||||
|
||||
matchKey, err := wildcard.CompilePattern(pattern)
|
||||
if err != nil {
|
||||
return result, -1
|
||||
}
|
||||
|
||||
shardCount := len(dict.table)
|
||||
shardIndex := cursor
|
||||
|
||||
for shardIndex < shardCount {
|
||||
shard := dict.table[shardIndex]
|
||||
shard.mutex.RLock()
|
||||
if len(result)+len(shard.m) > count && shardIndex > cursor {
|
||||
shard.mutex.RUnlock()
|
||||
return result, shardIndex
|
||||
}
|
||||
|
||||
for key := range shard.m {
|
||||
if pattern == "*" || matchKey.IsMatch(key) {
|
||||
result = append(result, []byte(key))
|
||||
}
|
||||
}
|
||||
shard.mutex.RUnlock()
|
||||
shardIndex++
|
||||
}
|
||||
|
||||
return result, 0
|
||||
}
|
||||
|
@@ -524,3 +524,51 @@ func TestConcurrentDict_Keys(t *testing.T) {
|
||||
t.Errorf("expect %d keys, actual: %d", size, len(d.Keys()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDictScan(t *testing.T) {
|
||||
d := MakeConcurrent(0)
|
||||
count := 100
|
||||
for i := 0; i < count; i++ {
|
||||
key := "kkk" + strconv.Itoa(i)
|
||||
d.Put(key, i)
|
||||
}
|
||||
for i := 0; i < count; i++ {
|
||||
key := "key" + strconv.Itoa(i)
|
||||
d.Put(key, i)
|
||||
}
|
||||
cursor := 0
|
||||
matchKey := "*"
|
||||
c := 20
|
||||
result := make([][]byte, 0)
|
||||
var returnKeys [][]byte
|
||||
for {
|
||||
returnKeys, cursor = d.DictScan(cursor, c, matchKey)
|
||||
result = append(result, returnKeys...)
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
result = utils.RemoveDuplicates(result)
|
||||
if len(result) != count*2 {
|
||||
t.Errorf("scan command result number error: %d, should be %d ", len(result), count*2)
|
||||
}
|
||||
matchKey = "key*"
|
||||
cursor = 0
|
||||
mresult := make([][]byte, 0)
|
||||
for {
|
||||
returnKeys, cursor = d.DictScan(cursor, c, matchKey)
|
||||
mresult = append(mresult, returnKeys...)
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
mresult = utils.RemoveDuplicates(mresult)
|
||||
if len(mresult) != count {
|
||||
t.Errorf("scan command result number error: %d, should be %d ", len(mresult), count)
|
||||
}
|
||||
matchKey = "no*"
|
||||
returnKeys, _ = d.DictScan(cursor, c, matchKey)
|
||||
if len(returnKeys) != 0 {
|
||||
t.Errorf("returnKeys should be empty")
|
||||
}
|
||||
}
|
||||
|
@@ -16,4 +16,5 @@ type Dict interface {
|
||||
RandomKeys(limit int) []string
|
||||
RandomDistinctKeys(limit int) []string
|
||||
Clear()
|
||||
DictScan(cursor int, count int, pattern string) ([][]byte, int)
|
||||
}
|
||||
|
@@ -120,3 +120,7 @@ func (dict *SimpleDict) RandomDistinctKeys(limit int) []string {
|
||||
func (dict *SimpleDict) Clear() {
|
||||
*dict = *MakeSimple()
|
||||
}
|
||||
|
||||
func (dict *SimpleDict) DictScan(cursor int, count int, pattern string) ([][]byte, int) {
|
||||
return stringsToBytes(dict.Keys()), 0
|
||||
}
|
||||
|
@@ -84,3 +84,20 @@ func ConvertRange(start int64, end int64, size int64) (int, int) {
|
||||
}
|
||||
return int(start), int(end)
|
||||
}
|
||||
|
||||
// RemoveDuplicates removes duplicate byte slices from a 2D byte slice
|
||||
func RemoveDuplicates(input [][]byte) [][]byte {
|
||||
uniqueMap := make(map[string]struct{})
|
||||
var result [][]byte
|
||||
|
||||
for _, item := range input {
|
||||
// Use bytes.Buffer to convert byte slice to string
|
||||
key := string(item)
|
||||
if _, exists := uniqueMap[key]; !exists {
|
||||
uniqueMap[key] = struct{}{}
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
Reference in New Issue
Block a user