refactor flush db

This commit is contained in:
finley
2022-07-11 22:42:17 +08:00
parent 461d1a7a3f
commit ab0754e2a5
6 changed files with 82 additions and 73 deletions

View File

@@ -20,7 +20,7 @@ import (
// MultiDB is a set of multiple database set
type MultiDB struct {
dbSet []*DB
dbSet []*atomic.Value // *DB
// handle publish/subscribe
hub *pubsub.Hub
@@ -39,11 +39,13 @@ func NewStandaloneServer() *MultiDB {
if config.Properties.Databases == 0 {
config.Properties.Databases = 16
}
mdb.dbSet = make([]*DB, config.Properties.Databases)
mdb.dbSet = make([]*atomic.Value, config.Properties.Databases)
for i := range mdb.dbSet {
singleDB := makeDB()
singleDB.index = i
mdb.dbSet[i] = singleDB
holder := &atomic.Value{}
holder.Store(singleDB)
mdb.dbSet[i] = holder
}
mdb.hub = pubsub.MakeHub()
validAof := false
@@ -56,8 +58,7 @@ func NewStandaloneServer() *MultiDB {
}
mdb.aofHandler = aofHandler
for _, db := range mdb.dbSet {
// avoid closure
singleDB := db
singleDB := db.Load().(*DB)
singleDB.addAof = func(line CmdLine) {
mdb.aofHandler.AddAof(singleDB.index, line)
}
@@ -77,9 +78,11 @@ func NewStandaloneServer() *MultiDB {
// MakeBasicMultiDB create a MultiDB only with basic abilities for aof rewrite and other usages
func MakeBasicMultiDB() *MultiDB {
mdb := &MultiDB{}
mdb.dbSet = make([]*DB, config.Properties.Databases)
mdb.dbSet = make([]*atomic.Value, config.Properties.Databases)
for i := range mdb.dbSet {
mdb.dbSet[i] = makeBasicDB()
holder := &atomic.Value{}
holder.Store(makeBasicDB())
mdb.dbSet[i] = holder
}
return mdb
}
@@ -139,6 +142,14 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep
return RewriteAOF(mdb, cmdLine[1:])
} else if cmdName == "flushall" {
return mdb.flushAll()
} else if cmdName == "flushdb" {
if !validateArity(1, cmdLine) {
return protocol.MakeArgNumErrReply(cmdName)
}
if c.InMultiState() {
return protocol.MakeErrReply("ERR command 'FlushDB' cannot be used in MULTI")
}
return mdb.flushDB(c.GetDBIndex())
} else if cmdName == "save" {
return SaveRDB(mdb, cmdLine[1:])
} else if cmdName == "bgsave" {
@@ -161,10 +172,10 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep
// normal commands
dbIndex := c.GetDBIndex()
if dbIndex >= len(mdb.dbSet) {
return protocol.MakeErrReply("ERR DB index is out of range")
selectedDB, errReply := mdb.selectDB(dbIndex)
if errReply != nil {
return errReply
}
selectedDB := mdb.dbSet[dbIndex]
return selectedDB.Exec(c, cmdLine)
}
@@ -194,9 +205,29 @@ func execSelect(c redis.Connection, mdb *MultiDB, args [][]byte) redis.Reply {
return protocol.MakeOkReply()
}
func (mdb *MultiDB) flushDB(dbIndex int) redis.Reply {
if dbIndex >= len(mdb.dbSet) || dbIndex < 0 {
return protocol.MakeErrReply("ERR DB index is out of range")
}
newDB := makeDB()
mdb.loadDB(dbIndex, newDB)
return &protocol.OkReply{}
}
func (mdb *MultiDB) loadDB(dbIndex int, newDB *DB) redis.Reply {
if dbIndex >= len(mdb.dbSet) || dbIndex < 0 {
return protocol.MakeErrReply("ERR DB index is out of range")
}
oldDB := mdb.mustSelectDB(dbIndex)
newDB.index = dbIndex
newDB.addAof = oldDB.addAof // inherit oldDB
mdb.dbSet[dbIndex].Store(newDB)
return &protocol.OkReply{}
}
func (mdb *MultiDB) flushAll() redis.Reply {
for _, db := range mdb.dbSet {
db.Flush()
for i := range mdb.dbSet {
mdb.flushDB(i)
}
if mdb.aofHandler != nil {
mdb.aofHandler.AddAof(0, utils.ToCmdLine("FlushAll"))
@@ -204,48 +235,56 @@ func (mdb *MultiDB) flushAll() redis.Reply {
return &protocol.OkReply{}
}
func (mdb *MultiDB) selectDB(dbIndex int) *DB {
if dbIndex >= len(mdb.dbSet) {
panic("ERR DB index is out of range")
func (mdb *MultiDB) selectDB(dbIndex int) (*DB, *protocol.StandardErrReply) {
if dbIndex >= len(mdb.dbSet) || dbIndex < 0 {
return nil, protocol.MakeErrReply("ERR DB index is out of range")
}
return mdb.dbSet[dbIndex]
return mdb.dbSet[dbIndex].Load().(*DB), nil
}
func (mdb *MultiDB) mustSelectDB(dbIndex int) *DB {
selectedDB, err := mdb.selectDB(dbIndex)
if err != nil {
panic(err)
}
return selectedDB
}
// ForEach traverses all the keys in the given database
func (mdb *MultiDB) ForEach(dbIndex int, cb func(key string, data *database.DataEntity, expiration *time.Time) bool) {
mdb.selectDB(dbIndex).ForEach(cb)
mdb.mustSelectDB(dbIndex).ForEach(cb)
}
// ExecMulti executes multi commands transaction Atomically and Isolated
func (mdb *MultiDB) ExecMulti(conn redis.Connection, watching map[string]uint32, cmdLines []CmdLine) redis.Reply {
if conn.GetDBIndex() >= len(mdb.dbSet) {
return protocol.MakeErrReply("ERR DB index is out of range")
selectedDB, errReply := mdb.selectDB(conn.GetDBIndex())
if errReply != nil {
return errReply
}
db := mdb.dbSet[conn.GetDBIndex()]
return db.ExecMulti(conn, watching, cmdLines)
return selectedDB.ExecMulti(conn, watching, cmdLines)
}
// RWLocks lock keys for writing and reading
func (mdb *MultiDB) RWLocks(dbIndex int, writeKeys []string, readKeys []string) {
mdb.selectDB(dbIndex).RWLocks(writeKeys, readKeys)
mdb.mustSelectDB(dbIndex).RWLocks(writeKeys, readKeys)
}
// RWUnLocks unlock keys for writing and reading
func (mdb *MultiDB) RWUnLocks(dbIndex int, writeKeys []string, readKeys []string) {
mdb.selectDB(dbIndex).RWUnLocks(writeKeys, readKeys)
mdb.mustSelectDB(dbIndex).RWUnLocks(writeKeys, readKeys)
}
// GetUndoLogs return rollback commands
func (mdb *MultiDB) GetUndoLogs(dbIndex int, cmdLine [][]byte) []CmdLine {
return mdb.selectDB(dbIndex).GetUndoLogs(cmdLine)
return mdb.mustSelectDB(dbIndex).GetUndoLogs(cmdLine)
}
// ExecWithLock executes normal commands, invoker should provide locks
func (mdb *MultiDB) ExecWithLock(conn redis.Connection, cmdLine [][]byte) redis.Reply {
if conn.GetDBIndex() >= len(mdb.dbSet) {
panic("ERR DB index is out of range")
db, errReply := mdb.selectDB(conn.GetDBIndex())
if errReply != nil {
return errReply
}
db := mdb.dbSet[conn.GetDBIndex()]
return db.execWithLock(cmdLine)
}
@@ -297,6 +336,6 @@ func BGSaveRDB(db *MultiDB, args [][]byte) redis.Reply {
// GetDBSize returns keys count and ttl key count
func (mdb *MultiDB) GetDBSize(dbIndex int) (int, int) {
db := mdb.selectDB(dbIndex)
db := mdb.mustSelectDB(dbIndex)
return db.data.Len(), db.ttlMap.Len()
}