mirror of
https://github.com/HDT3213/godis.git
synced 2025-10-05 00:42:43 +08:00

# Conflicts: # cluster/cluster.go # cluster/router.go # config/config.go # database/database.go # database/server.go
320 lines
8.0 KiB
Go
320 lines
8.0 KiB
Go
// Package database is a memory database with redis compatible interface
|
|
package database
|
|
|
|
import (
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/hdt3213/godis/datastruct/dict"
|
|
"github.com/hdt3213/godis/interface/database"
|
|
"github.com/hdt3213/godis/interface/redis"
|
|
"github.com/hdt3213/godis/lib/logger"
|
|
"github.com/hdt3213/godis/lib/timewheel"
|
|
"github.com/hdt3213/godis/redis/protocol"
|
|
)
|
|
|
|
// Size arguments handed to dict.MakeConcurrent by the DB constructors.
const (
	// dataDictSize is used for both the data dict and the version dict.
	dataDictSize = 1 << 16
	// ttlDictSize is used for the ttl dict.
	ttlDictSize = 1 << 10
)
|
|
|
|
// DB stores data and execute user's commands
|
|
type DB struct {
|
|
index int
|
|
// key -> DataEntity
|
|
data *dict.ConcurrentDict
|
|
// key -> expireTime (time.Time)
|
|
ttlMap *dict.ConcurrentDict
|
|
// key -> version(uint32)
|
|
versionMap *dict.ConcurrentDict
|
|
|
|
// addaof is used to add command to aof
|
|
addAof func(CmdLine)
|
|
|
|
// callbacks
|
|
insertCallback database.KeyEventCallback
|
|
deleteCallback database.KeyEventCallback
|
|
}
|
|
|
|
// ExecFunc is interface for command executor
|
|
// args don't include cmd line
|
|
type ExecFunc func(db *DB, args [][]byte) redis.Reply
|
|
|
|
// PreFunc analyses a command line, for instance when the command is queued
// into `multi`.
// It returns the related write keys first and the related read keys second.
type PreFunc func(args [][]byte) ([]string, []string)
|
|
|
|
// CmdLine represents one command line.
// It is a type alias of [][]byte, so both spellings denote the same type.
type CmdLine = [][]byte
|
|
|
|
// UndoFunc returns undo logs for the given command line
|
|
// execute from head to tail when undo
|
|
type UndoFunc func(db *DB, args [][]byte) []CmdLine
|
|
|
|
// makeDB create DB instance
|
|
func makeDB() *DB {
|
|
db := &DB{
|
|
data: dict.MakeConcurrent(dataDictSize),
|
|
ttlMap: dict.MakeConcurrent(ttlDictSize),
|
|
versionMap: dict.MakeConcurrent(dataDictSize),
|
|
addAof: func(line CmdLine) {},
|
|
}
|
|
return db
|
|
}
|
|
|
|
// makeBasicDB create DB instance only with basic abilities.
|
|
func makeBasicDB() *DB {
|
|
db := &DB{
|
|
data: dict.MakeConcurrent(dataDictSize),
|
|
ttlMap: dict.MakeConcurrent(ttlDictSize),
|
|
versionMap: dict.MakeConcurrent(dataDictSize),
|
|
addAof: func(line CmdLine) {},
|
|
}
|
|
return db
|
|
}
|
|
|
|
// Exec executes command within one database
|
|
func (db *DB) Exec(c redis.Connection, cmdLine [][]byte) redis.Reply {
|
|
// transaction control commands and other commands which cannot execute within transaction
|
|
cmdName := strings.ToLower(string(cmdLine[0]))
|
|
if cmdName == "multi" {
|
|
if len(cmdLine) != 1 {
|
|
return protocol.MakeArgNumErrReply(cmdName)
|
|
}
|
|
return StartMulti(c)
|
|
} else if cmdName == "discard" {
|
|
if len(cmdLine) != 1 {
|
|
return protocol.MakeArgNumErrReply(cmdName)
|
|
}
|
|
return DiscardMulti(c)
|
|
} else if cmdName == "exec" {
|
|
if len(cmdLine) != 1 {
|
|
return protocol.MakeArgNumErrReply(cmdName)
|
|
}
|
|
return execMulti(db, c)
|
|
} else if cmdName == "watch" {
|
|
if !validateArity(-2, cmdLine) {
|
|
return protocol.MakeArgNumErrReply(cmdName)
|
|
}
|
|
return Watch(db, c, cmdLine[1:])
|
|
}
|
|
if c != nil && c.InMultiState() {
|
|
return EnqueueCmd(c, cmdLine)
|
|
}
|
|
|
|
return db.execNormalCommand(cmdLine)
|
|
}
|
|
|
|
func (db *DB) execNormalCommand(cmdLine [][]byte) redis.Reply {
|
|
cmdName := strings.ToLower(string(cmdLine[0]))
|
|
cmd, ok := cmdTable[cmdName]
|
|
if !ok {
|
|
return protocol.MakeErrReply("ERR unknown command '" + cmdName + "'")
|
|
}
|
|
if !validateArity(cmd.arity, cmdLine) {
|
|
return protocol.MakeArgNumErrReply(cmdName)
|
|
}
|
|
|
|
prepare := cmd.prepare
|
|
write, read := prepare(cmdLine[1:])
|
|
db.addVersion(write...)
|
|
db.RWLocks(write, read)
|
|
defer db.RWUnLocks(write, read)
|
|
fun := cmd.executor
|
|
return fun(db, cmdLine[1:])
|
|
}
|
|
|
|
// execWithLock executes normal commands, invoker should provide locks
|
|
func (db *DB) execWithLock(cmdLine [][]byte) redis.Reply {
|
|
cmdName := strings.ToLower(string(cmdLine[0]))
|
|
cmd, ok := cmdTable[cmdName]
|
|
if !ok {
|
|
return protocol.MakeErrReply("ERR unknown command '" + cmdName + "'")
|
|
}
|
|
if !validateArity(cmd.arity, cmdLine) {
|
|
return protocol.MakeArgNumErrReply(cmdName)
|
|
}
|
|
fun := cmd.executor
|
|
return fun(db, cmdLine[1:])
|
|
}
|
|
|
|
// validateArity reports whether cmdArgs has an acceptable number of elements
// for the given arity. Callers in this file pass the whole command line, so
// the command name counts as one element.
//
// A non-negative arity requires exactly that many elements; a negative arity
// requires at least -arity elements.
func validateArity(arity int, cmdArgs [][]byte) bool {
	n := len(cmdArgs)
	if arity < 0 {
		return n >= -arity
	}
	return n == arity
}
|
|
|
|
/* ---- Data Access ----- */
|
|
|
|
// GetEntity returns DataEntity bind to given key
|
|
func (db *DB) GetEntity(key string) (*database.DataEntity, bool) {
|
|
raw, ok := db.data.GetWithLock(key)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
if db.IsExpired(key) {
|
|
return nil, false
|
|
}
|
|
entity, _ := raw.(*database.DataEntity)
|
|
return entity, true
|
|
}
|
|
|
|
// PutEntity a DataEntity into DB
|
|
func (db *DB) PutEntity(key string, entity *database.DataEntity) int {
|
|
ret := db.data.PutWithLock(key, entity)
|
|
// db.insertCallback may be set as nil, during `if` and actually callback
|
|
// so introduce a local variable `cb`
|
|
if cb := db.insertCallback; ret > 0 && cb != nil {
|
|
cb(db.index, key, entity)
|
|
}
|
|
return ret
|
|
}
|
|
|
|
// PutIfExists edit an existing DataEntity
|
|
func (db *DB) PutIfExists(key string, entity *database.DataEntity) int {
|
|
return db.data.PutIfExistsWithLock(key, entity)
|
|
}
|
|
|
|
// PutIfAbsent insert an DataEntity only if the key not exists
|
|
func (db *DB) PutIfAbsent(key string, entity *database.DataEntity) int {
|
|
ret := db.data.PutIfAbsentWithLock(key, entity)
|
|
// db.insertCallback may be set as nil, during `if` and actually callback
|
|
// so introduce a local variable `cb`
|
|
if cb := db.insertCallback; ret > 0 && cb != nil {
|
|
cb(db.index, key, entity)
|
|
}
|
|
return ret
|
|
}
|
|
|
|
// Remove the given key from db
|
|
func (db *DB) Remove(key string) {
|
|
raw, deleted := db.data.RemoveWithLock(key)
|
|
db.ttlMap.Remove(key)
|
|
taskKey := genExpireTask(key)
|
|
timewheel.Cancel(taskKey)
|
|
if cb := db.deleteCallback; cb != nil {
|
|
var entity *database.DataEntity
|
|
if deleted > 0 {
|
|
entity = raw.(*database.DataEntity)
|
|
}
|
|
cb(db.index, key, entity)
|
|
}
|
|
}
|
|
|
|
// Removes the given keys from db
|
|
func (db *DB) Removes(keys ...string) (deleted int) {
|
|
deleted = 0
|
|
for _, key := range keys {
|
|
_, exists := db.data.GetWithLock(key)
|
|
if exists {
|
|
db.Remove(key)
|
|
deleted++
|
|
}
|
|
}
|
|
return deleted
|
|
}
|
|
|
|
// Flush clean database
|
|
// deprecated
|
|
// for test only
|
|
func (db *DB) Flush() {
|
|
db.data.Clear()
|
|
db.ttlMap.Clear()
|
|
}
|
|
|
|
/* ---- Lock Function ----- */
|
|
|
|
// RWLocks lock keys for writing and reading
|
|
func (db *DB) RWLocks(writeKeys []string, readKeys []string) {
|
|
db.data.RWLocks(writeKeys, readKeys)
|
|
}
|
|
|
|
// RWUnLocks unlock keys for writing and reading
|
|
func (db *DB) RWUnLocks(writeKeys []string, readKeys []string) {
|
|
db.data.RWUnLocks(writeKeys, readKeys)
|
|
}
|
|
|
|
/* ---- TTL Functions ---- */
|
|
|
|
// genExpireTask builds the timewheel task key used for the expiration of key:
// the key prefixed with "expire:".
func genExpireTask(key string) string {
	const prefix = "expire:"
	return prefix + key
}
|
|
|
|
// Expire sets ttlCmd of key
|
|
func (db *DB) Expire(key string, expireTime time.Time) {
|
|
db.ttlMap.Put(key, expireTime)
|
|
taskKey := genExpireTask(key)
|
|
timewheel.At(expireTime, taskKey, func() {
|
|
keys := []string{key}
|
|
db.RWLocks(keys, nil)
|
|
defer db.RWUnLocks(keys, nil)
|
|
// check-lock-check, ttl may be updated during waiting lock
|
|
logger.Info("expire " + key)
|
|
rawExpireTime, ok := db.ttlMap.Get(key)
|
|
if !ok {
|
|
return
|
|
}
|
|
expireTime, _ := rawExpireTime.(time.Time)
|
|
expired := time.Now().After(expireTime)
|
|
if expired {
|
|
db.Remove(key)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Persist cancel ttlCmd of key
|
|
func (db *DB) Persist(key string) {
|
|
db.ttlMap.Remove(key)
|
|
taskKey := genExpireTask(key)
|
|
timewheel.Cancel(taskKey)
|
|
}
|
|
|
|
// IsExpired check whether a key is expired
|
|
func (db *DB) IsExpired(key string) bool {
|
|
rawExpireTime, ok := db.ttlMap.Get(key)
|
|
if !ok {
|
|
return false
|
|
}
|
|
expireTime, _ := rawExpireTime.(time.Time)
|
|
expired := time.Now().After(expireTime)
|
|
if expired {
|
|
db.Remove(key)
|
|
}
|
|
return expired
|
|
}
|
|
|
|
/* --- add version --- */
|
|
|
|
func (db *DB) addVersion(keys ...string) {
|
|
for _, key := range keys {
|
|
versionCode := db.GetVersion(key)
|
|
db.versionMap.Put(key, versionCode+1)
|
|
}
|
|
}
|
|
|
|
// GetVersion returns version code for given key
|
|
func (db *DB) GetVersion(key string) uint32 {
|
|
entity, ok := db.versionMap.Get(key)
|
|
if !ok {
|
|
return 0
|
|
}
|
|
return entity.(uint32)
|
|
}
|
|
|
|
// ForEach traverses all the keys in the database
|
|
func (db *DB) ForEach(cb func(key string, data *database.DataEntity, expiration *time.Time) bool) {
|
|
db.data.ForEach(func(key string, raw interface{}) bool {
|
|
entity, _ := raw.(*database.DataEntity)
|
|
var expiration *time.Time
|
|
rawExpireTime, ok := db.ttlMap.Get(key)
|
|
if ok {
|
|
expireTime, _ := rawExpireTime.(time.Time)
|
|
expiration = &expireTime
|
|
}
|
|
|
|
return cb(key, entity, expiration)
|
|
})
|
|
}
|