diff --git a/cluster/router.go b/cluster/router.go index 952c8ac..e617b7d 100644 --- a/cluster/router.go +++ b/cluster/router.go @@ -71,6 +71,7 @@ func makeRouter() map[string]CmdFunc { routerMap["sadd"] = defaultFunc routerMap["sismember"] = defaultFunc routerMap["srem"] = defaultFunc + routerMap["spop"] = defaultFunc routerMap["scard"] = defaultFunc routerMap["smembers"] = defaultFunc routerMap["sinter"] = defaultFunc diff --git a/commands.md b/commands.md index 2d56280..4155945 100644 --- a/commands.md +++ b/commands.md @@ -64,6 +64,7 @@ - sadd - sismember - srem + - spop - scard - smembers - sinter diff --git a/database/set.go b/database/set.go index 224e944..624fa2c 100644 --- a/database/set.go +++ b/database/set.go @@ -101,6 +101,46 @@ func execSRem(db *DB, args [][]byte) redis.Reply { return protocol.MakeIntReply(int64(counter)) } +// execSPop removes one or more random members from set +func execSPop(db *DB, args [][]byte) redis.Reply { + if len(args) != 1 && len(args) != 2 { + return protocol.MakeErrReply("ERR wrong number of arguments for 'spop' command") + } + key := string(args[0]) + + set, errReply := db.getAsSet(key) + if errReply != nil { + return errReply + } + if set == nil { + return &protocol.NullBulkReply{} + } + + count := 1 + if len(args) == 2 { + count64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil || count64 <= 0 { + return protocol.MakeErrReply("ERR value is out of range, must be positive") + } + count = int(count64) + } + if count > set.Len() { + count = set.Len() + } + + members := set.RandomDistinctMembers(count) + result := make([][]byte, len(members)) + for i, v := range members { + set.Remove(v) + result[i] = []byte(v) + } + + if count > 0 { + db.addAof(utils.ToCmdLine3("spop", args...)) + } + return protocol.MakeMultiBulkReply(result) +} + // execSCard gets the number of members in a set func execSCard(db *DB, args [][]byte) redis.Reply { key := string(args[0]) @@ -442,6 +482,7 @@ func init() { RegisterCommand("SAdd", execSAdd, writeFirstKey, undoSetChange, -3) RegisterCommand("SIsMember", execSIsMember, readFirstKey, nil, 3) RegisterCommand("SRem", execSRem, writeFirstKey, undoSetChange, -3) + RegisterCommand("SPop", execSPop, writeFirstKey, undoSetChange, -2) RegisterCommand("SCard", execSCard, readFirstKey, nil, 2) RegisterCommand("SMembers", execSMembers, readFirstKey, nil, 2) RegisterCommand("SInter", execSInter, prepareSetCalculate, nil, -2) diff --git a/database/set_test.go b/database/set_test.go index 46a7602..a4a3561 100644 --- a/database/set_test.go +++ b/database/set_test.go @@ -5,6 +5,7 @@ import ( "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/redis/protocol" "github.com/hdt3213/godis/redis/protocol/asserts" + "math/rand" "strconv" "testing" ) @@ -63,6 +64,40 @@ func TestSRem(t *testing.T) { } } +func TestSPop(t *testing.T) { + testDB.Flush() + size := 100 + + // mock data + key := utils.RandString(10) + for i := 0; i < size; i++ { + member := strconv.Itoa(i) + testDB.Exec(nil, utils.ToCmdLine("sadd", key, member)) + } + + result := testDB.Exec(nil, utils.ToCmdLine("spop", key)) + asserts.AssertMultiBulkReplySize(t, result, 1) + + currentSize := size - 1 + for currentSize > 0 { + count := rand.Intn(currentSize) + 1 + resultSpop := testDB.Exec(nil, utils.ToCmdLine("spop", key, strconv.FormatInt(int64(count), 10))) + multiBulk, ok := resultSpop.(*protocol.MultiBulkReply) + if !ok { + t.Error(fmt.Sprintf("expected bulk protocol, actually %s", resultSpop.ToBytes())) + return + } + removedSize := len(multiBulk.Args) + for _, arg := range multiBulk.Args { + resultSIsMember := testDB.Exec(nil, utils.ToCmdLine("SIsMember", key, string(arg))) + asserts.AssertIntReply(t, resultSIsMember, 0) + } + currentSize -= removedSize + resultSCard := testDB.Exec(nil, utils.ToCmdLine("SCard", key)) + asserts.AssertIntReply(t, resultSCard, currentSize) + } +} + func TestSInter(t *testing.T) { testDB.Flush() size := 100 diff --git a/go.mod b/go.mod index bfed9be..0fadc4a 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,7 @@ module github.com/hdt3213/godis go 1.16 require ( - github.com/emirpasic/gods v1.16.0 // indirect - github.com/hdt3213/rdb v1.0.0 // indirect + github.com/hdt3213/rdb v1.0.0 github.com/jolestar/go-commons-pool/v2 v2.1.1 github.com/shopspring/decimal v1.2.0 ) diff --git a/go.sum b/go.sum index 4d0d22e..65c274e 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,6 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/emirpasic/gods v1.12.0/go.mod h1:YfzfFFoVP/catgzJb4IKIqXjX78Ha8FMSDh3ymbK86o= -github.com/emirpasic/gods v1.16.0 h1:K8GFZcq7YD5BL7IuQULdIKMWxVmqiEBUBaN+v/Ku214= -github.com/emirpasic/gods v1.16.0/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/hdt3213/rdb v1.0.0 h1:rG8pRz6Y+2XtZw4C35rize3nXByClkFmwfM5ffj7sFs= @@ -17,6 +15,7 @@ github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFR github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.5 h1:ymVxjfMaHvXD8RqPRmzHHsB3VvucivSkIAvJFDI5O3c=