refactor flush db

This commit is contained in:
finley
2022-07-11 22:42:17 +08:00
parent 461d1a7a3f
commit ab0754e2a5
6 changed files with 82 additions and 73 deletions

View File

@@ -10,7 +10,6 @@ import (
"github.com/hdt3213/godis/lib/timewheel"
"github.com/hdt3213/godis/redis/protocol"
"strings"
"sync"
"time"
)
@@ -33,9 +32,7 @@ type DB struct {
// dict.Dict will ensure concurrent-safety of its method
// use this mutex for complicated command only, eg. rpush, incr ...
locker *lock.Locks
// stop all data access for execFlushDB
stopWorld sync.WaitGroup
addAof func(CmdLine)
addAof func(CmdLine)
}
// ExecFunc is interface for command executor
@@ -102,14 +99,6 @@ func (db *DB) Exec(c redis.Connection, cmdLine [][]byte) redis.Reply {
return protocol.MakeArgNumErrReply(cmdName)
}
return Watch(db, c, cmdLine[1:])
} else if cmdName == "flushdb" {
if !validateArity(1, cmdLine) {
return protocol.MakeArgNumErrReply(cmdName)
}
if c.InMultiState() {
return protocol.MakeErrReply("ERR command 'FlushDB' cannot be used in MULTI")
}
return execFlushDB(db, cmdLine[1:])
}
if c != nil && c.InMultiState() {
EnqueueCmd(c, cmdLine)
@@ -164,8 +153,6 @@ func validateArity(arity int, cmdArgs [][]byte) bool {
// GetEntity returns DataEntity bind to given key
func (db *DB) GetEntity(key string) (*database.DataEntity, bool) {
db.stopWorld.Wait()
raw, ok := db.data.Get(key)
if !ok {
return nil, false
@@ -179,25 +166,21 @@ func (db *DB) GetEntity(key string) (*database.DataEntity, bool) {
// PutEntity a DataEntity into DB
func (db *DB) PutEntity(key string, entity *database.DataEntity) int {
db.stopWorld.Wait()
return db.data.Put(key, entity)
}
// PutIfExists edit an existing DataEntity
func (db *DB) PutIfExists(key string, entity *database.DataEntity) int {
db.stopWorld.Wait()
return db.data.PutIfExists(key, entity)
}
// PutIfAbsent insert an DataEntity only if the key not exists
func (db *DB) PutIfAbsent(key string, entity *database.DataEntity) int {
db.stopWorld.Wait()
return db.data.PutIfAbsent(key, entity)
}
// Remove the given key from db
func (db *DB) Remove(key string) {
db.stopWorld.Wait()
db.data.Remove(key)
db.ttlMap.Remove(key)
taskKey := genExpireTask(key)
@@ -206,7 +189,6 @@ func (db *DB) Remove(key string) {
// Removes the given keys from db
func (db *DB) Removes(keys ...string) (deleted int) {
db.stopWorld.Wait()
deleted = 0
for _, key := range keys {
_, exists := db.data.Get(key)
@@ -219,24 +201,14 @@ func (db *DB) Removes(keys ...string) (deleted int) {
}
// Flush clean database
// deprecated
// for test only
func (db *DB) Flush() {
db.stopWorld.Add(1)
defer db.stopWorld.Done()
db.data.Clear()
db.ttlMap.Clear()
db.locker = lock.Make(lockerSize)
}
func (db *DB) Load(db2 *DB) {
db.stopWorld.Add(1)
defer db.stopWorld.Done()
db.data = db2.data
db.ttlMap = db2.ttlMap
db.locker = lock.Make(lockerSize)
}
/* ---- Lock Function ----- */
// RWLocks lock keys for writing and reading
@@ -257,7 +229,6 @@ func genExpireTask(key string) string {
// Expire sets ttlCmd of key
func (db *DB) Expire(key string, expireTime time.Time) {
db.stopWorld.Wait()
db.ttlMap.Put(key, expireTime)
taskKey := genExpireTask(key)
timewheel.At(expireTime, taskKey, func() {
@@ -280,7 +251,6 @@ func (db *DB) Expire(key string, expireTime time.Time) {
// Persist cancel ttlCmd of key
func (db *DB) Persist(key string) {
db.stopWorld.Wait()
db.ttlMap.Remove(key)
taskKey := genExpireTask(key)
timewheel.Cancel(taskKey)
@@ -288,7 +258,6 @@ func (db *DB) Persist(key string) {
// IsExpired check whether a key is expired
func (db *DB) IsExpired(key string) bool {
db.stopWorld.Wait()
rawExpireTime, ok := db.ttlMap.Get(key)
if !ok {
return false
@@ -304,7 +273,6 @@ func (db *DB) IsExpired(key string) bool {
/* --- add version --- */
func (db *DB) addVersion(keys ...string) {
db.stopWorld.Wait()
for _, key := range keys {
versionCode := db.GetVersion(key)
db.versionMap.Put(key, versionCode+1)
@@ -313,7 +281,6 @@ func (db *DB) addVersion(keys ...string) {
// GetVersion returns version code for given key
func (db *DB) GetVersion(key string) uint32 {
db.stopWorld.Wait()
entity, ok := db.versionMap.Get(key)
if !ok {
return 0
@@ -323,7 +290,6 @@ func (db *DB) GetVersion(key string) uint32 {
// ForEach traverses all the keys in the database
func (db *DB) ForEach(cb func(key string, data *database.DataEntity, expiration *time.Time) bool) {
db.stopWorld.Wait()
db.data.ForEach(func(key string, raw interface{}) bool {
entity, _ := raw.(*database.DataEntity)
var expiration *time.Time