diff --git a/src/datastruct/dict/concurrent.go b/src/datastruct/dict/concurrent.go new file mode 100644 index 0000000..9c2acd2 --- /dev/null +++ b/src/datastruct/dict/concurrent.go @@ -0,0 +1,262 @@ +package dict + +import ( + "math/rand" + "sync" + "sync/atomic" +) + +type ConcurrentDict struct { + table []*Shard + count int32 +} + +type Shard struct { + m map[string]interface{} + mutex sync.RWMutex +} + +func MakeConcurrent(shardCount int) *ConcurrentDict { + if shardCount < 1 { + shardCount = 16 + } + table := make([]*Shard, shardCount) + for i := 0; i < shardCount; i++ { + table[i] = &Shard{ + m: make(map[string]interface{}), + } + } + d := &ConcurrentDict{ + count: 0, + table: table, + } + return d +} + +const prime32 = uint32(16777619) + +func fnv32(key string) uint32 { + hash := uint32(2166136261) + for i := 0; i < len(key); i++ { + hash *= prime32 + hash ^= uint32(key[i]) + } + return hash +} + +func (dict *ConcurrentDict) spread(hashCode uint32) uint32 { + if dict == nil { + panic("dict is nil") + } + tableSize := uint32(len(dict.table)) + return (tableSize - 1) & uint32(hashCode) +} + +func (dict *ConcurrentDict) getShard(index uint32) *Shard { + if dict == nil { + panic("dict is nil") + } + return dict.table[index] +} + +func (dict *ConcurrentDict) Get(key string) (val interface{}, exists bool) { + if dict == nil { + panic("dict is nil") + } + hashCode := fnv32(key) + index := dict.spread(hashCode) + shard := dict.getShard(index) + shard.mutex.RLock() + defer shard.mutex.RUnlock() + val, exists = shard.m[key] + return +} + +func (dict *ConcurrentDict) Len() int { + if dict == nil { + panic("dict is nil") + } + return int(atomic.LoadInt32(&dict.count)) +} + +// return the number of new inserted key-value +func (dict *ConcurrentDict) Put(key string, val interface{}) (result int) { + if dict == nil { + panic("dict is nil") + } + hashCode := fnv32(key) + index := dict.spread(hashCode) + shard := dict.getShard(index) + shard.mutex.Lock() + defer shard.mutex.Unlock() + + if _, ok := shard.m[key]; ok { + shard.m[key] = val + return 0 + } else { + shard.m[key] = val + dict.addCount() + return 1 + } +} + +// return the number of updated key-value +func (dict *ConcurrentDict) PutIfAbsent(key string, val interface{}) (result int) { + if dict == nil { + panic("dict is nil") + } + hashCode := fnv32(key) + index := dict.spread(hashCode) + shard := dict.getShard(index) + shard.mutex.Lock() + defer shard.mutex.Unlock() + + if _, ok := shard.m[key]; ok { + return 0 + } else { + shard.m[key] = val + dict.addCount() + return 1 + } +} + +// return the number of updated key-value +func (dict *ConcurrentDict) PutIfExists(key string, val interface{}) (result int) { + if dict == nil { + panic("dict is nil") + } + hashCode := fnv32(key) + index := dict.spread(hashCode) + shard := dict.getShard(index) + shard.mutex.Lock() + defer shard.mutex.Unlock() + + if _, ok := shard.m[key]; ok { + shard.m[key] = val + return 1 + } else { + return 0 + } +} + +// return the number of deleted key-value +func (dict *ConcurrentDict) Remove(key string) (result int) { + if dict == nil { + panic("dict is nil") + } + hashCode := fnv32(key) + index := dict.spread(hashCode) + shard := dict.getShard(index) + shard.mutex.Lock() + defer shard.mutex.Unlock() + + if _, ok := shard.m[key]; ok { + delete(shard.m, key) + return 1 + } else { + return 0 + } + return +} + +func (dict *ConcurrentDict) addCount() int32 { + return atomic.AddInt32(&dict.count, 1) +} + +/* + * may not contains new entry inserted during traversal + */ +func (dict *ConcurrentDict) ForEach(consumer Consumer) { + if dict == nil { + panic("dict is nil") + } + + for _, shard := range dict.table { + for key, value := range shard.m { + shard.mutex.RLock() + continues := consumer(key, value) + shard.mutex.RUnlock() + if !continues { + return + } + } + } +} + +func (dict *ConcurrentDict) Keys() []string { + keys := make([]string, dict.Len()) + i := 0 + dict.ForEach(func(key string, val interface{}) bool { + if i < len(keys) { + keys[i] = key + i++ + } else { + keys = append(keys, key) + } + return true + }) + return keys +} + +func (shard *Shard) RandomKey() string { + if shard == nil { + panic("shard is nil") + } + shard.mutex.RLock() + defer shard.mutex.RUnlock() + + for key := range shard.m { + return key + } + return "" +} + +func (dict *ConcurrentDict) RandomKeys(limit int) []string { + size := dict.Len() + if limit >= size { + return dict.Keys() + } + shardCount := len(dict.table) + + result := make([]string, limit) + for i := 0; i < limit; { + shard := dict.getShard(uint32(rand.Intn(shardCount))) + if shard == nil { + continue + } + key := shard.RandomKey() + if key != "" { + result[i] = key + i++ + } + } + return result +} + +func (dict *ConcurrentDict) RandomDistinctKeys(limit int) []string { + size := dict.Len() + if limit >= size { + return dict.Keys() + } + + shardCount := len(dict.table) + result := make(map[string]bool) + for len(result) < limit { + shardIndex := uint32(rand.Intn(shardCount)) + shard := dict.getShard(shardIndex) + if shard == nil { + continue + } + key := shard.RandomKey() + if key != "" { + result[key] = true + } + } + arr := make([]string, limit) + i := 0 + for k := range result { + arr[i] = k + i++ + } + return arr +} diff --git a/src/datastruct/dict/dict.go b/src/datastruct/dict/dict.go index ccd6c65..5129bb0 100644 --- a/src/datastruct/dict/dict.go +++ b/src/datastruct/dict/dict.go @@ -1,266 +1,16 @@ package dict -import ( - "math/rand" - "sync" - "sync/atomic" -) - -type Dict struct { - table []*Shard - count int32 -} - -type Shard struct { - m map[string]interface{} - mutex sync.RWMutex -} - - -func Make(shardCount int) *Dict { - if shardCount < 1 { - shardCount = 16 - } - table := make([]*Shard, shardCount) - for i := 0; i < shardCount; i++ { - table[i] = &Shard{ - m: make(map[string]interface{}), - } - } - d := &Dict{ - count: 0, - table: table, - } - return d -} - -const prime32 = uint32(16777619) - -func fnv32(key string) uint32 { - hash := uint32(2166136261) - for i := 0; i < len(key); i++ { - hash *= prime32 - hash ^= uint32(key[i]) - } - return hash -} - -func (dict *Dict) spread(hashCode uint32) uint32 { - if dict == nil { - panic("dict is nil") - } - tableSize := uint32(len(dict.table)) - return (tableSize - 1) & uint32(hashCode) -} - -func (dict *Dict) getShard(index uint32) *Shard { - if dict == nil { - panic("dict is nil") - } - return dict.table[index] -} - -func (dict *Dict) Get(key string) (val interface{}, exists bool) { - if dict == nil { - panic("dict is nil") - } - hashCode := fnv32(key) - index := dict.spread(hashCode) - shard := dict.getShard(index) - shard.mutex.RLock() - defer shard.mutex.RUnlock() - val, exists = shard.m[key] - return -} - -func (dict *Dict) Len() int { - if dict == nil { - panic("dict is nil") - } - return int(atomic.LoadInt32(&dict.count)) -} - -// return the number of new inserted key-value -func (dict *Dict) Put(key string, val interface{}) (result int) { - if dict == nil { - panic("dict is nil") - } - hashCode := fnv32(key) - index := dict.spread(hashCode) - shard := dict.getShard(index) - shard.mutex.Lock() - defer shard.mutex.Unlock() - - if _, ok := shard.m[key]; ok { - shard.m[key] = val - return 0 - } else { - shard.m[key] = val - dict.addCount() - return 1 - } -} - -// return the number of updated key-value -func (dict *Dict) PutIfAbsent(key string, val interface{}) (result int) { - if dict == nil { - panic("dict is nil") - } - hashCode := fnv32(key) - index := dict.spread(hashCode) - shard := dict.getShard(index) - shard.mutex.Lock() - defer shard.mutex.Unlock() - - if _, ok := shard.m[key]; ok { - return 0 - } else { - shard.m[key] = val - dict.addCount() - return 1 - } -} - - -// return the number of updated key-value -func (dict *Dict) PutIfExists(key string, val interface{})(result int) { - if dict == nil { - panic("dict is nil") - } - hashCode := fnv32(key) - index := dict.spread(hashCode) - shard := dict.getShard(index) - shard.mutex.Lock() - defer shard.mutex.Unlock() - - if _, ok := shard.m[key]; ok { - shard.m[key] = val - return 1 - } else { - return 0 - } -} - -// return the number of deleted key-value -func (dict *Dict) Remove(key string)(result int) { - if dict == nil { - panic("dict is nil") - } - hashCode := fnv32(key) - index := dict.spread(hashCode) - shard := dict.getShard(index) - shard.mutex.Lock() - defer shard.mutex.Unlock() - - if _, ok := shard.m[key]; ok { - delete(shard.m, key) - return 1 - } else { - return 0 - } - return -} - -func (dict *Dict) addCount() int32 { - return atomic.AddInt32(&dict.count, 1) -} - type Consumer func(key string, val interface{})bool -/* - * may not contains new entry inserted during traversal - */ -func (dict *Dict)ForEach(consumer Consumer) { - if dict == nil { - panic("dict is nil") - } - - for _, shard := range dict.table { - for key, value := range shard.m { - shard.mutex.RLock() - continues := consumer(key, value) - shard.mutex.RUnlock() - if !continues { - return - } - } - } +type Dict interface { + Get(key string) (val interface{}, exists bool) + Len() int + Put(key string, val interface{}) (result int) + PutIfAbsent(key string, val interface{}) (result int) + PutIfExists(key string, val interface{}) (result int) + Remove(key string) (result int) + ForEach(consumer Consumer) + Keys() []string + RandomKeys(limit int) []string + RandomDistinctKeys(limit int) []string } - -func (dict *Dict)Keys()[]string { - keys := make([]string, dict.Len()) - i := 0 - dict.ForEach(func(key string, val interface{})bool { - if i < len(keys) { - keys[i] = key - i++ - } else { - keys = append(keys, key) - } - return true - }) - return keys -} - -func (shard *Shard)RandomKey()string { - if shard == nil { - panic("shard is nil") - } - shard.mutex.RLock() - defer shard.mutex.RUnlock() - - for key := range shard.m { - return key - } - return "" -} - -func (dict *Dict)RandomKeys(limit int)[]string { - size := dict.Len() - if limit >= size { - return dict.Keys() - } - shardCount := len(dict.table) - - result := make([]string, limit) - for i := 0; i < limit; { - shard := dict.getShard(uint32(rand.Intn(shardCount))) - if shard == nil { - continue - } - key := shard.RandomKey() - if key != "" { - result[i] = key - i++ - } - } - return result -} - -func (dict *Dict)RandomDistinctKeys(limit int)[]string { - size := dict.Len() - if limit >= size { - return dict.Keys() - } - - shardCount := len(dict.table) - result := make(map[string]bool) - for len(result) < limit { - shardIndex := uint32(rand.Intn(shardCount)) - shard := dict.getShard(shardIndex) - if shard == nil { - continue - } - key := shard.RandomKey() - if key != "" { - result[key] = true - } - } - arr := make([]string, limit) - i := 0 - for k := range result { - arr[i] = k - i++ - } - return arr -} \ No newline at end of file diff --git a/src/datastruct/dict/dict_test.go b/src/datastruct/dict/dict_test.go index 70ee3f6..285de31 100644 --- a/src/datastruct/dict/dict_test.go +++ b/src/datastruct/dict/dict_test.go @@ -7,7 +7,7 @@ import ( ) func TestPut(t *testing.T) { - d := Make(0) + d := MakeConcurrent(0) count := 100 var wg sync.WaitGroup wg.Add(count) @@ -36,7 +36,7 @@ func TestPut(t *testing.T) { } func TestPutIfAbsent(t *testing.T) { - d := Make(0) + d := MakeConcurrent(0) count := 100 var wg sync.WaitGroup wg.Add(count) @@ -81,7 +81,7 @@ func TestPutIfAbsent(t *testing.T) { } func TestPutIfExists(t *testing.T) { - d := Make(0) + d := MakeConcurrent(0) count := 100 var wg sync.WaitGroup wg.Add(count) @@ -114,7 +114,7 @@ func TestPutIfExists(t *testing.T) { } func TestRemove(t *testing.T) { - d := Make(0) + d := MakeConcurrent(0) // remove head node for i := 0; i < 100; i++ { @@ -150,7 +150,7 @@ func TestRemove(t *testing.T) { } // remove tail node - d = Make(0) + d = MakeConcurrent(0) for i := 0; i < 100; i++ { // insert key := "k" + strconv.Itoa(i) @@ -184,7 +184,7 @@ func TestRemove(t *testing.T) { } // remove middle node - d = Make(0) + d = MakeConcurrent(0) d.Put("head", 0) for i := 0; i < 10; i++ { // insert @@ -221,7 +221,7 @@ func TestRemove(t *testing.T) { } func TestForEach(t *testing.T) { - d := Make(0) + d := MakeConcurrent(0) size := 100 for i := 0; i < size; i++ { // insert diff --git a/src/datastruct/dict/simple.go b/src/datastruct/dict/simple.go new file mode 100644 index 0000000..ac39b67 --- /dev/null +++ b/src/datastruct/dict/simple.go @@ -0,0 +1,108 @@ +package dict + +type SimpleDict struct { + m map[string]interface{} +} + +func MakeSimple() *SimpleDict { + return &SimpleDict{ + m: make(map[string]interface{}), + } +} + +func (dict *SimpleDict) Get(key string) (val interface{}, exists bool) { + val, ok := dict.m[key] + return val, ok +} + +func (dict *SimpleDict) Len() int { + if dict.m == nil { + panic("m is nil") + } + return len(dict.m) +} + +func (dict *SimpleDict) Put(key string, val interface{}) (result int) { + _, existed := dict.m[key] + dict.m[key] = val + if existed { + return 0 + } else { + return 1 + } +} + +func (dict *SimpleDict) PutIfAbsent(key string, val interface{}) (result int) { + _, existed := dict.m[key] + if existed { + return 0 + } else { + dict.m[key] = val + return 1 + } +} + +func (dict *SimpleDict) PutIfExists(key string, val interface{}) (result int) { + _, existed := dict.m[key] + if existed { + dict.m[key] = val + return 1 + } else { + return 0 + } +} + +func (dict *SimpleDict) Remove(key string) (result int) { + _, existed := dict.m[key] + delete(dict.m, key) + if existed { + return 1 + } else { + return 0 + } +} + +func (dict *SimpleDict) Keys() []string { + result := make([]string, len(dict.m)) + i := 0 + for k := range dict.m { + result[i] = k + } + return result +} + +func (dict *SimpleDict) ForEach(consumer Consumer) { + for k, v := range dict.m { + if !consumer(k, v) { + break + } + } +} + +func (dict *SimpleDict) RandomKeys(limit int) []string { + result := make([]string, limit) + for i := 0; i < limit; i++ { + for k := range dict.m { + result[i] = k + break + } + } + return result +} + +func (dict *SimpleDict) RandomDistinctKeys(limit int) []string { + size := limit + if size > len(dict.m) { + size = len(dict.m) + } + result := make([]string, size) + i := 0 + for k := range dict.m { + if i == limit { + break + } + result[i] = k + i++ + } + return result +} diff --git a/src/datastruct/set/set.go b/src/datastruct/set/set.go index fbeac69..36daf91 100644 --- a/src/datastruct/set/set.go +++ b/src/datastruct/set/set.go @@ -3,18 +3,18 @@ package set import "github.com/HDT3213/godis/src/datastruct/dict" type Set struct { - dict *dict.Dict + dict dict.Dict } -func Make(shardCountHint int)*Set { +func Make() *Set { return &Set{ - dict: dict.Make(shardCountHint), + dict: dict.MakeSimple(), } } func MakeFromVals(members ...string)*Set { set := &Set{ - dict: dict.Make(len(members)), + dict: dict.MakeConcurrent(len(members)), } for _, member := range members { set.Add(member) @@ -65,14 +65,8 @@ func (set *Set)Intersect(another *Set)*Set { if set == nil { panic("set is nil") } - setSize := set.Len() - anotherSize := another.Len() - size := setSize - if anotherSize < setSize { - size = anotherSize - } - result := Make(size) + result := Make() another.ForEach(func(member string)bool { if set.Has(member) { result.Add(member) @@ -86,7 +80,7 @@ func (set *Set)Union(another *Set)*Set { if set == nil { panic("set is nil") } - result := Make(set.Len() + another.Len()) + result := Make() another.ForEach(func(member string)bool { result.Add(member) return true @@ -103,7 +97,7 @@ func (set *Set)Diff(another *Set)*Set { panic("set is nil") } - result := Make(set.Len()) + result := Make() set.ForEach(func(member string)bool { if !another.Has(member) { result.Add(member) diff --git a/src/datastruct/set/set_test.go b/src/datastruct/set/set_test.go index 2dbfa9c..b8b85fb 100644 --- a/src/datastruct/set/set_test.go +++ b/src/datastruct/set/set_test.go @@ -7,7 +7,7 @@ import ( func TestSet(t *testing.T) { size := 10 - set := Make(0) + set := Make() for i := 0; i < size; i++ { set.Add(strconv.Itoa(i)) } diff --git a/src/db/aof.go b/src/db/aof.go index 159de3e..80335c2 100644 --- a/src/db/aof.go +++ b/src/db/aof.go @@ -169,8 +169,8 @@ func (db *DB) aofRewrite() { // load aof file tmpDB := &DB{ - Data: dict.Make(dataDictSize), - TTLMap: dict.Make(ttlDictSize), + Data: dict.MakeSimple(), + TTLMap: dict.MakeSimple(), Locker: lock.Make(lockerSize), interval: 5 * time.Second, @@ -189,7 +189,7 @@ func (db *DB) aofRewrite() { cmd = persistList(key, val) case *set.Set: cmd = persistSet(key, val) - case *dict.Dict: + case dict.Dict: cmd = persistHash(key, val) case *SortedSet.SortedSet: cmd = persistZSet(key, val) @@ -245,7 +245,7 @@ func persistSet(key string, set *set.Set) *reply.MultiBulkReply { var hMSetCmd = []byte("HMSET") -func persistHash(key string, hash *dict.Dict) *reply.MultiBulkReply { +func persistHash(key string, hash dict.Dict) *reply.MultiBulkReply { args := make([][]byte, 2+hash.Len()*2) args[0] = hMSetCmd args[1] = []byte(key) diff --git a/src/db/db.go b/src/db/db.go index 2ba9172..8550463 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -38,11 +38,11 @@ type CmdFunc func(db *DB, args [][]byte) (redis.Reply, *extra) type DB struct { // key -> DataEntity - Data *dict.Dict + Data dict.Dict // key -> expireTime (time.Time) - TTLMap *dict.Dict + TTLMap dict.Dict // channel -> list<*client> - SubMap *dict.Dict + SubMap dict.Dict // dict will ensure thread safety of its method // use this mutex for complicated command only, eg. rpush, incr ... @@ -52,7 +52,7 @@ type DB struct { interval time.Duration // channel -> list(*Client) - subs *dict.Dict + subs dict.Dict // lock channel subsLocker *lock.Locks @@ -71,12 +71,12 @@ var router = MakeRouter() func MakeDB() *DB { db := &DB{ - Data: dict.Make(dataDictSize), - TTLMap: dict.Make(ttlDictSize), + Data: dict.MakeConcurrent(dataDictSize), + TTLMap: dict.MakeConcurrent(ttlDictSize), Locker: lock.Make(lockerSize), interval: 5 * time.Second, - subs: dict.Make(4), + subs: dict.MakeConcurrent(4), subsLocker: lock.Make(16), } @@ -216,8 +216,8 @@ func (db *DB) Flush() { db.stopWorld.Lock() defer db.stopWorld.Unlock() - db.Data = dict.Make(dataDictSize) - db.TTLMap = dict.Make(ttlDictSize) + db.Data = dict.MakeConcurrent(dataDictSize) + db.TTLMap = dict.MakeConcurrent(ttlDictSize) db.Locker = lock.Make(lockerSize) } diff --git a/src/db/hash.go b/src/db/hash.go index ba958d7..52acf13 100644 --- a/src/db/hash.go +++ b/src/db/hash.go @@ -8,26 +8,26 @@ import ( "strconv" ) -func (db *DB)getAsDict(key string)(*Dict.Dict, reply.ErrorReply) { +func (db *DB) getAsDict(key string) (Dict.Dict, reply.ErrorReply) { entity, exists := db.Get(key) if !exists { return nil, nil } - dict, ok := entity.Data.(*Dict.Dict) + dict, ok := entity.Data.(Dict.Dict) if !ok { return nil, &reply.WrongTypeErrReply{} } return dict, nil } -func (db *DB) getOrInitDict(key string)(dict *Dict.Dict, inited bool, errReply reply.ErrorReply) { +func (db *DB) getOrInitDict(key string) (dict Dict.Dict, inited bool, errReply reply.ErrorReply) { dict, errReply = db.getAsDict(key) if errReply != nil { return nil, false, errReply } inited = false if dict == nil { - dict = Dict.Make(1) + dict = Dict.MakeSimple() db.Put(key, &DataEntity{ Data: dict, }) diff --git a/src/db/set.go b/src/db/set.go index 6d3d367..29c3a46 100644 --- a/src/db/set.go +++ b/src/db/set.go @@ -27,7 +27,7 @@ func (db *DB) getOrInitSet(key string)(set *HashSet.Set, inited bool, errReply r } inited = false if set == nil { - set = HashSet.Make(0) + set = HashSet.Make() db.Put(key, &DataEntity{ Data: set, })