package dict import ( "math" "math/rand" "sync" "sync/atomic" "time" ) // ConcurrentDict is thread safe map using sharding lock type ConcurrentDict struct { table []*shard count int32 shardCount int } type shard struct { m map[string]interface{} mutex sync.RWMutex } func computeCapacity(param int) (size int) { if param <= 16 { return 16 } n := param - 1 n |= n >> 1 n |= n >> 2 n |= n >> 4 n |= n >> 8 n |= n >> 16 if n < 0 { return math.MaxInt32 } return n + 1 } // MakeConcurrent creates ConcurrentDict with the given shard count func MakeConcurrent(shardCount int) *ConcurrentDict { shardCount = computeCapacity(shardCount) table := make([]*shard, shardCount) for i := 0; i < shardCount; i++ { table[i] = &shard{ m: make(map[string]interface{}), } } d := &ConcurrentDict{ count: 0, table: table, shardCount: shardCount, } return d } const prime32 = uint32(16777619) func fnv32(key string) uint32 { hash := uint32(2166136261) for i := 0; i < len(key); i++ { hash *= prime32 hash ^= uint32(key[i]) } return hash } func (dict *ConcurrentDict) spread(hashCode uint32) uint32 { if dict == nil { panic("dict is nil") } tableSize := uint32(len(dict.table)) return (tableSize - 1) & uint32(hashCode) } func (dict *ConcurrentDict) getShard(index uint32) *shard { if dict == nil { panic("dict is nil") } return dict.table[index] } // Get returns the binding value and whether the key is exist func (dict *ConcurrentDict) Get(key string) (val interface{}, exists bool) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) shard := dict.getShard(index) shard.mutex.RLock() defer shard.mutex.RUnlock() val, exists = shard.m[key] return } // Len returns the number of dict func (dict *ConcurrentDict) Len() int { if dict == nil { panic("dict is nil") } return int(atomic.LoadInt32(&dict.count)) } // Put puts key value into dict and returns the number of new inserted key-value func (dict *ConcurrentDict) Put(key string, val interface{}) (result int) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) shard := dict.getShard(index) dict.addCount() shard.mutex.Lock() defer shard.mutex.Unlock() if _, ok := shard.m[key]; ok { shard.m[key] = val return 0 } shard.m[key] = val return 1 } // PutIfAbsent puts value if the key is not exists and returns the number of updated key-value func (dict *ConcurrentDict) PutIfAbsent(key string, val interface{}) (result int) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) shard := dict.getShard(index) shard.mutex.Lock() defer shard.mutex.Unlock() if _, ok := shard.m[key]; ok { return 0 } shard.m[key] = val dict.addCount() return 1 } // PutIfExists puts value if the key is exist and returns the number of inserted key-value func (dict *ConcurrentDict) PutIfExists(key string, val interface{}) (result int) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) shard := dict.getShard(index) shard.mutex.Lock() defer shard.mutex.Unlock() if _, ok := shard.m[key]; ok { shard.m[key] = val return 1 } return 0 } // Remove removes the key and return the number of deleted key-value func (dict *ConcurrentDict) Remove(key string) (result int) { if dict == nil { panic("dict is nil") } hashCode := fnv32(key) index := dict.spread(hashCode) shard := dict.getShard(index) shard.mutex.Lock() defer shard.mutex.Unlock() if _, ok := shard.m[key]; ok { delete(shard.m, key) dict.decreaseCount() return 1 } return 0 } func (dict *ConcurrentDict) addCount() int32 { return atomic.AddInt32(&dict.count, 1) } func (dict *ConcurrentDict) decreaseCount() int32 { return atomic.AddInt32(&dict.count, -1) } // ForEach traversal the dict // it may not visits new entry inserted during traversal func (dict *ConcurrentDict) ForEach(consumer Consumer) { if dict == nil { panic("dict is nil") } for _, shard := range dict.table { shard.mutex.RLock() func() { defer shard.mutex.RUnlock() for key, value := range shard.m { continues := consumer(key, value) if !continues { return } } }() } } // Keys returns all keys in dict func (dict *ConcurrentDict) Keys() []string { keys := make([]string, dict.Len()) i := 0 dict.ForEach(func(key string, val interface{}) bool { if i < len(keys) { keys[i] = key i++ } else { keys = append(keys, key) } return true }) return keys } // RandomKey returns a key randomly func (shard *shard) RandomKey() string { if shard == nil { panic("shard is nil") } shard.mutex.RLock() defer shard.mutex.RUnlock() for key := range shard.m { return key } return "" } // RandomKeys randomly returns keys of the given number, may contain duplicated key func (dict *ConcurrentDict) RandomKeys(limit int) []string { size := dict.Len() if limit >= size { return dict.Keys() } shardCount := len(dict.table) result := make([]string, limit) nR := rand.New(rand.NewSource(time.Now().UnixNano())) for i := 0; i < limit; { shard := dict.getShard(uint32(nR.Intn(shardCount))) if shard == nil { continue } key := shard.RandomKey() if key != "" { result[i] = key i++ } } return result } // RandomDistinctKeys randomly returns keys of the given number, won't contain duplicated key func (dict *ConcurrentDict) RandomDistinctKeys(limit int) []string { size := dict.Len() if limit >= size { return dict.Keys() } shardCount := len(dict.table) result := make(map[string]struct{}) nR := rand.New(rand.NewSource(time.Now().UnixNano())) for len(result) < limit { shardIndex := uint32(nR.Intn(shardCount)) shard := dict.getShard(shardIndex) if shard == nil { continue } key := shard.RandomKey() if key != "" { if _, exists := result[key]; !exists { result[key] = struct{}{} } } } arr := make([]string, limit) i := 0 for k := range result { arr[i] = k i++ } return arr } // Clear removes all keys in dict func (dict *ConcurrentDict) Clear() { *dict = *MakeConcurrent(dict.shardCount) }