diff --git a/database/database.go b/database/database.go index 7f13cd2..859327a 100644 --- a/database/database.go +++ b/database/database.go @@ -6,7 +6,6 @@ import ( "time" "github.com/hdt3213/godis/datastruct/dict" - "github.com/hdt3213/godis/datastruct/lock" "github.com/hdt3213/godis/interface/database" "github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/lib/logger" @@ -17,22 +16,18 @@ import ( const ( dataDictSize = 1 << 16 ttlDictSize = 1 << 10 - lockerSize = 1024 ) // DB stores data and execute user's commands type DB struct { index int // key -> DataEntity - data dict.Dict + data *dict.ConcurrentDict // key -> expireTime (time.Time) - ttlMap dict.Dict + ttlMap *dict.ConcurrentDict // key -> version(uint32) - versionMap dict.Dict + versionMap *dict.ConcurrentDict - // dict.Dict will ensure concurrent-safety of its method - // use this mutex for complicated command only, eg. rpush, incr ... - locker *lock.Locks // addaof is used to add command to aof addAof func(CmdLine) } @@ -58,20 +53,17 @@ func makeDB() *DB { data: dict.MakeConcurrent(dataDictSize), ttlMap: dict.MakeConcurrent(ttlDictSize), versionMap: dict.MakeConcurrent(dataDictSize), - locker: lock.Make(lockerSize), addAof: func(line CmdLine) {}, } return db } // makeBasicDB create DB instance only with basic abilities. -// It is not concurrent safe func makeBasicDB() *DB { db := &DB{ - data: dict.MakeSimple(), - ttlMap: dict.MakeSimple(), - versionMap: dict.MakeSimple(), - locker: lock.Make(1), + data: dict.MakeConcurrent(dataDictSize), + ttlMap: dict.MakeConcurrent(ttlDictSize), + versionMap: dict.MakeConcurrent(dataDictSize), addAof: func(line CmdLine) {}, } return db @@ -154,7 +146,7 @@ func validateArity(arity int, cmdArgs [][]byte) bool { // GetEntity returns DataEntity bind to given key func (db *DB) GetEntity(key string) (*database.DataEntity, bool) { - raw, ok := db.data.Get(key) + raw, ok := db.data.GetWithLock(key) if !ok { return nil, false } @@ -167,22 +159,22 @@ func (db *DB) GetEntity(key string) (*database.DataEntity, bool) { // PutEntity a DataEntity into DB func (db *DB) PutEntity(key string, entity *database.DataEntity) int { - return db.data.Put(key, entity) + return db.data.PutWithLock(key, entity) } // PutIfExists edit an existing DataEntity func (db *DB) PutIfExists(key string, entity *database.DataEntity) int { - return db.data.PutIfExists(key, entity) + return db.data.PutIfExistsWithLock(key, entity) } // PutIfAbsent insert an DataEntity only if the key not exists func (db *DB) PutIfAbsent(key string, entity *database.DataEntity) int { - return db.data.PutIfAbsent(key, entity) + return db.data.PutIfAbsentWithLock(key, entity) } // Remove the given key from db func (db *DB) Remove(key string) { - db.data.Remove(key) + db.data.RemoveWithLock(key) db.ttlMap.Remove(key) taskKey := genExpireTask(key) timewheel.Cancel(taskKey) @@ -192,7 +184,7 @@ func (db *DB) Remove(key string) { func (db *DB) Removes(keys ...string) (deleted int) { deleted = 0 for _, key := range keys { - _, exists := db.data.Get(key) + _, exists := db.data.GetWithLock(key) if exists { db.Remove(key) deleted++ @@ -207,19 +199,18 @@ func (db *DB) Removes(keys ...string) (deleted int) { func (db *DB) Flush() { db.data.Clear() db.ttlMap.Clear() - db.locker = lock.Make(lockerSize) } /* ---- Lock Function ----- */ // RWLocks lock keys for writing and reading func (db *DB) RWLocks(writeKeys []string, readKeys []string) { - db.locker.RWLocks(writeKeys, readKeys) + db.data.RWLocks(writeKeys, readKeys) } // RWUnLocks unlock keys for writing and reading func (db *DB) RWUnLocks(writeKeys []string, readKeys []string) { - db.locker.RWUnLocks(writeKeys, readKeys) + db.data.RWUnLocks(writeKeys, readKeys) } /* ---- TTL Functions ---- */ diff --git a/database/util_test.go b/database/util_test.go index a728dcf..c992b3f 100644 --- a/database/util_test.go +++ b/database/util_test.go @@ -2,7 +2,6 @@ package database import ( "github.com/hdt3213/godis/datastruct/dict" - "github.com/hdt3213/godis/datastruct/lock" ) func makeTestDB() *DB { @@ -10,7 +9,6 @@ func makeTestDB() *DB { data: dict.MakeConcurrent(dataDictSize), versionMap: dict.MakeConcurrent(dataDictSize), ttlMap: dict.MakeConcurrent(ttlDictSize), - locker: lock.Make(lockerSize), addAof: func(line CmdLine) {}, } } diff --git a/datastruct/dict/concurrent.go b/datastruct/dict/concurrent.go index 46bb4e1..eb6d5ce 100644 --- a/datastruct/dict/concurrent.go +++ b/datastruct/dict/concurrent.go @@ -3,6 +3,7 @@ package dict import ( "math" "math/rand" + "sort" "sync" "sync/atomic" "time" @@ -87,8 +88,19 @@ func (dict *ConcurrentDict) Get(key string) (val interface{}, exists bool) { hashCode := fnv32(key) index := dict.spread(hashCode) s := dict.getShard(index) - s.mutex.RLock() - defer s.mutex.RUnlock() + s.mutex.Lock() + defer s.mutex.Unlock() + val, exists = s.m[key] + return +} + +func (dict *ConcurrentDict) GetWithLock(key string) (val interface{}, exists bool) { + if dict == nil { + panic("dict is nil") + } + hashCode := fnv32(key) + index := dict.spread(hashCode) + s := dict.getShard(index) val, exists = s.m[key] return } @@ -121,6 +133,23 @@ func (dict *ConcurrentDict) Put(key string, val interface{}) (result int) { return 1 } +func (dict *ConcurrentDict) PutWithLock(key string, val interface{}) (result int) { + if dict == nil { + panic("dict is nil") + } + hashCode := fnv32(key) + index := dict.spread(hashCode) + s := dict.getShard(index) + + if _, ok := s.m[key]; ok { + s.m[key] = val + return 0 + } + dict.addCount() + s.m[key] = val + return 1 +} + // PutIfAbsent puts value if the key is not exists and returns the number of updated key-value func (dict *ConcurrentDict) PutIfAbsent(key string, val interface{}) (result int) { if dict == nil { @@ -140,6 +169,22 @@ func (dict *ConcurrentDict) PutIfAbsent(key string, val interface{}) (result int return 1 } +func (dict *ConcurrentDict) PutIfAbsentWithLock(key string, val interface{}) (result int) { + if dict == nil { + panic("dict is nil") + } + hashCode := fnv32(key) + index := dict.spread(hashCode) + s := dict.getShard(index) + + if _, ok := s.m[key]; ok { + return 0 + } + s.m[key] = val + dict.addCount() + return 1 +} + // PutIfExists puts value if the key is exist and returns the number of inserted key-value func (dict *ConcurrentDict) PutIfExists(key string, val interface{}) (result int) { if dict == nil { @@ -158,6 +203,21 @@ func (dict *ConcurrentDict) PutIfExists(key string, val interface{}) (result int return 0 } +func (dict *ConcurrentDict) PutIfExistsWithLock(key string, val interface{}) (result int) { + if dict == nil { + panic("dict is nil") + } + hashCode := fnv32(key) + index := dict.spread(hashCode) + s := dict.getShard(index) + + if _, ok := s.m[key]; ok { + s.m[key] = val + return 1 + } + return 0 +} + // Remove removes the key and return the number of deleted key-value func (dict *ConcurrentDict) Remove(key string) (result int) { if dict == nil { @@ -177,6 +237,22 @@ func (dict *ConcurrentDict) Remove(key string) (result int) { return 0 } +func (dict *ConcurrentDict) RemoveWithLock(key string) (result int) { + if dict == nil { + panic("dict is nil") + } + hashCode := fnv32(key) + index := dict.spread(hashCode) + s := dict.getShard(index) + + if _, ok := s.m[key]; ok { + delete(s.m, key) + dict.decreaseCount() + return 1 + } + return 0 +} + func (dict *ConcurrentDict) addCount() int32 { return atomic.AddInt32(&dict.count, 1) } @@ -300,3 +376,62 @@ func (dict *ConcurrentDict) RandomDistinctKeys(limit int) []string { func (dict *ConcurrentDict) Clear() { *dict = *MakeConcurrent(dict.shardCount) } + +func (dict *ConcurrentDict) toLockIndices(keys []string, reverse bool) []uint32 { + indexMap := make(map[uint32]struct{}) + for _, key := range keys { + index := dict.spread(fnv32(key)) + indexMap[index] = struct{}{} + } + indices := make([]uint32, 0, len(indexMap)) + for index := range indexMap { + indices = append(indices, index) + } + sort.Slice(indices, func(i, j int) bool { + if !reverse { + return indices[i] < indices[j] + } + return indices[i] > indices[j] + }) + return indices +} + +// RWLocks locks write keys and read keys together. allow duplicate keys +func (dict *ConcurrentDict) RWLocks(writeKeys []string, readKeys []string) { + keys := append(writeKeys, readKeys...) + indices := dict.toLockIndices(keys, false) + writeIndexSet := make(map[uint32]struct{}) + for _, wKey := range writeKeys { + idx := dict.spread(fnv32(wKey)) + writeIndexSet[idx] = struct{}{} + } + for _, index := range indices { + _, w := writeIndexSet[index] + mu := &dict.table[index].mutex + if w { + mu.Lock() + } else { + mu.RLock() + } + } +} + +// RWUnLocks unlocks write keys and read keys together. allow duplicate keys +func (dict *ConcurrentDict) RWUnLocks(writeKeys []string, readKeys []string) { + keys := append(writeKeys, readKeys...) + indices := dict.toLockIndices(keys, true) + writeIndexSet := make(map[uint32]struct{}) + for _, wKey := range writeKeys { + idx := dict.spread(fnv32(wKey)) + writeIndexSet[idx] = struct{}{} + } + for _, index := range indices { + _, w := writeIndexSet[index] + mu := &dict.table[index].mutex + if w { + mu.Unlock() + } else { + mu.RUnlock() + } + } +} diff --git a/datastruct/dict/concurrent_test.go b/datastruct/dict/concurrent_test.go index e31247f..1703074 100644 --- a/datastruct/dict/concurrent_test.go +++ b/datastruct/dict/concurrent_test.go @@ -36,6 +36,45 @@ func TestConcurrentPut(t *testing.T) { wg.Wait() } +func TestConcurrentPutWithLock(t *testing.T) { + d := MakeConcurrent(0) + count := 100 + var wg sync.WaitGroup + wg.Add(count) + keys := make([]string, count) + + for i := 0; i < count; i++ { + // insert + key := "k" + strconv.Itoa(i) + keys[i] = key + } + d.RWLocks(keys, nil) + defer d.RWUnLocks(keys, nil) + + for i := 0; i < count; i++ { + go func(i int) { + // insert + key := "k" + strconv.Itoa(i) + ret := d.PutWithLock(key, i) + if ret != 1 { // insert 1 + t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key) + } + val, ok := d.GetWithLock(key) + if ok { + intVal, _ := val.(int) + if intVal != i { + t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key) + } + } else { + _, ok := d.GetWithLock(key) + t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok)) + } + wg.Done() + }(i) + } + wg.Wait() +} + func TestConcurrentPutIfAbsent(t *testing.T) { d := MakeConcurrent(0) count := 100 @@ -81,11 +120,61 @@ func TestConcurrentPutIfAbsent(t *testing.T) { wg.Wait() } +func TestConcurrentPutIfAbsentWithLock(t *testing.T) { + d := MakeConcurrent(0) + count := 100 + var wg sync.WaitGroup + wg.Add(count) + + for i := 0; i < count; i++ { + go func(i int) { + // insert + key := "k" + strconv.Itoa(i) + keys := make([]string, 1) + d.RWLocks(keys, nil) + ret := d.PutIfAbsentWithLock(key, i) + if ret != 1 { // insert 1 + t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key) + } + val, ok := d.GetWithLock(key) + if ok { + intVal, _ := val.(int) + if intVal != i { + t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + + ", key: " + key) + } + } else { + _, ok := d.GetWithLock(key) + t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok)) + } + + // update + ret = d.PutIfAbsentWithLock(key, i*10) + if ret != 0 { // no update + t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret)) + } + val, ok = d.GetWithLock(key) + if ok { + intVal, _ := val.(int) + if intVal != i { + t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key) + } + } else { + t.Error("put test failed: expected true, actual: false, key: " + key) + } + d.RWUnLocks(keys, nil) + wg.Done() + }(i) + } + wg.Wait() +} + func TestConcurrentPutIfExists(t *testing.T) { d := MakeConcurrent(0) count := 100 var wg sync.WaitGroup wg.Add(count) + for i := 0; i < count; i++ { go func(i int) { // insert @@ -114,6 +203,42 @@ func TestConcurrentPutIfExists(t *testing.T) { wg.Wait() } +func TestConcurrentPutIfExistsWithLock(t *testing.T) { + d := MakeConcurrent(0) + count := 100 + var wg sync.WaitGroup + wg.Add(count) + + for i := 0; i < count; i++ { + go func(i int) { + // insert + key := "k" + strconv.Itoa(i) + keys := make([]string, 1) + d.RWLocks(keys, nil) + // insert + ret := d.PutIfExistsWithLock(key, i) + if ret != 0 { // insert + t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret)) + } + d.PutWithLock(key, i) + d.PutIfExistsWithLock(key, 10*i) + val, ok := d.GetWithLock(key) + if ok { + intVal, _ := val.(int) + if intVal != 10*i { + t.Error("put test failed: expected " + strconv.Itoa(10*i) + ", actual: " + strconv.Itoa(intVal)) + } + } else { + _, ok := d.GetWithLock(key) + t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok)) + } + d.RWUnLocks(keys, nil) + wg.Done() + }(i) + } + wg.Wait() +} + func TestConcurrentRemove(t *testing.T) { d := MakeConcurrent(0) totalCount := 100 @@ -123,7 +248,7 @@ func TestConcurrentRemove(t *testing.T) { key := "k" + strconv.Itoa(i) d.Put(key, i) } - if d.Len()!=totalCount{ + if d.Len() != totalCount { t.Error("put test failed: expected len is 100, actual: " + strconv.Itoa(d.Len())) } for i := 0; i < totalCount; i++ { @@ -143,7 +268,7 @@ func TestConcurrentRemove(t *testing.T) { if ret != 1 { t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key:" + key) } - if d.Len()!=totalCount-i-1{ + if d.Len() != totalCount-i-1 { t.Error("put test failed: expected len is 99, actual: " + strconv.Itoa(d.Len())) } _, ok = d.Get(key) @@ -154,7 +279,7 @@ func TestConcurrentRemove(t *testing.T) { if ret != 0 { t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) } - if d.Len()!=totalCount-i-1{ + if d.Len() != totalCount-i-1 { t.Error("put test failed: expected len is 99, actual: " + strconv.Itoa(d.Len())) } } @@ -230,6 +355,122 @@ func TestConcurrentRemove(t *testing.T) { } } +func TestConcurrentRemoveWithLock(t *testing.T) { + d := MakeConcurrent(0) + totalCount := 100 + // remove head node + for i := 0; i < totalCount; i++ { + // insert + key := "k" + strconv.Itoa(i) + d.PutWithLock(key, i) + } + if d.Len() != totalCount { + t.Error("put test failed: expected len is 100, actual: " + strconv.Itoa(d.Len())) + } + for i := 0; i < totalCount; i++ { + key := "k" + strconv.Itoa(i) + + val, ok := d.GetWithLock(key) + if ok { + intVal, _ := val.(int) + if intVal != i { + t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) + } + } else { + t.Error("put test failed: expected true, actual: false") + } + + ret := d.RemoveWithLock(key) + if ret != 1 { + t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key:" + key) + } + if d.Len() != totalCount-i-1 { + t.Error("put test failed: expected len is 99, actual: " + strconv.Itoa(d.Len())) + } + _, ok = d.GetWithLock(key) + if ok { + t.Error("remove test failed: expected true, actual false") + } + ret = d.RemoveWithLock(key) + if ret != 0 { + t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) + } + if d.Len() != totalCount-i-1 { + t.Error("put test failed: expected len is 99, actual: " + strconv.Itoa(d.Len())) + } + } + + // remove tail node + d = MakeConcurrent(0) + for i := 0; i < 100; i++ { + // insert + key := "k" + strconv.Itoa(i) + d.PutWithLock(key, i) + } + for i := 9; i >= 0; i-- { + key := "k" + strconv.Itoa(i) + + val, ok := d.GetWithLock(key) + if ok { + intVal, _ := val.(int) + if intVal != i { + t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) + } + } else { + t.Error("put test failed: expected true, actual: false") + } + + ret := d.RemoveWithLock(key) + if ret != 1 { + t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret)) + } + _, ok = d.GetWithLock(key) + if ok { + t.Error("remove test failed: expected true, actual false") + } + ret = d.RemoveWithLock(key) + if ret != 0 { + t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) + } + } + + // remove middle node + d = MakeConcurrent(0) + d.Put("head", 0) + for i := 0; i < 10; i++ { + // insert + key := "k" + strconv.Itoa(i) + d.PutWithLock(key, i) + } + d.PutWithLock("tail", 0) + for i := 9; i >= 0; i-- { + key := "k" + strconv.Itoa(i) + + val, ok := d.Get(key) + if ok { + intVal, _ := val.(int) + if intVal != i { + t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) + } + } else { + t.Error("put test failed: expected true, actual: false") + } + + ret := d.RemoveWithLock(key) + if ret != 1 { + t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret)) + } + _, ok = d.GetWithLock(key) + if ok { + t.Error("remove test failed: expected true, actual false") + } + ret = d.RemoveWithLock(key) + if ret != 0 { + t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) + } + } +} + //change t.Error remove->forEach func TestConcurrentForEach(t *testing.T) { d := MakeConcurrent(0) diff --git a/datastruct/lock/lock_map.go b/datastruct/lock/lock_map.go index 78d5011..e4bb4ef 100644 --- a/datastruct/lock/lock_map.go +++ b/datastruct/lock/lock_map.go @@ -71,10 +71,10 @@ func (locks *Locks) RUnLock(key string) { } func (locks *Locks) toLockIndices(keys []string, reverse bool) []uint32 { - indexMap := make(map[uint32]bool) + indexMap := make(map[uint32]struct{}) for _, key := range keys { index := locks.spread(fnv32(key)) - indexMap[index] = true + indexMap[index] = struct{}{} } indices := make([]uint32, 0, len(indexMap)) for index := range indexMap {