package dict import ( "github.com/hdt3213/godis/lib/wildcard" "math" "math/rand" "sort" "sync" "sync/atomic" "time" ) // ConcurrentDict is thread safe map using sharding lock type ConcurrentDict struct { table []*shard count int32 shardCount int } type shard struct { m map[string]interface{} mutex sync.RWMutex } func computeCapacity(param int) (size int) { if param <= 16 { return 16 } n := param - 1 n |= n >> 1 n |= n >> 2 n |= n >> 4 n |= n >> 8 n |= n >> 16 if n < 0 { return math.MaxInt32 } return n + 1 } // MakeConcurrent creates ConcurrentDict with the given shard count func MakeConcurrent(shardCount int) *ConcurrentDict { shardCount = computeCapacity(shardCount) table := make([]*shard, shardCount) for i := 0; i < shardCount; i++ { table[i] = &shard{ m: make(map[string]interface{}), } } d := &ConcurrentDict{ count: 0, table: table, shardCount: shardCount, } return d } const prime32 = uint32(16777619) func fnv32(key string) uint32 { hash := uint32(2166136261) for i := 0; i < len(key); i++ { hash *= prime32 hash ^= uint32(key[i]) } return hash } func (dict *ConcurrentDict) spread(hashCode uint32) uint32 { if dict == nil { panic("dict is nil") } tableSize := uint32(len(dict.table)) return (tableSize - 1) & hashCode } func (dict *ConcurrentDict) getShard(index uint32) *shard { if dict == nil { panic("dict is nil") } return dict.table[index] } // Get returns the binding value and whether the key is exist func (dict *ConcurrentDict) Get(key string) (val interface{}, exists bool) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) s := dict.getShard(index) s.mutex.Lock() defer s.mutex.Unlock() val, exists = s.m[key] return } func (dict *ConcurrentDict) GetWithLock(key string) (val interface{}, exists bool) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) s := dict.getShard(index) val, exists = s.m[key] return } // Len returns the number of dict func (dict *ConcurrentDict) Len() int { if dict == nil { panic("dict is nil") } return int(atomic.LoadInt32(&dict.count)) } // Put puts key value into dict and returns the number of new inserted key-value func (dict *ConcurrentDict) Put(key string, val interface{}) (result int) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) s := dict.getShard(index) s.mutex.Lock() defer s.mutex.Unlock() if _, ok := s.m[key]; ok { s.m[key] = val return 0 } dict.addCount() s.m[key] = val return 1 } func (dict *ConcurrentDict) PutWithLock(key string, val interface{}) (result int) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) s := dict.getShard(index) if _, ok := s.m[key]; ok { s.m[key] = val return 0 } dict.addCount() s.m[key] = val return 1 } // PutIfAbsent puts value if the key is not exists and returns the number of updated key-value func (dict *ConcurrentDict) PutIfAbsent(key string, val interface{}) (result int) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) s := dict.getShard(index) s.mutex.Lock() defer s.mutex.Unlock() if _, ok := s.m[key]; ok { return 0 } s.m[key] = val dict.addCount() return 1 } func (dict *ConcurrentDict) PutIfAbsentWithLock(key string, val interface{}) (result int) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) s := dict.getShard(index) if _, ok := s.m[key]; ok { return 0 } s.m[key] = val dict.addCount() return 1 } // PutIfExists puts value if the key is existed and returns the number of inserted key-value func (dict *ConcurrentDict) PutIfExists(key string, val interface{}) (result int) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) s := dict.getShard(index) s.mutex.Lock() defer s.mutex.Unlock() if _, ok := s.m[key]; ok { s.m[key] = val return 1 } return 0 } func (dict *ConcurrentDict) PutIfExistsWithLock(key string, val interface{}) (result int) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) s := dict.getShard(index) if _, ok := s.m[key]; ok { s.m[key] = val return 1 } return 0 } // Remove removes the key and return the number of deleted key-value func (dict *ConcurrentDict) Remove(key string) (val interface{}, result int) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) s := dict.getShard(index) s.mutex.Lock() defer s.mutex.Unlock() if val, ok := s.m[key]; ok { delete(s.m, key) dict.decreaseCount() return val, 1 } return nil, 0 } func (dict *ConcurrentDict) RemoveWithLock(key string) (val interface{}, result int) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) s := dict.getShard(index) if val, ok := s.m[key]; ok { delete(s.m, key) dict.decreaseCount() return val, 1 } return val, 0 } func (dict *ConcurrentDict) addCount() int32 { return atomic.AddInt32(&dict.count, 1) } func (dict *ConcurrentDict) decreaseCount() int32 { return atomic.AddInt32(&dict.count, -1) } // ForEach traversal the dict // it may not visit new entry inserted during traversal func (dict *ConcurrentDict) ForEach(consumer Consumer) { if dict == nil { panic("dict is nil") } for _, s := range dict.table { s.mutex.RLock() f := func() bool { defer s.mutex.RUnlock() for key, value := range s.m { continues := consumer(key, value) if !continues { return false } } return true } if !f() { break } } } // Keys returns all keys in dict func (dict *ConcurrentDict) Keys() []string { keys := make([]string, dict.Len()) i := 0 dict.ForEach(func(key string, val interface{}) bool { if i < len(keys) { keys[i] = key i++ } else { keys = append(keys, key) } return true }) return keys } // RandomKey returns a key randomly func (shard *shard) RandomKey() string { if shard == nil { panic("shard is nil") } shard.mutex.RLock() defer shard.mutex.RUnlock() for key := range shard.m { return key } return "" } // RandomKeys randomly returns keys of the given number, may contain duplicated key func (dict *ConcurrentDict) RandomKeys(limit int) []string { size := dict.Len() if limit >= size { return dict.Keys() } shardCount := len(dict.table) result := make([]string, limit) nR := rand.New(rand.NewSource(time.Now().UnixNano())) for i := 0; i < limit; { s := dict.getShard(uint32(nR.Intn(shardCount))) if s == nil { continue } key := s.RandomKey() if key != "" { result[i] = key i++ } } return result } // RandomDistinctKeys randomly returns keys of the given number, won't contain duplicated key func (dict *ConcurrentDict) RandomDistinctKeys(limit int) []string { size := dict.Len() if limit >= size { return dict.Keys() } shardCount := len(dict.table) result := make(map[string]struct{}) nR := rand.New(rand.NewSource(time.Now().UnixNano())) for len(result) < limit { shardIndex := uint32(nR.Intn(shardCount)) s := dict.getShard(shardIndex) if s == nil { continue } key := s.RandomKey() if key != "" { if _, exists := result[key]; !exists { result[key] = struct{}{} } } } arr := make([]string, limit) i := 0 for k := range result { arr[i] = k i++ } return arr } // Clear removes all keys in dict func (dict *ConcurrentDict) Clear() { *dict = *MakeConcurrent(dict.shardCount) } func (dict *ConcurrentDict) toLockIndices(keys []string, reverse bool) []uint32 { indexMap := make(map[uint32]struct{}) for _, key := range keys { index := dict.spread(fnv32(key)) indexMap[index] = struct{}{} } indices := make([]uint32, 0, len(indexMap)) for index := range indexMap { indices = append(indices, index) } sort.Slice(indices, func(i, j int) bool { if !reverse { return indices[i] < indices[j] } return indices[i] > indices[j] }) return indices } // RWLocks locks write keys and read keys together. allow duplicate keys func (dict *ConcurrentDict) RWLocks(writeKeys []string, readKeys []string) { keys := append(writeKeys, readKeys...) indices := dict.toLockIndices(keys, false) writeIndexSet := make(map[uint32]struct{}) for _, wKey := range writeKeys { idx := dict.spread(fnv32(wKey)) writeIndexSet[idx] = struct{}{} } for _, index := range indices { _, w := writeIndexSet[index] mu := &dict.table[index].mutex if w { mu.Lock() } else { mu.RLock() } } } // RWUnLocks unlocks write keys and read keys together. allow duplicate keys func (dict *ConcurrentDict) RWUnLocks(writeKeys []string, readKeys []string) { keys := append(writeKeys, readKeys...) indices := dict.toLockIndices(keys, true) writeIndexSet := make(map[uint32]struct{}) for _, wKey := range writeKeys { idx := dict.spread(fnv32(wKey)) writeIndexSet[idx] = struct{}{} } for _, index := range indices { _, w := writeIndexSet[index] mu := &dict.table[index].mutex if w { mu.Unlock() } else { mu.RUnlock() } } } func stringsToBytes(strSlice []string) [][]byte { byteSlice := make([][]byte, len(strSlice)) for i, str := range strSlice { byteSlice[i] = []byte(str) } return byteSlice } func (dict *ConcurrentDict) DictScan(cursor int, count int, pattern string) ([][]byte, int) { size := dict.Len() result := make([][]byte, 0) if pattern == "*" && count >= size { return stringsToBytes(dict.Keys()), 0 } matchKey, err := wildcard.CompilePattern(pattern) if err != nil { return result, -1 } shardCount := len(dict.table) shardIndex := cursor for shardIndex < shardCount { shard := dict.table[shardIndex] shard.mutex.RLock() if len(result)+len(shard.m) > count && shardIndex > cursor { shard.mutex.RUnlock() return result, shardIndex } for key := range shard.m { if pattern == "*" || matchKey.IsMatch(key) { result = append(result, []byte(key)) } } shard.mutex.RUnlock() shardIndex++ } return result, 0 }