add set data structure

This commit is contained in:
hdt3213
2019-08-18 01:11:29 +08:00
committed by wyb
parent 5e6e894c29
commit 900e92bb49
6 changed files with 764 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
package dict
import (
"math/rand"
"sync"
"sync/atomic"
)
@@ -563,4 +564,95 @@ func (dict *Dict)ForEach(consumer Consumer) {
}
}
}
}
func (dict *Dict)Keys()[]string {
keys := make([]string, dict.Len())
i := 0
dict.ForEach(func(key string, val interface{})bool {
if i < len(keys) {
keys[i] = key
i++
} else {
keys = append(keys, key)
}
return true
})
return keys
}
func (shard *Shard)RandomKey()string {
if shard == nil {
panic("shard is nil")
}
shard.mutex.RLock()
defer shard.mutex.RUnlock()
keys := make([]string, 0)
i := 0
node := shard.head
for node != nil {
if node.key != "" {
keys = append(keys, node.key)
i++
}
node = node.next
}
if i > 0 {
return keys[rand.Intn(i)]
} else {
return ""
}
}
func (dict *Dict)RandomKeys(limit int)[]string {
size := dict.Len()
if limit >= size {
return dict.Keys()
}
table, _ := dict.table.Load().([]*Shard)
shardCount := len(table)
result := make([]string, limit)
for i := 0; i < limit; {
shard := dict.getShard(uint32(rand.Intn(shardCount)))
if shard == nil {
continue
}
key := shard.RandomKey()
if key != "" {
result[i] = key
i++
}
}
return result
}
func (dict *Dict)RandomDistinctKeys(limit int)[]string {
size := dict.Len()
if limit >= size {
return dict.Keys()
}
table, _ := dict.table.Load().([]*Shard)
shardCount := len(table)
result := make(map[string]bool)
for len(result) < limit {
shardIndex := uint32(rand.Intn(shardCount))
shard := dict.getShard(shardIndex)
if shard == nil {
continue
}
key := shard.RandomKey()
if key != "" {
result[key] = true
}
}
arr := make([]string, limit)
i := 0
for k := range result {
arr[i] = k
i++
}
return arr
}

122
src/datastruct/set/set.go Normal file
View File

@@ -0,0 +1,122 @@
package set
import "github.com/HDT3213/godis/src/datastruct/dict"
type Set struct {
dict *dict.Dict
}
func Make(shardCountHint int)*Set {
return &Set{
dict: dict.Make(shardCountHint),
}
}
func MakeFromVals(members ...string)*Set {
set := &Set{
dict: dict.Make(len(members)),
}
for _, member := range members {
set.Add(member)
}
return set
}
func (set *Set)Add(val string)int {
return set.dict.Put(val, true)
}
func (set *Set)Remove(val string)int {
return set.dict.Remove(val)
}
func (set *Set)Has(val string)bool {
_, exists := set.dict.Get(val)
return exists
}
func (set *Set)Len()int {
return set.dict.Len()
}
func (set *Set)ToSlice()[]string {
slice := make([]string, set.Len())
i := 0
set.dict.ForEach(func(key string, val interface{})bool {
if i < len(slice) {
slice[i] = key
} else {
// set extended during traversal
slice = append(slice, key)
}
i++
return true
})
return slice
}
func (set *Set)ForEach(consumer func(member string)bool) {
set.dict.ForEach(func(key string, val interface{})bool {
return consumer(key)
})
}
func (set *Set)Intersect(another *Set)*Set {
if set == nil {
panic("set is nil")
}
setSize := set.Len()
anotherSize := another.Len()
size := setSize
if anotherSize < setSize {
size = anotherSize
}
result := Make(size)
another.ForEach(func(member string)bool {
if set.Has(member) {
result.Add(member)
}
return true
})
return result
}
func (set *Set)Union(another *Set)*Set {
if set == nil {
panic("set is nil")
}
result := Make(set.Len() + another.Len())
another.ForEach(func(member string)bool {
result.Add(member)
return true
})
set.ForEach(func(member string)bool {
result.Add(member)
return true
})
return result
}
func (set *Set)Diff(another *Set)*Set {
if set == nil {
panic("set is nil")
}
result := Make(set.Len())
set.ForEach(func(member string)bool {
if !another.Has(member) {
result.Add(member)
}
return true
})
return result
}
func (set *Set)RandomMembers(limit int)[]string {
return set.dict.RandomKeys(limit)
}
func (set *Set)RandomDistinctMembers(limit int)[]string {
return set.dict.RandomDistinctKeys(limit)
}

View File

@@ -0,0 +1,32 @@
package set
import (
"strconv"
"testing"
)
func TestSet(t *testing.T) {
size := 10
set := Make(0)
for i := 0; i < size; i++ {
set.Add(strconv.Itoa(i))
}
for i := 0; i < size; i++ {
ok := set.Has(strconv.Itoa(i))
if !ok {
t.Error("expected true actual false, key: " + strconv.Itoa(i))
}
}
for i := 0; i < size; i++ {
ok := set.Remove(strconv.Itoa(i))
if ok != 1 {
t.Error("expected true actual false, key: " + strconv.Itoa(i))
}
}
for i := 0; i < size; i++ {
ok := set.Has(strconv.Itoa(i))
if ok {
t.Error("expected false actual true, key: " + strconv.Itoa(i))
}
}
}

View File

@@ -104,6 +104,19 @@ func MakeCmdMap()map[string]CmdFunc {
cmdMap["hincrby"] = HIncrBy
cmdMap["hincrbyfloat"] = HIncrByFloat
cmdMap["sadd"] = SAdd
cmdMap["sismember"] = SIsMember
cmdMap["srem"] = SRem
cmdMap["scard"] = SCard
cmdMap["smembers"] = SMembers
cmdMap["sinter"] = SInter
cmdMap["sinterstore"] = SInterStore
cmdMap["sunion"] = SUnion
cmdMap["sunionstore"] = SUnionStore
cmdMap["sdiff"] = SDiff
cmdMap["sdiffstore"] = SDiffStore
cmdMap["srandmember"] = SRandMember
return cmdMap
}

504
src/db/set.go Normal file
View File

@@ -0,0 +1,504 @@
package db
import (
HashSet "github.com/HDT3213/godis/src/datastruct/set"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply"
"strconv"
)
func SAdd(db *DB, args [][]byte)redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sadd' command")
}
key := string(args[0])
members := args[1:]
// lock
db.Locks.Lock(key)
defer db.Locks.UnLock(key)
// get or init entity
entity, exists := db.Get(key)
if !exists {
entity = &DataEntity{
Code: SetCode,
Data: HashSet.Make(0),
}
db.Data.Put(key, entity)
}
// check type
if entity.Code != SetCode {
return &reply.WrongTypeErrReply{}
}
set, _ := entity.Data.(*HashSet.Set)
counter := 0
for _, member := range members {
counter += set.Add(string(member))
}
return reply.MakeIntReply(int64(counter))
}
func SIsMember(db *DB, args [][]byte)redis.Reply {
if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sismember' command")
}
key := string(args[0])
member := string(args[1])
// get or init entity
entity, exists := db.Get(key)
if !exists {
return reply.MakeIntReply(0)
}
// check type
if entity.Code != SetCode {
return &reply.WrongTypeErrReply{}
}
set, _ := entity.Data.(*HashSet.Set)
has := set.Has(member)
if has {
return reply.MakeIntReply(1)
} else {
return reply.MakeIntReply(0)
}
}
func SRem(db *DB, args [][]byte)redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'srem' command")
}
key := string(args[0])
members := args[1:]
// lock
db.Locks.Lock(key)
defer db.Locks.UnLock(key)
// get or init entity
entity, exists := db.Get(key)
if !exists {
return reply.MakeIntReply(0)
}
// check type
if entity.Code != SetCode {
return &reply.WrongTypeErrReply{}
}
set, _ := entity.Data.(*HashSet.Set)
counter := 0
for _, member := range members {
counter += set.Remove(string(member))
}
if set.Len() == 0 {
db.Remove(key)
}
return reply.MakeIntReply(int64(counter))
}
func SCard(db *DB, args [][]byte)redis.Reply {
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'scard' command")
}
key := string(args[0])
// get or init entity
entity, exists := db.Get(key)
if !exists {
return reply.MakeIntReply(0)
}
// check type
if entity.Code != SetCode {
return &reply.WrongTypeErrReply{}
}
set, _ := entity.Data.(*HashSet.Set)
return reply.MakeIntReply(int64(set.Len()))
}
func SMembers(db *DB, args [][]byte)redis.Reply {
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'smembers' command")
}
key := string(args[0])
// lock
db.Locks.RLock(key)
defer db.Locks.RUnLock(key)
// get or init entity
entity, exists := db.Get(key)
if !exists {
return &reply.EmptyMultiBulkReply{}
}
// check type
if entity.Code != SetCode {
return &reply.WrongTypeErrReply{}
}
set, _ := entity.Data.(*HashSet.Set)
arr := make([][]byte, set.Len())
i := 0
set.ForEach(func (member string)bool {
arr[i] = []byte(member)
i++
return true
})
return reply.MakeMultiBulkReply(arr)
}
func SInter(db *DB, args [][]byte)redis.Reply {
if len(args) < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sinter' command")
}
keys := make([]string, len(args))
for i, arg := range args {
keys[i] = string(arg)
}
// lock
db.Locks.RLocks(keys...)
defer db.Locks.RUnLocks(keys...)
var result *HashSet.Set
for _, key := range keys {
entity, exists := db.Get(key)
if !exists {
return &reply.EmptyMultiBulkReply{}
}
if entity.Code != SetCode {
return &reply.WrongTypeErrReply{}
}
set, _ := entity.Data.(*HashSet.Set)
if result == nil {
// init
result = HashSet.MakeFromVals(set.ToSlice()...)
} else {
result = result.Intersect(set)
if result.Len() == 0 {
// early termination
return &reply.EmptyMultiBulkReply{}
}
}
}
arr := make([][]byte, result.Len())
i := 0
result.ForEach(func (member string)bool {
arr[i] = []byte(member)
i++
return true
})
return reply.MakeMultiBulkReply(arr)
}
func SInterStore(db *DB, args [][]byte)redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sinterstore' command")
}
dest := string(args[0])
keys := make([]string, len(args) - 1)
keyArgs := args[1:]
for i, arg := range keyArgs {
keys[i] = string(arg)
}
// lock
db.Locks.RLocks(keys...)
defer db.Locks.RUnLocks(keys...)
db.Locks.Lock(dest)
defer db.Locks.UnLock(dest)
var result *HashSet.Set
for _, key := range keys {
entity, exists := db.Get(key)
if !exists {
db.Remove(dest) // clean ttl and old value
return reply.MakeIntReply(0)
}
if entity.Code != SetCode {
return &reply.WrongTypeErrReply{}
}
set, _ := entity.Data.(*HashSet.Set)
if result == nil {
// init
result = HashSet.MakeFromVals(set.ToSlice()...)
} else {
result = result.Intersect(set)
if result.Len() == 0 {
// early termination
db.Remove(dest) // clean ttl and old value
return reply.MakeIntReply(0)
}
}
}
set := HashSet.MakeFromVals(result.ToSlice()...)
entity := &DataEntity{
Code: SetCode,
Data: set,
}
db.Data.Put(dest, entity)
return reply.MakeIntReply(int64(set.Len()))
}
func SUnion(db *DB, args [][]byte)redis.Reply {
if len(args) < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sunion' command")
}
keys := make([]string, len(args))
for i, arg := range args {
keys[i] = string(arg)
}
// lock
db.Locks.RLocks(keys...)
defer db.Locks.RUnLocks(keys...)
var result *HashSet.Set
for _, key := range keys {
entity, exists := db.Get(key)
if !exists {
continue
}
if entity.Code != SetCode {
return &reply.WrongTypeErrReply{}
}
set, _ := entity.Data.(*HashSet.Set)
if result == nil {
// init
result = HashSet.MakeFromVals(set.ToSlice()...)
} else {
result = result.Union(set)
}
}
if result == nil {
// all keys are empty set
return &reply.EmptyMultiBulkReply{}
}
arr := make([][]byte, result.Len())
i := 0
result.ForEach(func (member string)bool {
arr[i] = []byte(member)
i++
return true
})
return reply.MakeMultiBulkReply(arr)
}
func SUnionStore(db *DB, args [][]byte)redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sunionstore' command")
}
dest := string(args[0])
keys := make([]string, len(args) - 1)
keyArgs := args[1:]
for i, arg := range keyArgs {
keys[i] = string(arg)
}
// lock
db.Locks.RLocks(keys...)
defer db.Locks.RUnLocks(keys...)
db.Locks.Lock(dest)
defer db.Locks.UnLock(dest)
var result *HashSet.Set
for _, key := range keys {
entity, exists := db.Get(key)
if !exists {
continue
}
if entity.Code != SetCode {
return &reply.WrongTypeErrReply{}
}
set, _ := entity.Data.(*HashSet.Set)
if result == nil {
// init
result = HashSet.MakeFromVals(set.ToSlice()...)
} else {
result = result.Union(set)
}
}
db.Remove(dest) // clean ttl
if result == nil {
// all keys are empty set
return &reply.EmptyMultiBulkReply{}
}
set := HashSet.MakeFromVals(result.ToSlice()...)
entity := &DataEntity{
Code: SetCode,
Data: set,
}
db.Data.Put(dest, entity)
return reply.MakeIntReply(int64(set.Len()))
}
func SDiff(db *DB, args [][]byte)redis.Reply {
if len(args) < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sdiff' command")
}
keys := make([]string, len(args))
for i, arg := range args {
keys[i] = string(arg)
}
// lock
db.Locks.RLocks(keys...)
defer db.Locks.RUnLocks(keys...)
var result *HashSet.Set
for i, key := range keys {
entity, exists := db.Get(key)
if !exists {
if i == 0 {
// early termination
return &reply.EmptyMultiBulkReply{}
} else {
continue
}
}
if entity.Code != SetCode {
return &reply.WrongTypeErrReply{}
}
set, _ := entity.Data.(*HashSet.Set)
if result == nil {
// init
result = HashSet.MakeFromVals(set.ToSlice()...)
} else {
result = result.Diff(set)
if result.Len() == 0 {
// early termination
return &reply.EmptyMultiBulkReply{}
}
}
}
if result == nil {
// all keys are nil
return &reply.EmptyMultiBulkReply{}
}
arr := make([][]byte, result.Len())
i := 0
result.ForEach(func (member string)bool {
arr[i] = []byte(member)
i++
return true
})
return reply.MakeMultiBulkReply(arr)
}
func SDiffStore(db *DB, args [][]byte)redis.Reply {
if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'sdiffstore' command")
}
dest := string(args[0])
keys := make([]string, len(args) - 1)
keyArgs := args[1:]
for i, arg := range keyArgs {
keys[i] = string(arg)
}
// lock
db.Locks.RLocks(keys...)
defer db.Locks.RUnLocks(keys...)
db.Locks.Lock(dest)
defer db.Locks.UnLock(dest)
var result *HashSet.Set
for i, key := range keys {
entity, exists := db.Get(key)
if !exists {
if i == 0 {
// early termination
db.Remove(dest)
return &reply.EmptyMultiBulkReply{}
} else {
continue
}
}
if entity.Code != SetCode {
return &reply.WrongTypeErrReply{}
}
set, _ := entity.Data.(*HashSet.Set)
if result == nil {
// init
result = HashSet.MakeFromVals(set.ToSlice()...)
} else {
result = result.Diff(set)
if result.Len() == 0 {
// early termination
db.Remove(dest)
return &reply.EmptyMultiBulkReply{}
}
}
}
if result == nil {
// all keys are nil
db.Remove(dest)
return &reply.EmptyMultiBulkReply{}
}
set := HashSet.MakeFromVals(result.ToSlice()...)
entity := &DataEntity{
Code: SetCode,
Data: set,
}
db.Data.Put(dest, entity)
return reply.MakeIntReply(int64(set.Len()))
}
func SRandMember(db *DB, args [][]byte)redis.Reply {
if len(args) != 1 && len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'srandmember' command")
}
key := string(args[0])
// lock
db.Locks.RLock(key)
defer db.Locks.RUnLock(key)
// get or init entity
entity, exists := db.Get(key)
if !exists {
return &reply.NullBulkReply{}
}
// check type
if entity.Code != SetCode {
return &reply.WrongTypeErrReply{}
}
set, _ := entity.Data.(*HashSet.Set)
if len(args) == 1 {
members := set.RandomMembers(1)
return reply.MakeBulkReply([]byte(members[0]))
} else {
count64, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range")
}
count := int(count64)
if count > 0 {
members := set.RandomMembers(count)
result := make([][]byte, len(members))
for i, v := range members {
result[i] = []byte(v)
}
return reply.MakeMultiBulkReply(result)
} else if count < 0 {
members := set.RandomDistinctMembers(-count)
result := make([][]byte, len(members))
for i, v := range members {
result[i] = []byte(v)
}
return reply.MakeMultiBulkReply(result)
} else {
return &reply.EmptyMultiBulkReply{}
}
}
}

View File

@@ -106,6 +106,7 @@ func Set(db *DB, args [][]byte)redis.Reply {
Data: value,
}
db.Remove(key) // clean ttl
switch policy {
case upsertPolicy:
db.Data.Put(key, entity)