refactor flush db

This commit is contained in:
finley
2022-07-11 22:42:17 +08:00
parent 461d1a7a3f
commit ab0754e2a5
6 changed files with 82 additions and 73 deletions

View File

@@ -20,7 +20,7 @@ import (
// MultiDB is a set of multiple database set // MultiDB is a set of multiple database set
type MultiDB struct { type MultiDB struct {
dbSet []*DB dbSet []*atomic.Value // *DB
// handle publish/subscribe // handle publish/subscribe
hub *pubsub.Hub hub *pubsub.Hub
@@ -39,11 +39,13 @@ func NewStandaloneServer() *MultiDB {
if config.Properties.Databases == 0 { if config.Properties.Databases == 0 {
config.Properties.Databases = 16 config.Properties.Databases = 16
} }
mdb.dbSet = make([]*DB, config.Properties.Databases) mdb.dbSet = make([]*atomic.Value, config.Properties.Databases)
for i := range mdb.dbSet { for i := range mdb.dbSet {
singleDB := makeDB() singleDB := makeDB()
singleDB.index = i singleDB.index = i
mdb.dbSet[i] = singleDB holder := &atomic.Value{}
holder.Store(singleDB)
mdb.dbSet[i] = holder
} }
mdb.hub = pubsub.MakeHub() mdb.hub = pubsub.MakeHub()
validAof := false validAof := false
@@ -56,8 +58,7 @@ func NewStandaloneServer() *MultiDB {
} }
mdb.aofHandler = aofHandler mdb.aofHandler = aofHandler
for _, db := range mdb.dbSet { for _, db := range mdb.dbSet {
// avoid closure singleDB := db.Load().(*DB)
singleDB := db
singleDB.addAof = func(line CmdLine) { singleDB.addAof = func(line CmdLine) {
mdb.aofHandler.AddAof(singleDB.index, line) mdb.aofHandler.AddAof(singleDB.index, line)
} }
@@ -77,9 +78,11 @@ func NewStandaloneServer() *MultiDB {
// MakeBasicMultiDB create a MultiDB only with basic abilities for aof rewrite and other usages // MakeBasicMultiDB create a MultiDB only with basic abilities for aof rewrite and other usages
func MakeBasicMultiDB() *MultiDB { func MakeBasicMultiDB() *MultiDB {
mdb := &MultiDB{} mdb := &MultiDB{}
mdb.dbSet = make([]*DB, config.Properties.Databases) mdb.dbSet = make([]*atomic.Value, config.Properties.Databases)
for i := range mdb.dbSet { for i := range mdb.dbSet {
mdb.dbSet[i] = makeBasicDB() holder := &atomic.Value{}
holder.Store(makeBasicDB())
mdb.dbSet[i] = holder
} }
return mdb return mdb
} }
@@ -139,6 +142,14 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep
return RewriteAOF(mdb, cmdLine[1:]) return RewriteAOF(mdb, cmdLine[1:])
} else if cmdName == "flushall" { } else if cmdName == "flushall" {
return mdb.flushAll() return mdb.flushAll()
} else if cmdName == "flushdb" {
if !validateArity(1, cmdLine) {
return protocol.MakeArgNumErrReply(cmdName)
}
if c.InMultiState() {
return protocol.MakeErrReply("ERR command 'FlushDB' cannot be used in MULTI")
}
return mdb.flushDB(c.GetDBIndex())
} else if cmdName == "save" { } else if cmdName == "save" {
return SaveRDB(mdb, cmdLine[1:]) return SaveRDB(mdb, cmdLine[1:])
} else if cmdName == "bgsave" { } else if cmdName == "bgsave" {
@@ -161,10 +172,10 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep
// normal commands // normal commands
dbIndex := c.GetDBIndex() dbIndex := c.GetDBIndex()
if dbIndex >= len(mdb.dbSet) { selectedDB, errReply := mdb.selectDB(dbIndex)
return protocol.MakeErrReply("ERR DB index is out of range") if errReply != nil {
return errReply
} }
selectedDB := mdb.dbSet[dbIndex]
return selectedDB.Exec(c, cmdLine) return selectedDB.Exec(c, cmdLine)
} }
@@ -194,9 +205,29 @@ func execSelect(c redis.Connection, mdb *MultiDB, args [][]byte) redis.Reply {
return protocol.MakeOkReply() return protocol.MakeOkReply()
} }
func (mdb *MultiDB) flushDB(dbIndex int) redis.Reply {
if dbIndex >= len(mdb.dbSet) || dbIndex < 0 {
return protocol.MakeErrReply("ERR DB index is out of range")
}
newDB := makeDB()
mdb.loadDB(dbIndex, newDB)
return &protocol.OkReply{}
}
func (mdb *MultiDB) loadDB(dbIndex int, newDB *DB) redis.Reply {
if dbIndex >= len(mdb.dbSet) || dbIndex < 0 {
return protocol.MakeErrReply("ERR DB index is out of range")
}
oldDB := mdb.mustSelectDB(dbIndex)
newDB.index = dbIndex
newDB.addAof = oldDB.addAof // inherit oldDB
mdb.dbSet[dbIndex].Store(newDB)
return &protocol.OkReply{}
}
func (mdb *MultiDB) flushAll() redis.Reply { func (mdb *MultiDB) flushAll() redis.Reply {
for _, db := range mdb.dbSet { for i := range mdb.dbSet {
db.Flush() mdb.flushDB(i)
} }
if mdb.aofHandler != nil { if mdb.aofHandler != nil {
mdb.aofHandler.AddAof(0, utils.ToCmdLine("FlushAll")) mdb.aofHandler.AddAof(0, utils.ToCmdLine("FlushAll"))
@@ -204,48 +235,56 @@ func (mdb *MultiDB) flushAll() redis.Reply {
return &protocol.OkReply{} return &protocol.OkReply{}
} }
func (mdb *MultiDB) selectDB(dbIndex int) *DB { func (mdb *MultiDB) selectDB(dbIndex int) (*DB, *protocol.StandardErrReply) {
if dbIndex >= len(mdb.dbSet) { if dbIndex >= len(mdb.dbSet) || dbIndex < 0 {
panic("ERR DB index is out of range") return nil, protocol.MakeErrReply("ERR DB index is out of range")
} }
return mdb.dbSet[dbIndex] return mdb.dbSet[dbIndex].Load().(*DB), nil
}
func (mdb *MultiDB) mustSelectDB(dbIndex int) *DB {
selectedDB, err := mdb.selectDB(dbIndex)
if err != nil {
panic(err)
}
return selectedDB
} }
// ForEach traverses all the keys in the given database // ForEach traverses all the keys in the given database
func (mdb *MultiDB) ForEach(dbIndex int, cb func(key string, data *database.DataEntity, expiration *time.Time) bool) { func (mdb *MultiDB) ForEach(dbIndex int, cb func(key string, data *database.DataEntity, expiration *time.Time) bool) {
mdb.selectDB(dbIndex).ForEach(cb) mdb.mustSelectDB(dbIndex).ForEach(cb)
} }
// ExecMulti executes multi commands transaction Atomically and Isolated // ExecMulti executes multi commands transaction Atomically and Isolated
func (mdb *MultiDB) ExecMulti(conn redis.Connection, watching map[string]uint32, cmdLines []CmdLine) redis.Reply { func (mdb *MultiDB) ExecMulti(conn redis.Connection, watching map[string]uint32, cmdLines []CmdLine) redis.Reply {
if conn.GetDBIndex() >= len(mdb.dbSet) { selectedDB, errReply := mdb.selectDB(conn.GetDBIndex())
return protocol.MakeErrReply("ERR DB index is out of range") if errReply != nil {
return errReply
} }
db := mdb.dbSet[conn.GetDBIndex()] return selectedDB.ExecMulti(conn, watching, cmdLines)
return db.ExecMulti(conn, watching, cmdLines)
} }
// RWLocks lock keys for writing and reading // RWLocks lock keys for writing and reading
func (mdb *MultiDB) RWLocks(dbIndex int, writeKeys []string, readKeys []string) { func (mdb *MultiDB) RWLocks(dbIndex int, writeKeys []string, readKeys []string) {
mdb.selectDB(dbIndex).RWLocks(writeKeys, readKeys) mdb.mustSelectDB(dbIndex).RWLocks(writeKeys, readKeys)
} }
// RWUnLocks unlock keys for writing and reading // RWUnLocks unlock keys for writing and reading
func (mdb *MultiDB) RWUnLocks(dbIndex int, writeKeys []string, readKeys []string) { func (mdb *MultiDB) RWUnLocks(dbIndex int, writeKeys []string, readKeys []string) {
mdb.selectDB(dbIndex).RWUnLocks(writeKeys, readKeys) mdb.mustSelectDB(dbIndex).RWUnLocks(writeKeys, readKeys)
} }
// GetUndoLogs return rollback commands // GetUndoLogs return rollback commands
func (mdb *MultiDB) GetUndoLogs(dbIndex int, cmdLine [][]byte) []CmdLine { func (mdb *MultiDB) GetUndoLogs(dbIndex int, cmdLine [][]byte) []CmdLine {
return mdb.selectDB(dbIndex).GetUndoLogs(cmdLine) return mdb.mustSelectDB(dbIndex).GetUndoLogs(cmdLine)
} }
// ExecWithLock executes normal commands, invoker should provide locks // ExecWithLock executes normal commands, invoker should provide locks
func (mdb *MultiDB) ExecWithLock(conn redis.Connection, cmdLine [][]byte) redis.Reply { func (mdb *MultiDB) ExecWithLock(conn redis.Connection, cmdLine [][]byte) redis.Reply {
if conn.GetDBIndex() >= len(mdb.dbSet) { db, errReply := mdb.selectDB(conn.GetDBIndex())
panic("ERR DB index is out of range") if errReply != nil {
return errReply
} }
db := mdb.dbSet[conn.GetDBIndex()]
return db.execWithLock(cmdLine) return db.execWithLock(cmdLine)
} }
@@ -297,6 +336,6 @@ func BGSaveRDB(db *MultiDB, args [][]byte) redis.Reply {
// GetDBSize returns keys count and ttl key count // GetDBSize returns keys count and ttl key count
func (mdb *MultiDB) GetDBSize(dbIndex int) (int, int) { func (mdb *MultiDB) GetDBSize(dbIndex int) (int, int) {
db := mdb.selectDB(dbIndex) db := mdb.mustSelectDB(dbIndex)
return db.data.Len(), db.ttlMap.Len() return db.data.Len(), db.ttlMap.Len()
} }

View File

@@ -51,6 +51,7 @@ func execExists(db *DB, args [][]byte) redis.Reply {
} }
// execFlushDB removes all data in current db // execFlushDB removes all data in current db
// deprecated, use MultiDB.flushDB
func execFlushDB(db *DB, args [][]byte) redis.Reply { func execFlushDB(db *DB, args [][]byte) redis.Reply {
db.Flush() db.Flush()
db.addAof(utils.ToCmdLine3("flushdb", args...)) db.addAof(utils.ToCmdLine3("flushdb", args...))
@@ -316,7 +317,7 @@ func undoExpire(db *DB, args [][]byte) []CmdLine {
// This command copies the value stored at the source key to the destination key. // This command copies the value stored at the source key to the destination key.
func execCopy(mdb *MultiDB, conn redis.Connection, args [][]byte) redis.Reply { func execCopy(mdb *MultiDB, conn redis.Connection, args [][]byte) redis.Reply {
dbIndex := conn.GetDBIndex() dbIndex := conn.GetDBIndex()
db := mdb.dbSet[dbIndex] // Current DB db := mdb.mustSelectDB(dbIndex) // Current DB
replaceFlag := false replaceFlag := false
srcKey := string(args[0]) srcKey := string(args[0])
destKey := string(args[1]) destKey := string(args[1])
@@ -356,7 +357,7 @@ func execCopy(mdb *MultiDB, conn redis.Connection, args [][]byte) redis.Reply {
return protocol.MakeIntReply(0) return protocol.MakeIntReply(0)
} }
destDB := mdb.dbSet[dbIndex] destDB := mdb.mustSelectDB(dbIndex)
if _, exists = destDB.GetEntity(destKey); exists != false { if _, exists = destDB.GetEntity(destKey); exists != false {
// If destKey exists and there is no "replace" option // If destKey exists and there is no "replace" option
if replaceFlag == false { if replaceFlag == false {

View File

@@ -31,7 +31,7 @@ func loadRdbFile(mdb *MultiDB) {
func dumpRDB(dec *core.Decoder, mdb *MultiDB) error { func dumpRDB(dec *core.Decoder, mdb *MultiDB) error {
return dec.Parse(func(o rdb.RedisObject) bool { return dec.Parse(func(o rdb.RedisObject) bool {
db := mdb.selectDB(o.GetDBIndex()) db := mdb.mustSelectDB(o.GetDBIndex())
switch o.GetType() { switch o.GetType() {
case rdb.StringType: case rdb.StringType:
str := o.(*rdb.StringObject) str := o.(*rdb.StringObject)

View File

@@ -303,9 +303,9 @@ func (mdb *MultiDB) doPsync() error {
} }
logger.Info("full resync from master: " + mdb.replication.replId) logger.Info("full resync from master: " + mdb.replication.replId)
logger.Info("current offset:", mdb.replication.replOffset) logger.Info("current offset:", mdb.replication.replOffset)
for i, newDB := range rdbHolder.dbSet { for i, h := range rdbHolder.dbSet {
oldDB := mdb.selectDB(i) newDB := h.Load().(*DB)
oldDB.Load(newDB) mdb.loadDB(i, newDB)
} }
// there is no CRLF between RDB and following AOF, reset stream to avoid parser error // there is no CRLF between RDB and following AOF, reset stream to avoid parser error
mdb.replication.masterChan = parser.ParseStream(mdb.replication.masterConn) mdb.replication.masterChan = parser.ParseStream(mdb.replication.masterConn)

View File

@@ -8,17 +8,20 @@ import (
"github.com/hdt3213/godis/redis/connection" "github.com/hdt3213/godis/redis/connection"
"github.com/hdt3213/godis/redis/protocol" "github.com/hdt3213/godis/redis/protocol"
"github.com/hdt3213/godis/redis/protocol/asserts" "github.com/hdt3213/godis/redis/protocol/asserts"
"sync/atomic"
"testing" "testing"
"time" "time"
) )
func TestReplication(t *testing.T) { func TestReplication(t *testing.T) {
mdb := &MultiDB{} mdb := &MultiDB{}
mdb.dbSet = make([]*DB, 16) mdb.dbSet = make([]*atomic.Value, 16)
for i := range mdb.dbSet { for i := range mdb.dbSet {
singleDB := makeDB() singleDB := makeDB()
singleDB.index = i singleDB.index = i
mdb.dbSet[i] = singleDB holder := &atomic.Value{}
holder.Store(singleDB)
mdb.dbSet[i] = holder
} }
mdb.replication = initReplStatus() mdb.replication = initReplStatus()
masterCli, err := client.MakeClient("127.0.0.1:6379") masterCli, err := client.MakeClient("127.0.0.1:6379")

View File

@@ -10,7 +10,6 @@ import (
"github.com/hdt3213/godis/lib/timewheel" "github.com/hdt3213/godis/lib/timewheel"
"github.com/hdt3213/godis/redis/protocol" "github.com/hdt3213/godis/redis/protocol"
"strings" "strings"
"sync"
"time" "time"
) )
@@ -33,8 +32,6 @@ type DB struct {
// dict.Dict will ensure concurrent-safety of its method // dict.Dict will ensure concurrent-safety of its method
// use this mutex for complicated command only, eg. rpush, incr ... // use this mutex for complicated command only, eg. rpush, incr ...
locker *lock.Locks locker *lock.Locks
// stop all data access for execFlushDB
stopWorld sync.WaitGroup
addAof func(CmdLine) addAof func(CmdLine)
} }
@@ -102,14 +99,6 @@ func (db *DB) Exec(c redis.Connection, cmdLine [][]byte) redis.Reply {
return protocol.MakeArgNumErrReply(cmdName) return protocol.MakeArgNumErrReply(cmdName)
} }
return Watch(db, c, cmdLine[1:]) return Watch(db, c, cmdLine[1:])
} else if cmdName == "flushdb" {
if !validateArity(1, cmdLine) {
return protocol.MakeArgNumErrReply(cmdName)
}
if c.InMultiState() {
return protocol.MakeErrReply("ERR command 'FlushDB' cannot be used in MULTI")
}
return execFlushDB(db, cmdLine[1:])
} }
if c != nil && c.InMultiState() { if c != nil && c.InMultiState() {
EnqueueCmd(c, cmdLine) EnqueueCmd(c, cmdLine)
@@ -164,8 +153,6 @@ func validateArity(arity int, cmdArgs [][]byte) bool {
// GetEntity returns DataEntity bind to given key // GetEntity returns DataEntity bind to given key
func (db *DB) GetEntity(key string) (*database.DataEntity, bool) { func (db *DB) GetEntity(key string) (*database.DataEntity, bool) {
db.stopWorld.Wait()
raw, ok := db.data.Get(key) raw, ok := db.data.Get(key)
if !ok { if !ok {
return nil, false return nil, false
@@ -179,25 +166,21 @@ func (db *DB) GetEntity(key string) (*database.DataEntity, bool) {
// PutEntity a DataEntity into DB // PutEntity a DataEntity into DB
func (db *DB) PutEntity(key string, entity *database.DataEntity) int { func (db *DB) PutEntity(key string, entity *database.DataEntity) int {
db.stopWorld.Wait()
return db.data.Put(key, entity) return db.data.Put(key, entity)
} }
// PutIfExists edit an existing DataEntity // PutIfExists edit an existing DataEntity
func (db *DB) PutIfExists(key string, entity *database.DataEntity) int { func (db *DB) PutIfExists(key string, entity *database.DataEntity) int {
db.stopWorld.Wait()
return db.data.PutIfExists(key, entity) return db.data.PutIfExists(key, entity)
} }
// PutIfAbsent insert an DataEntity only if the key not exists // PutIfAbsent insert an DataEntity only if the key not exists
func (db *DB) PutIfAbsent(key string, entity *database.DataEntity) int { func (db *DB) PutIfAbsent(key string, entity *database.DataEntity) int {
db.stopWorld.Wait()
return db.data.PutIfAbsent(key, entity) return db.data.PutIfAbsent(key, entity)
} }
// Remove the given key from db // Remove the given key from db
func (db *DB) Remove(key string) { func (db *DB) Remove(key string) {
db.stopWorld.Wait()
db.data.Remove(key) db.data.Remove(key)
db.ttlMap.Remove(key) db.ttlMap.Remove(key)
taskKey := genExpireTask(key) taskKey := genExpireTask(key)
@@ -206,7 +189,6 @@ func (db *DB) Remove(key string) {
// Removes the given keys from db // Removes the given keys from db
func (db *DB) Removes(keys ...string) (deleted int) { func (db *DB) Removes(keys ...string) (deleted int) {
db.stopWorld.Wait()
deleted = 0 deleted = 0
for _, key := range keys { for _, key := range keys {
_, exists := db.data.Get(key) _, exists := db.data.Get(key)
@@ -219,24 +201,14 @@ func (db *DB) Removes(keys ...string) (deleted int) {
} }
// Flush clean database // Flush clean database
// deprecated
// for test only
func (db *DB) Flush() { func (db *DB) Flush() {
db.stopWorld.Add(1)
defer db.stopWorld.Done()
db.data.Clear() db.data.Clear()
db.ttlMap.Clear() db.ttlMap.Clear()
db.locker = lock.Make(lockerSize) db.locker = lock.Make(lockerSize)
} }
func (db *DB) Load(db2 *DB) {
db.stopWorld.Add(1)
defer db.stopWorld.Done()
db.data = db2.data
db.ttlMap = db2.ttlMap
db.locker = lock.Make(lockerSize)
}
/* ---- Lock Function ----- */ /* ---- Lock Function ----- */
// RWLocks lock keys for writing and reading // RWLocks lock keys for writing and reading
@@ -257,7 +229,6 @@ func genExpireTask(key string) string {
// Expire sets ttlCmd of key // Expire sets ttlCmd of key
func (db *DB) Expire(key string, expireTime time.Time) { func (db *DB) Expire(key string, expireTime time.Time) {
db.stopWorld.Wait()
db.ttlMap.Put(key, expireTime) db.ttlMap.Put(key, expireTime)
taskKey := genExpireTask(key) taskKey := genExpireTask(key)
timewheel.At(expireTime, taskKey, func() { timewheel.At(expireTime, taskKey, func() {
@@ -280,7 +251,6 @@ func (db *DB) Expire(key string, expireTime time.Time) {
// Persist cancel ttlCmd of key // Persist cancel ttlCmd of key
func (db *DB) Persist(key string) { func (db *DB) Persist(key string) {
db.stopWorld.Wait()
db.ttlMap.Remove(key) db.ttlMap.Remove(key)
taskKey := genExpireTask(key) taskKey := genExpireTask(key)
timewheel.Cancel(taskKey) timewheel.Cancel(taskKey)
@@ -288,7 +258,6 @@ func (db *DB) Persist(key string) {
// IsExpired check whether a key is expired // IsExpired check whether a key is expired
func (db *DB) IsExpired(key string) bool { func (db *DB) IsExpired(key string) bool {
db.stopWorld.Wait()
rawExpireTime, ok := db.ttlMap.Get(key) rawExpireTime, ok := db.ttlMap.Get(key)
if !ok { if !ok {
return false return false
@@ -304,7 +273,6 @@ func (db *DB) IsExpired(key string) bool {
/* --- add version --- */ /* --- add version --- */
func (db *DB) addVersion(keys ...string) { func (db *DB) addVersion(keys ...string) {
db.stopWorld.Wait()
for _, key := range keys { for _, key := range keys {
versionCode := db.GetVersion(key) versionCode := db.GetVersion(key)
db.versionMap.Put(key, versionCode+1) db.versionMap.Put(key, versionCode+1)
@@ -313,7 +281,6 @@ func (db *DB) addVersion(keys ...string) {
// GetVersion returns version code for given key // GetVersion returns version code for given key
func (db *DB) GetVersion(key string) uint32 { func (db *DB) GetVersion(key string) uint32 {
db.stopWorld.Wait()
entity, ok := db.versionMap.Get(key) entity, ok := db.versionMap.Get(key)
if !ok { if !ok {
return 0 return 0
@@ -323,7 +290,6 @@ func (db *DB) GetVersion(key string) uint32 {
// ForEach traverses all the keys in the database // ForEach traverses all the keys in the database
func (db *DB) ForEach(cb func(key string, data *database.DataEntity, expiration *time.Time) bool) { func (db *DB) ForEach(cb func(key string, data *database.DataEntity, expiration *time.Time) bool) {
db.stopWorld.Wait()
db.data.ForEach(func(key string, raw interface{}) bool { db.data.ForEach(func(key string, raw interface{}) bool {
entity, _ := raw.(*database.DataEntity) entity, _ := raw.(*database.DataEntity)
var expiration *time.Time var expiration *time.Time