Files
redis-go/datastruct/dict/concurrent.go
2025-02-03 18:38:47 +08:00

489 lines
9.9 KiB
Go

package dict
import (
"github.com/hdt3213/godis/lib/wildcard"
"math"
"math/rand"
"sort"
"sync"
"sync/atomic"
"time"
)
// ConcurrentDict is a thread-safe map that partitions its keys across shards,
// each guarded by its own sync.RWMutex, so operations that land on different
// shards do not contend for the same lock.
type ConcurrentDict struct {
// table holds the shards; spread maps each key to an index in this slice.
table []*shard
// count is the total number of keys across all shards. Len, addCount and
// decreaseCount read/update it with sync/atomic.
count int32
// shardCount is the shard count chosen by MakeConcurrent; Clear passes it
// back to MakeConcurrent to rebuild an empty table of the same size.
shardCount int
}
// shard is one partition of a ConcurrentDict: a plain map together with the
// RWMutex that guards access to it.
type shard struct {
m map[string]interface{}
mutex sync.RWMutex
}
// computeCapacity rounds param up to a power of two, with a floor of 16.
// spread masks hashes with (table size - 1), which only selects a valid,
// evenly spread index when the table size is a power of two.
//
// NOTE(review): the smearing shifts cover 31 bit positions in total, so the
// result is only guaranteed to be a power of two for param <= 1<<32; the
// negative-value guard is a defensive fallback returning math.MaxInt32.
func computeCapacity(param int) (size int) {
	if param > 16 {
		smeared := param - 1
		// Copy the highest set bit into every lower position so smeared
		// becomes 2^k - 1; adding one then yields the power of two.
		for shift := uint(1); shift <= 16; shift <<= 1 {
			smeared |= smeared >> shift
		}
		if smeared < 0 {
			return math.MaxInt32
		}
		return smeared + 1
	}
	return 16
}
// MakeConcurrent creates an empty ConcurrentDict.
//
// A shardCount of exactly 1 yields a single-shard dict. Any other value is
// rounded up by computeCapacity to a power of two (minimum 16). The shard count
// actually used is recorded on the returned dict.
func MakeConcurrent(shardCount int) *ConcurrentDict {
	if shardCount != 1 {
		shardCount = computeCapacity(shardCount)
	}
	shards := make([]*shard, 0, shardCount)
	for len(shards) < shardCount {
		shards = append(shards, &shard{m: make(map[string]interface{})})
	}
	return &ConcurrentDict{
		table:      shards,
		shardCount: shardCount,
	}
}
// prime32 is the 32-bit FNV prime used by fnv32.
const prime32 = uint32(16777619)

// fnv32 hashes key with the 32-bit FNV-1 scheme. It starts from the FNV offset
// basis and folds in each byte by first multiplying by prime32 and then XOR-ing
// the byte. This is the multiply-then-XOR order of FNV-1; FNV-1a XORs first.
func fnv32(key string) uint32 {
	const offsetBasis = uint32(2166136261)
	h := offsetBasis
	for _, b := range []byte(key) {
		h = (h * prime32) ^ uint32(b)
	}
	return h
}
// spread maps key to the index of the shard that owns it.
// With a single shard the hash is skipped entirely. Otherwise the FNV hash is
// masked with (table size - 1), which relies on MakeConcurrent sizing the
// table as a power of two via computeCapacity. It panics on a nil dict.
func (dict *ConcurrentDict) spread(key string) uint32 {
	if dict == nil {
		panic("dict is nil")
	}
	size := uint32(len(dict.table))
	if size == 1 {
		return 0
	}
	mask := size - 1
	return fnv32(key) & mask
}
// getShard returns the shard stored at index. It panics on a nil dict; an
// index past the end of the table panics like any out-of-range slice access.
func (dict *ConcurrentDict) getShard(index uint32) *shard {
	if dict != nil {
		return dict.table[index]
	}
	panic("dict is nil")
}
// Get returns the value bound to key and whether the key exists.
//
// Only the owning shard's read lock is taken, so concurrent readers of the
// same shard proceed in parallel; writers (Put, Remove, ...) still exclude
// them. It panics on a nil dict.
func (dict *ConcurrentDict) Get(key string) (val interface{}, exists bool) {
	if dict == nil {
		panic("dict is nil")
	}
	s := dict.getShard(dict.spread(key))
	// A lookup never mutates the map, so a shared (read) lock is sufficient.
	s.mutex.RLock()
	defer s.mutex.RUnlock()
	val, exists = s.m[key]
	return val, exists
}
// GetWithLock returns the value bound to key and whether it exists, reading the
// owning shard's map directly without taking any lock.
// NOTE(review): the name suggests the caller already holds the shard lock
// (e.g. via RWLocks) — confirm at call sites. It panics on a nil dict.
func (dict *ConcurrentDict) GetWithLock(key string) (val interface{}, exists bool) {
	if dict == nil {
		panic("dict is nil")
	}
	m := dict.getShard(dict.spread(key)).m
	val, exists = m[key]
	return
}
// Len reports how many keys the dict holds, loading the shared counter
// atomically. It panics on a nil dict.
func (dict *ConcurrentDict) Len() int {
	if dict == nil {
		panic("dict is nil")
	}
	n := atomic.LoadInt32(&dict.count)
	return int(n)
}
// Put binds val to key while holding the owning shard's write lock.
// It returns 1 when key was newly inserted (and bumps the key count) and 0 when
// an existing binding was overwritten. It panics on a nil dict.
func (dict *ConcurrentDict) Put(key string, val interface{}) (result int) {
	if dict == nil {
		panic("dict is nil")
	}
	s := dict.getShard(dict.spread(key))
	s.mutex.Lock()
	defer s.mutex.Unlock()
	_, existed := s.m[key]
	s.m[key] = val
	if existed {
		return 0
	}
	dict.addCount()
	return 1
}
// PutWithLock binds val to key without taking any lock, returning 1 for a new
// key (and bumping the key count) and 0 for an overwrite.
// NOTE(review): presumably the caller already holds the shard's write lock
// (e.g. via RWLocks) — confirm at call sites. It panics on a nil dict.
func (dict *ConcurrentDict) PutWithLock(key string, val interface{}) (result int) {
	if dict == nil {
		panic("dict is nil")
	}
	m := dict.getShard(dict.spread(key)).m
	_, existed := m[key]
	m[key] = val
	if existed {
		return 0
	}
	dict.addCount()
	return 1
}
// PutIfAbsent binds val to key only when key is not already present, while
// holding the owning shard's write lock.
//
// It returns 1 when a new key was inserted (and bumps the key count), and 0
// when the key already existed, in which case its value is left unchanged.
// It panics on a nil dict.
func (dict *ConcurrentDict) PutIfAbsent(key string, val interface{}) (result int) {
	if dict == nil {
		panic("dict is nil")
	}
	s := dict.getShard(dict.spread(key))
	s.mutex.Lock()
	defer s.mutex.Unlock()
	if _, ok := s.m[key]; ok {
		return 0
	}
	s.m[key] = val
	dict.addCount()
	return 1
}
// PutIfAbsentWithLock inserts val under key only when key is absent, without
// taking any lock. It returns 1 on insertion (bumping the key count) and 0 when
// the key already existed.
// NOTE(review): presumably the caller already holds the shard's write lock
// (e.g. via RWLocks) — confirm at call sites. It panics on a nil dict.
func (dict *ConcurrentDict) PutIfAbsentWithLock(key string, val interface{}) (result int) {
	if dict == nil {
		panic("dict is nil")
	}
	m := dict.getShard(dict.spread(key)).m
	if _, present := m[key]; !present {
		m[key] = val
		dict.addCount()
		return 1
	}
	return 0
}
// PutIfExists overwrites the value bound to key only when key is already
// present, while holding the owning shard's write lock.
//
// It returns 1 when an existing binding was updated and 0 when key was absent.
// It never inserts a new key, so the key count is never changed.
// It panics on a nil dict.
func (dict *ConcurrentDict) PutIfExists(key string, val interface{}) (result int) {
	if dict == nil {
		panic("dict is nil")
	}
	s := dict.getShard(dict.spread(key))
	s.mutex.Lock()
	defer s.mutex.Unlock()
	if _, ok := s.m[key]; ok {
		s.m[key] = val
		return 1
	}
	return 0
}
// PutIfExistsWithLock overwrites the value of key only when key is present,
// without taking any lock. It returns 1 on update and 0 when key was absent;
// the key count is never changed.
// NOTE(review): presumably the caller already holds the shard's write lock
// (e.g. via RWLocks) — confirm at call sites. It panics on a nil dict.
func (dict *ConcurrentDict) PutIfExistsWithLock(key string, val interface{}) (result int) {
	if dict == nil {
		panic("dict is nil")
	}
	m := dict.getShard(dict.spread(key)).m
	if _, present := m[key]; !present {
		return 0
	}
	m[key] = val
	return 1
}
// Remove deletes key while holding the owning shard's write lock.
// It returns the removed value and 1 when the key was present (decrementing the
// key count), or nil and 0 otherwise. It panics on a nil dict.
func (dict *ConcurrentDict) Remove(key string) (val interface{}, result int) {
	if dict == nil {
		panic("dict is nil")
	}
	s := dict.getShard(dict.spread(key))
	s.mutex.Lock()
	defer s.mutex.Unlock()
	old, found := s.m[key]
	if !found {
		return nil, 0
	}
	delete(s.m, key)
	dict.decreaseCount()
	return old, 1
}
// RemoveWithLock deletes key without taking any lock. It returns the removed
// value and 1 when the key was present (decrementing the key count), or nil and
// 0 otherwise.
// NOTE(review): presumably the caller already holds the shard's write lock
// (e.g. via RWLocks) — confirm at call sites. It panics on a nil dict.
func (dict *ConcurrentDict) RemoveWithLock(key string) (val interface{}, result int) {
	if dict == nil {
		panic("dict is nil")
	}
	m := dict.getShard(dict.spread(key)).m
	old, found := m[key]
	if !found {
		return nil, 0
	}
	delete(m, key)
	dict.decreaseCount()
	return old, 1
}
// addCount atomically increments the key counter and returns the new total.
func (dict *ConcurrentDict) addCount() int32 {
	total := atomic.AddInt32(&dict.count, 1)
	return total
}
// decreaseCount atomically decrements the key counter and returns the new total.
func (dict *ConcurrentDict) decreaseCount() int32 {
	total := atomic.AddInt32(&dict.count, -1)
	return total
}
// ForEach calls consumer for every key/value pair, walking shards in table
// order and holding each shard's read lock while that shard is visited.
// Traversal stops as soon as consumer returns false. Entries inserted during
// the traversal may not be visited. It panics on a nil dict.
func (dict *ConcurrentDict) ForEach(consumer Consumer) {
	if dict == nil {
		panic("dict is nil")
	}
	// visit walks one shard under its read lock; the deferred unlock keeps the
	// lock balanced even if consumer panics.
	visit := func(s *shard) (keepGoing bool) {
		s.mutex.RLock()
		defer s.mutex.RUnlock()
		for k, v := range s.m {
			if !consumer(k, v) {
				return false
			}
		}
		return true
	}
	for _, s := range dict.table {
		if !visit(s) {
			return
		}
	}
}
// Keys returns a snapshot of all keys in the dict.
//
// Len is read only as a capacity hint. Keys may be added or removed
// concurrently between that read and the traversal, so the result is built by
// appending and contains exactly the keys ForEach visited. There are no
// zero-value placeholders and no overwritten entries.
func (dict *ConcurrentDict) Keys() []string {
	keys := make([]string, 0, dict.Len())
	dict.ForEach(func(key string, val interface{}) bool {
		keys = append(keys, key)
		return true
	})
	return keys
}
// RandomKey returns an arbitrary key from the shard, or "" when the shard is
// empty, while holding the shard's read lock. The pick relies on Go's
// unspecified map iteration order, so it is not guaranteed to be uniform.
// It panics on a nil shard.
func (sh *shard) RandomKey() string {
	if sh == nil {
		panic("shard is nil")
	}
	sh.mutex.RLock()
	defer sh.mutex.RUnlock()
	picked := ""
	for k := range sh.m {
		picked = k
		break
	}
	return picked
}
// RandomKeys returns limit randomly picked keys; the result may contain
// duplicates. When limit is at least the current size, all keys are returned
// via Keys instead. A non-positive limit yields an empty slice.
//
// Each pick chooses a random shard and takes an arbitrary key from it, retrying
// when the shard comes back empty.
// NOTE(review): a key equal to "" is indistinguishable from an empty shard and
// is never picked, and the loop does not terminate if the dict is emptied
// concurrently.
func (dict *ConcurrentDict) RandomKeys(limit int) []string {
	size := dict.Len()
	if limit >= size {
		return dict.Keys()
	}
	if limit <= 0 {
		// make([]string, limit) would panic for a negative limit.
		return []string{}
	}
	shardCount := len(dict.table)
	result := make([]string, 0, limit)
	nR := rand.New(rand.NewSource(time.Now().UnixNano()))
	for len(result) < limit {
		s := dict.getShard(uint32(nR.Intn(shardCount)))
		if s == nil {
			continue
		}
		if key := s.RandomKey(); key != "" {
			result = append(result, key)
		}
	}
	return result
}
// RandomDistinctKeys returns limit distinct, randomly picked keys. When limit is
// at least the current size, all keys are returned via Keys instead. A
// non-positive limit yields an empty slice.
//
// Picks are drawn as in RandomKeys (random shard, arbitrary key) and collected
// in a set until limit distinct keys have been seen; the output order follows
// map iteration and is unspecified.
// NOTE(review): a key equal to "" is never picked, and the loop does not
// terminate if enough keys are removed concurrently.
func (dict *ConcurrentDict) RandomDistinctKeys(limit int) []string {
	size := dict.Len()
	if limit >= size {
		return dict.Keys()
	}
	if limit <= 0 {
		// make([]string, limit) below would panic for a negative limit.
		return []string{}
	}
	shardCount := len(dict.table)
	picked := make(map[string]struct{}, limit)
	nR := rand.New(rand.NewSource(time.Now().UnixNano()))
	for len(picked) < limit {
		s := dict.getShard(uint32(nR.Intn(shardCount)))
		if s == nil {
			continue
		}
		// Re-inserting an already-picked key is a no-op, so no existence check is needed.
		if key := s.RandomKey(); key != "" {
			picked[key] = struct{}{}
		}
	}
	arr := make([]string, 0, limit)
	for k := range picked {
		arr = append(arr, k)
	}
	return arr
}
// Clear drops every key by overwriting the dict with a freshly built, empty
// dict that has the same shard count.
// NOTE(review): the swap is a plain struct assignment and is not synchronized
// with concurrent operations on the old shards — confirm callers serialize it.
func (dict *ConcurrentDict) Clear() {
	fresh := MakeConcurrent(dict.shardCount)
	*dict = *fresh
}
// toLockIndices returns the distinct shard indices that keys map to, sorted
// ascending, or descending when reverse is true. RWLocks calls it with
// reverse=false and RWUnLocks with reverse=true, so shards are always locked
// in one consistent order and released in the opposite order.
func (dict *ConcurrentDict) toLockIndices(keys []string, reverse bool) []uint32 {
	seen := make(map[uint32]struct{}, len(keys))
	indices := make([]uint32, 0, len(keys))
	for _, key := range keys {
		idx := dict.spread(key)
		if _, dup := seen[idx]; dup {
			continue
		}
		seen[idx] = struct{}{}
		indices = append(indices, idx)
	}
	sort.Slice(indices, func(i, j int) bool {
		if reverse {
			return indices[i] > indices[j]
		}
		return indices[i] < indices[j]
	})
	return indices
}
// RWLocks locks the shards owning writeKeys for writing and the shards owning
// readKeys for reading. Duplicate keys are allowed. A shard named by both lists
// gets a write lock. Shards are locked in ascending index order, and each
// shard is locked at most once. Release with RWUnLocks using the same
// arguments.
func (dict *ConcurrentDict) RWLocks(writeKeys []string, readKeys []string) {
	// Build a fresh slice: append(writeKeys, readKeys...) would write into any
	// spare capacity of the caller's writeKeys backing array.
	keys := make([]string, 0, len(writeKeys)+len(readKeys))
	keys = append(keys, writeKeys...)
	keys = append(keys, readKeys...)
	writeIndexSet := make(map[uint32]struct{}, len(writeKeys))
	for _, wKey := range writeKeys {
		writeIndexSet[dict.spread(wKey)] = struct{}{}
	}
	for _, index := range dict.toLockIndices(keys, false) {
		mu := &dict.table[index].mutex
		if _, w := writeIndexSet[index]; w {
			mu.Lock()
		} else {
			mu.RLock()
		}
	}
}
// RWUnLocks releases the locks taken by RWLocks for the same writeKeys and
// readKeys. Shards owning any write key are released with Unlock and the
// remaining shards with RUnlock. Shards are visited in descending index order.
// Duplicate keys are allowed; each shard is unlocked at most once.
func (dict *ConcurrentDict) RWUnLocks(writeKeys []string, readKeys []string) {
	// Build a fresh slice: append(writeKeys, readKeys...) would write into any
	// spare capacity of the caller's writeKeys backing array.
	keys := make([]string, 0, len(writeKeys)+len(readKeys))
	keys = append(keys, writeKeys...)
	keys = append(keys, readKeys...)
	writeIndexSet := make(map[uint32]struct{}, len(writeKeys))
	for _, wKey := range writeKeys {
		writeIndexSet[dict.spread(wKey)] = struct{}{}
	}
	for _, index := range dict.toLockIndices(keys, true) {
		mu := &dict.table[index].mutex
		if _, w := writeIndexSet[index]; w {
			mu.Unlock()
		} else {
			mu.RUnlock()
		}
	}
}
// stringsToBytes converts each string to its own freshly allocated byte slice,
// preserving order. The result is always non-nil, even for nil input.
func stringsToBytes(strSlice []string) [][]byte {
	out := make([][]byte, 0, len(strSlice))
	for _, s := range strSlice {
		out = append(out, []byte(s))
	}
	return out
}
// DictScan performs one step of a cursor-based scan over the dict.
//
// cursor is the shard index to resume from (0 starts a new scan). Keys matching
// pattern are collected shard by shard. The shard at cursor is always scanned
// in full. Scanning stops before any later shard whose size would push the
// result past count, and that shard's index is returned as the next cursor.
//
// Return cursors:
//   - 0: the scan is complete. A cursor that is negative or past the last shard
//     also returns no keys and 0.
//   - -1: pattern failed to compile.
//
// When pattern is "*" and count covers the whole dict, all keys are returned at
// once regardless of cursor.
//
// count is a soft limit: the size check uses each shard's total size, not its
// number of matching keys.
func (dict *ConcurrentDict) DictScan(cursor int, count int, pattern string) ([][]byte, int) {
	size := dict.Len()
	result := make([][]byte, 0)
	if pattern == "*" && count >= size {
		return stringsToBytes(dict.Keys()), 0
	}
	matchKey, err := wildcard.CompilePattern(pattern)
	if err != nil {
		return result, -1
	}
	shardCount := len(dict.table)
	if cursor < 0 || cursor >= shardCount {
		// Nothing left to scan. A negative cursor would otherwise index
		// dict.table out of range and panic.
		return result, 0
	}
	for shardIndex := cursor; shardIndex < shardCount; shardIndex++ {
		s := dict.table[shardIndex]
		s.mutex.RLock()
		if shardIndex > cursor && len(result)+len(s.m) > count {
			s.mutex.RUnlock()
			return result, shardIndex
		}
		for key := range s.m {
			if pattern == "*" || matchKey.IsMatch(key) {
				result = append(result, []byte(key))
			}
		}
		s.mutex.RUnlock()
	}
	return result, 0
}