// Package database is a memory database with redis compatible interface
package database

import (
	"github.com/hdt3213/godis/datastruct/dict"
	"github.com/hdt3213/godis/datastruct/lock"
	"github.com/hdt3213/godis/interface/database"
	"github.com/hdt3213/godis/interface/redis"
	"github.com/hdt3213/godis/lib/logger"
	"github.com/hdt3213/godis/lib/timewheel"
	"github.com/hdt3213/godis/redis/protocol"
	"strings"
	"sync"
	"time"
)

const (
	// dataDictSize is the size passed to dict.MakeConcurrent for the data
	// and version dicts in makeDB.
	dataDictSize = 1 << 16
	// ttlDictSize is the size passed to dict.MakeConcurrent for the TTL dict
	// in makeDB.
	ttlDictSize = 1 << 10
	// lockerSize is the size passed to lock.Make for a DB's key locker.
	// NOTE(review): presumably the number of lock slots keys hash into — confirm in package lock.
	lockerSize = 1024
)

// DB stores data and executes user's commands.
type DB struct {
	// index is not referenced by the code in this file.
	// NOTE(review): presumably the logical DB number — confirm with callers.
	index int
	// key -> DataEntity (*database.DataEntity)
	data dict.Dict
	// key -> expireTime (time.Time)
	ttlMap dict.Dict
	// key -> version(uint32); bumped for every write key of a normal command
	versionMap dict.Dict

	// dict.Dict will ensure concurrent-safety of its method
	// use this mutex for complicated command only, eg. rpush, incr ...
	locker *lock.Locks
	// stop all data access for execFlushDB: Flush/Load call Add(1) and data
	// accessors call Wait() before touching the dicts.
	// NOTE(review): sync.WaitGroup requires an Add with positive delta at
	// counter zero to happen before a concurrent Wait, so this barrier is
	// best-effort and racy; a sync.RWMutex would give a real stop-the-world.
	stopWorld sync.WaitGroup
	// addAof receives command lines; makeDB and makeBasicDB install a no-op.
	// NOTE(review): presumably replaced by the server to append to the AOF — confirm.
	addAof func(CmdLine)
}

// ExecFunc is interface for command executor
// args don't include cmd line
type ExecFunc func(db *DB, args [][]byte) redis.Reply

// PreFunc analyses command line when queued command to `multi`
// returns related write keys and read keys
type PreFunc func(args [][]byte) ([]string, []string)

// CmdLine is alias for [][]byte, represents a command line
type CmdLine = [][]byte

// UndoFunc returns undo logs for the given command line
// execute from head to tail when undo
type UndoFunc func(db *DB, args [][]byte) []CmdLine

// makeDB creates a DB instance backed by concurrent dicts
// (dict.MakeConcurrent) and a locker of lockerSize; addAof is a no-op.
func makeDB() *DB {
	db := &DB{
		data:       dict.MakeConcurrent(dataDictSize),
		ttlMap:     dict.MakeConcurrent(ttlDictSize),
		versionMap: dict.MakeConcurrent(dataDictSize),
		locker:     lock.Make(lockerSize),
		addAof:     func(line CmdLine) {},
	}
	return db
}

// makeBasicDB creates a DB instance with only basic abilities.
// It is not concurrent safe func makeBasicDB() *DB { db := &DB{ data: dict.MakeSimple(), ttlMap: dict.MakeSimple(), versionMap: dict.MakeSimple(), locker: lock.Make(1), addAof: func(line CmdLine) {}, } return db } // Exec executes command within one database func (db *DB) Exec(c redis.Connection, cmdLine [][]byte) redis.Reply { // transaction control commands and other commands which cannot execute within transaction cmdName := strings.ToLower(string(cmdLine[0])) if cmdName == "multi" { if len(cmdLine) != 1 { return protocol.MakeArgNumErrReply(cmdName) } return StartMulti(c) } else if cmdName == "discard" { if len(cmdLine) != 1 { return protocol.MakeArgNumErrReply(cmdName) } return DiscardMulti(c) } else if cmdName == "exec" { if len(cmdLine) != 1 { return protocol.MakeArgNumErrReply(cmdName) } return execMulti(db, c) } else if cmdName == "watch" { if !validateArity(-2, cmdLine) { return protocol.MakeArgNumErrReply(cmdName) } return Watch(db, c, cmdLine[1:]) } else if cmdName == "flushdb" { if !validateArity(1, cmdLine) { return protocol.MakeArgNumErrReply(cmdName) } if c.InMultiState() { return protocol.MakeErrReply("ERR command 'FlushDB' cannot be used in MULTI") } return execFlushDB(db, cmdLine[1:]) } if c != nil && c.InMultiState() { EnqueueCmd(c, cmdLine) return protocol.MakeQueuedReply() } return db.execNormalCommand(cmdLine) } func (db *DB) execNormalCommand(cmdLine [][]byte) redis.Reply { cmdName := strings.ToLower(string(cmdLine[0])) cmd, ok := cmdTable[cmdName] if !ok { return protocol.MakeErrReply("ERR unknown command '" + cmdName + "'") } if !validateArity(cmd.arity, cmdLine) { return protocol.MakeArgNumErrReply(cmdName) } prepare := cmd.prepare write, read := prepare(cmdLine[1:]) db.addVersion(write...) 
db.RWLocks(write, read) defer db.RWUnLocks(write, read) fun := cmd.executor return fun(db, cmdLine[1:]) } // execWithLock executes normal commands, invoker should provide locks func (db *DB) execWithLock(cmdLine [][]byte) redis.Reply { cmdName := strings.ToLower(string(cmdLine[0])) cmd, ok := cmdTable[cmdName] if !ok { return protocol.MakeErrReply("ERR unknown command '" + cmdName + "'") } if !validateArity(cmd.arity, cmdLine) { return protocol.MakeArgNumErrReply(cmdName) } fun := cmd.executor return fun(db, cmdLine[1:]) } func validateArity(arity int, cmdArgs [][]byte) bool { argNum := len(cmdArgs) if arity >= 0 { return argNum == arity } return argNum >= -arity } /* ---- Data Access ----- */ // GetEntity returns DataEntity bind to given key func (db *DB) GetEntity(key string) (*database.DataEntity, bool) { db.stopWorld.Wait() raw, ok := db.data.Get(key) if !ok { return nil, false } if db.IsExpired(key) { return nil, false } entity, _ := raw.(*database.DataEntity) return entity, true } // PutEntity a DataEntity into DB func (db *DB) PutEntity(key string, entity *database.DataEntity) int { db.stopWorld.Wait() return db.data.Put(key, entity) } // PutIfExists edit an existing DataEntity func (db *DB) PutIfExists(key string, entity *database.DataEntity) int { db.stopWorld.Wait() return db.data.PutIfExists(key, entity) } // PutIfAbsent insert an DataEntity only if the key not exists func (db *DB) PutIfAbsent(key string, entity *database.DataEntity) int { db.stopWorld.Wait() return db.data.PutIfAbsent(key, entity) } // Remove the given key from db func (db *DB) Remove(key string) { db.stopWorld.Wait() db.data.Remove(key) db.ttlMap.Remove(key) taskKey := genExpireTask(key) timewheel.Cancel(taskKey) } // Removes the given keys from db func (db *DB) Removes(keys ...string) (deleted int) { db.stopWorld.Wait() deleted = 0 for _, key := range keys { _, exists := db.data.Get(key) if exists { db.Remove(key) deleted++ } } return deleted } // Flush clean database func (db *DB) 
Flush() { db.stopWorld.Add(1) defer db.stopWorld.Done() db.data.Clear() db.ttlMap.Clear() db.locker = lock.Make(lockerSize) } func (db *DB) Load(db2 *DB) { db.stopWorld.Add(1) defer db.stopWorld.Done() db.data = db2.data db.ttlMap = db2.ttlMap db.locker = lock.Make(lockerSize) } /* ---- Lock Function ----- */ // RWLocks lock keys for writing and reading func (db *DB) RWLocks(writeKeys []string, readKeys []string) { db.locker.RWLocks(writeKeys, readKeys) } // RWUnLocks unlock keys for writing and reading func (db *DB) RWUnLocks(writeKeys []string, readKeys []string) { db.locker.RWUnLocks(writeKeys, readKeys) } /* ---- TTL Functions ---- */ func genExpireTask(key string) string { return "expire:" + key } // Expire sets ttlCmd of key func (db *DB) Expire(key string, expireTime time.Time) { db.stopWorld.Wait() db.ttlMap.Put(key, expireTime) taskKey := genExpireTask(key) timewheel.At(expireTime, taskKey, func() { keys := []string{key} db.RWLocks(keys, nil) defer db.RWUnLocks(keys, nil) // check-lock-check, ttl may be updated during waiting lock logger.Info("expire " + key) rawExpireTime, ok := db.ttlMap.Get(key) if !ok { return } expireTime, _ := rawExpireTime.(time.Time) expired := time.Now().After(expireTime) if expired { db.Remove(key) } }) } // Persist cancel ttlCmd of key func (db *DB) Persist(key string) { db.stopWorld.Wait() db.ttlMap.Remove(key) taskKey := genExpireTask(key) timewheel.Cancel(taskKey) } // IsExpired check whether a key is expired func (db *DB) IsExpired(key string) bool { db.stopWorld.Wait() rawExpireTime, ok := db.ttlMap.Get(key) if !ok { return false } expireTime, _ := rawExpireTime.(time.Time) expired := time.Now().After(expireTime) if expired { db.Remove(key) } return expired } /* --- add version --- */ func (db *DB) addVersion(keys ...string) { db.stopWorld.Wait() for _, key := range keys { versionCode := db.GetVersion(key) db.versionMap.Put(key, versionCode+1) } } // GetVersion returns version code for given key func (db *DB) GetVersion(key 
string) uint32 { db.stopWorld.Wait() entity, ok := db.versionMap.Get(key) if !ok { return 0 } return entity.(uint32) } // ForEach traverses all the keys in the database func (db *DB) ForEach(cb func(key string, data *database.DataEntity, expiration *time.Time) bool) { db.stopWorld.Wait() db.data.ForEach(func(key string, raw interface{}) bool { entity, _ := raw.(*database.DataEntity) var expiration *time.Time rawExpireTime, ok := db.ttlMap.Get(key) if ok { expireTime, _ := rawExpireTime.(time.Time) expiration = &expireTime } return cb(key, entity, expiration) }) }