diff --git a/datastruct/lock/lock_map.go b/datastruct/lock/lock_map.go index ef81206..427d68c 100644 --- a/datastruct/lock/lock_map.go +++ b/datastruct/lock/lock_map.go @@ -126,3 +126,41 @@ func (locks *Locks) RUnLocks(keys ...string) { mu.RUnlock() } } + +func (locks *Locks) RWLocks(writeKeys []string, readKeys []string) { + keys := append(writeKeys, readKeys...) + indices := locks.toLockIndices(keys, false) + writeIndices := locks.toLockIndices(writeKeys, false) + writeIndexSet := make(map[uint32]struct{}) + for _, idx := range writeIndices { + writeIndexSet[idx] = struct{}{} + } + for _, index := range indices { + _, w := writeIndexSet[index] + mu := locks.table[index] + if w { + mu.Lock() + } else { + mu.RLock() + } + } +} + +func (locks *Locks) RWUnLocks(writeKeys []string, readKeys []string) { + keys := append(writeKeys, readKeys...) + indices := locks.toLockIndices(keys, true) + writeIndices := locks.toLockIndices(writeKeys, true) + writeIndexSet := make(map[uint32]struct{}) + for _, idx := range writeIndices { + writeIndexSet[idx] = struct{}{} + } + for _, index := range indices { + _, w := writeIndexSet[index] + mu := locks.table[index] + if w { + mu.Unlock() + } else { + mu.RUnlock() + } + } +} diff --git a/db.go b/db.go index 1fb1183..bc6f0e8 100644 --- a/db.go +++ b/db.go @@ -220,6 +220,16 @@ func (db *DB) RUnLocks(keys ...string) { db.locker.RUnLocks(keys...) } +// RWLocks lock keys for writing and reading +func (db *DB) RWLocks(writeKeys []string, readKeys []string) { + db.locker.RWLocks(writeKeys, readKeys) +} + +// RWUnLocks unlock keys for writing and reading +func (db *DB) RWUnLocks(writeKeys []string, readKeys []string) { + db.locker.RWUnLocks(writeKeys, readKeys) +} + /* ---- TTL Functions ---- */ func genExpireTask(key string) string { diff --git a/set.go b/set.go index 0a2b449..0a31213 100644 --- a/set.go +++ b/set.go @@ -208,11 +208,8 @@ func execSInterStore(db *DB, args [][]byte) redis.Reply { } // lock - lockedKeySet := HashSet.Make(keys...) - lockedKeySet.Add(dest) - lockedKeys := lockedKeySet.ToSlice() - db.Locks(lockedKeys...) - defer db.UnLocks(lockedKeys...) + db.RWLocks([]string{dest}, keys) + defer db.RWUnLocks([]string{dest}, keys) var result *HashSet.Set for _, key := range keys { @@ -299,11 +296,8 @@ func execSUnionStore(db *DB, args [][]byte) redis.Reply { } // lock - lockedKeySet := HashSet.Make(keys...) - lockedKeySet.Add(dest) - lockedKeys := lockedKeySet.ToSlice() - db.Locks(lockedKeys...) - defer db.UnLocks(lockedKeys...) + db.RWLocks([]string{dest}, keys) + defer db.RWUnLocks([]string{dest}, keys) var result *HashSet.Set for _, key := range keys { @@ -397,11 +391,8 @@ func execSDiffStore(db *DB, args [][]byte) redis.Reply { } // lock - lockedKeySet := HashSet.Make(keys...) - lockedKeySet.Add(dest) - lockedKeys := lockedKeySet.ToSlice() - db.Locks(lockedKeys...) - defer db.UnLocks(lockedKeys...) + db.RWLocks([]string{dest}, keys) + defer db.RWUnLocks([]string{dest}, keys) var result *HashSet.Set for i, key := range keys {