Files
redis-go/src/db/set.go
2020-02-20 00:54:32 +08:00

505 lines
13 KiB
Go

package db
import (
HashSet "github.com/HDT3213/godis/src/datastruct/set"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"strconv"
)
// getAsSet fetches the value stored at key and asserts that it is a set.
//
// It returns (nil, nil) when key does not exist, and a WRONGTYPE error reply
// when key holds a value that is not a *HashSet.Set. It takes no lock itself.
func (db *DB) getAsSet(key string) (*HashSet.Set, reply.ErrorReply) {
	entity, found := db.Get(key)
	if !found {
		return nil, nil
	}
	if set, isSet := entity.Data.(*HashSet.Set); isSet {
		return set, nil
	}
	return nil, &reply.WrongTypeErrReply{}
}
// getOrInitSet returns the set stored at key. When key is absent, a new empty
// set is created and stored under key, and inited is reported as true.
// A WRONGTYPE error reply is returned when key holds a non-set value.
func (db *DB) getOrInitSet(key string) (set *HashSet.Set, inited bool, errReply reply.ErrorReply) {
	set, errReply = db.getAsSet(key)
	if errReply != nil {
		return nil, false, errReply
	}
	if set != nil {
		// existing set: nothing to create
		return set, false, nil
	}
	set = HashSet.Make()
	db.Put(key, &DataEntity{Data: set})
	return set, true, nil
}
// SAdd implements SADD key member [member ...].
//
// It adds every member to the set at key, creating the set when key is
// missing. It replies with the sum of the values reported by set.Add for each
// member (the count of members reported as added).
func SAdd(db *DB, args [][]byte) (redis.Reply, *extra) {
	if len(args) < 2 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'sadd' command"), nil
	}
	key := string(args[0])

	db.Lock(key)
	defer db.UnLock(key)

	set, _, errReply := db.getOrInitSet(key)
	if errReply != nil {
		return errReply, nil
	}
	added := 0
	for idx := 1; idx < len(args); idx++ {
		added += set.Add(string(args[idx]))
	}
	return reply.MakeIntReply(int64(added)), &extra{toPersist: true}
}
// SIsMember implements SISMEMBER key member.
//
// It replies 1 when member belongs to the set stored at key, and 0 otherwise,
// including when key does not exist. A WRONGTYPE error is returned when key
// holds a non-set value.
func SIsMember(db *DB, args [][]byte) (redis.Reply, *extra) {
	if len(args) != 2 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'sismember' command"), nil
	}
	key := string(args[0])
	member := string(args[1])

	// Hold the key's read lock, the same way SMembers and SRandMember do.
	// Without it, the lookup can race with writers such as SAdd/SRem.
	db.Locker.RLock(key)
	defer db.Locker.RUnLock(key)

	set, errReply := db.getAsSet(key)
	if errReply != nil {
		return errReply, nil
	}
	if set == nil || !set.Has(member) {
		return reply.MakeIntReply(0), nil
	}
	return reply.MakeIntReply(1), nil
}
// SRem implements SREM key member [member ...].
//
// It removes the given members from the set at key and replies with the sum
// of set.Remove results (the count of members reported as removed). The key
// is deleted once its set becomes empty. The command is only marked for
// persistence when that count is positive.
func SRem(db *DB, args [][]byte) (redis.Reply, *extra) {
	if len(args) < 2 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'srem' command"), nil
	}
	key := string(args[0])

	db.Lock(key)
	defer db.UnLock(key)

	set, errReply := db.getAsSet(key)
	switch {
	case errReply != nil:
		return errReply, nil
	case set == nil:
		return reply.MakeIntReply(0), nil
	}

	removed := 0
	for _, member := range args[1:] {
		removed += set.Remove(string(member))
	}
	if set.Len() == 0 {
		// an empty set is not kept as a key
		db.Remove(key)
	}
	return reply.MakeIntReply(int64(removed)), &extra{toPersist: removed > 0}
}
// SCard implements SCARD key.
//
// It replies with the number of members in the set stored at key, or 0 when
// key does not exist. A WRONGTYPE error is returned when key holds a non-set
// value.
func SCard(db *DB, args [][]byte) (redis.Reply, *extra) {
	if len(args) != 1 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'scard' command"), nil
	}
	key := string(args[0])

	// Hold the key's read lock, as SMembers does, so Len() does not race with
	// concurrent SAdd/SRem writers.
	db.Locker.RLock(key)
	defer db.Locker.RUnLock(key)

	set, errReply := db.getAsSet(key)
	if errReply != nil {
		return errReply, nil
	}
	if set == nil {
		return reply.MakeIntReply(0), nil
	}
	return reply.MakeIntReply(int64(set.Len())), nil
}
// SMembers implements SMEMBERS key.
//
// It replies with every member of the set stored at key, or an empty
// multi-bulk reply when key does not exist.
func SMembers(db *DB, args [][]byte) (redis.Reply, *extra) {
	if len(args) != 1 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'smembers' command"), nil
	}
	key := string(args[0])

	db.Locker.RLock(key)
	defer db.Locker.RUnLock(key)

	set, errReply := db.getAsSet(key)
	if errReply != nil {
		return errReply, nil
	}
	if set == nil {
		return &reply.EmptyMultiBulkReply{}, nil
	}
	members := make([][]byte, 0, set.Len())
	set.ForEach(func(member string) bool {
		members = append(members, []byte(member))
		return true
	})
	return reply.MakeMultiBulkReply(members), nil
}
// SInter implements SINTER key [key ...].
//
// It replies with the members present in every given set. A missing key is
// treated as an empty set, so it makes the whole intersection empty.
func SInter(db *DB, args [][]byte) (redis.Reply, *extra) {
	if len(args) < 1 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'sinter' command"), nil
	}
	keys := make([]string, 0, len(args))
	for _, arg := range args {
		keys = append(keys, string(arg))
	}

	db.Locker.RLocks(keys...)
	defer db.Locker.RUnLocks(keys...)

	var acc *HashSet.Set
	for _, key := range keys {
		set, errReply := db.getAsSet(key)
		if errReply != nil {
			return errReply, nil
		}
		if set == nil {
			return &reply.EmptyMultiBulkReply{}, nil
		}
		if acc == nil {
			// start from a copy of the first set rather than the stored set itself
			acc = HashSet.MakeFromVals(set.ToSlice()...)
			continue
		}
		acc = acc.Intersect(set)
		if acc.Len() == 0 {
			// nothing can survive further intersections
			return &reply.EmptyMultiBulkReply{}, nil
		}
	}

	members := make([][]byte, 0, acc.Len())
	acc.ForEach(func(member string) bool {
		members = append(members, []byte(member))
		return true
	})
	return reply.MakeMultiBulkReply(members), nil
}
// SInterStore implements SINTERSTORE destination key [key ...].
//
// It stores the intersection of the given sets at destination and replies
// with the size of the stored set. As in Redis, an empty intersection
// (including one caused by a missing source key) deletes destination and
// replies with the integer 0. Every path that modifies destination is marked
// for persistence.
func SInterStore(db *DB, args [][]byte) (redis.Reply, *extra) {
	if len(args) < 2 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'sinterstore' command"), nil
	}
	dest := string(args[0])
	keys := make([]string, len(args)-1)
	for i, arg := range args[1:] {
		keys[i] = string(arg)
	}

	// dest may also be a source key (e.g. SINTERSTORE k k other). Read-locking
	// it and then write-locking it would self-deadlock, so dest is read under
	// its write lock only.
	// NOTE(review): if the locker maps keys onto shared lock slots, distinct
	// keys sharing a slot could still conflict — confirm against the Locker.
	readKeys := make([]string, 0, len(keys))
	for _, key := range keys {
		if key != dest {
			readKeys = append(readKeys, key)
		}
	}
	db.Locker.RLocks(readKeys...)
	defer db.Locker.RUnLocks(readKeys...)
	db.Locker.Lock(dest)
	defer db.Locker.UnLock(dest)

	// storeEmpty deletes dest (clearing its ttl and old value) and replies 0.
	storeEmpty := func() (redis.Reply, *extra) {
		db.Remove(dest)
		return reply.MakeIntReply(0), &extra{toPersist: true}
	}

	var result *HashSet.Set
	for _, key := range keys {
		set, errReply := db.getAsSet(key)
		if errReply != nil {
			return errReply, nil
		}
		if set == nil {
			// a missing key is an empty set, so the intersection is empty
			return storeEmpty()
		}
		if result == nil {
			result = HashSet.MakeFromVals(set.ToSlice()...)
		} else {
			result = result.Intersect(set)
		}
		if result.Len() == 0 {
			// early termination: nothing can survive further intersections
			return storeEmpty()
		}
	}

	set := HashSet.MakeFromVals(result.ToSlice()...)
	db.Remove(dest) // clean ttl and old value before storing the new set
	db.Put(dest, &DataEntity{
		Data: set,
	})
	return reply.MakeIntReply(int64(set.Len())), &extra{toPersist: true}
}
// SUnion implements SUNION key [key ...].
//
// It replies with every member found in at least one of the given sets.
// Missing keys contribute nothing. When no key exists, the reply is an empty
// multi-bulk reply.
func SUnion(db *DB, args [][]byte) (redis.Reply, *extra) {
	if len(args) < 1 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'sunion' command"), nil
	}
	keys := make([]string, 0, len(args))
	for _, arg := range args {
		keys = append(keys, string(arg))
	}

	db.Locker.RLocks(keys...)
	defer db.Locker.RUnLocks(keys...)

	var union *HashSet.Set
	for _, key := range keys {
		set, errReply := db.getAsSet(key)
		switch {
		case errReply != nil:
			return errReply, nil
		case set == nil:
			// absent key: skip it
		case union == nil:
			union = HashSet.MakeFromVals(set.ToSlice()...)
		default:
			union = union.Union(set)
		}
	}
	if union == nil {
		return &reply.EmptyMultiBulkReply{}, nil
	}

	members := make([][]byte, 0, union.Len())
	union.ForEach(func(member string) bool {
		members = append(members, []byte(member))
		return true
	})
	return reply.MakeMultiBulkReply(members), nil
}
// SUnionStore implements SUNIONSTORE destination key [key ...].
//
// It stores the union of the given sets at destination and replies with the
// size of the stored set. When every source key is missing, destination is
// deleted and the reply is the integer 0, as in Redis. Every path that
// modifies destination is marked for persistence.
func SUnionStore(db *DB, args [][]byte) (redis.Reply, *extra) {
	if len(args) < 2 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'sunionstore' command"), nil
	}
	dest := string(args[0])
	keys := make([]string, len(args)-1)
	for i, arg := range args[1:] {
		keys[i] = string(arg)
	}

	// dest may also be a source key. Read-locking it and then write-locking it
	// would self-deadlock, so dest is read under its write lock only.
	// NOTE(review): if the locker maps keys onto shared lock slots, distinct
	// keys sharing a slot could still conflict — confirm against the Locker.
	readKeys := make([]string, 0, len(keys))
	for _, key := range keys {
		if key != dest {
			readKeys = append(readKeys, key)
		}
	}
	db.Locker.RLocks(readKeys...)
	defer db.Locker.RUnLocks(readKeys...)
	db.Locker.Lock(dest)
	defer db.Locker.UnLock(dest)

	var result *HashSet.Set
	for _, key := range keys {
		set, errReply := db.getAsSet(key)
		if errReply != nil {
			return errReply, nil
		}
		if set == nil {
			continue
		}
		if result == nil {
			result = HashSet.MakeFromVals(set.ToSlice()...)
		} else {
			result = result.Union(set)
		}
	}

	db.Remove(dest) // clean ttl and old value
	if result == nil {
		// every source key is missing: dest stays deleted and the reply is 0
		return reply.MakeIntReply(0), &extra{toPersist: true}
	}
	set := HashSet.MakeFromVals(result.ToSlice()...)
	db.Put(dest, &DataEntity{
		Data: set,
	})
	return reply.MakeIntReply(int64(set.Len())), &extra{toPersist: true}
}
// SDiff implements SDIFF key [key ...].
//
// It replies with the members of the first set that appear in none of the
// following sets. A missing first key yields an empty reply. Missing later
// keys are treated as empty sets and skipped.
func SDiff(db *DB, args [][]byte) (redis.Reply, *extra) {
	if len(args) < 1 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'sdiff' command"), nil
	}
	keys := make([]string, 0, len(args))
	for _, arg := range args {
		keys = append(keys, string(arg))
	}

	db.Locker.RLocks(keys...)
	defer db.Locker.RUnLocks(keys...)

	var diff *HashSet.Set
	for idx, key := range keys {
		set, errReply := db.getAsSet(key)
		if errReply != nil {
			return errReply, nil
		}
		if set == nil {
			if idx == 0 {
				// nothing to subtract from
				return &reply.EmptyMultiBulkReply{}, nil
			}
			continue
		}
		if diff == nil {
			diff = HashSet.MakeFromVals(set.ToSlice()...)
			continue
		}
		diff = diff.Diff(set)
		if diff.Len() == 0 {
			// nothing left to subtract from
			return &reply.EmptyMultiBulkReply{}, nil
		}
	}
	if diff == nil {
		return &reply.EmptyMultiBulkReply{}, nil
	}

	members := make([][]byte, 0, diff.Len())
	diff.ForEach(func(member string) bool {
		members = append(members, []byte(member))
		return true
	})
	return reply.MakeMultiBulkReply(members), nil
}
// SDiffStore implements SDIFFSTORE destination key [key ...].
//
// It stores the members of the first set that appear in none of the following
// sets at destination, and replies with the size of the stored set. As in
// Redis, an empty difference (including a missing first key) deletes
// destination and replies with the integer 0. Every path that modifies
// destination is marked for persistence.
func SDiffStore(db *DB, args [][]byte) (redis.Reply, *extra) {
	if len(args) < 2 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'sdiffstore' command"), nil
	}
	dest := string(args[0])
	keys := make([]string, len(args)-1)
	for i, arg := range args[1:] {
		keys[i] = string(arg)
	}

	// dest may also be a source key. Read-locking it and then write-locking it
	// would self-deadlock, so dest is read under its write lock only.
	// NOTE(review): if the locker maps keys onto shared lock slots, distinct
	// keys sharing a slot could still conflict — confirm against the Locker.
	readKeys := make([]string, 0, len(keys))
	for _, key := range keys {
		if key != dest {
			readKeys = append(readKeys, key)
		}
	}
	db.Locker.RLocks(readKeys...)
	defer db.Locker.RUnLocks(readKeys...)
	db.Locker.Lock(dest)
	defer db.Locker.UnLock(dest)

	// storeEmpty deletes dest (clearing its ttl and old value) and replies 0.
	storeEmpty := func() (redis.Reply, *extra) {
		db.Remove(dest)
		return reply.MakeIntReply(0), &extra{toPersist: true}
	}

	var result *HashSet.Set
	for i, key := range keys {
		set, errReply := db.getAsSet(key)
		if errReply != nil {
			return errReply, nil
		}
		if set == nil {
			if i == 0 {
				// missing first key: there is nothing to subtract from
				return storeEmpty()
			}
			continue
		}
		if result == nil {
			result = HashSet.MakeFromVals(set.ToSlice()...)
			continue
		}
		result = result.Diff(set)
		if result.Len() == 0 {
			// early termination: nothing left to subtract from
			return storeEmpty()
		}
	}
	if result == nil {
		// defensive: keys[0] either initializes result or returns early
		return storeEmpty()
	}

	set := HashSet.MakeFromVals(result.ToSlice()...)
	db.Remove(dest) // clean ttl and old value before storing the new set
	db.Put(dest, &DataEntity{
		Data: set,
	})
	return reply.MakeIntReply(int64(set.Len())), &extra{toPersist: true}
}
// SRandMember implements SRANDMEMBER key [count].
//
// Without count, it replies with one random member as a bulk string, or a
// null bulk reply when key does not exist. With count (Redis semantics):
//   - count > 0: distinct members, via set.RandomDistinctMembers(count);
//   - count < 0: -count members that may repeat, via set.RandomMembers(-count);
//   - count == 0, or key does not exist: an empty multi-bulk reply.
//
// A count that is not an integer is rejected before the key is looked up.
func SRandMember(db *DB, args [][]byte) (redis.Reply, *extra) {
	if len(args) != 1 && len(args) != 2 {
		return reply.MakeErrReply("ERR wrong number of arguments for 'srandmember' command"), nil
	}
	key := string(args[0])

	// Parse count first, so a malformed count is reported even when key is
	// absent.
	withCount := len(args) == 2
	count := 0
	if withCount {
		count64, err := strconv.ParseInt(string(args[1]), 10, 64)
		if err != nil {
			return reply.MakeErrReply("ERR value is not an integer or out of range"), nil
		}
		count = int(count64)
		if count < 0 && -count < 0 {
			// -count overflows for the most negative int; reject it instead of
			// passing a negative size on to the set.
			return reply.MakeErrReply("ERR value is not an integer or out of range"), nil
		}
	}

	db.Locker.RLock(key)
	defer db.Locker.RUnLock(key)

	set, errReply := db.getAsSet(key)
	if errReply != nil {
		return errReply, nil
	}

	if !withCount {
		if set == nil {
			return &reply.NullBulkReply{}, nil
		}
		members := set.RandomMembers(1)
		return reply.MakeBulkReply([]byte(members[0])), nil
	}

	if set == nil || count == 0 {
		return &reply.EmptyMultiBulkReply{}, nil
	}

	var members []string
	if count > 0 {
		// positive count: members must be distinct
		// (presumably capped at the set size by RandomDistinctMembers — confirm)
		members = set.RandomDistinctMembers(count)
	} else {
		// negative count: the same member may be returned more than once
		members = set.RandomMembers(-count)
	}
	result := make([][]byte, len(members))
	for i, member := range members {
		result[i] = []byte(member)
	}
	return reply.MakeMultiBulkReply(result), nil
}