diff --git a/database/database.go b/database/database.go index a790af7..0e0a7d5 100644 --- a/database/database.go +++ b/database/database.go @@ -20,7 +20,7 @@ import ( // MultiDB is a set of multiple database set type MultiDB struct { - dbSet []*DB + dbSet []*atomic.Value // *DB // handle publish/subscribe hub *pubsub.Hub @@ -39,11 +39,13 @@ func NewStandaloneServer() *MultiDB { if config.Properties.Databases == 0 { config.Properties.Databases = 16 } - mdb.dbSet = make([]*DB, config.Properties.Databases) + mdb.dbSet = make([]*atomic.Value, config.Properties.Databases) for i := range mdb.dbSet { singleDB := makeDB() singleDB.index = i - mdb.dbSet[i] = singleDB + holder := &atomic.Value{} + holder.Store(singleDB) + mdb.dbSet[i] = holder } mdb.hub = pubsub.MakeHub() validAof := false @@ -56,8 +58,7 @@ func NewStandaloneServer() *MultiDB { } mdb.aofHandler = aofHandler for _, db := range mdb.dbSet { - // avoid closure - singleDB := db + singleDB := db.Load().(*DB) singleDB.addAof = func(line CmdLine) { mdb.aofHandler.AddAof(singleDB.index, line) } @@ -77,9 +78,11 @@ func NewStandaloneServer() *MultiDB { // MakeBasicMultiDB create a MultiDB only with basic abilities for aof rewrite and other usages func MakeBasicMultiDB() *MultiDB { mdb := &MultiDB{} - mdb.dbSet = make([]*DB, config.Properties.Databases) + mdb.dbSet = make([]*atomic.Value, config.Properties.Databases) for i := range mdb.dbSet { - mdb.dbSet[i] = makeBasicDB() + holder := &atomic.Value{} + holder.Store(makeBasicDB()) + mdb.dbSet[i] = holder } return mdb } @@ -139,6 +142,14 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep return RewriteAOF(mdb, cmdLine[1:]) } else if cmdName == "flushall" { return mdb.flushAll() + } else if cmdName == "flushdb" { + if !validateArity(1, cmdLine) { + return protocol.MakeArgNumErrReply(cmdName) + } + if c.InMultiState() { + return protocol.MakeErrReply("ERR command 'FlushDB' cannot be used in MULTI") + } + return mdb.flushDB(c.GetDBIndex()) } else if cmdName == "save" { return SaveRDB(mdb, cmdLine[1:]) } else if cmdName == "bgsave" { @@ -161,10 +172,10 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep // normal commands dbIndex := c.GetDBIndex() - if dbIndex >= len(mdb.dbSet) { - return protocol.MakeErrReply("ERR DB index is out of range") + selectedDB, errReply := mdb.selectDB(dbIndex) + if errReply != nil { + return errReply } - selectedDB := mdb.dbSet[dbIndex] return selectedDB.Exec(c, cmdLine) } @@ -194,9 +205,29 @@ func execSelect(c redis.Connection, mdb *MultiDB, args [][]byte) redis.Reply { return protocol.MakeOkReply() } +func (mdb *MultiDB) flushDB(dbIndex int) redis.Reply { + if dbIndex >= len(mdb.dbSet) || dbIndex < 0 { + return protocol.MakeErrReply("ERR DB index is out of range") + } + newDB := makeDB() + mdb.loadDB(dbIndex, newDB) + return &protocol.OkReply{} +} + +func (mdb *MultiDB) loadDB(dbIndex int, newDB *DB) redis.Reply { + if dbIndex >= len(mdb.dbSet) || dbIndex < 0 { + return protocol.MakeErrReply("ERR DB index is out of range") + } + oldDB := mdb.mustSelectDB(dbIndex) + newDB.index = dbIndex + newDB.addAof = oldDB.addAof // inherit oldDB + mdb.dbSet[dbIndex].Store(newDB) + return &protocol.OkReply{} +} + func (mdb *MultiDB) flushAll() redis.Reply { - for _, db := range mdb.dbSet { - db.Flush() + for i := range mdb.dbSet { + mdb.flushDB(i) } if mdb.aofHandler != nil { mdb.aofHandler.AddAof(0, utils.ToCmdLine("FlushAll")) @@ -204,48 +235,56 @@ func (mdb *MultiDB) flushAll() redis.Reply { return &protocol.OkReply{} } -func (mdb *MultiDB) selectDB(dbIndex int) *DB { - if dbIndex >= len(mdb.dbSet) { - panic("ERR DB index is out of range") +func (mdb *MultiDB) selectDB(dbIndex int) (*DB, *protocol.StandardErrReply) { + if dbIndex >= len(mdb.dbSet) || dbIndex < 0 { + return nil, protocol.MakeErrReply("ERR DB index is out of range") } - return mdb.dbSet[dbIndex] + return mdb.dbSet[dbIndex].Load().(*DB), nil +} + +func (mdb *MultiDB) mustSelectDB(dbIndex int) *DB { + selectedDB, err := mdb.selectDB(dbIndex) + if err != nil { + panic(err) + } + return selectedDB } // ForEach traverses all the keys in the given database func (mdb *MultiDB) ForEach(dbIndex int, cb func(key string, data *database.DataEntity, expiration *time.Time) bool) { - mdb.selectDB(dbIndex).ForEach(cb) + mdb.mustSelectDB(dbIndex).ForEach(cb) } // ExecMulti executes multi commands transaction Atomically and Isolated func (mdb *MultiDB) ExecMulti(conn redis.Connection, watching map[string]uint32, cmdLines []CmdLine) redis.Reply { - if conn.GetDBIndex() >= len(mdb.dbSet) { - return protocol.MakeErrReply("ERR DB index is out of range") + selectedDB, errReply := mdb.selectDB(conn.GetDBIndex()) + if errReply != nil { + return errReply } - db := mdb.dbSet[conn.GetDBIndex()] - return db.ExecMulti(conn, watching, cmdLines) + return selectedDB.ExecMulti(conn, watching, cmdLines) } // RWLocks lock keys for writing and reading func (mdb *MultiDB) RWLocks(dbIndex int, writeKeys []string, readKeys []string) { - mdb.selectDB(dbIndex).RWLocks(writeKeys, readKeys) + mdb.mustSelectDB(dbIndex).RWLocks(writeKeys, readKeys) } // RWUnLocks unlock keys for writing and reading func (mdb *MultiDB) RWUnLocks(dbIndex int, writeKeys []string, readKeys []string) { - mdb.selectDB(dbIndex).RWUnLocks(writeKeys, readKeys) + mdb.mustSelectDB(dbIndex).RWUnLocks(writeKeys, readKeys) } // GetUndoLogs return rollback commands func (mdb *MultiDB) GetUndoLogs(dbIndex int, cmdLine [][]byte) []CmdLine { - return mdb.selectDB(dbIndex).GetUndoLogs(cmdLine) + return mdb.mustSelectDB(dbIndex).GetUndoLogs(cmdLine) } // ExecWithLock executes normal commands, invoker should provide locks func (mdb *MultiDB) ExecWithLock(conn redis.Connection, cmdLine [][]byte) redis.Reply { - if conn.GetDBIndex() >= len(mdb.dbSet) { - panic("ERR DB index is out of range") + db, errReply := mdb.selectDB(conn.GetDBIndex()) + if errReply != nil { + return errReply } - db := mdb.dbSet[conn.GetDBIndex()] return db.execWithLock(cmdLine) } @@ -297,6 +336,6 @@ func BGSaveRDB(db *MultiDB, args [][]byte) redis.Reply { // GetDBSize returns keys count and ttl key count func (mdb *MultiDB) GetDBSize(dbIndex int) (int, int) { - db := mdb.selectDB(dbIndex) + db := mdb.mustSelectDB(dbIndex) return db.data.Len(), db.ttlMap.Len() } diff --git a/database/keys.go b/database/keys.go index c9095d6..dbb2210 100644 --- a/database/keys.go +++ b/database/keys.go @@ -51,6 +51,7 @@ func execExists(db *DB, args [][]byte) redis.Reply { } // execFlushDB removes all data in current db +// deprecated, use MultiDB.flushDB func execFlushDB(db *DB, args [][]byte) redis.Reply { db.Flush() db.addAof(utils.ToCmdLine3("flushdb", args...)) @@ -316,7 +317,7 @@ func undoExpire(db *DB, args [][]byte) []CmdLine { // This command copies the value stored at the source key to the destination key. func execCopy(mdb *MultiDB, conn redis.Connection, args [][]byte) redis.Reply { dbIndex := conn.GetDBIndex() - db := mdb.dbSet[dbIndex] // Current DB + db := mdb.mustSelectDB(dbIndex) // Current DB replaceFlag := false srcKey := string(args[0]) destKey := string(args[1]) @@ -356,7 +357,7 @@ func execCopy(mdb *MultiDB, conn redis.Connection, args [][]byte) redis.Reply { return protocol.MakeIntReply(0) } - destDB := mdb.dbSet[dbIndex] + destDB := mdb.mustSelectDB(dbIndex) if _, exists = destDB.GetEntity(destKey); exists != false { // If destKey exists and there is no "replace" option if replaceFlag == false { diff --git a/database/rdb.go b/database/rdb.go index e80bbb4..b628780 100644 --- a/database/rdb.go +++ b/database/rdb.go @@ -31,7 +31,7 @@ func loadRdbFile(mdb *MultiDB) { func dumpRDB(dec *core.Decoder, mdb *MultiDB) error { return dec.Parse(func(o rdb.RedisObject) bool { - db := mdb.selectDB(o.GetDBIndex()) + db := mdb.mustSelectDB(o.GetDBIndex()) switch o.GetType() { case rdb.StringType: str := o.(*rdb.StringObject) diff --git a/database/replication.go b/database/replication.go index 72d7cfd..ccc59d9 100644 --- a/database/replication.go +++ b/database/replication.go @@ -303,9 +303,9 @@ func (mdb *MultiDB) doPsync() error { } logger.Info("full resync from master: " + mdb.replication.replId) logger.Info("current offset:", mdb.replication.replOffset) - for i, newDB := range rdbHolder.dbSet { - oldDB := mdb.selectDB(i) - oldDB.Load(newDB) + for i, h := range rdbHolder.dbSet { + newDB := h.Load().(*DB) + mdb.loadDB(i, newDB) } // there is no CRLF between RDB and following AOF, reset stream to avoid parser error mdb.replication.masterChan = parser.ParseStream(mdb.replication.masterConn) diff --git a/database/replication_test.go b/database/replication_test.go index c2c19a5..537631b 100644 --- a/database/replication_test.go +++ b/database/replication_test.go @@ -8,17 +8,20 @@ import ( "github.com/hdt3213/godis/redis/connection" "github.com/hdt3213/godis/redis/protocol" "github.com/hdt3213/godis/redis/protocol/asserts" + "sync/atomic" "testing" "time" ) func TestReplication(t *testing.T) { mdb := &MultiDB{} - mdb.dbSet = make([]*DB, 16) + mdb.dbSet = make([]*atomic.Value, 16) for i := range mdb.dbSet { singleDB := makeDB() singleDB.index = i - mdb.dbSet[i] = singleDB + holder := &atomic.Value{} + holder.Store(singleDB) + mdb.dbSet[i] = holder } mdb.replication = initReplStatus() masterCli, err := client.MakeClient("127.0.0.1:6379") diff --git a/database/single_db.go b/database/single_db.go index e5cf34f..5e402a5 100644 --- a/database/single_db.go +++ b/database/single_db.go @@ -10,7 +10,6 @@ import ( "github.com/hdt3213/godis/lib/timewheel" "github.com/hdt3213/godis/redis/protocol" "strings" - "sync" "time" ) @@ -33,9 +32,7 @@ type DB struct { // dict.Dict will ensure concurrent-safety of its method // use this mutex for complicated command only, eg. rpush, incr ... locker *lock.Locks - // stop all data access for execFlushDB - stopWorld sync.WaitGroup - addAof func(CmdLine) + addAof func(CmdLine) } // ExecFunc is interface for command executor @@ -102,14 +99,6 @@ func (db *DB) Exec(c redis.Connection, cmdLine [][]byte) redis.Reply { return protocol.MakeArgNumErrReply(cmdName) } return Watch(db, c, cmdLine[1:]) - } else if cmdName == "flushdb" { - if !validateArity(1, cmdLine) { - return protocol.MakeArgNumErrReply(cmdName) - } - if c.InMultiState() { - return protocol.MakeErrReply("ERR command 'FlushDB' cannot be used in MULTI") - } - return execFlushDB(db, cmdLine[1:]) } if c != nil && c.InMultiState() { EnqueueCmd(c, cmdLine) @@ -164,8 +153,6 @@ func validateArity(arity int, cmdArgs [][]byte) bool { // GetEntity returns DataEntity bind to given key func (db *DB) GetEntity(key string) (*database.DataEntity, bool) { - db.stopWorld.Wait() - raw, ok := db.data.Get(key) if !ok { return nil, false @@ -179,25 +166,21 @@ func (db *DB) GetEntity(key string) (*database.DataEntity, bool) { // PutEntity a DataEntity into DB func (db *DB) PutEntity(key string, entity *database.DataEntity) int { - db.stopWorld.Wait() return db.data.Put(key, entity) } // PutIfExists edit an existing DataEntity func (db *DB) PutIfExists(key string, entity *database.DataEntity) int { - db.stopWorld.Wait() return db.data.PutIfExists(key, entity) } // PutIfAbsent insert an DataEntity only if the key not exists func (db *DB) PutIfAbsent(key string, entity *database.DataEntity) int { - db.stopWorld.Wait() return db.data.PutIfAbsent(key, entity) } // Remove the given key from db func (db *DB) Remove(key string) { - db.stopWorld.Wait() db.data.Remove(key) db.ttlMap.Remove(key) taskKey := genExpireTask(key) @@ -206,7 +189,6 @@ func (db *DB) Remove(key string) { // Removes the given keys from db func (db *DB) Removes(keys ...string) (deleted int) { - db.stopWorld.Wait() deleted = 0 for _, key := range keys { _, exists := db.data.Get(key) @@ -219,24 +201,14 @@ func (db *DB) Removes(keys ...string) (deleted int) { } // Flush clean database +// deprecated +// for test only func (db *DB) Flush() { - db.stopWorld.Add(1) - defer db.stopWorld.Done() - db.data.Clear() db.ttlMap.Clear() db.locker = lock.Make(lockerSize) } -func (db *DB) Load(db2 *DB) { - db.stopWorld.Add(1) - defer db.stopWorld.Done() - - db.data = db2.data - db.ttlMap = db2.ttlMap - db.locker = lock.Make(lockerSize) -} - /* ---- Lock Function ----- */ // RWLocks lock keys for writing and reading @@ -257,7 +229,6 @@ func genExpireTask(key string) string { // Expire sets ttlCmd of key func (db *DB) Expire(key string, expireTime time.Time) { - db.stopWorld.Wait() db.ttlMap.Put(key, expireTime) taskKey := genExpireTask(key) timewheel.At(expireTime, taskKey, func() { @@ -280,7 +251,6 @@ func (db *DB) Expire(key string, expireTime time.Time) { // Persist cancel ttlCmd of key func (db *DB) Persist(key string) { - db.stopWorld.Wait() db.ttlMap.Remove(key) taskKey := genExpireTask(key) timewheel.Cancel(taskKey) @@ -288,7 +258,6 @@ func (db *DB) Persist(key string) { // IsExpired check whether a key is expired func (db *DB) IsExpired(key string) bool { - db.stopWorld.Wait() rawExpireTime, ok := db.ttlMap.Get(key) if !ok { return false @@ -304,7 +273,6 @@ func (db *DB) IsExpired(key string) bool { /* --- add version --- */ func (db *DB) addVersion(keys ...string) { - db.stopWorld.Wait() for _, key := range keys { versionCode := db.GetVersion(key) db.versionMap.Put(key, versionCode+1) @@ -313,7 +281,6 @@ func (db *DB) addVersion(keys ...string) { // GetVersion returns version code for given key func (db *DB) GetVersion(key string) uint32 { - db.stopWorld.Wait() entity, ok := db.versionMap.Get(key) if !ok { return 0 @@ -323,7 +290,6 @@ func (db *DB) GetVersion(key string) uint32 { // ForEach traverses all the keys in the database func (db *DB) ForEach(cb func(key string, data *database.DataEntity, expiration *time.Time) bool) { - db.stopWorld.Wait() db.data.ForEach(func(key string, raw interface{}) bool { entity, _ := raw.(*database.DataEntity) var expiration *time.Time