multi database

This commit is contained in:
hdt3213
2021-08-28 22:31:23 +08:00
parent c9f33d08c2
commit e91294bcf4
46 changed files with 1135 additions and 821 deletions

209
aof.go
View File

@@ -1,209 +0,0 @@
package godis
import (
"github.com/hdt3213/godis/config"
"github.com/hdt3213/godis/datastruct/dict"
"github.com/hdt3213/godis/datastruct/lock"
"github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/redis/parser"
"github.com/hdt3213/godis/redis/reply"
"io"
"io/ioutil"
"os"
"strconv"
"strings"
"time"
)
var pExpireAtBytes = []byte("PEXPIREAT")
func makeExpireCmd(key string, expireAt time.Time) *reply.MultiBulkReply {
args := make([][]byte, 3)
args[0] = pExpireAtBytes
args[1] = []byte(key)
args[2] = []byte(strconv.FormatInt(expireAt.UnixNano()/1e6, 10))
return reply.MakeMultiBulkReply(args)
}
func makeAofCmd(cmd string, args [][]byte) *reply.MultiBulkReply {
params := make([][]byte, len(args)+1)
copy(params[1:], args)
params[0] = []byte(cmd)
return reply.MakeMultiBulkReply(params)
}
// AddAof send command to aof goroutine through channel
func (db *DB) AddAof(args *reply.MultiBulkReply) {
// aofChan == nil when loadAof
if config.Properties.AppendOnly && db.aofChan != nil {
db.aofChan <- args
}
}
// handleAof listen aof channel and write into file
func (db *DB) handleAof() {
for cmd := range db.aofChan {
// todo: use switch and channels instead of mutex
db.pausingAof.RLock() // prevent other goroutines from pausing aof
if db.aofRewriteBuffer != nil {
// replica during rewrite
db.aofRewriteBuffer <- cmd
}
_, err := db.aofFile.Write(cmd.ToBytes())
if err != nil {
logger.Warn(err)
}
db.pausingAof.RUnlock()
}
db.aofFinished <- struct{}{}
}
// loadAof read aof file
func (db *DB) loadAof(maxBytes int) {
// delete aofChan to prevent write again
aofChan := db.aofChan
db.aofChan = nil
defer func(aofChan chan *reply.MultiBulkReply) {
db.aofChan = aofChan
}(aofChan)
file, err := os.Open(db.aofFilename)
if err != nil {
if _, ok := err.(*os.PathError); ok {
return
}
logger.Warn(err)
return
}
defer file.Close()
var reader io.Reader
if maxBytes > 0 {
reader = io.LimitReader(file, int64(maxBytes))
} else {
reader = file
}
ch := parser.ParseStream(reader)
for p := range ch {
if p.Err != nil {
if p.Err == io.EOF {
break
}
logger.Error("parse error: " + p.Err.Error())
continue
}
if p.Data == nil {
logger.Error("empty payload")
continue
}
r, ok := p.Data.(*reply.MultiBulkReply)
if !ok {
logger.Error("require multi bulk reply")
continue
}
cmd := strings.ToLower(string(r.Args[0]))
command, ok := cmdTable[cmd]
if ok {
handler := command.executor
handler(db, r.Args[1:])
}
}
}
/*-- aof rewrite --*/
func (db *DB) aofRewrite() {
file, fileSize, err := db.startRewrite()
if err != nil {
logger.Warn(err)
return
}
// load aof file
tmpDB := &DB{
data: dict.MakeSimple(),
ttlMap: dict.MakeSimple(),
locker: lock.Make(lockerSize),
aofFilename: db.aofFilename,
}
tmpDB.loadAof(int(fileSize))
// rewrite aof file
tmpDB.data.ForEach(func(key string, raw interface{}) bool {
entity, _ := raw.(*DataEntity)
cmd := EntityToCmd(key, entity)
if cmd != nil {
_, _ = file.Write(cmd.ToBytes())
}
return true
})
tmpDB.ttlMap.ForEach(func(key string, raw interface{}) bool {
expireTime, _ := raw.(time.Time)
cmd := makeExpireCmd(key, expireTime)
if cmd != nil {
_, _ = file.Write(cmd.ToBytes())
}
return true
})
db.finishRewrite(file)
}
func (db *DB) startRewrite() (*os.File, int64, error) {
db.pausingAof.Lock() // pausing aof
defer db.pausingAof.Unlock()
err := db.aofFile.Sync()
if err != nil {
logger.Warn("fsync failed")
return nil, 0, err
}
// create rewrite channel
db.aofRewriteBuffer = make(chan *reply.MultiBulkReply, aofQueueSize)
// get current aof file size
fileInfo, _ := os.Stat(db.aofFilename)
filesize := fileInfo.Size()
// create tmp file
file, err := ioutil.TempFile("", "aof")
if err != nil {
logger.Warn("tmp file create failed")
return nil, 0, err
}
return file, filesize, nil
}
func (db *DB) finishRewrite(tmpFile *os.File) {
db.pausingAof.Lock() // pausing aof
defer db.pausingAof.Unlock()
// write commands created during rewriting to tmp file
loop:
for {
// aof is pausing, there won't be any new commands in aofRewriteBuffer
select {
case cmd := <-db.aofRewriteBuffer:
_, err := tmpFile.Write(cmd.ToBytes())
if err != nil {
logger.Warn(err)
}
default:
// channel is empty, break loop
break loop
}
}
close(db.aofRewriteBuffer)
db.aofRewriteBuffer = nil
// replace current aof file by tmp file
_ = db.aofFile.Close()
_ = os.Rename(tmpFile.Name(), db.aofFilename)
// reopen aof file for further write
aofFile, err := os.OpenFile(db.aofFilename, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
panic(err)
}
db.aofFile = aofFile
}

162
aof/aof.go Normal file
View File

@@ -0,0 +1,162 @@
package aof
import (
"github.com/hdt3213/godis/config"
"github.com/hdt3213/godis/interface/database"
"github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/connection"
"github.com/hdt3213/godis/redis/parser"
"github.com/hdt3213/godis/redis/reply"
"io"
"os"
"strconv"
"sync"
)
type CmdLine = [][]byte
const (
aofQueueSize = 1 << 16
)
type payload struct {
cmdLine CmdLine
dbIndex int
}
type Handler struct {
db database.EmbedDB
tmpDBMaker func() database.EmbedDB
aofChan chan *payload
aofFile *os.File
aofFilename string
// aof goroutine will send msg to main goroutine through this channel when aof tasks finished and ready to shutdown
aofFinished chan struct{}
// buffer commands received during aof rewrite progress
aofRewriteBuffer chan *payload
// pause aof for start/finish aof rewrite progress
pausingAof sync.RWMutex
currentDB int
}
func NewAOFHandler(db database.EmbedDB, tmpDBMaker func() database.EmbedDB) (*Handler, error) {
handler := &Handler{}
handler.aofFilename = config.Properties.AppendFilename
handler.db = db
handler.tmpDBMaker = tmpDBMaker
handler.LoadAof(0)
aofFile, err := os.OpenFile(handler.aofFilename, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
return nil, err
}
handler.aofFile = aofFile
handler.aofChan = make(chan *payload, aofQueueSize)
handler.aofFinished = make(chan struct{})
go func() {
handler.handleAof()
}()
return handler, nil
}
// AddAof send command to aof goroutine through channel
func (handler *Handler) AddAof(dbIndex int, cmdLine CmdLine) {
if config.Properties.AppendOnly && handler.aofChan != nil {
handler.aofChan <- &payload{
cmdLine: cmdLine,
dbIndex: dbIndex,
}
}
}
// handleAof listen aof channel and write into file
func (handler *Handler) handleAof() {
// serialized execution
handler.currentDB = 0
for p := range handler.aofChan {
handler.pausingAof.RLock() // prevent other goroutines from pausing aof
if handler.aofRewriteBuffer != nil {
// replica during rewrite
handler.aofRewriteBuffer <- p
}
if p.dbIndex != handler.currentDB {
// select db
data := reply.MakeMultiBulkReply(utils.ToCmdLine("SELECT", strconv.Itoa(p.dbIndex))).ToBytes()
_, err := handler.aofFile.Write(data)
if err != nil {
logger.Warn(err)
continue // skip this command
}
handler.currentDB = p.dbIndex
}
data := reply.MakeMultiBulkReply(p.cmdLine).ToBytes()
_, err := handler.aofFile.Write(data)
if err != nil {
logger.Warn(err)
}
handler.pausingAof.RUnlock()
}
handler.aofFinished <- struct{}{}
}
// LoadAof read aof file
func (handler *Handler) LoadAof(maxBytes int) {
// delete aofChan to prevent write again
aofChan := handler.aofChan
handler.aofChan = nil
defer func(aofChan chan *payload) {
handler.aofChan = aofChan
}(aofChan)
file, err := os.Open(handler.aofFilename)
if err != nil {
if _, ok := err.(*os.PathError); ok {
return
}
logger.Warn(err)
return
}
defer file.Close()
var reader io.Reader
if maxBytes > 0 {
reader = io.LimitReader(file, int64(maxBytes))
} else {
reader = file
}
ch := parser.ParseStream(reader)
fakeConn := &connection.FakeConn{} // only used for save dbIndex
for p := range ch {
if p.Err != nil {
if p.Err == io.EOF {
break
}
logger.Error("parse error: " + p.Err.Error())
continue
}
if p.Data == nil {
logger.Error("empty payload")
continue
}
r, ok := p.Data.(*reply.MultiBulkReply)
if !ok {
logger.Error("require multi bulk reply")
continue
}
ret := handler.db.Exec(fakeConn, r.Args)
if reply.IsErrorReply(ret) {
logger.Error("exec err", err)
}
}
}
func (handler *Handler) Close() {
if handler.aofFile != nil {
close(handler.aofChan)
<-handler.aofFinished // wait for aof finished
err := handler.aofFile.Close()
if err != nil {
logger.Warn(err)
}
}
}

View File

@@ -1,18 +1,18 @@
package godis package aof
import ( import (
"github.com/hdt3213/godis/datastruct/dict" "github.com/hdt3213/godis/datastruct/dict"
List "github.com/hdt3213/godis/datastruct/list" List "github.com/hdt3213/godis/datastruct/list"
"github.com/hdt3213/godis/datastruct/set" "github.com/hdt3213/godis/datastruct/set"
SortedSet "github.com/hdt3213/godis/datastruct/sortedset" SortedSet "github.com/hdt3213/godis/datastruct/sortedset"
"github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/interface/database"
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
"strconv" "strconv"
"time" "time"
) )
// EntityToCmd serialize data entity to redis command // EntityToCmd serialize data entity to redis command
func EntityToCmd(key string, entity *DataEntity) *reply.MultiBulkReply { func EntityToCmd(key string, entity *database.DataEntity) *reply.MultiBulkReply {
if entity == nil { if entity == nil {
return nil return nil
} }
@@ -32,18 +32,6 @@ func EntityToCmd(key string, entity *DataEntity) *reply.MultiBulkReply {
return cmd return cmd
} }
// toTTLCmd serialize ttl config
func toTTLCmd(db *DB, key string) *reply.MultiBulkReply {
raw, exists := db.ttlMap.Get(key)
if !exists {
// 无 TTL
return reply.MakeMultiBulkReply(utils.ToCmdLine("PERSIST", key))
}
expireTime, _ := raw.(time.Time)
timestamp := strconv.FormatInt(expireTime.UnixNano()/1000/1000, 10)
return reply.MakeMultiBulkReply(utils.ToCmdLine("PEXPIREAT", key, timestamp))
}
var setCmd = []byte("SET") var setCmd = []byte("SET")
func stringToCmd(key string, bytes []byte) *reply.MultiBulkReply { func stringToCmd(key string, bytes []byte) *reply.MultiBulkReply {
@@ -116,3 +104,14 @@ func zSetToCmd(key string, zset *SortedSet.SortedSet) *reply.MultiBulkReply {
}) })
return reply.MakeMultiBulkReply(args) return reply.MakeMultiBulkReply(args)
} }
var pExpireAtBytes = []byte("PEXPIREAT")
// MakeExpireCmd generates command line to set expiration for the given key
func MakeExpireCmd(key string, expireAt time.Time) *reply.MultiBulkReply {
args := make([][]byte, 3)
args[0] = pExpireAtBytes
args[1] = []byte(key)
args[2] = []byte(strconv.FormatInt(expireAt.UnixNano()/1e6, 10))
return reply.MakeMultiBulkReply(args)
}

138
aof/rewrite.go Normal file
View File

@@ -0,0 +1,138 @@
package aof
import (
"github.com/hdt3213/godis/config"
"github.com/hdt3213/godis/interface/database"
"github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/reply"
"io/ioutil"
"os"
"strconv"
"time"
)
func (handler *Handler) newRewriteHandler() *Handler {
h := &Handler{}
h.aofFilename = handler.aofFilename
h.db = handler.tmpDBMaker()
return h
}
func (handler *Handler) Rewrite() {
tmpFile, fileSize, err := handler.startRewrite()
if err != nil {
logger.Warn(err)
return
}
// load aof tmpFile
tmpAof := handler.newRewriteHandler()
tmpAof.LoadAof(int(fileSize))
// rewrite aof tmpFile
for i := 0; i < config.Properties.Databases; i++ {
// select db
data := reply.MakeMultiBulkReply(utils.ToCmdLine("SELECT", strconv.Itoa(i))).ToBytes()
_, err := tmpFile.Write(data)
if err != nil {
logger.Warn(err)
return
}
// dump db
tmpAof.db.ForEach(i, func(key string, entity *database.DataEntity, expiration *time.Time) bool {
cmd := EntityToCmd(key, entity)
if cmd != nil {
_, _ = tmpFile.Write(cmd.ToBytes())
}
if expiration != nil {
cmd := MakeExpireCmd(key, *expiration)
if cmd != nil {
_, _ = tmpFile.Write(cmd.ToBytes())
}
}
return true
})
}
handler.finishRewrite(tmpFile)
}
func (handler *Handler) startRewrite() (*os.File, int64, error) {
handler.pausingAof.Lock() // pausing aof
defer handler.pausingAof.Unlock()
err := handler.aofFile.Sync()
if err != nil {
logger.Warn("fsync failed")
return nil, 0, err
}
// create rewrite channel
handler.aofRewriteBuffer = make(chan *payload, aofQueueSize)
// get current aof file size
fileInfo, _ := os.Stat(handler.aofFilename)
filesize := fileInfo.Size()
// create tmp file
file, err := ioutil.TempFile("", "aof")
if err != nil {
logger.Warn("tmp file create failed")
return nil, 0, err
}
return file, filesize, nil
}
func (handler *Handler) finishRewrite(tmpFile *os.File) {
handler.pausingAof.Lock() // pausing aof
defer handler.pausingAof.Unlock()
// write commands created during rewriting to tmp file
currentDB := -1
loop:
for {
// aof is pausing, there won't be any new commands in aofRewriteBuffer
select {
case p := <-handler.aofRewriteBuffer:
if p.dbIndex != currentDB {
// select db
// always do `select` during first loop 第一次进入循环时必须执行一次 select 确保数据库一致
data := reply.MakeMultiBulkReply(utils.ToCmdLine("SELECT", strconv.Itoa(p.dbIndex))).ToBytes()
_, err := tmpFile.Write(data)
if err != nil {
logger.Warn(err)
continue // skip this command
}
currentDB = p.dbIndex
}
data := reply.MakeMultiBulkReply(p.cmdLine).ToBytes()
_, err := tmpFile.Write(data)
if err != nil {
logger.Warn(err)
}
default:
// channel is empty, break loop
break loop
}
}
close(handler.aofRewriteBuffer)
handler.aofRewriteBuffer = nil
// replace current aof file by tmp file
_ = handler.aofFile.Close()
_ = os.Rename(tmpFile.Name(), handler.aofFilename)
// reopen aof file for further write
aofFile, err := os.OpenFile(handler.aofFilename, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
panic(err)
}
handler.aofFile = aofFile
// reset selected db 重新写入一次 select 指令保证 aof 中的数据库与 handler.currentDB 一致
data := reply.MakeMultiBulkReply(utils.ToCmdLine("SELECT", strconv.Itoa(handler.currentDB))).ToBytes()
_, err = handler.aofFile.Write(data)
if err != nil {
panic(err)
}
}

View File

@@ -2,8 +2,12 @@ package godis
import ( import (
"github.com/hdt3213/godis/config" "github.com/hdt3213/godis/config"
"github.com/hdt3213/godis/interface/database"
"github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/connection"
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
"github.com/hdt3213/godis/redis/reply/asserts"
"io/ioutil" "io/ioutil"
"os" "os"
"path" "path"
@@ -12,6 +16,84 @@ import (
"time" "time"
) )
func makeTestData(db database.DB, dbIndex int, prefix string, size int) {
conn := &connection.FakeConn{}
conn.SelectDB(dbIndex)
db.Exec(conn, utils.ToCmdLine("FlushDB"))
cursor := 0
for i := 0; i < size; i++ {
key := prefix + strconv.Itoa(cursor)
cursor++
db.Exec(conn, utils.ToCmdLine("SET", key, key, "EX", "10000"))
}
for i := 0; i < size; i++ {
key := prefix + strconv.Itoa(cursor)
cursor++
db.Exec(conn, utils.ToCmdLine("RPUSH", key, key))
}
for i := 0; i < size; i++ {
key := prefix + strconv.Itoa(cursor)
cursor++
db.Exec(conn, utils.ToCmdLine("HSET", key, key, key))
}
for i := 0; i < size; i++ {
key := prefix + strconv.Itoa(cursor)
cursor++
db.Exec(conn, utils.ToCmdLine("SADD", key, key))
}
for i := 0; i < size; i++ {
key := prefix + strconv.Itoa(cursor)
cursor++
db.Exec(conn, utils.ToCmdLine("ZADD", key, "10", key))
}
}
func validateTestData(t *testing.T, db database.DB, dbIndex int, prefix string, size int) {
conn := &connection.FakeConn{}
conn.SelectDB(dbIndex)
cursor := 0
var ret redis.Reply
for i := 0; i < size; i++ {
key := prefix + strconv.Itoa(cursor)
cursor++
ret = db.Exec(conn, utils.ToCmdLine("GET", key))
asserts.AssertBulkReply(t, ret, key)
ret = db.Exec(conn, utils.ToCmdLine("TTL", key))
intResult, ok := ret.(*reply.IntReply)
if !ok {
t.Errorf("expected int reply, actually %s", ret.ToBytes())
return
}
if intResult.Code <= 0 || intResult.Code > 10000 {
t.Error("wrong ttl")
}
}
for i := 0; i < size; i++ {
key := prefix + strconv.Itoa(cursor)
cursor++
ret = db.Exec(conn, utils.ToCmdLine("LRANGE", key, "0", "-1"))
asserts.AssertMultiBulkReply(t, ret, []string{key})
}
for i := 0; i < size; i++ {
key := prefix + strconv.Itoa(cursor)
cursor++
ret = db.Exec(conn, utils.ToCmdLine("HGET", key, key))
asserts.AssertBulkReply(t, ret, key)
}
for i := 0; i < size; i++ {
key := prefix + strconv.Itoa(cursor)
cursor++
ret = db.Exec(conn, utils.ToCmdLine("SIsMember", key, key))
asserts.AssertIntReply(t, ret, 1)
}
for i := 0; i < size; i++ {
key := prefix + strconv.Itoa(cursor)
cursor++
ret = db.Exec(conn, utils.ToCmdLine("ZRANGE", key, "0", "-1"))
asserts.AssertMultiBulkReply(t, ret, []string{key})
}
}
func TestAof(t *testing.T) { func TestAof(t *testing.T) {
tmpDir, err := ioutil.TempDir("", "godis") tmpDir, err := ioutil.TempDir("", "godis")
if err != nil { if err != nil {
@@ -26,69 +108,31 @@ func TestAof(t *testing.T) {
AppendOnly: true, AppendOnly: true,
AppendFilename: aofFilename, AppendFilename: aofFilename,
} }
aofWriteDB := MakeDB() dbNum := 4
size := 10 size := 10
keys := make([]string, 0) var prefixes []string
cursor := 0 aofWriteDB := NewStandaloneServer()
for i := 0; i < size; i++ { for i := 0; i < dbNum; i++ {
key := strconv.Itoa(cursor) prefix := utils.RandString(8)
cursor++ prefixes = append(prefixes, prefix)
execSet(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8), "EX", "10000")) makeTestData(aofWriteDB, i, prefix, size)
keys = append(keys, key)
}
for i := 0; i < size; i++ {
key := strconv.Itoa(cursor)
cursor++
execRPush(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8)))
keys = append(keys, key)
}
for i := 0; i < size; i++ {
key := strconv.Itoa(cursor)
cursor++
execHSet(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8), utils.RandString(8)))
keys = append(keys, key)
}
for i := 0; i < size; i++ {
key := strconv.Itoa(cursor)
cursor++
execSAdd(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8)))
keys = append(keys, key)
}
for i := 0; i < size; i++ {
key := strconv.Itoa(cursor)
cursor++
execZAdd(aofWriteDB, utils.ToCmdLine(key, "10", utils.RandString(8)))
keys = append(keys, key)
} }
aofWriteDB.Close() // wait for aof finished aofWriteDB.Close() // wait for aof finished
aofReadDB := MakeDB() // start new db and read aof file aofReadDB := NewStandaloneServer() // start new db and read aof file
for _, key := range keys { for i := 0; i < dbNum; i++ {
expect, ok := aofWriteDB.GetEntity(key) prefix := prefixes[i]
if !ok { validateTestData(t, aofReadDB, i, prefix, size)
t.Errorf("key not found in origin: %s", key)
continue
}
actual, ok := aofReadDB.GetEntity(key)
if !ok {
t.Errorf("key not found: %s", key)
continue
}
expectData := EntityToCmd(key, expect).ToBytes()
actualData := EntityToCmd(key, actual).ToBytes()
if !utils.BytesEquals(expectData, actualData) {
t.Errorf("wrong value of key: %s", key)
}
} }
aofReadDB.Close() aofReadDB.Close()
} }
func TestRewriteAOF(t *testing.T) { func TestRewriteAOF(t *testing.T) {
tmpDir, err := ioutil.TempDir("", "godis") tmpFile, err := ioutil.TempFile("", "*.aof")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
aofFilename := path.Join(tmpDir, "a.aof") aofFilename := tmpFile.Name()
defer func() { defer func() {
_ = os.Remove(aofFilename) _ = os.Remove(aofFilename)
}() }()
@@ -96,85 +140,23 @@ func TestRewriteAOF(t *testing.T) {
AppendOnly: true, AppendOnly: true,
AppendFilename: aofFilename, AppendFilename: aofFilename,
} }
aofWriteDB := MakeDB() aofWriteDB := NewStandaloneServer()
size := 1 size := 1
keys := make([]string, 0) dbNum := 4
ttlKeys := make([]string, 0) var prefixes []string
cursor := 0 for i := 0; i < dbNum; i++ {
for i := 0; i < size; i++ { prefix := "" // utils.RandString(8)
key := "str" + strconv.Itoa(cursor) prefixes = append(prefixes, prefix)
cursor++ makeTestData(aofWriteDB, i, prefix, size)
execSet(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8)))
execSet(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8)))
keys = append(keys, key)
} }
// test ttl //time.Sleep(2 * time.Second)
for i := 0; i < size; i++ { aofWriteDB.Exec(nil, utils.ToCmdLine("rewriteaof"))
key := "str" + strconv.Itoa(cursor) time.Sleep(2 * time.Second) // wait for async goroutine finish its job
cursor++
execSet(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8), "EX", "1000"))
ttlKeys = append(ttlKeys, key)
}
for i := 0; i < size; i++ {
key := "list" + strconv.Itoa(cursor)
cursor++
execRPush(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8)))
execRPush(aofWriteDB, utils.ToCmdLine(key, utils.RandString(8)))
keys = append(keys, key)
}
for i := 0; i < size; i++ {
key := "hash" + strconv.Itoa(cursor)
cursor++
field := utils.RandString(8)
execHSet(aofWriteDB, utils.ToCmdLine(key, field, utils.RandString(8)))
execHSet(aofWriteDB, utils.ToCmdLine(key, field, utils.RandString(8)))
keys = append(keys, key)
}
for i := 0; i < size; i++ {
key := "set" + strconv.Itoa(cursor)
cursor++
member := utils.RandString(8)
execSAdd(aofWriteDB, utils.ToCmdLine(key, member))
execSAdd(aofWriteDB, utils.ToCmdLine(key, member))
keys = append(keys, key)
}
for i := 0; i < size; i++ {
key := "zset" + strconv.Itoa(cursor)
cursor++
execZAdd(aofWriteDB, utils.ToCmdLine(key, "10", utils.RandString(8)))
keys = append(keys, key)
}
time.Sleep(time.Second) // wait for async goroutine finish its job
aofWriteDB.aofRewrite()
aofWriteDB.Close() // wait for aof finished aofWriteDB.Close() // wait for aof finished
aofReadDB := MakeDB() // start new db and read aof file aofReadDB := NewStandaloneServer() // start new db and read aof file
for _, key := range keys { for i := 0; i < dbNum; i++ {
expect, ok := aofWriteDB.GetEntity(key) prefix := prefixes[i]
if !ok { validateTestData(t, aofReadDB, i, prefix, size)
t.Errorf("key not found in origin: %s", key)
continue
}
actual, ok := aofReadDB.GetEntity(key)
if !ok {
t.Errorf("key not found: %s", key)
continue
}
expectData := EntityToCmd(key, expect).ToBytes()
actualData := EntityToCmd(key, actual).ToBytes()
if !utils.BytesEquals(expectData, actualData) {
t.Errorf("wrong value of key: %s", key)
}
}
for _, key := range ttlKeys {
ret := execTTL(aofReadDB, utils.ToCmdLine(key))
intResult, ok := ret.(*reply.IntReply)
if !ok {
t.Errorf("expected int reply, actually %s", ret.ToBytes())
return
}
if intResult.Code <= 0 {
t.Errorf("expect a positive integer, actual: %d", intResult.Code)
}
} }
aofReadDB.Close() aofReadDB.Close()
} }

View File

@@ -7,6 +7,7 @@ import (
"github.com/hdt3213/godis" "github.com/hdt3213/godis"
"github.com/hdt3213/godis/config" "github.com/hdt3213/godis/config"
"github.com/hdt3213/godis/datastruct/dict" "github.com/hdt3213/godis/datastruct/dict"
"github.com/hdt3213/godis/interface/database"
"github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/consistenthash" "github.com/hdt3213/godis/lib/consistenthash"
"github.com/hdt3213/godis/lib/idgenerator" "github.com/hdt3213/godis/lib/idgenerator"
@@ -14,6 +15,7 @@ import (
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
"github.com/jolestar/go-commons-pool/v2" "github.com/jolestar/go-commons-pool/v2"
"runtime/debug" "runtime/debug"
"strconv"
"strings" "strings"
) )
@@ -26,7 +28,7 @@ type Cluster struct {
peerPicker *consistenthash.Map peerPicker *consistenthash.Map
peerConnection map[string]*pool.ObjectPool peerConnection map[string]*pool.ObjectPool
db *godis.DB db database.EmbedDB
transactions *dict.SimpleDict // id -> Transaction transactions *dict.SimpleDict // id -> Transaction
idGenerator *idgenerator.IDGenerator idGenerator *idgenerator.IDGenerator
@@ -45,7 +47,7 @@ func MakeCluster() *Cluster {
cluster := &Cluster{ cluster := &Cluster{
self: config.Properties.Self, self: config.Properties.Self,
db: godis.MakeDB(), db: godis.NewStandaloneServer(),
transactions: dict.MakeSimple(), transactions: dict.MakeSimple(),
peerPicker: consistenthash.New(replicas, nil), peerPicker: consistenthash.New(replicas, nil),
peerConnection: make(map[string]*pool.ObjectPool), peerConnection: make(map[string]*pool.ObjectPool),
@@ -100,7 +102,7 @@ func (cluster *Cluster) Exec(c redis.Connection, cmdLine [][]byte) (result redis
}() }()
cmdName := strings.ToLower(string(cmdLine[0])) cmdName := strings.ToLower(string(cmdLine[0]))
if cmdName == "auth" { if cmdName == "auth" {
return godis.Auth(cluster.db, c, cmdLine[1:]) return godis.Auth(c, cmdLine[1:])
} }
if !isAuthenticated(c) { if !isAuthenticated(c) {
return reply.MakeErrReply("NOAUTH Authentication required") return reply.MakeErrReply("NOAUTH Authentication required")
@@ -110,20 +112,25 @@ func (cluster *Cluster) Exec(c redis.Connection, cmdLine [][]byte) (result redis
if len(cmdLine) != 1 { if len(cmdLine) != 1 {
return reply.MakeArgNumErrReply(cmdName) return reply.MakeArgNumErrReply(cmdName)
} }
return godis.StartMulti(cluster.db, c) return godis.StartMulti(c)
} else if cmdName == "discard" { } else if cmdName == "discard" {
if len(cmdLine) != 1 { if len(cmdLine) != 1 {
return reply.MakeArgNumErrReply(cmdName) return reply.MakeArgNumErrReply(cmdName)
} }
return godis.DiscardMulti(cluster.db, c) return godis.DiscardMulti(c)
} else if cmdName == "exec" { } else if cmdName == "exec" {
if len(cmdLine) != 1 { if len(cmdLine) != 1 {
return reply.MakeArgNumErrReply(cmdName) return reply.MakeArgNumErrReply(cmdName)
} }
return execMulti(cluster, c, nil) return execMulti(cluster, c, nil)
} else if cmdName == "select" {
if len(cmdLine) != 2 {
return reply.MakeArgNumErrReply(cmdName)
}
return execSelect(c, cmdLine)
} }
if c != nil && c.InMultiState() { if c != nil && c.InMultiState() {
return godis.EnqueueCmd(cluster.db, c, cmdLine) return godis.EnqueueCmd(c, cmdLine)
} }
cmdFunc, ok := router[cmdName] cmdFunc, ok := router[cmdName]
if !ok { if !ok {
@@ -138,8 +145,8 @@ func (cluster *Cluster) AfterClientClose(c redis.Connection) {
cluster.db.AfterClientClose(c) cluster.db.AfterClientClose(c)
} }
func ping(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func ping(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Reply {
return godis.Ping(cluster.db, args[1:]) return cluster.db.Exec(c, cmdLine)
} }
/*----- utils -------*/ /*----- utils -------*/
@@ -167,3 +174,15 @@ func (cluster *Cluster) groupBy(keys []string) map[string][]string {
} }
return result return result
} }
func execSelect(c redis.Connection, args [][]byte) redis.Reply {
dbIndex, err := strconv.Atoi(string(args[1]))
if err != nil {
return reply.MakeErrReply("ERR invalid DB index")
}
if dbIndex >= config.Properties.Databases {
return reply.MakeErrReply("ERR DB index is out of range")
}
c.SelectDB(dbIndex)
return reply.MakeOkReply()
}

View File

@@ -4,16 +4,18 @@ import (
"context" "context"
"errors" "errors"
"github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/client" "github.com/hdt3213/godis/redis/client"
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
"strconv"
) )
func (cluster *Cluster) getPeerClient(peer string) (*client.Client, error) { func (cluster *Cluster) getPeerClient(peer string) (*client.Client, error) {
connectionFactory, ok := cluster.peerConnection[peer] factory, ok := cluster.peerConnection[peer]
if !ok { if !ok {
return nil, errors.New("connection factory not found") return nil, errors.New("connection factory not found")
} }
raw, err := connectionFactory.BorrowObject(context.Background()) raw, err := factory.BorrowObject(context.Background())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -33,6 +35,7 @@ func (cluster *Cluster) returnPeerClient(peer string, peerClient *client.Client)
} }
// relay relays command to peer // relay relays command to peer
// select db by c.GetDBIndex()
// cannot call Prepare, Commit, execRollback of self node // cannot call Prepare, Commit, execRollback of self node
func (cluster *Cluster) relay(peer string, c redis.Connection, args [][]byte) redis.Reply { func (cluster *Cluster) relay(peer string, c redis.Connection, args [][]byte) redis.Reply {
if peer == cluster.self { if peer == cluster.self {
@@ -46,6 +49,7 @@ func (cluster *Cluster) relay(peer string, c redis.Connection, args [][]byte) re
defer func() { defer func() {
_ = cluster.returnPeerClient(peer, peerClient) _ = cluster.returnPeerClient(peer, peerClient)
}() }()
peerClient.Send(utils.ToCmdLine("SELECT", strconv.Itoa(c.GetDBIndex())))
return peerClient.Send(args) return peerClient.Send(args)
} }

View File

@@ -10,11 +10,12 @@ import (
func TestExec(t *testing.T) { func TestExec(t *testing.T) {
testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"}) testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"})
conn := &connection.FakeConn{}
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
key := RandString(4) key := RandString(4)
value := RandString(4) value := RandString(4)
testCluster2.Exec(nil, toArgs("SET", key, value)) testCluster2.Exec(conn, toArgs("SET", key, value))
ret := testCluster2.Exec(nil, toArgs("GET", key)) ret := testCluster2.Exec(conn, toArgs("GET", key))
asserts.AssertBulkReply(t, ret, value) asserts.AssertBulkReply(t, ret, value)
} }
} }
@@ -38,9 +39,10 @@ func TestRelay(t *testing.T) {
testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"}) testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"})
key := RandString(4) key := RandString(4)
value := RandString(4) value := RandString(4)
ret := testCluster2.relay("127.0.0.1:6379", nil, toArgs("SET", key, value)) conn := &connection.FakeConn{}
ret := testCluster2.relay("127.0.0.1:6379", conn, toArgs("SET", key, value))
asserts.AssertNotError(t, ret) asserts.AssertNotError(t, ret)
ret = testCluster2.relay("127.0.0.1:6379", nil, toArgs("GET", key)) ret = testCluster2.relay("127.0.0.1:6379", conn, toArgs("GET", key))
asserts.AssertBulkReply(t, ret, value) asserts.AssertBulkReply(t, ret, value)
} }
@@ -48,7 +50,7 @@ func TestBroadcast(t *testing.T) {
testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"}) testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"})
key := RandString(4) key := RandString(4)
value := RandString(4) value := RandString(4)
rets := testCluster2.broadcast(nil, toArgs("SET", key, value)) rets := testCluster2.broadcast(&connection.FakeConn{}, toArgs("SET", key, value))
for _, v := range rets { for _, v := range rets {
asserts.AssertNotError(t, v) asserts.AssertNotError(t, v)
} }

View File

@@ -1,15 +1,17 @@
package cluster package cluster
import ( import (
"github.com/hdt3213/godis/redis/connection"
"github.com/hdt3213/godis/redis/reply/asserts" "github.com/hdt3213/godis/redis/reply/asserts"
"testing" "testing"
) )
func TestDel(t *testing.T) { func TestDel(t *testing.T) {
conn := &connection.FakeConn{}
allowFastTransaction = false allowFastTransaction = false
testCluster.Exec(nil, toArgs("SET", "a", "a")) testCluster.Exec(conn, toArgs("SET", "a", "a"))
ret := Del(testCluster, nil, toArgs("DEL", "a", "b", "c")) ret := Del(testCluster, conn, toArgs("DEL", "a", "b", "c"))
asserts.AssertNotError(t, ret) asserts.AssertNotError(t, ret)
ret = testCluster.Exec(nil, toArgs("GET", "a")) ret = testCluster.Exec(conn, toArgs("GET", "a"))
asserts.AssertNullBulk(t, ret) asserts.AssertNullBulk(t, ret)
} }

View File

@@ -1,21 +1,24 @@
package cluster package cluster
import ( import (
"github.com/hdt3213/godis/redis/connection"
"github.com/hdt3213/godis/redis/reply/asserts" "github.com/hdt3213/godis/redis/reply/asserts"
"testing" "testing"
) )
func TestMSet(t *testing.T) { func TestMSet(t *testing.T) {
conn := &connection.FakeConn{}
allowFastTransaction = false allowFastTransaction = false
ret := MSet(testCluster, nil, toArgs("MSET", "a", "a", "b", "b")) ret := MSet(testCluster, conn, toArgs("MSET", "a", "a", "b", "b"))
asserts.AssertNotError(t, ret) asserts.AssertNotError(t, ret)
ret = testCluster.Exec(nil, toArgs("MGET", "a", "b")) ret = testCluster.Exec(conn, toArgs("MGET", "a", "b"))
asserts.AssertMultiBulkReply(t, ret, []string{"a", "b"}) asserts.AssertMultiBulkReply(t, ret, []string{"a", "b"})
} }
func TestMSetNx(t *testing.T) { func TestMSetNx(t *testing.T) {
conn := &connection.FakeConn{}
allowFastTransaction = false allowFastTransaction = false
FlushAll(testCluster, nil, toArgs("FLUSHALL")) FlushAll(testCluster, conn, toArgs("FLUSHALL"))
ret := MSetNX(testCluster, nil, toArgs("MSETNX", "a", "a", "b", "b")) ret := MSetNX(testCluster, conn, toArgs("MSETNX", "a", "a", "b", "b"))
asserts.AssertNotError(t, ret) asserts.AssertNotError(t, ret)
} }

View File

@@ -24,7 +24,7 @@ func execMulti(cluster *Cluster, conn redis.Connection, cmdLine CmdLine) redis.R
// analysis related keys // analysis related keys
keys := make([]string, 0) // may contains duplicate keys := make([]string, 0) // may contains duplicate
for _, cl := range cmdLines { for _, cl := range cmdLines {
wKeys, rKeys := cluster.db.GetRelatedKeys(cl) wKeys, rKeys := godis.GetRelatedKeys(cl)
keys = append(keys, wKeys...) keys = append(keys, wKeys...)
keys = append(keys, rKeys...) keys = append(keys, rKeys...)
} }
@@ -36,7 +36,7 @@ func execMulti(cluster *Cluster, conn redis.Connection, cmdLine CmdLine) redis.R
keys = append(keys, watchingKeys...) keys = append(keys, watchingKeys...)
if len(keys) == 0 { if len(keys) == 0 {
// empty transaction or only `PING`s // empty transaction or only `PING`s
return godis.ExecMulti(cluster.db, conn, watching, cmdLines) return cluster.db.ExecMulti(conn, watching, cmdLines)
} }
groupMap := cluster.groupBy(keys) groupMap := cluster.groupBy(keys)
if len(groupMap) > 1 { if len(groupMap) > 1 {
@@ -50,7 +50,7 @@ func execMulti(cluster *Cluster, conn redis.Connection, cmdLine CmdLine) redis.R
// out parser not support reply.MultiRawReply, so we have to encode it // out parser not support reply.MultiRawReply, so we have to encode it
if peer == cluster.self { if peer == cluster.self {
return godis.ExecMulti(cluster.db, conn, watching, cmdLines) return cluster.db.ExecMulti(conn, watching, cmdLines)
} }
return execMultiOnOtherNode(cluster, conn, peer, watching, cmdLines) return execMultiOnOtherNode(cluster, conn, peer, watching, cmdLines)
} }
@@ -74,7 +74,7 @@ func execMultiOnOtherNode(cluster *Cluster, conn redis.Connection, peer string,
var rawRelayResult redis.Reply var rawRelayResult redis.Reply
if peer == cluster.self { if peer == cluster.self {
// this branch just for testing // this branch just for testing
rawRelayResult = execRelayedMulti(cluster, nil, relayCmdLine) rawRelayResult = execRelayedMulti(cluster, conn, relayCmdLine)
} else { } else {
rawRelayResult = cluster.relay(peer, conn, relayCmdLine) rawRelayResult = cluster.relay(peer, conn, relayCmdLine)
} }
@@ -126,7 +126,7 @@ func execRelayedMulti(cluster *Cluster, conn redis.Connection, cmdLine CmdLine)
} }
watching[key] = uint32(ver) watching[key] = uint32(ver)
} }
rawResult := godis.ExecMulti(cluster.db, conn, watching, txCmdLines[1:]) rawResult := cluster.db.ExecMulti(conn, watching, txCmdLines[1:])
_, ok := rawResult.(*reply.EmptyMultiBulkReply) _, ok := rawResult.(*reply.EmptyMultiBulkReply)
if ok { if ok {
return rawResult return rawResult

View File

@@ -9,8 +9,8 @@ import (
) )
func TestMultiExecOnSelf(t *testing.T) { func TestMultiExecOnSelf(t *testing.T) {
testCluster.db.Flush()
conn := new(connection.FakeConn) conn := new(connection.FakeConn)
testCluster.db.Exec(conn, utils.ToCmdLine("FLUSHALL"))
result := testCluster.Exec(conn, toArgs("MULTI")) result := testCluster.Exec(conn, toArgs("MULTI"))
asserts.AssertNotError(t, result) asserts.AssertNotError(t, result)
key := utils.RandString(10) key := utils.RandString(10)
@@ -27,8 +27,8 @@ func TestMultiExecOnSelf(t *testing.T) {
} }
func TestEmptyMulti(t *testing.T) { func TestEmptyMulti(t *testing.T) {
testCluster.db.Flush()
conn := new(connection.FakeConn) conn := new(connection.FakeConn)
testCluster.db.Exec(conn, utils.ToCmdLine("FLUSHALL"))
result := testCluster.Exec(conn, toArgs("MULTI")) result := testCluster.Exec(conn, toArgs("MULTI"))
asserts.AssertNotError(t, result) asserts.AssertNotError(t, result)
result = testCluster.Exec(conn, utils.ToCmdLine("PING")) result = testCluster.Exec(conn, utils.ToCmdLine("PING"))
@@ -40,8 +40,8 @@ func TestEmptyMulti(t *testing.T) {
} }
func TestMultiExecOnOthers(t *testing.T) { func TestMultiExecOnOthers(t *testing.T) {
testCluster.db.Flush()
conn := new(connection.FakeConn) conn := new(connection.FakeConn)
testCluster.db.Exec(conn, utils.ToCmdLine("FLUSHALL"))
result := testCluster.Exec(conn, toArgs("MULTI")) result := testCluster.Exec(conn, toArgs("MULTI"))
asserts.AssertNotError(t, result) asserts.AssertNotError(t, result)
key := utils.RandString(10) key := utils.RandString(10)
@@ -59,8 +59,8 @@ func TestMultiExecOnOthers(t *testing.T) {
} }
func TestWatch(t *testing.T) { func TestWatch(t *testing.T) {
testCluster.db.Flush()
conn := new(connection.FakeConn) conn := new(connection.FakeConn)
testCluster.db.Exec(conn, utils.ToCmdLine("FLUSHALL"))
key := utils.RandString(10) key := utils.RandString(10)
value := utils.RandString(10) value := utils.RandString(10)
testCluster.Exec(conn, utils.ToCmdLine("watch", key)) testCluster.Exec(conn, utils.ToCmdLine("watch", key))
@@ -86,8 +86,8 @@ func TestWatch(t *testing.T) {
} }
func TestWatch2(t *testing.T) { func TestWatch2(t *testing.T) {
testCluster.db.Flush()
conn := new(connection.FakeConn) conn := new(connection.FakeConn)
testCluster.db.Exec(conn, utils.ToCmdLine("FLUSHALL"))
key := utils.RandString(10) key := utils.RandString(10)
value := utils.RandString(10) value := utils.RandString(10)
testCluster.Exec(conn, utils.ToCmdLine("watch", key)) testCluster.Exec(conn, utils.ToCmdLine("watch", key))

View File

@@ -3,29 +3,31 @@ package cluster
import ( import (
"fmt" "fmt"
"github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/connection"
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
"github.com/hdt3213/godis/redis/reply/asserts" "github.com/hdt3213/godis/redis/reply/asserts"
"testing" "testing"
) )
func TestRename(t *testing.T) { func TestRename(t *testing.T) {
conn := new(connection.FakeConn)
testDB := testCluster.db testDB := testCluster.db
testDB.Exec(nil, utils.ToCmdLine("FlushALL")) testDB.Exec(conn, utils.ToCmdLine("FlushALL"))
key := utils.RandString(10) key := utils.RandString(10)
value := utils.RandString(10) value := utils.RandString(10)
newKey := key + utils.RandString(2) newKey := key + utils.RandString(2)
testDB.Exec(nil, utils.ToCmdLine("SET", key, value, "ex", "1000")) testDB.Exec(conn, utils.ToCmdLine("SET", key, value, "ex", "1000"))
result := Rename(testCluster, nil, utils.ToCmdLine("RENAME", key, newKey)) result := Rename(testCluster, conn, utils.ToCmdLine("RENAME", key, newKey))
if _, ok := result.(*reply.OkReply); !ok { if _, ok := result.(*reply.OkReply); !ok {
t.Error("expect ok") t.Error("expect ok")
return return
} }
result = testDB.Exec(nil, utils.ToCmdLine("EXISTS", key)) result = testDB.Exec(conn, utils.ToCmdLine("EXISTS", key))
asserts.AssertIntReply(t, result, 0) asserts.AssertIntReply(t, result, 0)
result = testDB.Exec(nil, utils.ToCmdLine("EXISTS", newKey)) result = testDB.Exec(conn, utils.ToCmdLine("EXISTS", newKey))
asserts.AssertIntReply(t, result, 1) asserts.AssertIntReply(t, result, 1)
// check ttl // check ttl
result = testDB.Exec(nil, utils.ToCmdLine("TTL", newKey)) result = testDB.Exec(conn, utils.ToCmdLine("TTL", newKey))
intResult, ok := result.(*reply.IntReply) intResult, ok := result.(*reply.IntReply)
if !ok { if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes()))
@@ -38,20 +40,21 @@ func TestRename(t *testing.T) {
} }
func TestRenameNx(t *testing.T) { func TestRenameNx(t *testing.T) {
conn := new(connection.FakeConn)
testDB := testCluster.db testDB := testCluster.db
testDB.Exec(nil, utils.ToCmdLine("FlushALL")) testDB.Exec(conn, utils.ToCmdLine("FlushALL"))
key := utils.RandString(10) key := utils.RandString(10)
value := utils.RandString(10) value := utils.RandString(10)
newKey := key + utils.RandString(2) newKey := key + utils.RandString(2)
testCluster.db.Exec(nil, utils.ToCmdLine("SET", key, value, "ex", "1000")) testCluster.db.Exec(conn, utils.ToCmdLine("SET", key, value, "ex", "1000"))
result := RenameNx(testCluster, nil, utils.ToCmdLine("RENAMENX", key, newKey)) result := RenameNx(testCluster, conn, utils.ToCmdLine("RENAMENX", key, newKey))
asserts.AssertIntReply(t, result, 1) asserts.AssertIntReply(t, result, 1)
result = testDB.Exec(nil, utils.ToCmdLine("EXISTS", key)) result = testDB.Exec(conn, utils.ToCmdLine("EXISTS", key))
asserts.AssertIntReply(t, result, 0) asserts.AssertIntReply(t, result, 0)
result = testDB.Exec(nil, utils.ToCmdLine("EXISTS", newKey)) result = testDB.Exec(conn, utils.ToCmdLine("EXISTS", newKey))
asserts.AssertIntReply(t, result, 1) asserts.AssertIntReply(t, result, 1)
result = testDB.Exec(nil, utils.ToCmdLine("TTL", newKey)) result = testDB.Exec(conn, utils.ToCmdLine("TTL", newKey))
intResult, ok := result.(*reply.IntReply) intResult, ok := result.(*reply.IntReply)
if !ok { if !ok {
t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes())) t.Error(fmt.Sprintf("expected int reply, actually %s", result.ToBytes()))

View File

@@ -2,6 +2,7 @@ package cluster
import ( import (
"fmt" "fmt"
"github.com/hdt3213/godis"
"github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/timewheel" "github.com/hdt3213/godis/lib/timewheel"
@@ -17,6 +18,7 @@ type Transaction struct {
cmdLine [][]byte // cmd cmdLine cmdLine [][]byte // cmd cmdLine
cluster *Cluster cluster *Cluster
conn redis.Connection conn redis.Connection
dbIndex int
writeKeys []string writeKeys []string
readKeys []string readKeys []string
@@ -48,6 +50,7 @@ func NewTransaction(cluster *Cluster, c redis.Connection, id string, cmdLine [][
cmdLine: cmdLine, cmdLine: cmdLine,
cluster: cluster, cluster: cluster,
conn: c, conn: c,
dbIndex: c.GetDBIndex(),
status: createdStatus, status: createdStatus,
mu: new(sync.Mutex), mu: new(sync.Mutex),
} }
@@ -57,14 +60,14 @@ func NewTransaction(cluster *Cluster, c redis.Connection, id string, cmdLine [][
// invoker should hold tx.mu // invoker should hold tx.mu
func (tx *Transaction) lockKeys() { func (tx *Transaction) lockKeys() {
if !tx.keysLocked { if !tx.keysLocked {
tx.cluster.db.RWLocks(tx.writeKeys, tx.readKeys) tx.cluster.db.RWLocks(tx.dbIndex, tx.writeKeys, tx.readKeys)
tx.keysLocked = true tx.keysLocked = true
} }
} }
func (tx *Transaction) unLockKeys() { func (tx *Transaction) unLockKeys() {
if tx.keysLocked { if tx.keysLocked {
tx.cluster.db.RWUnLocks(tx.writeKeys, tx.readKeys) tx.cluster.db.RWUnLocks(tx.dbIndex, tx.writeKeys, tx.readKeys)
tx.keysLocked = false tx.keysLocked = false
} }
} }
@@ -74,12 +77,12 @@ func (tx *Transaction) prepare() error {
tx.mu.Lock() tx.mu.Lock()
defer tx.mu.Unlock() defer tx.mu.Unlock()
tx.writeKeys, tx.readKeys = tx.cluster.db.GetRelatedKeys(tx.cmdLine) tx.writeKeys, tx.readKeys = godis.GetRelatedKeys(tx.cmdLine)
// lock writeKeys // lock writeKeys
tx.lockKeys() tx.lockKeys()
// build undoLog // build undoLog
tx.undoLog = tx.cluster.db.GetUndoLogs(tx.cmdLine) tx.undoLog = tx.cluster.db.GetUndoLogs(tx.dbIndex, tx.cmdLine)
tx.status = preparedStatus tx.status = preparedStatus
taskKey := genTaskKey(tx.id) taskKey := genTaskKey(tx.id)
timewheel.Delay(maxLockTime, taskKey, func() { timewheel.Delay(maxLockTime, taskKey, func() {
@@ -104,7 +107,7 @@ func (tx *Transaction) rollback() error {
} }
tx.lockKeys() tx.lockKeys()
for _, cmdLine := range tx.undoLog { for _, cmdLine := range tx.undoLog {
tx.cluster.db.ExecWithLock(cmdLine) tx.cluster.db.ExecWithLock(tx.conn, cmdLine)
} }
tx.unLockKeys() tx.unLockKeys()
tx.status = rolledBackStatus tx.status = rolledBackStatus
@@ -163,7 +166,7 @@ func execCommit(cluster *Cluster, c redis.Connection, cmdLine CmdLine) redis.Rep
tx.mu.Lock() tx.mu.Lock()
defer tx.mu.Unlock() defer tx.mu.Unlock()
result := cluster.db.ExecWithLock(tx.cmdLine) result := cluster.db.ExecWithLock(c, tx.cmdLine)
if reply.IsErrorReply(result) { if reply.IsErrorReply(result) {
// failed // failed

View File

@@ -1,6 +1,7 @@
package cluster package cluster
import ( import (
"github.com/hdt3213/godis/redis/connection"
"github.com/hdt3213/godis/redis/reply/asserts" "github.com/hdt3213/godis/redis/reply/asserts"
"math/rand" "math/rand"
"strconv" "strconv"
@@ -9,37 +10,38 @@ import (
func TestRollback(t *testing.T) { func TestRollback(t *testing.T) {
// rollback uncommitted transaction // rollback uncommitted transaction
FlushAll(testCluster, nil, toArgs("FLUSHALL")) conn := new(connection.FakeConn)
FlushAll(testCluster, conn, toArgs("FLUSHALL"))
txID := rand.Int63() txID := rand.Int63()
txIDStr := strconv.FormatInt(txID, 10) txIDStr := strconv.FormatInt(txID, 10)
keys := []string{"a", "b"} keys := []string{"a", "b"}
groupMap := testCluster.groupBy(keys) groupMap := testCluster.groupBy(keys)
args := []string{txIDStr, "DEL"} args := []string{txIDStr, "DEL"}
args = append(args, keys...) args = append(args, keys...)
testCluster.Exec(nil, toArgs("SET", "a", "a")) testCluster.Exec(conn, toArgs("SET", "a", "a"))
ret := execPrepare(testCluster, nil, makeArgs("Prepare", args...)) ret := execPrepare(testCluster, conn, makeArgs("Prepare", args...))
asserts.AssertNotError(t, ret) asserts.AssertNotError(t, ret)
requestRollback(testCluster, nil, txID, groupMap) requestRollback(testCluster, conn, txID, groupMap)
ret = testCluster.Exec(nil, toArgs("GET", "a")) ret = testCluster.Exec(conn, toArgs("GET", "a"))
asserts.AssertBulkReply(t, ret, "a") asserts.AssertBulkReply(t, ret, "a")
// rollback committed transaction // rollback committed transaction
FlushAll(testCluster, nil, toArgs("FLUSHALL")) FlushAll(testCluster, conn, toArgs("FLUSHALL"))
txID = rand.Int63() txID = rand.Int63()
txIDStr = strconv.FormatInt(txID, 10) txIDStr = strconv.FormatInt(txID, 10)
args = []string{txIDStr, "DEL"} args = []string{txIDStr, "DEL"}
args = append(args, keys...) args = append(args, keys...)
testCluster.Exec(nil, toArgs("SET", "a", "a")) testCluster.Exec(conn, toArgs("SET", "a", "a"))
ret = execPrepare(testCluster, nil, makeArgs("Prepare", args...)) ret = execPrepare(testCluster, conn, makeArgs("Prepare", args...))
asserts.AssertNotError(t, ret) asserts.AssertNotError(t, ret)
_, err := requestCommit(testCluster, nil, txID, groupMap) _, err := requestCommit(testCluster, conn, txID, groupMap)
if err != nil { if err != nil {
t.Errorf("del failed %v", err) t.Errorf("del failed %v", err)
return return
} }
ret = testCluster.Exec(nil, toArgs("GET", "a")) ret = testCluster.Exec(conn, toArgs("GET", "a"))
asserts.AssertNullBulk(t, ret) asserts.AssertNullBulk(t, ret)
requestRollback(testCluster, nil, txID, groupMap) requestRollback(testCluster, conn, txID, groupMap)
ret = testCluster.Exec(nil, toArgs("GET", "a")) ret = testCluster.Exec(conn, toArgs("GET", "a"))
asserts.AssertBulkReply(t, ret, "a") asserts.AssertBulkReply(t, ret, "a")
} }

View File

@@ -18,6 +18,7 @@ type ServerProperties struct {
AppendFilename string `cfg:"appendFilename"` AppendFilename string `cfg:"appendFilename"`
MaxClients int `cfg:"maxclients"` MaxClients int `cfg:"maxclients"`
RequirePass string `cfg:"requirepass"` RequirePass string `cfg:"requirepass"`
Databases int `cfg:"databases"`
Peers []string `cfg:"peers"` Peers []string `cfg:"peers"`
Self string `cfg:"self"` Self string `cfg:"self"`

View File

@@ -11,6 +11,7 @@ import (
type ConcurrentDict struct { type ConcurrentDict struct {
table []*shard table []*shard
count int32 count int32
shardCount int
} }
type shard struct { type shard struct {
@@ -46,6 +47,7 @@ func MakeConcurrent(shardCount int) *ConcurrentDict {
d := &ConcurrentDict{ d := &ConcurrentDict{
count: 0, count: 0,
table: table, table: table,
shardCount: shardCount,
} }
return d return d
} }
@@ -284,3 +286,7 @@ func (dict *ConcurrentDict) RandomDistinctKeys(limit int) []string {
} }
return arr return arr
} }
func (dict *ConcurrentDict) Clear() {
*dict = *MakeConcurrent(dict.shardCount)
}

View File

@@ -15,4 +15,5 @@ type Dict interface {
Keys() []string Keys() []string
RandomKeys(limit int) []string RandomKeys(limit int) []string
RandomDistinctKeys(limit int) []string RandomDistinctKeys(limit int) []string
Clear()
} }

View File

@@ -114,3 +114,7 @@ func (dict *SimpleDict) RandomDistinctKeys(limit int) []string {
} }
return result return result
} }
func (dict *SimpleDict) Clear() {
*dict = *MakeSimple()
}

147
db.go
View File

@@ -2,15 +2,14 @@
package godis package godis
import ( import (
"github.com/hdt3213/godis/config"
"github.com/hdt3213/godis/datastruct/dict" "github.com/hdt3213/godis/datastruct/dict"
"github.com/hdt3213/godis/datastruct/lock" "github.com/hdt3213/godis/datastruct/lock"
"github.com/hdt3213/godis/interface/database"
"github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/timewheel" "github.com/hdt3213/godis/lib/timewheel"
"github.com/hdt3213/godis/pubsub"
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
"os" "strings"
"sync" "sync"
"time" "time"
) )
@@ -19,11 +18,11 @@ const (
dataDictSize = 1 << 16 dataDictSize = 1 << 16
ttlDictSize = 1 << 10 ttlDictSize = 1 << 10
lockerSize = 1024 lockerSize = 1024
aofQueueSize = 1 << 16
) )
// DB stores data and execute user's commands // DB stores data and execute user's commands
type DB struct { type DB struct {
index int
// key -> DataEntity // key -> DataEntity
data dict.Dict data dict.Dict
// key -> expireTime (time.Time) // key -> expireTime (time.Time)
@@ -36,24 +35,7 @@ type DB struct {
locker *lock.Locks locker *lock.Locks
// stop all data access for execFlushDB // stop all data access for execFlushDB
stopWorld sync.WaitGroup stopWorld sync.WaitGroup
// handle publish/subscribe addAof func(CmdLine)
hub *pubsub.Hub
// main goroutine send commands to aof goroutine through aofChan
aofChan chan *reply.MultiBulkReply
aofFile *os.File
aofFilename string
// aof goroutine will send msg to main goroutine through this channel when aof tasks finished and ready to shutdown
aofFinished chan struct{}
// buffer commands received during aof rewrite progress
aofRewriteBuffer chan *reply.MultiBulkReply
// pause aof for start/finish aof rewrite progress
pausingAof sync.RWMutex
}
// DataEntity stores data bound to a key, including a string, list, hash, set and so on
type DataEntity struct {
Data interface{}
} }
// ExecFunc is interface for command executor // ExecFunc is interface for command executor
@@ -71,45 +53,80 @@ type CmdLine = [][]byte
// execute from head to tail when undo // execute from head to tail when undo
type UndoFunc func(db *DB, args [][]byte) []CmdLine type UndoFunc func(db *DB, args [][]byte) []CmdLine
// MakeDB create DB instance and start it // makeDB create DB instance
func MakeDB() *DB { func makeDB() *DB {
db := &DB{ db := &DB{
data: dict.MakeConcurrent(dataDictSize), data: dict.MakeConcurrent(dataDictSize),
ttlMap: dict.MakeConcurrent(ttlDictSize), ttlMap: dict.MakeConcurrent(ttlDictSize),
versionMap: dict.MakeConcurrent(dataDictSize), versionMap: dict.MakeConcurrent(dataDictSize),
locker: lock.Make(lockerSize), locker: lock.Make(lockerSize),
hub: pubsub.MakeHub(), addAof: func(line CmdLine) {},
}
// aof
if config.Properties.AppendOnly {
db.aofFilename = config.Properties.AppendFilename
db.loadAof(0)
aofFile, err := os.OpenFile(db.aofFilename, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
logger.Warn(err)
} else {
db.aofFile = aofFile
db.aofChan = make(chan *reply.MultiBulkReply, aofQueueSize)
}
db.aofFinished = make(chan struct{})
go func() {
db.handleAof()
}()
} }
return db return db
} }
// Close graceful shutdown database // makeBasicDB create DB instance only with basic abilities.
func (db *DB) Close() { // It is not concurrent safe
if db.aofFile != nil { func makeBasicDB() *DB {
close(db.aofChan) db := &DB{
<-db.aofFinished // wait for aof finished data: dict.MakeSimple(),
err := db.aofFile.Close() ttlMap: dict.MakeSimple(),
if err != nil { versionMap: dict.MakeSimple(),
logger.Warn(err) locker: lock.Make(1),
addAof: func(line CmdLine) {},
} }
return db
}
// Exec executes command within one database
func (db *DB) Exec(c redis.Connection, cmdLine [][]byte) redis.Reply {
cmdName := strings.ToLower(string(cmdLine[0]))
if cmdName == "multi" {
if len(cmdLine) != 1 {
return reply.MakeArgNumErrReply(cmdName)
} }
return StartMulti(c)
} else if cmdName == "discard" {
if len(cmdLine) != 1 {
return reply.MakeArgNumErrReply(cmdName)
}
return DiscardMulti(c)
} else if cmdName == "exec" {
if len(cmdLine) != 1 {
return reply.MakeArgNumErrReply(cmdName)
}
return execMulti(db, c)
} else if cmdName == "watch" {
if !validateArity(-2, cmdLine) {
return reply.MakeArgNumErrReply(cmdName)
}
return Watch(db, c, cmdLine[1:])
}
if c != nil && c.InMultiState() {
EnqueueCmd(c, cmdLine)
return reply.MakeQueuedReply()
}
return db.execNormalCommand(cmdLine)
}
func (db *DB) execNormalCommand(cmdLine [][]byte) redis.Reply {
cmdName := strings.ToLower(string(cmdLine[0]))
cmd, ok := cmdTable[cmdName]
if !ok {
return reply.MakeErrReply("ERR unknown command '" + cmdName + "'")
}
if !validateArity(cmd.arity, cmdLine) {
return reply.MakeArgNumErrReply(cmdName)
}
prepare := cmd.prepare
write, read := prepare(cmdLine[1:])
db.addVersion(write...)
db.RWLocks(write, read)
defer db.RWUnLocks(write, read)
fun := cmd.executor
return fun(db, cmdLine[1:])
} }
func validateArity(arity int, cmdArgs [][]byte) bool { func validateArity(arity int, cmdArgs [][]byte) bool {
@@ -123,7 +140,7 @@ func validateArity(arity int, cmdArgs [][]byte) bool {
/* ---- Data Access ----- */ /* ---- Data Access ----- */
// GetEntity returns DataEntity bind to given key // GetEntity returns DataEntity bind to given key
func (db *DB) GetEntity(key string) (*DataEntity, bool) { func (db *DB) GetEntity(key string) (*database.DataEntity, bool) {
db.stopWorld.Wait() db.stopWorld.Wait()
raw, ok := db.data.Get(key) raw, ok := db.data.Get(key)
@@ -133,24 +150,24 @@ func (db *DB) GetEntity(key string) (*DataEntity, bool) {
if db.IsExpired(key) { if db.IsExpired(key) {
return nil, false return nil, false
} }
entity, _ := raw.(*DataEntity) entity, _ := raw.(*database.DataEntity)
return entity, true return entity, true
} }
// PutEntity a DataEntity into DB // PutEntity a DataEntity into DB
func (db *DB) PutEntity(key string, entity *DataEntity) int { func (db *DB) PutEntity(key string, entity *database.DataEntity) int {
db.stopWorld.Wait() db.stopWorld.Wait()
return db.data.Put(key, entity) return db.data.Put(key, entity)
} }
// PutIfExists edit an existing DataEntity // PutIfExists edit an existing DataEntity
func (db *DB) PutIfExists(key string, entity *DataEntity) int { func (db *DB) PutIfExists(key string, entity *database.DataEntity) int {
db.stopWorld.Wait() db.stopWorld.Wait()
return db.data.PutIfExists(key, entity) return db.data.PutIfExists(key, entity)
} }
// PutIfAbsent insert an DataEntity only if the key not exists // PutIfAbsent insert an DataEntity only if the key not exists
func (db *DB) PutIfAbsent(key string, entity *DataEntity) int { func (db *DB) PutIfAbsent(key string, entity *database.DataEntity) int {
db.stopWorld.Wait() db.stopWorld.Wait()
return db.data.PutIfAbsent(key, entity) return db.data.PutIfAbsent(key, entity)
} }
@@ -183,8 +200,8 @@ func (db *DB) Flush() {
db.stopWorld.Add(1) db.stopWorld.Add(1)
defer db.stopWorld.Done() defer db.stopWorld.Done()
db.data = dict.MakeConcurrent(dataDictSize) db.data.Clear()
db.ttlMap = dict.MakeConcurrent(ttlDictSize) db.ttlMap.Clear()
db.locker = lock.Make(lockerSize) db.locker = lock.Make(lockerSize)
} }
@@ -269,9 +286,15 @@ func (db *DB) GetVersion(key string) uint32 {
return entity.(uint32) return entity.(uint32)
} }
/* ---- Subscribe Functions ---- */ func (db *DB) ForEach(cb func(key string, data *database.DataEntity, expiration *time.Time) bool) {
db.data.ForEach(func(key string, raw interface{}) bool {
// AfterClientClose does some clean after client close connection entity, _ := raw.(*database.DataEntity)
func (db *DB) AfterClientClose(c redis.Connection) { var expiration *time.Time
pubsub.UnsubscribeAll(db.hub, c) rawExpireTime, ok := db.ttlMap.Get(key)
if ok {
expireTime, _ := rawExpireTime.(time.Time)
expiration = &expireTime
}
return cb(key, entity, expiration)
})
} }

81
exec.go
View File

@@ -1,81 +0,0 @@
package godis
import (
"fmt"
"github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/pubsub"
"github.com/hdt3213/godis/redis/reply"
"runtime/debug"
"strings"
)
// Exec executes command
// parameter `cmdLine` contains command and its arguments, for example: "set key value"
func (db *DB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Reply) {
defer func() {
if err := recover(); err != nil {
logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack())))
result = &reply.UnknownErrReply{}
}
}()
cmdName := strings.ToLower(string(cmdLine[0]))
// authenticate
if cmdName == "auth" {
return Auth(db, c, cmdLine[1:])
}
if !isAuthenticated(c) {
return reply.MakeErrReply("NOAUTH Authentication required")
}
// special commands
done := false
result, done = execSpecialCmd(c, cmdLine, cmdName, db)
if done {
return result
}
if c != nil && c.InMultiState() {
return EnqueueCmd(db, c, cmdLine)
}
// normal commands
return execNormalCommand(db, cmdLine)
}
func execSpecialCmd(c redis.Connection, cmdLine [][]byte, cmdName string, db *DB) (redis.Reply, bool) {
if cmdName == "subscribe" {
if len(cmdLine) < 2 {
return reply.MakeArgNumErrReply("subscribe"), true
}
return pubsub.Subscribe(db.hub, c, cmdLine[1:]), true
} else if cmdName == "publish" {
return pubsub.Publish(db.hub, cmdLine[1:]), true
} else if cmdName == "unsubscribe" {
return pubsub.UnSubscribe(db.hub, c, cmdLine[1:]), true
} else if cmdName == "bgrewriteaof" {
// aof.go imports router.go, router.go cannot import BGRewriteAOF from aof.go
return BGRewriteAOF(db, cmdLine[1:]), true
} else if cmdName == "multi" {
if len(cmdLine) != 1 {
return reply.MakeArgNumErrReply(cmdName), true
}
return StartMulti(db, c), true
} else if cmdName == "discard" {
if len(cmdLine) != 1 {
return reply.MakeArgNumErrReply(cmdName), true
}
return DiscardMulti(db, c), true
} else if cmdName == "exec" {
if len(cmdLine) != 1 {
return reply.MakeArgNumErrReply(cmdName), true
}
return execMulti(db, c), true
} else if cmdName == "watch" {
if !validateArity(-2, cmdLine) {
return reply.MakeArgNumErrReply(cmdName), true
}
return Watch(db, c, cmdLine[1:]), true
}
return nil, false
}

View File

@@ -1,68 +0,0 @@
package godis
import (
"github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/redis/reply"
"strings"
)
func execNormalCommand(db *DB, cmdArgs [][]byte) redis.Reply {
cmdName := strings.ToLower(string(cmdArgs[0]))
cmd, ok := cmdTable[cmdName]
if !ok {
return reply.MakeErrReply("ERR unknown command '" + cmdName + "'")
}
if !validateArity(cmd.arity, cmdArgs) {
return reply.MakeArgNumErrReply(cmdName)
}
prepare := cmd.prepare
write, read := prepare(cmdArgs[1:])
db.addVersion(write...)
db.RWLocks(write, read)
defer db.RWUnLocks(write, read)
fun := cmd.executor
return fun(db, cmdArgs[1:])
}
// ExecWithLock executes normal commands, invoker should provide locks
func (db *DB) ExecWithLock(cmdLine [][]byte) redis.Reply {
cmdName := strings.ToLower(string(cmdLine[0]))
cmd, ok := cmdTable[cmdName]
if !ok {
return reply.MakeErrReply("ERR unknown command '" + cmdName + "'")
}
if !validateArity(cmd.arity, cmdLine) {
return reply.MakeArgNumErrReply(cmdName)
}
fun := cmd.executor
return fun(db, cmdLine[1:])
}
// GetRelatedKeys analysis related keys
func (db *DB) GetRelatedKeys(cmdLine [][]byte) ([]string, []string) {
cmdName := strings.ToLower(string(cmdLine[0]))
cmd, ok := cmdTable[cmdName]
if !ok {
return nil, nil
}
prepare := cmd.prepare
if prepare == nil {
return nil, nil
}
return prepare(cmdLine[1:])
}
// GetUndoLogs return rollback commands
func (db *DB) GetUndoLogs(cmdLine [][]byte) []CmdLine {
cmdName := strings.ToLower(string(cmdLine[0]))
cmd, ok := cmdTable[cmdName]
if !ok {
return nil
}
undo := cmd.undo
if undo == nil {
return nil
}
return undo(db, cmdLine[1:])
}

7
geo.go
View File

@@ -5,6 +5,7 @@ import (
"github.com/hdt3213/godis/datastruct/sortedset" "github.com/hdt3213/godis/datastruct/sortedset"
"github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/geohash" "github.com/hdt3213/godis/lib/geohash"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
"strconv" "strconv"
"strings" "strings"
@@ -51,9 +52,7 @@ func execGeoAdd(db *DB, args [][]byte) redis.Reply {
i++ i++
} }
} }
db.addAof(utils.ToCmdLine3("geoadd", args...))
db.AddAof(makeAofCmd("geoadd", args))
return reply.MakeIntReply(int64(i)) return reply.MakeIntReply(int64(i))
} }
@@ -87,7 +86,7 @@ func execGeoPos(db *DB, args [][]byte) redis.Reply {
member := string(args[i+1]) member := string(args[i+1])
elem, exists := sortedSet.Get(member) elem, exists := sortedSet.Get(member)
if !exists { if !exists {
positions[i] = (&reply.EmptyMultiBulkReply{}) positions[i] = &reply.EmptyMultiBulkReply{}
continue continue
} }
lat, lng := geohash.Decode(uint64(elem.Score)) lat, lng := geohash.Decode(uint64(elem.Score))

18
hash.go
View File

@@ -2,7 +2,9 @@ package godis
import ( import (
Dict "github.com/hdt3213/godis/datastruct/dict" Dict "github.com/hdt3213/godis/datastruct/dict"
"github.com/hdt3213/godis/interface/database"
"github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
"github.com/shopspring/decimal" "github.com/shopspring/decimal"
"strconv" "strconv"
@@ -28,7 +30,7 @@ func (db *DB) getOrInitDict(key string) (dict Dict.Dict, inited bool, errReply r
inited = false inited = false
if dict == nil { if dict == nil {
dict = Dict.MakeSimple() dict = Dict.MakeSimple()
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: dict, Data: dict,
}) })
inited = true inited = true
@@ -50,7 +52,7 @@ func execHSet(db *DB, args [][]byte) redis.Reply {
} }
result := dict.Put(field, value) result := dict.Put(field, value)
db.AddAof(makeAofCmd("hset", args)) db.addAof(utils.ToCmdLine3("hset", args...))
return reply.MakeIntReply(int64(result)) return reply.MakeIntReply(int64(result))
} }
@@ -74,7 +76,7 @@ func execHSetNX(db *DB, args [][]byte) redis.Reply {
result := dict.PutIfAbsent(field, value) result := dict.PutIfAbsent(field, value)
if result > 0 { if result > 0 {
db.AddAof(makeAofCmd("hsetnx", args)) db.addAof(utils.ToCmdLine3("hsetnx", args...))
} }
return reply.MakeIntReply(int64(result)) return reply.MakeIntReply(int64(result))
@@ -153,7 +155,7 @@ func execHDel(db *DB, args [][]byte) redis.Reply {
db.Remove(key) db.Remove(key)
} }
if deleted > 0 { if deleted > 0 {
db.AddAof(makeAofCmd("hdel", args)) db.addAof(utils.ToCmdLine3("hdel", args...))
} }
return reply.MakeIntReply(int64(deleted)) return reply.MakeIntReply(int64(deleted))
@@ -210,7 +212,7 @@ func execHMSet(db *DB, args [][]byte) redis.Reply {
value := values[i] value := values[i]
dict.Put(field, value) dict.Put(field, value)
} }
db.AddAof(makeAofCmd("hmset", args)) db.addAof(utils.ToCmdLine3("hmset", args...))
return &reply.OkReply{} return &reply.OkReply{}
} }
@@ -344,7 +346,7 @@ func execHIncrBy(db *DB, args [][]byte) redis.Reply {
value, exists := dict.Get(field) value, exists := dict.Get(field)
if !exists { if !exists {
dict.Put(field, args[2]) dict.Put(field, args[2])
db.AddAof(makeAofCmd("hincrby", args)) db.addAof(utils.ToCmdLine3("hincrby", args...))
return reply.MakeBulkReply(args[2]) return reply.MakeBulkReply(args[2])
} }
val, err := strconv.ParseInt(string(value.([]byte)), 10, 64) val, err := strconv.ParseInt(string(value.([]byte)), 10, 64)
@@ -354,7 +356,7 @@ func execHIncrBy(db *DB, args [][]byte) redis.Reply {
val += delta val += delta
bytes := []byte(strconv.FormatInt(val, 10)) bytes := []byte(strconv.FormatInt(val, 10))
dict.Put(field, bytes) dict.Put(field, bytes)
db.AddAof(makeAofCmd("hincrby", args)) db.addAof(utils.ToCmdLine3("hincrby", args...))
return reply.MakeBulkReply(bytes) return reply.MakeBulkReply(bytes)
} }
@@ -392,7 +394,7 @@ func execHIncrByFloat(db *DB, args [][]byte) redis.Reply {
result := val.Add(delta) result := val.Add(delta)
resultBytes := []byte(result.String()) resultBytes := []byte(result.String())
dict.Put(field, resultBytes) dict.Put(field, resultBytes)
db.AddAof(makeAofCmd("hincrbyfloat", args)) db.addAof(utils.ToCmdLine3("hincrbyfloat", args...))
return reply.MakeBulkReply(resultBytes) return reply.MakeBulkReply(resultBytes)
} }

31
interface/database/db.go Normal file
View File

@@ -0,0 +1,31 @@
package database
import (
"github.com/hdt3213/godis/interface/redis"
"time"
)
type CmdLine = [][]byte
// DB is the interface for redis style storage engine
type DB interface {
Exec(client redis.Connection, args [][]byte) redis.Reply
AfterClientClose(c redis.Connection)
Close()
}
// EmbedDB is the embedding storage engine exposing more methods for complex application
type EmbedDB interface {
DB
ExecWithLock(conn redis.Connection, args [][]byte) redis.Reply
ExecMulti(conn redis.Connection, watching map[string]uint32, cmdLines []CmdLine) redis.Reply
GetUndoLogs(dbIndex int, cmdLine [][]byte) []CmdLine
ForEach(dbIndex int, cb func(key string, data *DataEntity, expiration *time.Time) bool)
RWLocks(dbIndex int, writeKeys []string, readKeys []string)
RWUnLocks(dbIndex int, writeKeys []string, readKeys []string)
}
// DataEntity stores data bound to a key, including a string, list, hash, set and so on
type DataEntity struct {
Data interface{}
}

View File

@@ -1,10 +0,0 @@
package db
import "github.com/hdt3213/godis/interface/redis"
// DB is the interface for redis style storage engine
type DB interface {
Exec(client redis.Connection, args [][]byte) redis.Reply
AfterClientClose(c redis.Connection)
Close()
}

View File

@@ -19,4 +19,8 @@ type Connection interface {
EnqueueCmd([][]byte) EnqueueCmd([][]byte)
ClearQueuedCmds() ClearQueuedCmds()
GetWatching() map[string]uint32 GetWatching() map[string]uint32
// used for multi database
GetDBIndex() int
SelectDB(int)
} }

57
keys.go
View File

@@ -1,11 +1,13 @@
package godis package godis
import ( import (
"github.com/hdt3213/godis/aof"
"github.com/hdt3213/godis/datastruct/dict" "github.com/hdt3213/godis/datastruct/dict"
"github.com/hdt3213/godis/datastruct/list" "github.com/hdt3213/godis/datastruct/list"
"github.com/hdt3213/godis/datastruct/set" "github.com/hdt3213/godis/datastruct/set"
"github.com/hdt3213/godis/datastruct/sortedset" "github.com/hdt3213/godis/datastruct/sortedset"
"github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/lib/wildcard" "github.com/hdt3213/godis/lib/wildcard"
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
"strconv" "strconv"
@@ -21,7 +23,7 @@ func execDel(db *DB, args [][]byte) redis.Reply {
deleted := db.Removes(keys...) deleted := db.Removes(keys...)
if deleted > 0 { if deleted > 0 {
db.AddAof(makeAofCmd("del", args)) db.addAof(utils.ToCmdLine3("del", args...))
} }
return reply.MakeIntReply(int64(deleted)) return reply.MakeIntReply(int64(deleted))
} }
@@ -50,14 +52,7 @@ func execExists(db *DB, args [][]byte) redis.Reply {
// execFlushDB removes all data in current db // execFlushDB removes all data in current db
func execFlushDB(db *DB, args [][]byte) redis.Reply { func execFlushDB(db *DB, args [][]byte) redis.Reply {
db.Flush() db.Flush()
db.AddAof(makeAofCmd("flushdb", args)) db.addAof(utils.ToCmdLine3("flushdb", args...))
return &reply.OkReply{}
}
// execFlushAll removes all data in all db
func execFlushAll(db *DB, args [][]byte) redis.Reply {
db.Flush()
db.AddAof(makeAofCmd("flushdb", args))
return &reply.OkReply{} return &reply.OkReply{}
} }
@@ -110,7 +105,7 @@ func execRename(db *DB, args [][]byte) redis.Reply {
expireTime, _ := rawTTL.(time.Time) expireTime, _ := rawTTL.(time.Time)
db.Expire(dest, expireTime) db.Expire(dest, expireTime)
} }
db.AddAof(makeAofCmd("rename", args)) db.addAof(utils.ToCmdLine3("rename", args...))
return &reply.OkReply{} return &reply.OkReply{}
} }
@@ -143,7 +138,7 @@ func execRenameNx(db *DB, args [][]byte) redis.Reply {
expireTime, _ := rawTTL.(time.Time) expireTime, _ := rawTTL.(time.Time)
db.Expire(dest, expireTime) db.Expire(dest, expireTime)
} }
db.AddAof(makeAofCmd("renamenx", args)) db.addAof(utils.ToCmdLine3("renamenx", args...))
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} }
@@ -164,7 +159,7 @@ func execExpire(db *DB, args [][]byte) redis.Reply {
expireAt := time.Now().Add(ttl) expireAt := time.Now().Add(ttl)
db.Expire(key, expireAt) db.Expire(key, expireAt)
db.AddAof(makeExpireCmd(key, expireAt)) db.addAof(aof.MakeExpireCmd(key, expireAt).Args)
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} }
@@ -176,15 +171,15 @@ func execExpireAt(db *DB, args [][]byte) redis.Reply {
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
expireTime := time.Unix(raw, 0) expireAt := time.Unix(raw, 0)
_, exists := db.GetEntity(key) _, exists := db.GetEntity(key)
if !exists { if !exists {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
db.Expire(key, expireTime) db.Expire(key, expireAt)
db.AddAof(makeExpireCmd(key, expireTime)) db.addAof(aof.MakeExpireCmd(key, expireAt).Args)
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} }
@@ -203,9 +198,9 @@ func execPExpire(db *DB, args [][]byte) redis.Reply {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
expireTime := time.Now().Add(ttl) expireAt := time.Now().Add(ttl)
db.Expire(key, expireTime) db.Expire(key, expireAt)
db.AddAof(makeExpireCmd(key, expireTime)) db.addAof(aof.MakeExpireCmd(key, expireAt).Args)
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} }
@@ -217,16 +212,16 @@ func execPExpireAt(db *DB, args [][]byte) redis.Reply {
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
expireTime := time.Unix(0, raw*int64(time.Millisecond)) expireAt := time.Unix(0, raw*int64(time.Millisecond))
_, exists := db.GetEntity(key) _, exists := db.GetEntity(key)
if !exists { if !exists {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
db.Expire(key, expireTime) db.Expire(key, expireAt)
db.AddAof(makeExpireCmd(key, expireTime)) db.addAof(aof.MakeExpireCmd(key, expireAt).Args)
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} }
@@ -278,16 +273,10 @@ func execPersist(db *DB, args [][]byte) redis.Reply {
} }
db.Persist(key) db.Persist(key)
db.AddAof(makeAofCmd("persist", args)) db.addAof(utils.ToCmdLine3("persist", args...))
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} }
// BGRewriteAOF asynchronously rewrites Append-Only-File
func BGRewriteAOF(db *DB, args [][]byte) redis.Reply {
go db.aofRewrite()
return reply.MakeStatusReply("Background append only file rewriting started")
}
// execKeys returns all keys matching the given pattern // execKeys returns all keys matching the given pattern
func execKeys(db *DB, args [][]byte) redis.Reply { func execKeys(db *DB, args [][]byte) redis.Reply {
pattern := wildcard.CompilePattern(string(args[0])) pattern := wildcard.CompilePattern(string(args[0]))
@@ -301,6 +290,17 @@ func execKeys(db *DB, args [][]byte) redis.Reply {
return reply.MakeMultiBulkReply(result) return reply.MakeMultiBulkReply(result)
} }
func toTTLCmd(db *DB, key string) *reply.MultiBulkReply {
raw, exists := db.ttlMap.Get(key)
if !exists {
// 无 TTL
return reply.MakeMultiBulkReply(utils.ToCmdLine("PERSIST", key))
}
expireTime, _ := raw.(time.Time)
timestamp := strconv.FormatInt(expireTime.UnixNano()/1000/1000, 10)
return reply.MakeMultiBulkReply(utils.ToCmdLine("PEXPIREAT", key, timestamp))
}
func undoExpire(db *DB, args [][]byte) []CmdLine { func undoExpire(db *DB, args [][]byte) []CmdLine {
key := string(args[0]) key := string(args[0])
return []CmdLine{ return []CmdLine{
@@ -322,6 +322,5 @@ func init() {
RegisterCommand("Rename", execRename, prepareRename, undoRename, 3) RegisterCommand("Rename", execRename, prepareRename, undoRename, 3)
RegisterCommand("RenameNx", execRenameNx, prepareRename, undoRename, 3) RegisterCommand("RenameNx", execRenameNx, prepareRename, undoRename, 3)
RegisterCommand("FlushDB", execFlushDB, noPrepare, nil, -1) RegisterCommand("FlushDB", execFlushDB, noPrepare, nil, -1)
RegisterCommand("FlushAll", execFlushAll, noPrepare, nil, -1)
RegisterCommand("Keys", execKeys, noPrepare, nil, 2) RegisterCommand("Keys", execKeys, noPrepare, nil, 2)
} }

View File

@@ -18,6 +18,15 @@ func ToCmdLine2(commandName string, args ...string) [][]byte {
return result return result
} }
func ToCmdLine3(commandName string, args ...[]byte) [][]byte {
result := make([][]byte, len(args)+1)
result[0] = []byte(commandName)
for i, s := range args {
result[i+1] = s
}
return result
}
// Equals check whether the given value is equal // Equals check whether the given value is equal
func Equals(a interface{}, b interface{}) bool { func Equals(a interface{}, b interface{}) bool {
sliceA, okA := a.([]byte) sliceA, okA := a.([]byte)

21
list.go
View File

@@ -2,6 +2,7 @@ package godis
import ( import (
List "github.com/hdt3213/godis/datastruct/list" List "github.com/hdt3213/godis/datastruct/list"
"github.com/hdt3213/godis/interface/database"
"github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
@@ -28,7 +29,7 @@ func (db *DB) getOrInitList(key string) (list *List.LinkedList, isNew bool, errR
isNew = false isNew = false
if list == nil { if list == nil {
list = &List.LinkedList{} list = &List.LinkedList{}
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: list, Data: list,
}) })
isNew = true isNew = true
@@ -103,7 +104,7 @@ func execLPop(db *DB, args [][]byte) redis.Reply {
if list.Len() == 0 { if list.Len() == 0 {
db.Remove(key) db.Remove(key)
} }
db.AddAof(makeAofCmd("lpop", args)) db.addAof(utils.ToCmdLine3("lpop", args...))
return reply.MakeBulkReply(val) return reply.MakeBulkReply(val)
} }
@@ -144,7 +145,7 @@ func execLPush(db *DB, args [][]byte) redis.Reply {
list.Insert(0, value) list.Insert(0, value)
} }
db.AddAof(makeAofCmd("lpush", args)) db.addAof(utils.ToCmdLine3("lpush", args...))
return reply.MakeIntReply(int64(list.Len())) return reply.MakeIntReply(int64(list.Len()))
} }
@@ -176,7 +177,7 @@ func execLPushX(db *DB, args [][]byte) redis.Reply {
for _, value := range values { for _, value := range values {
list.Insert(0, value) list.Insert(0, value)
} }
db.AddAof(makeAofCmd("lpushx", args)) db.addAof(utils.ToCmdLine3("lpushx", args...))
return reply.MakeIntReply(int64(list.Len())) return reply.MakeIntReply(int64(list.Len()))
} }
@@ -269,7 +270,7 @@ func execLRem(db *DB, args [][]byte) redis.Reply {
db.Remove(key) db.Remove(key)
} }
if removed > 0 { if removed > 0 {
db.AddAof(makeAofCmd("lrem", args)) db.addAof(utils.ToCmdLine3("lrem", args...))
} }
return reply.MakeIntReply(int64(removed)) return reply.MakeIntReply(int64(removed))
@@ -305,7 +306,7 @@ func execLSet(db *DB, args [][]byte) redis.Reply {
} }
list.Set(index, value) list.Set(index, value)
db.AddAof(makeAofCmd("lset", args)) db.addAof(utils.ToCmdLine3("lset", args...))
return &reply.OkReply{} return &reply.OkReply{}
} }
@@ -360,7 +361,7 @@ func execRPop(db *DB, args [][]byte) redis.Reply {
if list.Len() == 0 { if list.Len() == 0 {
db.Remove(key) db.Remove(key)
} }
db.AddAof(makeAofCmd("rpop", args)) db.addAof(utils.ToCmdLine3("rpop", args...))
return reply.MakeBulkReply(val) return reply.MakeBulkReply(val)
} }
@@ -420,7 +421,7 @@ func execRPopLPush(db *DB, args [][]byte) redis.Reply {
db.Remove(sourceKey) db.Remove(sourceKey)
} }
db.AddAof(makeAofCmd("rpoplpush", args)) db.addAof(utils.ToCmdLine3("rpoplpush", args...))
return reply.MakeBulkReply(val) return reply.MakeBulkReply(val)
} }
@@ -463,7 +464,7 @@ func execRPush(db *DB, args [][]byte) redis.Reply {
for _, value := range values { for _, value := range values {
list.Add(value) list.Add(value)
} }
db.AddAof(makeAofCmd("rpush", args)) db.addAof(utils.ToCmdLine3("rpush", args...))
return reply.MakeIntReply(int64(list.Len())) return reply.MakeIntReply(int64(list.Len()))
} }
@@ -498,7 +499,7 @@ func execRPushX(db *DB, args [][]byte) redis.Reply {
for _, value := range values { for _, value := range values {
list.Add(value) list.Add(value)
} }
db.AddAof(makeAofCmd("rpushx", args)) db.addAof(utils.ToCmdLine3("rpushx", args...))
return reply.MakeIntReply(int64(list.Len())) return reply.MakeIntReply(int64(list.Len()))
} }

View File

@@ -2,5 +2,5 @@ bind 0.0.0.0
port 6399 port 6399
maxclients 128 maxclients 128
appendonly no appendonly yes
appendfilename appendonly.aof appendfilename appendonly.aof

View File

@@ -28,6 +28,9 @@ type Connection struct {
multiState bool multiState bool
queue [][][]byte queue [][][]byte
watching map[string]uint32 watching map[string]uint32
// selected db
selectedDB int
} }
// RemoteAddr returns the remote network address // RemoteAddr returns the remote network address
@@ -147,6 +150,14 @@ func (c *Connection) GetWatching() map[string]uint32 {
return c.watching return c.watching
} }
func (c *Connection) GetDBIndex() int {
return c.selectedDB
}
func (c *Connection) SelectDB(dbNum int) {
c.selectedDB = dbNum
}
// FakeConn implements redis.Connection for test // FakeConn implements redis.Connection for test
type FakeConn struct { type FakeConn struct {
Connection Connection

View File

@@ -9,7 +9,7 @@ import (
"github.com/hdt3213/godis" "github.com/hdt3213/godis"
"github.com/hdt3213/godis/cluster" "github.com/hdt3213/godis/cluster"
"github.com/hdt3213/godis/config" "github.com/hdt3213/godis/config"
"github.com/hdt3213/godis/interface/db" "github.com/hdt3213/godis/interface/database"
"github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/sync/atomic" "github.com/hdt3213/godis/lib/sync/atomic"
"github.com/hdt3213/godis/redis/connection" "github.com/hdt3213/godis/redis/connection"
@@ -28,18 +28,18 @@ var (
// Handler implements tcp.Handler and serves as a redis server // Handler implements tcp.Handler and serves as a redis server
type Handler struct { type Handler struct {
activeConn sync.Map // *client -> placeholder activeConn sync.Map // *client -> placeholder
db db.DB db database.DB
closing atomic.Boolean // refusing new client and new request closing atomic.Boolean // refusing new client and new request
} }
// MakeHandler creates a Handler instance // MakeHandler creates a Handler instance
func MakeHandler() *Handler { func MakeHandler() *Handler {
var db db.DB var db database.DB
if config.Properties.Self != "" && if config.Properties.Self != "" &&
len(config.Properties.Peers) > 0 { len(config.Properties.Peers) > 0 {
db = cluster.MakeCluster() db = cluster.MakeCluster()
} else { } else {
db = godis.MakeDB() db = godis.NewStandaloneServer()
} }
return &Handler{ return &Handler{
db: db, db: db,

View File

@@ -11,6 +11,7 @@ type command struct {
prepare PreFunc // return related keys command prepare PreFunc // return related keys command
undo UndoFunc undo UndoFunc
arity int // allow number of args, arity < 0 means len(args) >= -arity arity int // allow number of args, arity < 0 means len(args) >= -arity
flags int
} }
// RegisterCommand registers a new command // RegisterCommand registers a new command

220
server.go
View File

@@ -1,45 +1,217 @@
package godis package godis
import ( import (
"fmt"
"github.com/hdt3213/godis/aof"
"github.com/hdt3213/godis/config" "github.com/hdt3213/godis/config"
"github.com/hdt3213/godis/interface/database"
"github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/pubsub"
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
"runtime/debug"
"strconv"
"strings"
"time"
) )
// Ping the server type MultiDB struct {
func Ping(db *DB, args [][]byte) redis.Reply { dbSet []*DB
if len(args) == 0 {
return &reply.PongReply{} // handle publish/subscribe
} else if len(args) == 1 { hub *pubsub.Hub
return reply.MakeStatusReply(string(args[0])) // handle aof persistence
} else { aofHandler *aof.Handler
return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command") }
// NewStandaloneServer creates a standalone redis server, with multi database and all other funtions
func NewStandaloneServer() *MultiDB {
mdb := &MultiDB{}
if config.Properties.Databases == 0 {
config.Properties.Databases = 16
}
mdb.dbSet = make([]*DB, config.Properties.Databases)
for i := range mdb.dbSet {
singleDB := makeDB()
singleDB.index = i
mdb.dbSet[i] = singleDB
}
mdb.hub = pubsub.MakeHub()
if config.Properties.AppendOnly {
aofHandler, err := aof.NewAOFHandler(mdb, func() database.EmbedDB {
return MakeBasicMultiDB()
})
if err != nil {
panic(err)
}
mdb.aofHandler = aofHandler
for _, db := range mdb.dbSet {
// avoid closure
singleDB := db
singleDB.addAof = func(line CmdLine) {
mdb.aofHandler.AddAof(singleDB.index, line)
}
}
}
return mdb
}
// MakeBasicMultiDB create a MultiDB only with basic abilities for aof rewrite and other usages
func MakeBasicMultiDB() *MultiDB {
mdb := &MultiDB{}
mdb.dbSet = make([]*DB, config.Properties.Databases)
for i := range mdb.dbSet {
mdb.dbSet[i] = makeBasicDB()
}
return mdb
}
// Exec executes command
// parameter `cmdLine` contains command and its arguments, for example: "set key value"
func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Reply) {
defer func() {
if err := recover(); err != nil {
logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack())))
result = &reply.UnknownErrReply{}
}
}()
cmdName := strings.ToLower(string(cmdLine[0]))
// authenticate
if cmdName == "auth" {
return Auth(c, cmdLine[1:])
}
if !isAuthenticated(c) {
return reply.MakeErrReply("NOAUTH Authentication required")
}
// special commands
if cmdName == "subscribe" {
if len(cmdLine) < 2 {
return reply.MakeArgNumErrReply("subscribe")
}
return pubsub.Subscribe(mdb.hub, c, cmdLine[1:])
} else if cmdName == "publish" {
return pubsub.Publish(mdb.hub, cmdLine[1:])
} else if cmdName == "unsubscribe" {
return pubsub.UnSubscribe(mdb.hub, c, cmdLine[1:])
} else if cmdName == "bgrewriteaof" {
// aof.go imports router.go, router.go cannot import BGRewriteAOF from aof.go
return BGRewriteAOF(mdb, cmdLine[1:])
} else if cmdName == "rewriteaof" {
return RewriteAOF(mdb, cmdLine[1:])
} else if cmdName == "flushall" {
return mdb.flushAll()
} else if cmdName == "select" {
if c != nil && c.InMultiState() {
return reply.MakeErrReply("cannot select database within multi")
}
if len(cmdLine) != 2 {
return reply.MakeArgNumErrReply("select")
}
return execSelect(c, mdb, cmdLine[1:])
}
// todo: support multi database transaction
// normal commands
dbIndex := c.GetDBIndex()
if dbIndex >= len(mdb.dbSet) {
return reply.MakeErrReply("ERR DB index is out of range")
}
selectedDB := mdb.dbSet[dbIndex]
return selectedDB.Exec(c, cmdLine)
}
// AfterClientClose does some clean after client close connection
func (mdb *MultiDB) AfterClientClose(c redis.Connection) {
pubsub.UnsubscribeAll(mdb.hub, c)
}
// Close graceful shutdown database
func (mdb *MultiDB) Close() {
if mdb.aofHandler != nil {
mdb.aofHandler.Close()
} }
} }
// Auth validate client's password func execSelect(c redis.Connection, mdb *MultiDB, args [][]byte) redis.Reply {
func Auth(db *DB, c redis.Connection, args [][]byte) redis.Reply { dbIndex, err := strconv.Atoi(string(args[0]))
if len(args) != 1 { if err != nil {
return reply.MakeErrReply("ERR wrong number of arguments for 'auth' command") return reply.MakeErrReply("ERR invalid DB index")
} }
if config.Properties.RequirePass == "" { if dbIndex >= len(mdb.dbSet) {
return reply.MakeErrReply("ERR Client sent AUTH, but no password is set") return reply.MakeErrReply("ERR DB index is out of range")
} }
passwd := string(args[0]) c.SelectDB(dbIndex)
c.SetPassword(passwd) return reply.MakeOkReply()
if config.Properties.RequirePass != passwd { }
return reply.MakeErrReply("ERR invalid password")
func (mdb *MultiDB) flushAll() redis.Reply {
for _, db := range mdb.dbSet {
db.Flush()
}
if mdb.aofHandler != nil {
mdb.aofHandler.AddAof(0, utils.ToCmdLine("FlushAll"))
} }
return &reply.OkReply{} return &reply.OkReply{}
} }
func isAuthenticated(c redis.Connection) bool { func (mdb *MultiDB) ForEach(dbIndex int, cb func(key string, data *database.DataEntity, expiration *time.Time) bool) {
if config.Properties.RequirePass == "" { if dbIndex >= len(mdb.dbSet) {
return true return
} }
return c.GetPassword() == config.Properties.RequirePass db := mdb.dbSet[dbIndex]
db.ForEach(cb)
} }
func init() { func (mdb *MultiDB) ExecMulti(conn redis.Connection, watching map[string]uint32, cmdLines []CmdLine) redis.Reply {
RegisterCommand("ping", Ping, noPrepare, nil, -1) if conn.GetDBIndex() >= len(mdb.dbSet) {
return reply.MakeErrReply("ERR DB index is out of range")
}
db := mdb.dbSet[conn.GetDBIndex()]
return db.ExecMulti(conn, watching, cmdLines)
}
func (mdb *MultiDB) RWLocks(dbIndex int, writeKeys []string, readKeys []string) {
if dbIndex >= len(mdb.dbSet) {
panic("ERR DB index is out of range")
}
db := mdb.dbSet[dbIndex]
db.RWLocks(writeKeys, readKeys)
}
func (mdb *MultiDB) RWUnLocks(dbIndex int, writeKeys []string, readKeys []string) {
if dbIndex >= len(mdb.dbSet) {
panic("ERR DB index is out of range")
}
db := mdb.dbSet[dbIndex]
db.RWUnLocks(writeKeys, readKeys)
}
func (mdb *MultiDB) GetUndoLogs(dbIndex int, cmdLine [][]byte) []CmdLine {
if dbIndex >= len(mdb.dbSet) {
panic("ERR DB index is out of range")
}
db := mdb.dbSet[dbIndex]
return db.GetUndoLogs(cmdLine)
}
func (mdb *MultiDB) ExecWithLock(conn redis.Connection, cmdLine [][]byte) redis.Reply {
if conn.GetDBIndex() >= len(mdb.dbSet) {
panic("ERR DB index is out of range")
}
db := mdb.dbSet[conn.GetDBIndex()]
return db.execWithLock(cmdLine)
}
// BGRewriteAOF asynchronously rewrites Append-Only-File
func BGRewriteAOF(db *MultiDB, args [][]byte) redis.Reply {
go db.aofHandler.Rewrite()
return reply.MakeStatusReply("Background append only file rewriting started")
}
func RewriteAOF(db *MultiDB, args [][]byte) redis.Reply {
db.aofHandler.Rewrite()
return reply.MakeStatusReply("Background append only file rewriting started")
} }

20
set.go
View File

@@ -2,7 +2,9 @@ package godis
import ( import (
HashSet "github.com/hdt3213/godis/datastruct/set" HashSet "github.com/hdt3213/godis/datastruct/set"
"github.com/hdt3213/godis/interface/database"
"github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
"strconv" "strconv"
) )
@@ -27,7 +29,7 @@ func (db *DB) getOrInitSet(key string) (set *HashSet.Set, inited bool, errReply
inited = false inited = false
if set == nil { if set == nil {
set = HashSet.Make() set = HashSet.Make()
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: set, Data: set,
}) })
inited = true inited = true
@@ -49,7 +51,7 @@ func execSAdd(db *DB, args [][]byte) redis.Reply {
for _, member := range members { for _, member := range members {
counter += set.Add(string(member)) counter += set.Add(string(member))
} }
db.AddAof(makeAofCmd("sadd", args)) db.addAof(utils.ToCmdLine3("sadd", args...))
return reply.MakeIntReply(int64(counter)) return reply.MakeIntReply(int64(counter))
} }
@@ -94,7 +96,7 @@ func execSRem(db *DB, args [][]byte) redis.Reply {
db.Remove(key) db.Remove(key)
} }
if counter > 0 { if counter > 0 {
db.AddAof(makeAofCmd("srem", args)) db.addAof(utils.ToCmdLine3("srem", args...))
} }
return reply.MakeIntReply(int64(counter)) return reply.MakeIntReply(int64(counter))
} }
@@ -210,10 +212,10 @@ func execSInterStore(db *DB, args [][]byte) redis.Reply {
} }
set := HashSet.Make(result.ToSlice()...) set := HashSet.Make(result.ToSlice()...)
db.PutEntity(dest, &DataEntity{ db.PutEntity(dest, &database.DataEntity{
Data: set, Data: set,
}) })
db.AddAof(makeAofCmd("sinterstore", args)) db.addAof(utils.ToCmdLine3("sinterstore", args...))
return reply.MakeIntReply(int64(set.Len())) return reply.MakeIntReply(int64(set.Len()))
} }
@@ -289,11 +291,11 @@ func execSUnionStore(db *DB, args [][]byte) redis.Reply {
} }
set := HashSet.Make(result.ToSlice()...) set := HashSet.Make(result.ToSlice()...)
db.PutEntity(dest, &DataEntity{ db.PutEntity(dest, &database.DataEntity{
Data: set, Data: set,
}) })
db.AddAof(makeAofCmd("sunionstore", args)) db.addAof(utils.ToCmdLine3("sunionstore", args...))
return reply.MakeIntReply(int64(set.Len())) return reply.MakeIntReply(int64(set.Len()))
} }
@@ -385,11 +387,11 @@ func execSDiffStore(db *DB, args [][]byte) redis.Reply {
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
set := HashSet.Make(result.ToSlice()...) set := HashSet.Make(result.ToSlice()...)
db.PutEntity(dest, &DataEntity{ db.PutEntity(dest, &database.DataEntity{
Data: set, Data: set,
}) })
db.AddAof(makeAofCmd("sdiffstore", args)) db.addAof(utils.ToCmdLine3("sdiffstore", args...))
return reply.MakeIntReply(int64(set.Len())) return reply.MakeIntReply(int64(set.Len()))
} }

View File

@@ -2,7 +2,9 @@ package godis
import ( import (
SortedSet "github.com/hdt3213/godis/datastruct/sortedset" SortedSet "github.com/hdt3213/godis/datastruct/sortedset"
"github.com/hdt3213/godis/interface/database"
"github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
"strconv" "strconv"
"strings" "strings"
@@ -28,7 +30,7 @@ func (db *DB) getOrInitSortedSet(key string) (sortedSet *SortedSet.SortedSet, in
inited = false inited = false
if sortedSet == nil { if sortedSet == nil {
sortedSet = SortedSet.Make() sortedSet = SortedSet.Make()
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: sortedSet, Data: sortedSet,
}) })
inited = true inited = true
@@ -70,7 +72,7 @@ func execZAdd(db *DB, args [][]byte) redis.Reply {
} }
} }
db.AddAof(makeAofCmd("zadd", args)) db.addAof(utils.ToCmdLine3("zadd", args...))
return reply.MakeIntReply(int64(i)) return reply.MakeIntReply(int64(i))
} }
@@ -456,7 +458,7 @@ func execZRemRangeByScore(db *DB, args [][]byte) redis.Reply {
removed := sortedSet.RemoveByScore(min, max) removed := sortedSet.RemoveByScore(min, max)
if removed > 0 { if removed > 0 {
db.AddAof(makeAofCmd("zremrangebyscore", args)) db.addAof(utils.ToCmdLine3("zremrangebyscore", args...))
} }
return reply.MakeIntReply(removed) return reply.MakeIntReply(removed)
} }
@@ -507,7 +509,7 @@ func execZRemRangeByRank(db *DB, args [][]byte) redis.Reply {
// assert: start in [0, size - 1], stop in [start, size] // assert: start in [0, size - 1], stop in [start, size]
removed := sortedSet.RemoveByRank(start, stop) removed := sortedSet.RemoveByRank(start, stop)
if removed > 0 { if removed > 0 {
db.AddAof(makeAofCmd("zremrangebyrank", args)) db.addAof(utils.ToCmdLine3("zremrangebyrank", args...))
} }
return reply.MakeIntReply(removed) return reply.MakeIntReply(removed)
} }
@@ -538,7 +540,7 @@ func execZRem(db *DB, args [][]byte) redis.Reply {
} }
} }
if deleted > 0 { if deleted > 0 {
db.AddAof(makeAofCmd("zrem", args)) db.addAof(utils.ToCmdLine3("zrem", args...))
} }
return reply.MakeIntReply(deleted) return reply.MakeIntReply(deleted)
} }
@@ -572,13 +574,13 @@ func execZIncrBy(db *DB, args [][]byte) redis.Reply {
element, exists := sortedSet.Get(field) element, exists := sortedSet.Get(field)
if !exists { if !exists {
sortedSet.Add(field, delta) sortedSet.Add(field, delta)
db.AddAof(makeAofCmd("zincrby", args)) db.addAof(utils.ToCmdLine3("zincrby", args...))
return reply.MakeBulkReply(args[1]) return reply.MakeBulkReply(args[1])
} }
score := element.Score + delta score := element.Score + delta
sortedSet.Add(field, score) sortedSet.Add(field, score)
bytes := []byte(strconv.FormatFloat(score, 'f', -1, 64)) bytes := []byte(strconv.FormatFloat(score, 'f', -1, 64))
db.AddAof(makeAofCmd("zincrby", args)) db.addAof(utils.ToCmdLine3("zincrby", args...))
return reply.MakeBulkReply(bytes) return reply.MakeBulkReply(bytes)
} }

View File

@@ -1,7 +1,10 @@
package godis package godis
import ( import (
"github.com/hdt3213/godis/aof"
"github.com/hdt3213/godis/interface/database"
"github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
"github.com/shopspring/decimal" "github.com/shopspring/decimal"
"strconv" "strconv"
@@ -102,7 +105,7 @@ func execSet(db *DB, args [][]byte) redis.Reply {
} }
} }
entity := &DataEntity{ entity := &database.DataEntity{
Data: value, Data: value,
} }
@@ -124,17 +127,17 @@ func execSet(db *DB, args [][]byte) redis.Reply {
if ttl != unlimitedTTL { if ttl != unlimitedTTL {
expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond)
db.Expire(key, expireTime) db.Expire(key, expireTime)
db.AddAof(reply.MakeMultiBulkReply([][]byte{ db.addAof(CmdLine{
[]byte("SET"), []byte("SET"),
args[0], args[0],
args[1], args[1],
})) })
db.AddAof(makeExpireCmd(key, expireTime)) db.addAof(aof.MakeExpireCmd(key, expireTime).Args)
} else if result > 0 { } else if result > 0 {
db.Persist(key) // override ttl db.Persist(key) // override ttl
db.AddAof(makeAofCmd("set", args)) db.addAof(utils.ToCmdLine3("set", args...))
} else { } else {
db.AddAof(makeAofCmd("set", args)) db.addAof(utils.ToCmdLine3("set", args...))
} }
if policy == upsertPolicy || result > 0 { if policy == upsertPolicy || result > 0 {
@@ -147,11 +150,11 @@ func execSet(db *DB, args [][]byte) redis.Reply {
func execSetNX(db *DB, args [][]byte) redis.Reply { func execSetNX(db *DB, args [][]byte) redis.Reply {
key := string(args[0]) key := string(args[0])
value := args[1] value := args[1]
entity := &DataEntity{ entity := &database.DataEntity{
Data: value, Data: value,
} }
result := db.PutIfAbsent(key, entity) result := db.PutIfAbsent(key, entity)
db.AddAof(makeAofCmd("setnx", args)) db.addAof(utils.ToCmdLine3("setnx", args...))
return reply.MakeIntReply(int64(result)) return reply.MakeIntReply(int64(result))
} }
@@ -169,15 +172,15 @@ func execSetEX(db *DB, args [][]byte) redis.Reply {
} }
ttl := ttlArg * 1000 ttl := ttlArg * 1000
entity := &DataEntity{ entity := &database.DataEntity{
Data: value, Data: value,
} }
db.PutEntity(key, entity) db.PutEntity(key, entity)
expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond) expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond)
db.Expire(key, expireTime) db.Expire(key, expireTime)
db.AddAof(makeAofCmd("setex", args)) db.addAof(utils.ToCmdLine3("setex", args...))
db.AddAof(makeExpireCmd(key, expireTime)) db.addAof(aof.MakeExpireCmd(key, expireTime).Args)
return &reply.OkReply{} return &reply.OkReply{}
} }
@@ -194,15 +197,15 @@ func execPSetEX(db *DB, args [][]byte) redis.Reply {
return reply.MakeErrReply("ERR invalid expire time in setex") return reply.MakeErrReply("ERR invalid expire time in setex")
} }
entity := &DataEntity{ entity := &database.DataEntity{
Data: value, Data: value,
} }
db.PutEntity(key, entity) db.PutEntity(key, entity)
expireTime := time.Now().Add(time.Duration(ttlArg) * time.Millisecond) expireTime := time.Now().Add(time.Duration(ttlArg) * time.Millisecond)
db.Expire(key, expireTime) db.Expire(key, expireTime)
db.AddAof(makeAofCmd("setex", args)) db.addAof(utils.ToCmdLine3("setex", args...))
db.AddAof(makeExpireCmd(key, expireTime)) db.addAof(aof.MakeExpireCmd(key, expireTime).Args)
return &reply.OkReply{} return &reply.OkReply{}
} }
@@ -237,9 +240,9 @@ func execMSet(db *DB, args [][]byte) redis.Reply {
for i, key := range keys { for i, key := range keys {
value := values[i] value := values[i]
db.PutEntity(key, &DataEntity{Data: value}) db.PutEntity(key, &database.DataEntity{Data: value})
} }
db.AddAof(makeAofCmd("mset", args)) db.addAof(utils.ToCmdLine3("mset", args...))
return &reply.OkReply{} return &reply.OkReply{}
} }
@@ -299,9 +302,9 @@ func execMSetNX(db *DB, args [][]byte) redis.Reply {
for i, key := range keys { for i, key := range keys {
value := values[i] value := values[i]
db.PutEntity(key, &DataEntity{Data: value}) db.PutEntity(key, &database.DataEntity{Data: value})
} }
db.AddAof(makeAofCmd("msetnx", args)) db.addAof(utils.ToCmdLine3("msetnx", args...))
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} }
@@ -315,9 +318,9 @@ func execGetSet(db *DB, args [][]byte) redis.Reply {
return err return err
} }
db.PutEntity(key, &DataEntity{Data: value}) db.PutEntity(key, &database.DataEntity{Data: value})
db.Persist(key) // override ttl db.Persist(key) // override ttl
db.AddAof(makeAofCmd("getset", args)) db.addAof(utils.ToCmdLine3("getset", args...))
if old == nil { if old == nil {
return new(reply.NullBulkReply) return new(reply.NullBulkReply)
} }
@@ -337,16 +340,16 @@ func execIncr(db *DB, args [][]byte) redis.Reply {
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: []byte(strconv.FormatInt(val+1, 10)), Data: []byte(strconv.FormatInt(val+1, 10)),
}) })
db.AddAof(makeAofCmd("incr", args)) db.addAof(utils.ToCmdLine3("incr", args...))
return reply.MakeIntReply(val + 1) return reply.MakeIntReply(val + 1)
} }
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: []byte("1"), Data: []byte("1"),
}) })
db.AddAof(makeAofCmd("incr", args)) db.addAof(utils.ToCmdLine3("incr", args...))
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} }
@@ -369,16 +372,16 @@ func execIncrBy(db *DB, args [][]byte) redis.Reply {
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: []byte(strconv.FormatInt(val+delta, 10)), Data: []byte(strconv.FormatInt(val+delta, 10)),
}) })
db.AddAof(makeAofCmd("incrby", args)) db.addAof(utils.ToCmdLine3("incrby", args...))
return reply.MakeIntReply(val + delta) return reply.MakeIntReply(val + delta)
} }
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: args[1], Data: args[1],
}) })
db.AddAof(makeAofCmd("incrby", args)) db.addAof(utils.ToCmdLine3("incrby", args...))
return reply.MakeIntReply(delta) return reply.MakeIntReply(delta)
} }
@@ -401,16 +404,16 @@ func execIncrByFloat(db *DB, args [][]byte) redis.Reply {
return reply.MakeErrReply("ERR value is not a valid float") return reply.MakeErrReply("ERR value is not a valid float")
} }
resultBytes := []byte(val.Add(delta).String()) resultBytes := []byte(val.Add(delta).String())
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: resultBytes, Data: resultBytes,
}) })
db.AddAof(makeAofCmd("incrbyfloat", args)) db.addAof(utils.ToCmdLine3("incrbyfloat", args...))
return reply.MakeBulkReply(resultBytes) return reply.MakeBulkReply(resultBytes)
} }
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: args[1], Data: args[1],
}) })
db.AddAof(makeAofCmd("incrbyfloat", args)) db.addAof(utils.ToCmdLine3("incrbyfloat", args...))
return reply.MakeBulkReply(args[1]) return reply.MakeBulkReply(args[1])
} }
@@ -427,17 +430,17 @@ func execDecr(db *DB, args [][]byte) redis.Reply {
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: []byte(strconv.FormatInt(val-1, 10)), Data: []byte(strconv.FormatInt(val-1, 10)),
}) })
db.AddAof(makeAofCmd("decr", args)) db.addAof(utils.ToCmdLine3("decr", args...))
return reply.MakeIntReply(val - 1) return reply.MakeIntReply(val - 1)
} }
entity := &DataEntity{ entity := &database.DataEntity{
Data: []byte("-1"), Data: []byte("-1"),
} }
db.PutEntity(key, entity) db.PutEntity(key, entity)
db.AddAof(makeAofCmd("decr", args)) db.addAof(utils.ToCmdLine3("decr", args...))
return reply.MakeIntReply(-1) return reply.MakeIntReply(-1)
} }
@@ -459,17 +462,17 @@ func execDecrBy(db *DB, args [][]byte) redis.Reply {
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: []byte(strconv.FormatInt(val-delta, 10)), Data: []byte(strconv.FormatInt(val-delta, 10)),
}) })
db.AddAof(makeAofCmd("decrby", args)) db.addAof(utils.ToCmdLine3("decrby", args...))
return reply.MakeIntReply(val - delta) return reply.MakeIntReply(val - delta)
} }
valueStr := strconv.FormatInt(-delta, 10) valueStr := strconv.FormatInt(-delta, 10)
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: []byte(valueStr), Data: []byte(valueStr),
}) })
db.AddAof(makeAofCmd("decrby", args)) db.addAof(utils.ToCmdLine3("decrby", args...))
return reply.MakeIntReply(-delta) return reply.MakeIntReply(-delta)
} }
@@ -494,10 +497,10 @@ func execAppend(db *DB, args [][]byte) redis.Reply {
return err return err
} }
bytes = append(bytes, args[1]...) bytes = append(bytes, args[1]...)
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: bytes, Data: bytes,
}) })
db.AddAof(makeAofCmd("append", args)) db.addAof(utils.ToCmdLine3("append", args...))
return reply.MakeIntReply(int64(len(bytes))) return reply.MakeIntReply(int64(len(bytes)))
} }
@@ -529,10 +532,10 @@ func execSetRange(db *DB, args [][]byte) redis.Reply {
bytes[idx] = value[i] bytes[idx] = value[i]
} }
} }
db.PutEntity(key, &DataEntity{ db.PutEntity(key, &database.DataEntity{
Data: bytes, Data: bytes,
}) })
db.AddAof(makeAofCmd("setRange", args)) db.addAof(utils.ToCmdLine3("setRange", args...))
return reply.MakeIntReply(int64(len(bytes))) return reply.MakeIntReply(int64(len(bytes)))
} }

View File

@@ -10,6 +10,7 @@ import (
) )
var testDB = makeTestDB() var testDB = makeTestDB()
var testServer = NewStandaloneServer()
func TestSet2(t *testing.T) { func TestSet2(t *testing.T) {
key := utils.RandString(10) key := utils.RandString(10)

45
sys.go Normal file
View File

@@ -0,0 +1,45 @@
package godis
import (
"github.com/hdt3213/godis/config"
"github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/redis/reply"
)
// Ping the server
func Ping(db *DB, args [][]byte) redis.Reply {
if len(args) == 0 {
return &reply.PongReply{}
} else if len(args) == 1 {
return reply.MakeStatusReply(string(args[0]))
} else {
return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command")
}
}
// Auth validate client's password
func Auth(c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'auth' command")
}
if config.Properties.RequirePass == "" {
return reply.MakeErrReply("ERR Client sent AUTH, but no password is set")
}
passwd := string(args[0])
c.SetPassword(passwd)
if config.Properties.RequirePass != passwd {
return reply.MakeErrReply("ERR invalid password")
}
return &reply.OkReply{}
}
func isAuthenticated(c redis.Connection) bool {
if config.Properties.RequirePass == "" {
return true
}
return c.GetPassword() == config.Properties.RequirePass
}
func init() {
RegisterCommand("ping", Ping, noPrepare, nil, -1)
}

View File

@@ -21,20 +21,20 @@ func TestPing(t *testing.T) {
func TestAuth(t *testing.T) { func TestAuth(t *testing.T) {
passwd := utils.RandString(10) passwd := utils.RandString(10)
c := &connection.FakeConn{} c := &connection.FakeConn{}
ret := testDB.Exec(c, utils.ToCmdLine("AUTH")) ret := testServer.Exec(c, utils.ToCmdLine("AUTH"))
asserts.AssertErrReply(t, ret, "ERR wrong number of arguments for 'auth' command") asserts.AssertErrReply(t, ret, "ERR wrong number of arguments for 'auth' command")
ret = testDB.Exec(c, utils.ToCmdLine("AUTH", passwd)) ret = testServer.Exec(c, utils.ToCmdLine("AUTH", passwd))
asserts.AssertErrReply(t, ret, "ERR Client sent AUTH, but no password is set") asserts.AssertErrReply(t, ret, "ERR Client sent AUTH, but no password is set")
config.Properties.RequirePass = passwd config.Properties.RequirePass = passwd
defer func() { defer func() {
config.Properties.RequirePass = "" config.Properties.RequirePass = ""
}() }()
ret = testDB.Exec(c, utils.ToCmdLine("AUTH", passwd+"wrong")) ret = testServer.Exec(c, utils.ToCmdLine("AUTH", passwd+"wrong"))
asserts.AssertErrReply(t, ret, "ERR invalid password") asserts.AssertErrReply(t, ret, "ERR invalid password")
ret = testDB.Exec(c, utils.ToCmdLine("PING")) ret = testServer.Exec(c, utils.ToCmdLine("PING"))
asserts.AssertErrReply(t, ret, "NOAUTH Authentication required") asserts.AssertErrReply(t, ret, "NOAUTH Authentication required")
ret = testDB.Exec(c, utils.ToCmdLine("AUTH", passwd)) ret = testServer.Exec(c, utils.ToCmdLine("AUTH", passwd))
asserts.AssertStatusReply(t, ret, "OK") asserts.AssertStatusReply(t, ret, "OK")
} }

View File

@@ -44,7 +44,7 @@ func isWatchingChanged(db *DB, watching map[string]uint32) bool {
} }
// StartMulti starts multi-command-transaction // StartMulti starts multi-command-transaction
func StartMulti(db *DB, conn redis.Connection) redis.Reply { func StartMulti(conn redis.Connection) redis.Reply {
if conn.InMultiState() { if conn.InMultiState() {
return reply.MakeErrReply("ERR MULTI calls can not be nested") return reply.MakeErrReply("ERR MULTI calls can not be nested")
} }
@@ -53,7 +53,7 @@ func StartMulti(db *DB, conn redis.Connection) redis.Reply {
} }
// EnqueueCmd puts command line into `multi` pending queue // EnqueueCmd puts command line into `multi` pending queue
func EnqueueCmd(db *DB, conn redis.Connection, cmdLine [][]byte) redis.Reply { func EnqueueCmd(conn redis.Connection, cmdLine [][]byte) redis.Reply {
cmdName := strings.ToLower(string(cmdLine[0])) cmdName := strings.ToLower(string(cmdLine[0]))
cmd, ok := cmdTable[cmdName] cmd, ok := cmdTable[cmdName]
if !ok { if !ok {
@@ -79,11 +79,11 @@ func execMulti(db *DB, conn redis.Connection) redis.Reply {
} }
defer conn.SetMultiState(false) defer conn.SetMultiState(false)
cmdLines := conn.GetQueuedCmdLine() cmdLines := conn.GetQueuedCmdLine()
return ExecMulti(db, conn, conn.GetWatching(), cmdLines) return db.ExecMulti(conn, conn.GetWatching(), cmdLines)
} }
// ExecMulti executes multi commands transaction Atomically and Isolated // ExecMulti executes multi commands transaction Atomically and Isolated
func ExecMulti(db *DB, conn redis.Connection, watching map[string]uint32, cmdLines []CmdLine) redis.Reply { func (db *DB) ExecMulti(conn redis.Connection, watching map[string]uint32, cmdLines []CmdLine) redis.Reply {
// prepare // prepare
writeKeys := make([]string, 0) // may contains duplicate writeKeys := make([]string, 0) // may contains duplicate
readKeys := make([]string, 0) readKeys := make([]string, 0)
@@ -113,7 +113,7 @@ func ExecMulti(db *DB, conn redis.Connection, watching map[string]uint32, cmdLin
undoCmdLines := make([][]CmdLine, 0, len(cmdLines)) undoCmdLines := make([][]CmdLine, 0, len(cmdLines))
for _, cmdLine := range cmdLines { for _, cmdLine := range cmdLines {
undoCmdLines = append(undoCmdLines, db.GetUndoLogs(cmdLine)) undoCmdLines = append(undoCmdLines, db.GetUndoLogs(cmdLine))
result := db.ExecWithLock(cmdLine) result := db.execWithLock(cmdLine)
if reply.IsErrorReply(result) { if reply.IsErrorReply(result) {
aborted = true aborted = true
// don't rollback failed commands // don't rollback failed commands
@@ -134,14 +134,14 @@ func ExecMulti(db *DB, conn redis.Connection, watching map[string]uint32, cmdLin
continue continue
} }
for _, cmdLine := range curCmdLines { for _, cmdLine := range curCmdLines {
db.ExecWithLock(cmdLine) db.execWithLock(cmdLine)
} }
} }
return reply.MakeErrReply("EXECABORT Transaction discarded because of previous errors.") return reply.MakeErrReply("EXECABORT Transaction discarded because of previous errors.")
} }
// DiscardMulti drops MULTI pending commands // DiscardMulti drops MULTI pending commands
func DiscardMulti(db *DB, conn redis.Connection) redis.Reply { func DiscardMulti(conn redis.Connection) redis.Reply {
if !conn.InMultiState() { if !conn.InMultiState() {
return reply.MakeErrReply("ERR DISCARD without MULTI") return reply.MakeErrReply("ERR DISCARD without MULTI")
} }
@@ -149,3 +149,45 @@ func DiscardMulti(db *DB, conn redis.Connection) redis.Reply {
conn.SetMultiState(false) conn.SetMultiState(false)
return reply.MakeQueuedReply() return reply.MakeQueuedReply()
} }
// GetUndoLogs return rollback commands
func (db *DB) GetUndoLogs(cmdLine [][]byte) []CmdLine {
cmdName := strings.ToLower(string(cmdLine[0]))
cmd, ok := cmdTable[cmdName]
if !ok {
return nil
}
undo := cmd.undo
if undo == nil {
return nil
}
return undo(db, cmdLine[1:])
}
// execWithLock executes normal commands, invoker should provide locks
func (db *DB) execWithLock(cmdLine [][]byte) redis.Reply {
cmdName := strings.ToLower(string(cmdLine[0]))
cmd, ok := cmdTable[cmdName]
if !ok {
return reply.MakeErrReply("ERR unknown command '" + cmdName + "'")
}
if !validateArity(cmd.arity, cmdLine) {
return reply.MakeArgNumErrReply(cmdName)
}
fun := cmd.executor
return fun(db, cmdLine[1:])
}
// GetRelatedKeys analysis related keys
func GetRelatedKeys(cmdLine [][]byte) ([]string, []string) {
cmdName := strings.ToLower(string(cmdLine[0]))
cmd, ok := cmdTable[cmdName]
if !ok {
return nil, nil
}
prepare := cmd.prepare
if prepare == nil {
return nil, nil
}
return prepare(cmdLine[1:])
}

View File

@@ -8,20 +8,20 @@ import (
) )
func TestMulti(t *testing.T) { func TestMulti(t *testing.T) {
testDB.Flush()
conn := new(connection.FakeConn) conn := new(connection.FakeConn)
result := testDB.Exec(conn, utils.ToCmdLine("multi")) testServer.Exec(conn, utils.ToCmdLine("FLUSHALL"))
result := testServer.Exec(conn, utils.ToCmdLine("multi"))
asserts.AssertNotError(t, result) asserts.AssertNotError(t, result)
key := utils.RandString(10) key := utils.RandString(10)
value := utils.RandString(10) value := utils.RandString(10)
testDB.Exec(conn, utils.ToCmdLine("set", key, value)) testServer.Exec(conn, utils.ToCmdLine("set", key, value))
key2 := utils.RandString(10) key2 := utils.RandString(10)
testDB.Exec(conn, utils.ToCmdLine("rpush", key2, value)) testServer.Exec(conn, utils.ToCmdLine("rpush", key2, value))
result = testDB.Exec(conn, utils.ToCmdLine("exec")) result = testServer.Exec(conn, utils.ToCmdLine("exec"))
asserts.AssertNotError(t, result) asserts.AssertNotError(t, result)
result = testDB.Exec(conn, utils.ToCmdLine("get", key)) result = testServer.Exec(conn, utils.ToCmdLine("get", key))
asserts.AssertBulkReply(t, result, value) asserts.AssertBulkReply(t, result, value)
result = testDB.Exec(conn, utils.ToCmdLine("lrange", key2, "0", "-1")) result = testServer.Exec(conn, utils.ToCmdLine("lrange", key2, "0", "-1"))
asserts.AssertMultiBulkReply(t, result, []string{value}) asserts.AssertMultiBulkReply(t, result, []string{value})
if len(conn.GetWatching()) > 0 { if len(conn.GetWatching()) > 0 {
t.Error("watching map should be reset") t.Error("watching map should be reset")
@@ -32,17 +32,17 @@ func TestMulti(t *testing.T) {
} }
func TestRollback(t *testing.T) { func TestRollback(t *testing.T) {
testDB.Flush()
conn := new(connection.FakeConn) conn := new(connection.FakeConn)
result := testDB.Exec(conn, utils.ToCmdLine("multi")) testServer.Exec(conn, utils.ToCmdLine("FLUSHALL"))
result := testServer.Exec(conn, utils.ToCmdLine("multi"))
asserts.AssertNotError(t, result) asserts.AssertNotError(t, result)
key := utils.RandString(10) key := utils.RandString(10)
value := utils.RandString(10) value := utils.RandString(10)
testDB.Exec(conn, utils.ToCmdLine("set", key, value)) testServer.Exec(conn, utils.ToCmdLine("set", key, value))
testDB.Exec(conn, utils.ToCmdLine("rpush", key, value)) testServer.Exec(conn, utils.ToCmdLine("rpush", key, value))
result = testDB.Exec(conn, utils.ToCmdLine("exec")) result = testServer.Exec(conn, utils.ToCmdLine("exec"))
asserts.AssertErrReply(t, result, "EXECABORT Transaction discarded because of previous errors.") asserts.AssertErrReply(t, result, "EXECABORT Transaction discarded because of previous errors.")
result = testDB.Exec(conn, utils.ToCmdLine("type", key)) result = testServer.Exec(conn, utils.ToCmdLine("type", key))
asserts.AssertStatusReply(t, result, "none") asserts.AssertStatusReply(t, result, "none")
if len(conn.GetWatching()) > 0 { if len(conn.GetWatching()) > 0 {
t.Error("watching map should be reset") t.Error("watching map should be reset")
@@ -53,20 +53,20 @@ func TestRollback(t *testing.T) {
} }
func TestDiscard(t *testing.T) { func TestDiscard(t *testing.T) {
testDB.Flush()
conn := new(connection.FakeConn) conn := new(connection.FakeConn)
result := testDB.Exec(conn, utils.ToCmdLine("multi")) testServer.Exec(conn, utils.ToCmdLine("FLUSHALL"))
result := testServer.Exec(conn, utils.ToCmdLine("multi"))
asserts.AssertNotError(t, result) asserts.AssertNotError(t, result)
key := utils.RandString(10) key := utils.RandString(10)
value := utils.RandString(10) value := utils.RandString(10)
testDB.Exec(conn, utils.ToCmdLine("set", key, value)) testServer.Exec(conn, utils.ToCmdLine("set", key, value))
key2 := utils.RandString(10) key2 := utils.RandString(10)
testDB.Exec(conn, utils.ToCmdLine("rpush", key2, value)) testServer.Exec(conn, utils.ToCmdLine("rpush", key2, value))
result = testDB.Exec(conn, utils.ToCmdLine("discard")) result = testServer.Exec(conn, utils.ToCmdLine("discard"))
asserts.AssertNotError(t, result) asserts.AssertNotError(t, result)
result = testDB.Exec(conn, utils.ToCmdLine("get", key)) result = testServer.Exec(conn, utils.ToCmdLine("get", key))
asserts.AssertNullBulk(t, result) asserts.AssertNullBulk(t, result)
result = testDB.Exec(conn, utils.ToCmdLine("lrange", key2, "0", "-1")) result = testServer.Exec(conn, utils.ToCmdLine("lrange", key2, "0", "-1"))
asserts.AssertMultiBulkReplySize(t, result, 0) asserts.AssertMultiBulkReplySize(t, result, 0)
if len(conn.GetWatching()) > 0 { if len(conn.GetWatching()) > 0 {
t.Error("watching map should be reset") t.Error("watching map should be reset")
@@ -77,21 +77,21 @@ func TestDiscard(t *testing.T) {
} }
func TestWatch(t *testing.T) { func TestWatch(t *testing.T) {
testDB.Flush()
conn := new(connection.FakeConn) conn := new(connection.FakeConn)
testServer.Exec(conn, utils.ToCmdLine("FLUSHALL"))
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
key := utils.RandString(10) key := utils.RandString(10)
value := utils.RandString(10) value := utils.RandString(10)
testDB.Exec(conn, utils.ToCmdLine("watch", key)) testServer.Exec(conn, utils.ToCmdLine("watch", key))
testDB.Exec(conn, utils.ToCmdLine("set", key, value)) testServer.Exec(conn, utils.ToCmdLine("set", key, value))
result := testDB.Exec(conn, utils.ToCmdLine("multi")) result := testServer.Exec(conn, utils.ToCmdLine("multi"))
asserts.AssertNotError(t, result) asserts.AssertNotError(t, result)
key2 := utils.RandString(10) key2 := utils.RandString(10)
value2 := utils.RandString(10) value2 := utils.RandString(10)
testDB.Exec(conn, utils.ToCmdLine("set", key2, value2)) testServer.Exec(conn, utils.ToCmdLine("set", key2, value2))
result = testDB.Exec(conn, utils.ToCmdLine("exec")) result = testServer.Exec(conn, utils.ToCmdLine("exec"))
asserts.AssertNotError(t, result) asserts.AssertNotError(t, result)
result = testDB.Exec(conn, utils.ToCmdLine("get", key2)) result = testServer.Exec(conn, utils.ToCmdLine("get", key2))
asserts.AssertNullBulk(t, result) asserts.AssertNullBulk(t, result)
if len(conn.GetWatching()) > 0 { if len(conn.GetWatching()) > 0 {
t.Error("watching map should be reset") t.Error("watching map should be reset")

View File

@@ -1,6 +1,7 @@
package godis package godis
import ( import (
"github.com/hdt3213/godis/aof"
"github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/lib/utils"
"strconv" "strconv"
) )
@@ -52,7 +53,7 @@ func rollbackGivenKeys(db *DB, keys ...string) []CmdLine {
} else { } else {
undoCmdLines = append(undoCmdLines, undoCmdLines = append(undoCmdLines,
utils.ToCmdLine("DEL", key), // clean existed first utils.ToCmdLine("DEL", key), // clean existed first
EntityToCmd(key, entity).Args, aof.EntityToCmd(key, entity).Args,
toTTLCmd(db, key).Args, toTTLCmd(db, key).Args,
) )
} }

View File

@@ -11,5 +11,8 @@ func makeTestDB() *DB {
versionMap: dict.MakeConcurrent(dataDictSize), versionMap: dict.MakeConcurrent(dataDictSize),
ttlMap: dict.MakeConcurrent(ttlDictSize), ttlMap: dict.MakeConcurrent(ttlDictSize),
locker: lock.Make(lockerSize), locker: lock.Make(lockerSize),
addAof: func(line CmdLine) {
},
} }
} }