replication master side

This commit is contained in:
finley
2022-11-21 23:36:35 +08:00
parent a7a3da6e49
commit ba7ea942cb
23 changed files with 886 additions and 217 deletions

View File

@@ -26,6 +26,10 @@ type payload struct {
dbIndex int dbIndex int
} }
// Listener is a channel receive a replication of all aof payloads
// with a listener we can forward the updates to slave nodes etc.
type Listener chan<- CmdLine
// Handler receive msgs from channel and write to AOF file // Handler receive msgs from channel and write to AOF file
type Handler struct { type Handler struct {
db database.EmbedDB db database.EmbedDB
@@ -38,6 +42,7 @@ type Handler struct {
// pause aof for start/finish aof rewrite progress // pause aof for start/finish aof rewrite progress
pausingAof sync.RWMutex pausingAof sync.RWMutex
currentDB int currentDB int
listeners map[Listener]struct{}
} }
// NewAOFHandler creates a new aof.Handler // NewAOFHandler creates a new aof.Handler
@@ -54,12 +59,20 @@ func NewAOFHandler(db database.EmbedDB, tmpDBMaker func() database.EmbedDB) (*Ha
handler.aofFile = aofFile handler.aofFile = aofFile
handler.aofChan = make(chan *payload, aofQueueSize) handler.aofChan = make(chan *payload, aofQueueSize)
handler.aofFinished = make(chan struct{}) handler.aofFinished = make(chan struct{})
handler.listeners = make(map[Listener]struct{})
go func() { go func() {
handler.handleAof() handler.handleAof()
}() }()
return handler, nil return handler, nil
} }
// RemoveListener removes a listener from aof handler, so we can close the listener
func (handler *Handler) RemoveListener(listener Listener) {
handler.pausingAof.Lock()
defer handler.pausingAof.Unlock()
delete(handler.listeners, listener)
}
// AddAof send command to aof goroutine through channel // AddAof send command to aof goroutine through channel
func (handler *Handler) AddAof(dbIndex int, cmdLine CmdLine) { func (handler *Handler) AddAof(dbIndex int, cmdLine CmdLine) {
if config.Properties.AppendOnly && handler.aofChan != nil { if config.Properties.AppendOnly && handler.aofChan != nil {
@@ -73,12 +86,16 @@ func (handler *Handler) AddAof(dbIndex int, cmdLine CmdLine) {
// handleAof listen aof channel and write into file // handleAof listen aof channel and write into file
func (handler *Handler) handleAof() { func (handler *Handler) handleAof() {
// serialized execution // serialized execution
var cmdLines []CmdLine
handler.currentDB = 0 handler.currentDB = 0
for p := range handler.aofChan { for p := range handler.aofChan {
cmdLines = cmdLines[:0] // reuse underlying array
handler.pausingAof.RLock() // prevent other goroutines from pausing aof handler.pausingAof.RLock() // prevent other goroutines from pausing aof
if p.dbIndex != handler.currentDB { if p.dbIndex != handler.currentDB {
// select db // select db
data := protocol.MakeMultiBulkReply(utils.ToCmdLine("SELECT", strconv.Itoa(p.dbIndex))).ToBytes() selectCmd := utils.ToCmdLine("SELECT", strconv.Itoa(p.dbIndex))
cmdLines = append(cmdLines, selectCmd)
data := protocol.MakeMultiBulkReply(selectCmd).ToBytes()
_, err := handler.aofFile.Write(data) _, err := handler.aofFile.Write(data)
if err != nil { if err != nil {
logger.Warn(err) logger.Warn(err)
@@ -88,11 +105,17 @@ func (handler *Handler) handleAof() {
handler.currentDB = p.dbIndex handler.currentDB = p.dbIndex
} }
data := protocol.MakeMultiBulkReply(p.cmdLine).ToBytes() data := protocol.MakeMultiBulkReply(p.cmdLine).ToBytes()
cmdLines = append(cmdLines, p.cmdLine)
_, err := handler.aofFile.Write(data) _, err := handler.aofFile.Write(data)
if err != nil { if err != nil {
logger.Warn(err) logger.Warn(err)
} }
handler.pausingAof.RUnlock() handler.pausingAof.RUnlock()
for listener := range handler.listeners {
for _, line := range cmdLines {
listener <- line
}
}
} }
handler.aofFinished <- struct{}{} handler.aofFinished <- struct{}{}
} }
@@ -123,7 +146,7 @@ func (handler *Handler) LoadAof(maxBytes int) {
reader = file reader = file
} }
ch := parser.ParseStream(reader) ch := parser.ParseStream(reader)
fakeConn := &connection.FakeConn{} // only used for save dbIndex fakeConn := connection.NewFakeConn() // only used for save dbIndex
for p := range ch { for p := range ch {
if p.Err != nil { if p.Err != nil {
if p.Err == io.EOF { if p.Err == io.EOF {
@@ -143,7 +166,7 @@ func (handler *Handler) LoadAof(maxBytes int) {
} }
ret := handler.db.Exec(fakeConn, r.Args) ret := handler.db.Exec(fakeConn, r.Args)
if protocol.IsErrorReply(ret) { if protocol.IsErrorReply(ret) {
logger.Error("exec err", ret.ToBytes()) logger.Error("exec err", string(ret.ToBytes()))
} }
} }
} }

View File

@@ -16,8 +16,12 @@ import (
"time" "time"
) )
func (handler *Handler) Rewrite2RDB() error { // todo: forbid concurrent rewrite
ctx, err := handler.startRewrite2RDB()
// Rewrite2RDB rewrite aof data into rdb
// if extraListener is not nil, it will be appended to Handler.listeners, it will receive all updates after rdb
func (handler *Handler) Rewrite2RDB(rdbFilename string, extraListener Listener) error {
ctx, err := handler.startRewrite2RDB(extraListener)
if err != nil { if err != nil {
return err return err
} }
@@ -25,10 +29,6 @@ func (handler *Handler) Rewrite2RDB() error {
if err != nil { if err != nil {
return err return err
} }
rdbFilename := config.Properties.RDBFilename
if rdbFilename == "" {
rdbFilename = "dump.rdb"
}
err = ctx.tmpFile.Close() err = ctx.tmpFile.Close()
if err != nil { if err != nil {
return err return err
@@ -40,7 +40,7 @@ func (handler *Handler) Rewrite2RDB() error {
return nil return nil
} }
func (handler *Handler) startRewrite2RDB() (*RewriteCtx, error) { func (handler *Handler) startRewrite2RDB(extraListener Listener) (*RewriteCtx, error) {
handler.pausingAof.Lock() // pausing aof handler.pausingAof.Lock() // pausing aof
defer handler.pausingAof.Unlock() defer handler.pausingAof.Unlock()
@@ -59,6 +59,9 @@ func (handler *Handler) startRewrite2RDB() (*RewriteCtx, error) {
logger.Warn("tmp file create failed") logger.Warn("tmp file create failed")
return nil, err return nil, err
} }
if extraListener != nil {
handler.listeners[extraListener] = struct{}{}
}
return &RewriteCtx{ return &RewriteCtx{
tmpFile: file, tmpFile: file,
fileSize: filesize, fileSize: filesize,

View File

@@ -10,7 +10,7 @@ import (
func TestExec(t *testing.T) { func TestExec(t *testing.T) {
testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"}) testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"})
conn := &connection.FakeConn{} conn := connection.NewFakeConn()
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
key := RandString(4) key := RandString(4)
value := RandString(4) value := RandString(4)
@@ -26,7 +26,7 @@ func TestAuth(t *testing.T) {
defer func() { defer func() {
config.Properties.RequirePass = "" config.Properties.RequirePass = ""
}() }()
conn := &connection.FakeConn{} conn := connection.NewFakeConn()
ret := testNodeA.Exec(conn, toArgs("GET", "a")) ret := testNodeA.Exec(conn, toArgs("GET", "a"))
asserts.AssertErrReply(t, ret, "NOAUTH Authentication required") asserts.AssertErrReply(t, ret, "NOAUTH Authentication required")
ret = testNodeA.Exec(conn, toArgs("AUTH", passwd)) ret = testNodeA.Exec(conn, toArgs("AUTH", passwd))
@@ -39,7 +39,7 @@ func TestRelay(t *testing.T) {
testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"}) testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"})
key := RandString(4) key := RandString(4)
value := RandString(4) value := RandString(4)
conn := &connection.FakeConn{} conn := connection.NewFakeConn()
ret := testCluster2.relay("127.0.0.1:6379", conn, toArgs("SET", key, value)) ret := testCluster2.relay("127.0.0.1:6379", conn, toArgs("SET", key, value))
asserts.AssertNotError(t, ret) asserts.AssertNotError(t, ret)
ret = testCluster2.relay("127.0.0.1:6379", conn, toArgs("GET", key)) ret = testCluster2.relay("127.0.0.1:6379", conn, toArgs("GET", key))
@@ -50,7 +50,7 @@ func TestBroadcast(t *testing.T) {
testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"}) testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"})
key := RandString(4) key := RandString(4)
value := RandString(4) value := RandString(4)
rets := testCluster2.broadcast(&connection.FakeConn{}, toArgs("SET", key, value)) rets := testCluster2.broadcast(connection.NewFakeConn(), toArgs("SET", key, value))
for _, v := range rets { for _, v := range rets {
asserts.AssertNotError(t, v) asserts.AssertNotError(t, v)
} }

View File

@@ -7,7 +7,7 @@ import (
) )
func TestDel(t *testing.T) { func TestDel(t *testing.T) {
conn := &connection.FakeConn{} conn := connection.NewFakeConn()
allowFastTransaction = false allowFastTransaction = false
testNodeA.Exec(conn, toArgs("SET", "a", "a")) testNodeA.Exec(conn, toArgs("SET", "a", "a"))
ret := Del(testNodeA, conn, toArgs("DEL", "a", "b", "c")) ret := Del(testNodeA, conn, toArgs("DEL", "a", "b", "c"))

View File

@@ -7,7 +7,7 @@ import (
) )
func TestMSet(t *testing.T) { func TestMSet(t *testing.T) {
conn := &connection.FakeConn{} conn := connection.NewFakeConn()
allowFastTransaction = false allowFastTransaction = false
ret := MSet(testNodeA, conn, toArgs("MSET", "a", "a", "b", "b")) ret := MSet(testNodeA, conn, toArgs("MSET", "a", "a", "b", "b"))
asserts.AssertNotError(t, ret) asserts.AssertNotError(t, ret)
@@ -16,7 +16,7 @@ func TestMSet(t *testing.T) {
} }
func TestMSetNx(t *testing.T) { func TestMSetNx(t *testing.T) {
conn := &connection.FakeConn{} conn := connection.NewFakeConn()
allowFastTransaction = false allowFastTransaction = false
FlushAll(testNodeA, conn, toArgs("FLUSHALL")) FlushAll(testNodeA, conn, toArgs("FLUSHALL"))
ret := MSetNX(testNodeA, conn, toArgs("MSETNX", "a", "a", "b", "b")) ret := MSetNX(testNodeA, conn, toArgs("MSETNX", "a", "a", "b", "b"))

View File

@@ -11,7 +11,7 @@ import (
func TestPublish(t *testing.T) { func TestPublish(t *testing.T) {
channel := utils.RandString(5) channel := utils.RandString(5)
msg := utils.RandString(5) msg := utils.RandString(5)
conn := &connection.FakeConn{} conn := connection.NewFakeConn()
Subscribe(testNodeA, conn, utils.ToCmdLine("SUBSCRIBE", channel)) Subscribe(testNodeA, conn, utils.ToCmdLine("SUBSCRIBE", channel))
conn.Clean() // clean subscribe success conn.Clean() // clean subscribe success
Publish(testNodeA, conn, utils.ToCmdLine("PUBLISH", channel, msg)) Publish(testNodeA, conn, utils.ToCmdLine("PUBLISH", channel, msg))

View File

@@ -17,7 +17,7 @@ import (
) )
func makeTestData(db database.DB, dbIndex int, prefix string, size int) { func makeTestData(db database.DB, dbIndex int, prefix string, size int) {
conn := &connection.FakeConn{} conn := connection.NewFakeConn()
conn.SelectDB(dbIndex) conn.SelectDB(dbIndex)
db.Exec(conn, utils.ToCmdLine("FlushDB")) db.Exec(conn, utils.ToCmdLine("FlushDB"))
cursor := 0 cursor := 0
@@ -49,7 +49,7 @@ func makeTestData(db database.DB, dbIndex int, prefix string, size int) {
} }
func validateTestData(t *testing.T, db database.DB, dbIndex int, prefix string, size int) { func validateTestData(t *testing.T, db database.DB, dbIndex int, prefix string, size int) {
conn := &connection.FakeConn{} conn := connection.NewFakeConn()
conn.SelectDB(dbIndex) conn.SelectDB(dbIndex)
cursor := 0 cursor := 0
var ret redis.Reply var ret redis.Reply
@@ -146,7 +146,7 @@ func TestRDB(t *testing.T) {
dbNum := 4 dbNum := 4
size := 10 size := 10
var prefixes []string var prefixes []string
conn := &connection.FakeConn{} conn := connection.NewFakeConn()
writeDB := NewStandaloneServer() writeDB := NewStandaloneServer()
for i := 0; i < dbNum; i++ { for i := 0; i < dbNum; i++ {
prefix := utils.RandString(8) prefix := utils.RandString(8)
@@ -216,7 +216,7 @@ func TestRewriteAOF2(t *testing.T) {
} }
aofWriteDB := NewStandaloneServer() aofWriteDB := NewStandaloneServer()
dbNum := 4 dbNum := 4
conn := &connection.FakeConn{} conn := connection.NewFakeConn()
for i := 0; i < dbNum; i++ { for i := 0; i < dbNum; i++ {
conn.SelectDB(i) conn.SelectDB(i)
key := strconv.Itoa(i) key := strconv.Itoa(i)

View File

@@ -9,7 +9,6 @@ import (
"github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/pubsub" "github.com/hdt3213/godis/pubsub"
"github.com/hdt3213/godis/redis/connection"
"github.com/hdt3213/godis/redis/protocol" "github.com/hdt3213/godis/redis/protocol"
"runtime/debug" "runtime/debug"
"strconv" "strconv"
@@ -27,10 +26,10 @@ type MultiDB struct {
// handle aof persistence // handle aof persistence
aofHandler *aof.Handler aofHandler *aof.Handler
// store master node address // for replication
slaveOf string role int32
role int32 slaveStatus *slaveStatus
replication *slaveStatus masterStatus *masterStatus
} }
// NewStandaloneServer creates a standalone redis server, with multi database and all other funtions // NewStandaloneServer creates a standalone redis server, with multi database and all other funtions
@@ -50,31 +49,36 @@ func NewStandaloneServer() *MultiDB {
mdb.hub = pubsub.MakeHub() mdb.hub = pubsub.MakeHub()
validAof := false validAof := false
if config.Properties.AppendOnly { if config.Properties.AppendOnly {
aofHandler, err := aof.NewAOFHandler(mdb, func() database.EmbedDB { mdb.initAof()
return MakeBasicMultiDB()
})
if err != nil {
panic(err)
}
mdb.aofHandler = aofHandler
for _, db := range mdb.dbSet {
singleDB := db.Load().(*DB)
singleDB.addAof = func(line CmdLine) {
mdb.aofHandler.AddAof(singleDB.index, line)
}
}
validAof = true validAof = true
} }
if config.Properties.RDBFilename != "" && !validAof { if config.Properties.RDBFilename != "" && !validAof {
// load rdb // load rdb
loadRdbFile(mdb) loadRdbFile(mdb)
} }
mdb.replication = initReplStatus() mdb.slaveStatus = initReplSlaveStatus()
mdb.startAsMaster()
mdb.startReplCron() mdb.startReplCron()
mdb.role = masterRole // The initialization process does not require atomicity mdb.role = masterRole // The initialization process does not require atomicity
return mdb return mdb
} }
func (mdb *MultiDB) initAof() {
aofHandler, err := aof.NewAOFHandler(mdb, func() database.EmbedDB {
return MakeBasicMultiDB()
})
if err != nil {
panic(err)
}
mdb.aofHandler = aofHandler
for _, db := range mdb.dbSet {
singleDB := db.Load().(*DB)
singleDB.addAof = func(line CmdLine) {
mdb.aofHandler.AddAof(singleDB.index, line)
}
}
}
// MakeBasicMultiDB create a MultiDB only with basic abilities for aof rewrite and other usages // MakeBasicMultiDB create a MultiDB only with basic abilities for aof rewrite and other usages
func MakeBasicMultiDB() *MultiDB { func MakeBasicMultiDB() *MultiDB {
mdb := &MultiDB{} mdb := &MultiDB{}
@@ -117,8 +121,7 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep
// read only slave // read only slave
role := atomic.LoadInt32(&mdb.role) role := atomic.LoadInt32(&mdb.role)
if role == slaveRole && if role == slaveRole && !c.IsMaster() {
c.GetRole() != connection.ReplicationRecvCli {
// only allow read only command, forbid all special commands except `auth` and `slaveof` // only allow read only command, forbid all special commands except `auth` and `slaveof`
if !isReadOnlyCommand(cmdName) { if !isReadOnlyCommand(cmdName) {
return protocol.MakeErrReply("READONLY You can't write against a read only slave.") return protocol.MakeErrReply("READONLY You can't write against a read only slave.")
@@ -167,6 +170,10 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep
return protocol.MakeArgNumErrReply("copy") return protocol.MakeArgNumErrReply("copy")
} }
return execCopy(mdb, c, cmdLine[1:]) return execCopy(mdb, c, cmdLine[1:])
} else if cmdName == "replconf" {
return mdb.execReplConf(c, cmdLine[1:])
} else if cmdName == "psync" {
return mdb.execPSync(c, cmdLine[1:])
} }
// todo: support multi database transaction // todo: support multi database transaction
@@ -186,8 +193,8 @@ func (mdb *MultiDB) AfterClientClose(c redis.Connection) {
// Close graceful shutdown database // Close graceful shutdown database
func (mdb *MultiDB) Close() { func (mdb *MultiDB) Close() {
// stop replication first // stop slaveStatus first
mdb.replication.close() mdb.slaveStatus.close()
if mdb.aofHandler != nil { if mdb.aofHandler != nil {
mdb.aofHandler.Close() mdb.aofHandler.Close()
} }
@@ -308,7 +315,11 @@ func SaveRDB(db *MultiDB, args [][]byte) redis.Reply {
if db.aofHandler == nil { if db.aofHandler == nil {
return protocol.MakeErrReply("please enable aof before using save") return protocol.MakeErrReply("please enable aof before using save")
} }
err := db.aofHandler.Rewrite2RDB() rdbFilename := config.Properties.RDBFilename
if rdbFilename == "" {
rdbFilename = "dump.rdb"
}
err := db.aofHandler.Rewrite2RDB(rdbFilename, nil)
if err != nil { if err != nil {
return protocol.MakeErrReply(err.Error()) return protocol.MakeErrReply(err.Error())
} }
@@ -326,7 +337,11 @@ func BGSaveRDB(db *MultiDB, args [][]byte) redis.Reply {
logger.Error(err) logger.Error(err)
} }
}() }()
err := db.aofHandler.Rewrite2RDB() rdbFilename := config.Properties.RDBFilename
if rdbFilename == "" {
rdbFilename = "dump.rdb"
}
err := db.aofHandler.Rewrite2RDB(rdbFilename, nil)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
} }
@@ -339,3 +354,18 @@ func (mdb *MultiDB) GetDBSize(dbIndex int) (int, int) {
db := mdb.mustSelectDB(dbIndex) db := mdb.mustSelectDB(dbIndex)
return db.data.Len(), db.ttlMap.Len() return db.data.Len(), db.ttlMap.Len()
} }
func (mdb *MultiDB) startReplCron() {
go func() {
defer func() {
if err := recover(); err != nil {
logger.Error("panic", err)
}
}()
ticker := time.Tick(time.Second * 10)
for range ticker {
mdb.slaveCron()
mdb.masterCron()
}
}()
}

View File

@@ -17,7 +17,7 @@ func TestLoadRDB(t *testing.T) {
AppendOnly: false, AppendOnly: false,
RDBFilename: filepath.Join(projectRoot, "test.rdb"), // set working directory to project root RDBFilename: filepath.Join(projectRoot, "test.rdb"), // set working directory to project root
} }
conn := &connection.FakeConn{} conn := connection.NewFakeConn()
rdbDB := NewStandaloneServer() rdbDB := NewStandaloneServer()
result := rdbDB.Exec(conn, utils.ToCmdLine("Get", "str")) result := rdbDB.Exec(conn, utils.ToCmdLine("Get", "str"))
asserts.AssertBulkReply(t, result, "str") asserts.AssertBulkReply(t, result, "str")

View File

@@ -0,0 +1,362 @@
package database
import (
"context"
"errors"
"fmt"
"github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/protocol"
"io"
"io/ioutil"
"os"
"strconv"
"strings"
"sync"
"time"
)
const (
slaveStateHandShake = uint8(iota)
slaveStateWaitSaveEnd
slaveStateSendingRDB
slaveStateOnline
)
const (
bgSaveIdle = uint8(iota)
bgSaveRunning
bgSaveFinish
)
const (
salveCapacityNone = 0
salveCapacityEOF = 1 << iota
salveCapacityPsync2
)
// slaveClient stores slave status in the view of master
type slaveClient struct {
conn redis.Connection
state uint8
offset int64
lastAckTime time.Time
announceIp string
announcePort int
capacity uint8
}
type masterStatus struct {
ctx context.Context
mu sync.RWMutex
replId string
backlog []byte // backlog can be appended or replaced as a whole, cannot be modified(insert/set/delete)
beginOffset int64
currentOffset int64
slaveMap map[redis.Connection]*slaveClient
waitSlaves map[*slaveClient]struct{}
onlineSlaves map[*slaveClient]struct{}
bgSaveState uint8
rdbFilename string
}
func (master *masterStatus) appendBacklog(bin []byte) {
master.backlog = append(master.backlog, bin...)
master.currentOffset += int64(len(bin))
}
func (mdb *MultiDB) bgSaveForReplication() {
go func() {
defer func() {
if e := recover(); e != nil {
logger.Errorf("panic: %v", e)
}
}()
if err := mdb.saveForReplication(); err != nil {
logger.Errorf("save for replication error: %v", err)
}
}()
}
// saveForReplication does bg-save and send rdb to waiting slaves
func (mdb *MultiDB) saveForReplication() error {
rdbFile, err := ioutil.TempFile("", "*.rdb")
if err != nil {
return fmt.Errorf("create temp rdb failed: %v", err)
}
rdbFilename := rdbFile.Name()
mdb.masterStatus.mu.Lock()
mdb.masterStatus.bgSaveState = bgSaveRunning
mdb.masterStatus.rdbFilename = rdbFilename // todo: can reuse config.Properties.RDBFilename?
mdb.masterStatus.mu.Unlock()
aofListener := make(chan CmdLine, 1024) // give channel enough capacity to store all updates during rewrite to db
err = mdb.aofHandler.Rewrite2RDB(rdbFilename, aofListener)
if err != nil {
return err
}
go func() {
mdb.masterListenAof(aofListener)
}()
// change bgSaveState and get waitSlaves for sending
waitSlaves := make(map[*slaveClient]struct{})
mdb.masterStatus.mu.Lock()
mdb.masterStatus.bgSaveState = bgSaveFinish
for slave := range mdb.masterStatus.waitSlaves {
waitSlaves[slave] = struct{}{}
}
mdb.masterStatus.waitSlaves = nil
mdb.masterStatus.mu.Unlock()
for slave := range waitSlaves {
err = mdb.masterFullReSyncWithSlave(slave)
if err != nil {
mdb.removeSlave(slave)
logger.Errorf("masterFullReSyncWithSlave error: %v", err)
continue
}
}
return nil
}
// masterFullReSyncWithSlave send replication header, rdb file and all backlogs to slave
func (mdb *MultiDB) masterFullReSyncWithSlave(slave *slaveClient) error {
// write replication header
header := "+FULLRESYNC " + mdb.masterStatus.replId + " " +
strconv.FormatInt(mdb.masterStatus.beginOffset, 10) + protocol.CRLF
_, err := slave.conn.Write([]byte(header))
if err != nil {
return fmt.Errorf("write replication header to slave failed: %v", err)
}
// send rdb
rdbFile, err := os.Open(mdb.masterStatus.rdbFilename)
if err != nil {
return fmt.Errorf("open rdb file %s for replication error: %v", mdb.masterStatus.rdbFilename, err)
}
slave.state = slaveStateSendingRDB
rdbInfo, _ := os.Stat(mdb.masterStatus.rdbFilename)
rdbSize := rdbInfo.Size()
rdbHeader := "$" + strconv.FormatInt(rdbSize, 10) + protocol.CRLF
_, err = slave.conn.Write([]byte(rdbHeader))
if err != nil {
return fmt.Errorf("write rdb header to slave failed: %v", err)
}
_, err = io.Copy(slave.conn, rdbFile)
if err != nil {
return fmt.Errorf("write rdb file to slave failed: %v", err)
}
// send backlog
mdb.masterStatus.mu.RLock()
currentOffset := mdb.masterStatus.currentOffset
backlog := mdb.masterStatus.backlog[:currentOffset-mdb.masterStatus.beginOffset]
mdb.masterStatus.mu.RUnlock()
_, err = slave.conn.Write(backlog)
if err != nil {
return fmt.Errorf("full resync write backlog to slave failed: %v", err)
}
// set slave as online
mdb.setSlaveOnline(slave, currentOffset)
return nil
}
var cannotPartialSync = errors.New("cannot do partial sync")
func (mdb *MultiDB) masterTryPartialSyncWithSlave(slave *slaveClient, replId string, slaveOffset int64) error {
mdb.masterStatus.mu.RLock()
if replId != mdb.masterStatus.replId {
mdb.masterStatus.mu.RUnlock()
return cannotPartialSync
}
if slaveOffset < mdb.masterStatus.beginOffset || slaveOffset > mdb.masterStatus.currentOffset {
mdb.masterStatus.mu.RUnlock()
return cannotPartialSync
}
currentOffset := mdb.masterStatus.currentOffset
backlog := mdb.masterStatus.backlog[slaveOffset-mdb.masterStatus.beginOffset : currentOffset-mdb.masterStatus.beginOffset]
mdb.masterStatus.mu.RUnlock()
// send replication header
header := "+CONTINUE " + mdb.masterStatus.replId + protocol.CRLF
_, err := slave.conn.Write([]byte(header))
if err != nil {
return fmt.Errorf("write replication header to slave failed: %v", err)
}
// send backlog
_, err = slave.conn.Write(backlog)
if err != nil {
return fmt.Errorf("partial resync write backlog to slave failed: %v", err)
}
// set slave online
mdb.setSlaveOnline(slave, currentOffset)
return nil
}
func (mdb *MultiDB) masterSendUpdatesToSlave() error {
onlineSlaves := make(map[*slaveClient]struct{})
mdb.masterStatus.mu.RLock()
currentOffset := mdb.masterStatus.currentOffset
beginOffset := mdb.masterStatus.beginOffset
backlog := mdb.masterStatus.backlog[:currentOffset-beginOffset]
for slave := range mdb.masterStatus.onlineSlaves {
onlineSlaves[slave] = struct{}{}
}
mdb.masterStatus.mu.RUnlock()
for slave := range onlineSlaves {
slaveBeginOffset := slave.offset - beginOffset
_, err := slave.conn.Write(backlog[slaveBeginOffset:])
if err != nil {
logger.Errorf("send updates write backlog to slave failed: %v", err)
mdb.removeSlave(slave)
continue
}
slave.offset = currentOffset
}
return nil
}
func (mdb *MultiDB) execPSync(c redis.Connection, args [][]byte) redis.Reply {
replId := string(args[0])
replOffset, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil {
return protocol.MakeErrReply("ERR value is not an integer or out of range")
}
mdb.masterStatus.mu.Lock()
defer mdb.masterStatus.mu.Unlock()
slave := mdb.masterStatus.slaveMap[c]
if slave == nil {
slave = &slaveClient{
conn: c,
}
c.SetSlave()
mdb.masterStatus.slaveMap[c] = slave
}
if mdb.masterStatus.bgSaveState == bgSaveIdle {
slave.state = slaveStateWaitSaveEnd
mdb.masterStatus.waitSlaves[slave] = struct{}{}
mdb.bgSaveForReplication()
} else if mdb.masterStatus.bgSaveState == bgSaveRunning {
slave.state = slaveStateWaitSaveEnd
mdb.masterStatus.waitSlaves[slave] = struct{}{}
} else if mdb.masterStatus.bgSaveState == bgSaveFinish {
go func() {
defer func() {
if e := recover(); e != nil {
logger.Errorf("panic: %v", e)
}
}()
err := mdb.masterTryPartialSyncWithSlave(slave, replId, replOffset)
if err == nil {
return
}
if err != nil && err != cannotPartialSync {
mdb.removeSlave(slave)
logger.Errorf("masterTryPartialSyncWithSlave error: %v", err)
return
}
// assert err == cannotPartialSync
if err := mdb.masterFullReSyncWithSlave(slave); err != nil {
mdb.removeSlave(slave)
logger.Errorf("masterFullReSyncWithSlave error: %v", err)
return
}
}()
}
return &protocol.NoReply{}
}
func (mdb *MultiDB) execReplConf(c redis.Connection, args [][]byte) redis.Reply {
if len(args)%2 != 0 {
return protocol.MakeSyntaxErrReply()
}
mdb.masterStatus.mu.RLock()
slave := mdb.masterStatus.slaveMap[c]
mdb.masterStatus.mu.RUnlock()
for i := 0; i < len(args); i += 2 {
key := strings.ToLower(string(args[i]))
value := string(args[i+1])
switch key {
case "ack":
offset, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return protocol.MakeErrReply("ERR value is not an integer or out of range")
}
slave.offset = offset
slave.lastAckTime = time.Now()
return &protocol.NoReply{}
}
}
return protocol.MakeOkReply()
}
func (mdb *MultiDB) removeSlave(slave *slaveClient) {
mdb.masterStatus.mu.Lock()
defer mdb.masterStatus.mu.Unlock()
_ = slave.conn.Close()
delete(mdb.masterStatus.slaveMap, slave.conn)
delete(mdb.masterStatus.waitSlaves, slave)
delete(mdb.masterStatus.onlineSlaves, slave)
}
func (mdb *MultiDB) setSlaveOnline(slave *slaveClient, currentOffset int64) {
mdb.masterStatus.mu.Lock()
defer mdb.masterStatus.mu.Unlock()
slave.state = slaveStateOnline
slave.offset = currentOffset
mdb.masterStatus.onlineSlaves[slave] = struct{}{}
}
var pingBytes = protocol.MakeMultiBulkReply(utils.ToCmdLine("ping")).ToBytes()
func (mdb *MultiDB) masterCron() {
if mdb.role != masterRole {
return
}
mdb.masterStatus.mu.Lock()
if mdb.masterStatus.bgSaveState == bgSaveFinish {
mdb.masterStatus.appendBacklog(pingBytes)
}
mdb.masterStatus.mu.Unlock()
if err := mdb.masterSendUpdatesToSlave(); err != nil {
logger.Errorf("masterSendUpdatesToSlave error: %v", err)
}
}
func (mdb *MultiDB) masterListenAof(listener chan CmdLine) {
for {
select {
case cmdLine := <-listener:
mdb.masterStatus.mu.Lock()
reply := protocol.MakeMultiBulkReply(cmdLine)
mdb.masterStatus.appendBacklog(reply.ToBytes())
mdb.masterStatus.mu.Unlock()
if err := mdb.masterSendUpdatesToSlave(); err != nil {
logger.Errorf("masterSendUpdatesToSlave after receive aof error: %v", err)
}
// if bgSave is running, updates will be sent after the save finished
case <-mdb.masterStatus.ctx.Done():
break
}
}
}
func (mdb *MultiDB) startAsMaster() {
mdb.masterStatus = &masterStatus{
ctx: context.Background(),
mu: sync.RWMutex{},
replId: utils.RandHexString(40),
backlog: nil,
beginOffset: 0,
currentOffset: 0,
slaveMap: make(map[redis.Connection]*slaveClient),
waitSlaves: make(map[*slaveClient]struct{}),
onlineSlaves: make(map[*slaveClient]struct{}),
bgSaveState: bgSaveIdle,
rdbFilename: "",
}
}

View File

@@ -0,0 +1,194 @@
package database
import (
"bytes"
"github.com/hdt3213/godis/config"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/connection"
"github.com/hdt3213/godis/redis/parser"
"github.com/hdt3213/godis/redis/protocol"
"github.com/hdt3213/godis/redis/protocol/asserts"
rdb "github.com/hdt3213/rdb/parser"
"io/ioutil"
"os"
"path"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
)
func mockServer() *MultiDB {
server := &MultiDB{}
server.dbSet = make([]*atomic.Value, 16)
for i := range server.dbSet {
singleDB := makeDB()
singleDB.index = i
holder := &atomic.Value{}
holder.Store(singleDB)
server.dbSet[i] = holder
}
server.slaveStatus = initReplSlaveStatus()
return server
}
func TestReplicationMasterSide(t *testing.T) {
tmpDir, err := ioutil.TempDir("", "godis")
if err != nil {
t.Error(err)
return
}
aofFilename := path.Join(tmpDir, "a.aof")
defer func() {
_ = os.Remove(aofFilename)
}()
config.Properties = &config.ServerProperties{
Databases: 16,
AppendOnly: true,
AppendFilename: aofFilename,
}
master := mockServer()
master.initAof()
master.startAsMaster()
slave := mockServer()
replConn := connection.NewFakeConn()
// set data to master
masterConn := connection.NewFakeConn()
resp := master.Exec(masterConn, utils.ToCmdLine("SET", "a", "a"))
asserts.AssertNotError(t, resp)
time.Sleep(time.Millisecond * 100) // wait write aof
// full re-sync
master.Exec(replConn, utils.ToCmdLine("psync", "?", "-1"))
masterChan := parser.ParseStream(replConn)
psyncPayload := <-masterChan
if psyncPayload.Err != nil {
t.Errorf("master bad protocol: %v", psyncPayload.Err)
return
}
psyncHeader, ok := psyncPayload.Data.(*protocol.StatusReply)
if !ok {
t.Error("psync header is not a status reply")
return
}
headers := strings.Split(psyncHeader.Status, " ")
if len(headers) != 3 {
t.Errorf("illegal psync header: %s", psyncHeader.Status)
return
}
replId := headers[1]
replOffset, err := strconv.ParseInt(headers[2], 10, 64)
if err != nil {
t.Errorf("illegal offset: %s", headers[2])
return
}
t.Logf("repl id: %s, offset: %d", replId, replOffset)
rdbPayload := <-masterChan
if rdbPayload.Err != nil {
t.Error("read response failed: " + rdbPayload.Err.Error())
return
}
rdbReply, ok := rdbPayload.Data.(*protocol.BulkReply)
if !ok {
t.Error("illegal payload header: " + string(rdbPayload.Data.ToBytes()))
return
}
rdbDec := rdb.NewDecoder(bytes.NewReader(rdbReply.Arg))
err = importRDB(rdbDec, slave)
if err != nil {
t.Error("import rdb failed: " + err.Error())
return
}
// get a
slaveConn := connection.NewFakeConn()
resp = slave.Exec(slaveConn, utils.ToCmdLine("get", "a"))
asserts.AssertBulkReply(t, resp, "a")
/*---- test broadcast aof ----*/
masterConn = connection.NewFakeConn()
resp = master.Exec(masterConn, utils.ToCmdLine("SET", "b", "b"))
time.Sleep(time.Millisecond * 100) // wait write aof
asserts.AssertNotError(t, resp)
master.masterCron()
for {
payload := <-masterChan
if payload.Err != nil {
t.Error(payload.Err)
return
}
cmdLine, ok := payload.Data.(*protocol.MultiBulkReply)
if !ok {
t.Error("unexpected payload: " + string(payload.Data.ToBytes()))
return
}
slave.Exec(replConn, cmdLine.Args)
n := len(cmdLine.ToBytes())
slave.slaveStatus.replOffset += int64(n)
if string(cmdLine.Args[0]) != "ping" {
break
}
}
resp = slave.Exec(slaveConn, utils.ToCmdLine("get", "b"))
asserts.AssertBulkReply(t, resp, "b")
/*---- test partial reconnect ----*/
_ = replConn.Close() // mock disconnect
replConn = connection.NewFakeConn()
master.Exec(replConn, utils.ToCmdLine("psync", replId,
strconv.FormatInt(slave.slaveStatus.replOffset, 10)))
masterChan = parser.ParseStream(replConn)
psyncPayload = <-masterChan
if psyncPayload.Err != nil {
t.Errorf("master bad protocol: %v", psyncPayload.Err)
return
}
psyncHeader, ok = psyncPayload.Data.(*protocol.StatusReply)
if !ok {
t.Error("psync header is not a status reply")
return
}
headers = strings.Split(psyncHeader.Status, " ")
if len(headers) != 2 {
t.Errorf("illegal psync header: %s", psyncHeader.Status)
return
}
if headers[0] != "CONTINUE" {
t.Errorf("expect CONTINUE actual %s", headers[0])
return
}
replId = headers[1]
t.Logf("partial resync repl id: %s, offset: %d", replId, slave.slaveStatus.replOffset)
resp = master.Exec(masterConn, utils.ToCmdLine("SET", "c", "c"))
time.Sleep(time.Millisecond * 100) // wait write aof
asserts.AssertNotError(t, resp)
master.masterCron()
for {
payload := <-masterChan
if payload.Err != nil {
t.Error(payload.Err)
return
}
cmdLine, ok := payload.Data.(*protocol.MultiBulkReply)
if !ok {
t.Error("unexpected payload: " + string(payload.Data.ToBytes()))
return
}
slave.Exec(replConn, cmdLine.Args)
if string(cmdLine.Args[0]) != "ping" {
break
}
}
resp = slave.Exec(slaveConn, utils.ToCmdLine("get", "c"))
asserts.AssertBulkReply(t, resp, "c")
}

View File

@@ -31,9 +31,9 @@ type slaveStatus struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
// configVersion stands for the version of replication config. Any change of master host/port will cause configVersion increment // configVersion stands for the version of slaveStatus config. Any change of master host/port will cause configVersion increment
// If configVersion change has been found during replication current replication procedure will stop. // If configVersion change has been found during slaveStatus current slaveStatus procedure will stop.
// It is designed to abort a running replication procedure // It is designed to abort a running slaveStatus procedure
configVersion int32 configVersion int32
masterHost string masterHost string
@@ -47,26 +47,10 @@ type slaveStatus struct {
running sync.WaitGroup running sync.WaitGroup
} }
var configChangedErr = errors.New("replication config changed") var configChangedErr = errors.New("slaveStatus config changed")
func initReplStatus() *slaveStatus { func initReplSlaveStatus() *slaveStatus {
repl := &slaveStatus{} return &slaveStatus{}
// start cron
return repl
}
func (mdb *MultiDB) startReplCron() {
go func() {
defer func() {
if err := recover(); err != nil {
logger.Error("panic", err)
}
}()
ticker := time.Tick(time.Second)
for range ticker {
mdb.slaveCron()
}
}()
} }
func (mdb *MultiDB) execSlaveOf(c redis.Connection, args [][]byte) redis.Reply { func (mdb *MultiDB) execSlaveOf(c redis.Connection, args [][]byte) redis.Reply {
@@ -80,29 +64,29 @@ func (mdb *MultiDB) execSlaveOf(c redis.Connection, args [][]byte) redis.Reply {
if err != nil { if err != nil {
return protocol.MakeErrReply("ERR value is not an integer or out of range") return protocol.MakeErrReply("ERR value is not an integer or out of range")
} }
mdb.replication.mutex.Lock() mdb.slaveStatus.mutex.Lock()
atomic.StoreInt32(&mdb.role, slaveRole) atomic.StoreInt32(&mdb.role, slaveRole)
mdb.replication.masterHost = host mdb.slaveStatus.masterHost = host
mdb.replication.masterPort = port mdb.slaveStatus.masterPort = port
// use buffered channel in case receiver goroutine exited before controller send stop signal // use buffered channel in case receiver goroutine exited before controller send stop signal
atomic.AddInt32(&mdb.replication.configVersion, 1) atomic.AddInt32(&mdb.slaveStatus.configVersion, 1)
mdb.replication.mutex.Unlock() mdb.slaveStatus.mutex.Unlock()
go mdb.setupMaster() go mdb.setupMaster()
return protocol.MakeOkReply() return protocol.MakeOkReply()
} }
func (mdb *MultiDB) slaveOfNone() { func (mdb *MultiDB) slaveOfNone() {
mdb.replication.mutex.Lock() mdb.slaveStatus.mutex.Lock()
defer mdb.replication.mutex.Unlock() defer mdb.slaveStatus.mutex.Unlock()
mdb.replication.masterHost = "" mdb.slaveStatus.masterHost = ""
mdb.replication.masterPort = 0 mdb.slaveStatus.masterPort = 0
mdb.replication.replId = "" mdb.slaveStatus.replId = ""
mdb.replication.replOffset = -1 mdb.slaveStatus.replOffset = -1
mdb.replication.stopSlaveWithMutex() mdb.slaveStatus.stopSlaveWithMutex()
} }
// stopSlaveWithMutex stops in-progress connectWithMaster/fullSync/receiveAOF // stopSlaveWithMutex stops in-progress connectWithMaster/fullSync/receiveAOF
// invoker should have replication mutex // invoker should have slaveStatus mutex
func (repl *slaveStatus) stopSlaveWithMutex() { func (repl *slaveStatus) stopSlaveWithMutex() {
// update configVersion to stop connectWithMaster and fullSync // update configVersion to stop connectWithMaster and fullSync
atomic.AddInt32(&repl.configVersion, 1) atomic.AddInt32(&repl.configVersion, 1)
@@ -135,11 +119,11 @@ func (mdb *MultiDB) setupMaster() {
}() }()
var configVersion int32 var configVersion int32
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
mdb.replication.mutex.Lock() mdb.slaveStatus.mutex.Lock()
mdb.replication.ctx = ctx mdb.slaveStatus.ctx = ctx
mdb.replication.cancel = cancel mdb.slaveStatus.cancel = cancel
configVersion = mdb.replication.configVersion configVersion = mdb.slaveStatus.configVersion
mdb.replication.mutex.Unlock() mdb.slaveStatus.mutex.Unlock()
isFullReSync, err := mdb.connectWithMaster(configVersion) isFullReSync, err := mdb.connectWithMaster(configVersion)
if err != nil { if err != nil {
// connect failed, abort master // connect failed, abort master
@@ -167,7 +151,7 @@ func (mdb *MultiDB) setupMaster() {
// connectWithMaster finishes handshake with master // connectWithMaster finishes handshake with master
// returns: isFullReSync, error // returns: isFullReSync, error
func (mdb *MultiDB) connectWithMaster(configVersion int32) (bool, error) { func (mdb *MultiDB) connectWithMaster(configVersion int32) (bool, error) {
addr := mdb.replication.masterHost + ":" + strconv.Itoa(mdb.replication.masterPort) addr := mdb.slaveStatus.masterHost + ":" + strconv.Itoa(mdb.slaveStatus.masterPort)
conn, err := net.Dial("tcp", addr) conn, err := net.Dial("tcp", addr)
if err != nil { if err != nil {
mdb.slaveOfNone() // abort mdb.slaveOfNone() // abort
@@ -256,34 +240,34 @@ func (mdb *MultiDB) connectWithMaster(configVersion int32) (bool, error) {
} }
// update connection // update connection
mdb.replication.mutex.Lock() mdb.slaveStatus.mutex.Lock()
defer mdb.replication.mutex.Unlock() defer mdb.slaveStatus.mutex.Unlock()
if mdb.replication.configVersion != configVersion { if mdb.slaveStatus.configVersion != configVersion {
// replication conf changed during connecting and waiting mutex // slaveStatus conf changed during connecting and waiting mutex
return false, configChangedErr return false, configChangedErr
} }
mdb.replication.masterConn = conn mdb.slaveStatus.masterConn = conn
mdb.replication.masterChan = masterChan mdb.slaveStatus.masterChan = masterChan
mdb.replication.lastRecvTime = time.Now() mdb.slaveStatus.lastRecvTime = time.Now()
return mdb.psyncHandshake() return mdb.psyncHandshake()
} }
// psyncHandshake send `psync` to master and sync repl-id/offset with master // psyncHandshake send `psync` to master and sync repl-id/offset with master
// invoker should provide with replication.mutex // invoker should provide with slaveStatus.mutex
func (mdb *MultiDB) psyncHandshake() (bool, error) { func (mdb *MultiDB) psyncHandshake() (bool, error) {
replId := "?" replId := "?"
var replOffset int64 = -1 var replOffset int64 = -1
if mdb.replication.replId != "" { if mdb.slaveStatus.replId != "" {
replId = mdb.replication.replId replId = mdb.slaveStatus.replId
replOffset = mdb.replication.replOffset replOffset = mdb.slaveStatus.replOffset
} }
psyncCmdLine := utils.ToCmdLine("psync", replId, strconv.FormatInt(replOffset, 10)) psyncCmdLine := utils.ToCmdLine("psync", replId, strconv.FormatInt(replOffset, 10))
psyncReq := protocol.MakeMultiBulkReply(psyncCmdLine) psyncReq := protocol.MakeMultiBulkReply(psyncCmdLine)
_, err := mdb.replication.masterConn.Write(psyncReq.ToBytes()) _, err := mdb.slaveStatus.masterConn.Write(psyncReq.ToBytes())
if err != nil { if err != nil {
return false, errors.New("send failed " + err.Error()) return false, errors.New("send failed " + err.Error())
} }
psyncPayload := <-mdb.replication.masterChan psyncPayload := <-mdb.slaveStatus.masterChan
if psyncPayload.Err != nil { if psyncPayload.Err != nil {
return false, errors.New("read response failed: " + psyncPayload.Err.Error()) return false, errors.New("read response failed: " + psyncPayload.Err.Error())
} }
@@ -300,12 +284,12 @@ func (mdb *MultiDB) psyncHandshake() (bool, error) {
var isFullReSync bool var isFullReSync bool
if headers[0] == "FULLRESYNC" { if headers[0] == "FULLRESYNC" {
logger.Info("full re-sync with master") logger.Info("full re-sync with master")
mdb.replication.replId = headers[1] mdb.slaveStatus.replId = headers[1]
mdb.replication.replOffset, err = strconv.ParseInt(headers[2], 10, 64) mdb.slaveStatus.replOffset, err = strconv.ParseInt(headers[2], 10, 64)
isFullReSync = true isFullReSync = true
} else if headers[0] == "CONTINUE" { } else if headers[0] == "CONTINUE" {
logger.Info("continue partial sync") logger.Info("continue partial sync")
mdb.replication.replId = headers[1] mdb.slaveStatus.replId = headers[1]
isFullReSync = false isFullReSync = false
} else { } else {
return false, errors.New("illegal psync resp: " + psyncHeader.Status) return false, errors.New("illegal psync resp: " + psyncHeader.Status)
@@ -314,13 +298,13 @@ func (mdb *MultiDB) psyncHandshake() (bool, error) {
if err != nil { if err != nil {
return false, errors.New("get illegal repl offset: " + headers[2]) return false, errors.New("get illegal repl offset: " + headers[2])
} }
logger.Info(fmt.Sprintf("repl id: %s, current offset: %d", mdb.replication.replId, mdb.replication.replOffset)) logger.Info(fmt.Sprintf("repl id: %s, current offset: %d", mdb.slaveStatus.replId, mdb.slaveStatus.replOffset))
return isFullReSync, nil return isFullReSync, nil
} }
// loadMasterRDB downloads rdb after handshake has been done // loadMasterRDB downloads rdb after handshake has been done
func (mdb *MultiDB) loadMasterRDB(configVersion int32) error { func (mdb *MultiDB) loadMasterRDB(configVersion int32) error {
rdbPayload := <-mdb.replication.masterChan rdbPayload := <-mdb.slaveStatus.masterChan
if rdbPayload.Err != nil { if rdbPayload.Err != nil {
return errors.New("read response failed: " + rdbPayload.Err.Error()) return errors.New("read response failed: " + rdbPayload.Err.Error())
} }
@@ -337,10 +321,10 @@ func (mdb *MultiDB) loadMasterRDB(configVersion int32) error {
return errors.New("dump rdb failed: " + err.Error()) return errors.New("dump rdb failed: " + err.Error())
} }
mdb.replication.mutex.Lock() mdb.slaveStatus.mutex.Lock()
defer mdb.replication.mutex.Unlock() defer mdb.slaveStatus.mutex.Unlock()
if mdb.replication.configVersion != configVersion { if mdb.slaveStatus.configVersion != configVersion {
// replication conf changed during connecting and waiting mutex // slaveStatus conf changed during connecting and waiting mutex
return configChangedErr return configChangedErr
} }
for i, h := range rdbHolder.dbSet { for i, h := range rdbHolder.dbSet {
@@ -353,13 +337,13 @@ func (mdb *MultiDB) loadMasterRDB(configVersion int32) error {
} }
func (mdb *MultiDB) receiveAOF(ctx context.Context, configVersion int32) error { func (mdb *MultiDB) receiveAOF(ctx context.Context, configVersion int32) error {
conn := connection.NewConn(mdb.replication.masterConn) conn := connection.NewConn(mdb.slaveStatus.masterConn)
conn.SetRole(connection.ReplicationRecvCli) conn.SetMaster()
mdb.replication.running.Add(1) mdb.slaveStatus.running.Add(1)
defer mdb.replication.running.Done() defer mdb.slaveStatus.running.Done()
for { for {
select { select {
case payload, open := <-mdb.replication.masterChan: case payload, open := <-mdb.slaveStatus.masterChan:
if !open { if !open {
return errors.New("master channel unexpected close") return errors.New("master channel unexpected close")
} }
@@ -370,31 +354,32 @@ func (mdb *MultiDB) receiveAOF(ctx context.Context, configVersion int32) error {
if !ok { if !ok {
return errors.New("unexpected payload: " + string(payload.Data.ToBytes())) return errors.New("unexpected payload: " + string(payload.Data.ToBytes()))
} }
mdb.replication.mutex.Lock() mdb.slaveStatus.mutex.Lock()
if mdb.replication.configVersion != configVersion { if mdb.slaveStatus.configVersion != configVersion {
// replication conf changed during connecting and waiting mutex // slaveStatus conf changed during connecting and waiting mutex
return configChangedErr return configChangedErr
} }
mdb.Exec(conn, cmdLine.Args) mdb.Exec(conn, cmdLine.Args)
n := len(cmdLine.ToBytes()) // todo: directly get size from socket n := len(cmdLine.ToBytes()) // todo: directly get size from socket
mdb.replication.replOffset += int64(n) mdb.slaveStatus.replOffset += int64(n)
mdb.replication.lastRecvTime = time.Now() mdb.slaveStatus.lastRecvTime = time.Now()
logger.Info(fmt.Sprintf("receive %d bytes from master, current offset %d, %s", logger.Info(fmt.Sprintf("receive %d bytes from master, current offset %d, %s",
n, mdb.replication.replOffset, strconv.Quote(string(cmdLine.ToBytes())))) n, mdb.slaveStatus.replOffset, strconv.Quote(string(cmdLine.ToBytes()))))
mdb.replication.mutex.Unlock() mdb.slaveStatus.mutex.Unlock()
case <-ctx.Done(): case <-ctx.Done():
conn.GetConnPool().Put(conn) _ = conn.Close()
return nil return nil
} }
} }
} }
func (mdb *MultiDB) slaveCron() { func (mdb *MultiDB) slaveCron() {
repl := mdb.replication repl := mdb.slaveStatus
if repl.masterConn == nil { if repl.masterConn == nil {
return return
} }
// check master timeout
replTimeout := 60 * time.Second replTimeout := 60 * time.Second
if config.Properties.ReplTimeout != 0 { if config.Properties.ReplTimeout != 0 {
replTimeout = time.Duration(config.Properties.ReplTimeout) * time.Second replTimeout = time.Duration(config.Properties.ReplTimeout) * time.Second
@@ -427,9 +412,9 @@ func (repl *slaveStatus) sendAck2Master() error {
func (mdb *MultiDB) reconnectWithMaster() error { func (mdb *MultiDB) reconnectWithMaster() error {
logger.Info("reconnecting with master") logger.Info("reconnecting with master")
mdb.replication.mutex.Lock() mdb.slaveStatus.mutex.Lock()
defer mdb.replication.mutex.Unlock() defer mdb.slaveStatus.mutex.Unlock()
mdb.replication.stopSlaveWithMutex() mdb.slaveStatus.stopSlaveWithMutex()
go mdb.setupMaster() go mdb.setupMaster()
return nil return nil
} }

View File

@@ -8,22 +8,12 @@ import (
"github.com/hdt3213/godis/redis/connection" "github.com/hdt3213/godis/redis/connection"
"github.com/hdt3213/godis/redis/protocol" "github.com/hdt3213/godis/redis/protocol"
"github.com/hdt3213/godis/redis/protocol/asserts" "github.com/hdt3213/godis/redis/protocol/asserts"
"sync/atomic"
"testing" "testing"
"time" "time"
) )
func TestReplication(t *testing.T) { func TestReplicationSlaveSide(t *testing.T) {
mdb := &MultiDB{} mdb := mockServer()
mdb.dbSet = make([]*atomic.Value, 16)
for i := range mdb.dbSet {
singleDB := makeDB()
singleDB.index = i
holder := &atomic.Value{}
holder.Store(singleDB)
mdb.dbSet[i] = holder
}
mdb.replication = initReplStatus()
masterCli, err := client.MakeClient("127.0.0.1:6379") masterCli, err := client.MakeClient("127.0.0.1:6379")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
@@ -33,7 +23,7 @@ func TestReplication(t *testing.T) {
// sync with master // sync with master
ret := masterCli.Send(utils.ToCmdLine("set", "1", "1")) ret := masterCli.Send(utils.ToCmdLine("set", "1", "1"))
asserts.AssertStatusReply(t, ret, "OK") asserts.AssertStatusReply(t, ret, "OK")
conn := &connection.FakeConn{} conn := connection.NewFakeConn()
ret = mdb.Exec(conn, utils.ToCmdLine("SLAVEOF", "127.0.0.1", "6379")) ret = mdb.Exec(conn, utils.ToCmdLine("SLAVEOF", "127.0.0.1", "6379"))
asserts.AssertStatusReply(t, ret, "OK") asserts.AssertStatusReply(t, ret, "OK")
success := false success := false
@@ -74,7 +64,7 @@ func TestReplication(t *testing.T) {
t.Error("sync failed") t.Error("sync failed")
return return
} }
err = mdb.replication.sendAck2Master() err = mdb.slaveStatus.sendAck2Master()
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@@ -83,8 +73,8 @@ func TestReplication(t *testing.T) {
// test reconnect // test reconnect
config.Properties.ReplTimeout = 1 config.Properties.ReplTimeout = 1
_ = mdb.replication.masterConn.Close() _ = mdb.slaveStatus.masterConn.Close()
mdb.replication.lastRecvTime = time.Now().Add(-time.Hour) // mock timeout mdb.slaveStatus.lastRecvTime = time.Now().Add(-time.Hour) // mock timeout
mdb.slaveCron() mdb.slaveCron()
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
ret = masterCli.Send(utils.ToCmdLine("set", "1", "3")) ret = masterCli.Send(utils.ToCmdLine("set", "1", "3"))
@@ -134,7 +124,7 @@ func TestReplication(t *testing.T) {
return return
} }
err = mdb.replication.close() err = mdb.slaveStatus.close()
if err != nil { if err != nil {
t.Error("cannot close") t.Error("cannot close")
} }

View File

@@ -20,7 +20,7 @@ func TestPing(t *testing.T) {
func TestAuth(t *testing.T) { func TestAuth(t *testing.T) {
passwd := utils.RandString(10) passwd := utils.RandString(10)
c := &connection.FakeConn{} c := connection.NewFakeConn()
ret := testServer.Exec(c, utils.ToCmdLine("AUTH")) ret := testServer.Exec(c, utils.ToCmdLine("AUTH"))
asserts.AssertErrReply(t, ret, "ERR wrong number of arguments for 'auth' command") asserts.AssertErrReply(t, ret, "ERR wrong number of arguments for 'auth' command")
ret = testServer.Exec(c, utils.ToCmdLine("AUTH", passwd)) ret = testServer.Exec(c, utils.ToCmdLine("AUTH", passwd))

View File

@@ -2,7 +2,9 @@ package redis
// Connection represents a connection with redis client // Connection represents a connection with redis client
type Connection interface { type Connection interface {
Write([]byte) error Write([]byte) (int, error)
Close() error
SetPassword(string) SetPassword(string)
GetPassword() string GetPassword() string
@@ -12,7 +14,6 @@ type Connection interface {
SubsCount() int SubsCount() int
GetChannels() []string GetChannels() []string
// used for `Multi` command
InMultiState() bool InMultiState() bool
SetMultiState(bool) SetMultiState(bool)
GetQueuedCmdLine() [][][]byte GetQueuedCmdLine() [][][]byte
@@ -22,10 +23,12 @@ type Connection interface {
AddTxError(err error) AddTxError(err error)
GetTxErrors() []error GetTxErrors() []error
// used for multi database
GetDBIndex() int GetDBIndex() int
SelectDB(int) SelectDB(int)
// returns role of conn, such as connection with client, connection with master node
GetRole() int32 SetSlave()
SetRole(int32) IsSlave() bool
SetMaster()
IsMaster() bool
} }

View File

@@ -107,6 +107,13 @@ func Error(v ...interface{}) {
logger.Println(v...) logger.Println(v...)
} }
func Errorf(format string, v ...interface{}) {
mu.Lock()
defer mu.Unlock()
setPrefix(ERROR)
logger.Println(fmt.Sprintf(format, v...))
}
// Fatal prints error log then stop the program // Fatal prints error log then stop the program
func Fatal(v ...interface{}) { func Fatal(v ...interface{}) {
mu.Lock() mu.Lock()

View File

@@ -16,3 +16,13 @@ func RandString(n int) string {
} }
return string(b) return string(b)
} }
var hexLetters = []rune("0123456789abcdef")
func RandHexString(n int) string {
b := make([]rune, n)
for i := range b {
b[i] = hexLetters[rand.Intn(len(hexLetters))]
}
return string(b)
}

View File

@@ -82,7 +82,7 @@ func Subscribe(hub *Hub, c redis.Connection, args [][]byte) redis.Reply {
for _, channel := range channels { for _, channel := range channels {
if subscribe0(hub, channel, c) { if subscribe0(hub, channel, c) {
_ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount()))) _, _ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount())))
} }
} }
return &protocol.NoReply{} return &protocol.NoReply{}
@@ -117,13 +117,13 @@ func UnSubscribe(db *Hub, c redis.Connection, args [][]byte) redis.Reply {
defer db.subsLocker.UnLocks(channels...) defer db.subsLocker.UnLocks(channels...)
if len(channels) == 0 { if len(channels) == 0 {
_ = c.Write(unSubscribeNothing) _, _ = c.Write(unSubscribeNothing)
return &protocol.NoReply{} return &protocol.NoReply{}
} }
for _, channel := range channels { for _, channel := range channels {
if unsubscribe0(db, channel, c) { if unsubscribe0(db, channel, c) {
_ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount()))) _, _ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount())))
} }
} }
return &protocol.NoReply{} return &protocol.NoReply{}
@@ -151,7 +151,7 @@ func Publish(hub *Hub, args [][]byte) redis.Reply {
replyArgs[0] = messageBytes replyArgs[0] = messageBytes
replyArgs[1] = []byte(channel) replyArgs[1] = []byte(channel)
replyArgs[2] = message replyArgs[2] = message
_ = client.Write(protocol.MakeMultiBulkReply(replyArgs).ToBytes()) _, _ = client.Write(protocol.MakeMultiBulkReply(replyArgs).ToBytes())
return true return true
}) })
return protocol.MakeIntReply(int64(subscribers.Len())) return protocol.MakeIntReply(int64(subscribers.Len()))

View File

@@ -2,6 +2,6 @@ bind 0.0.0.0
port 6399 port 6399
maxclients 128 maxclients 128
#appendonly no appendonly yes
#appendfilename appendonly.aof appendfilename appendonly.aof
#dbfilename test.rdb #dbfilename test.rdb

View File

@@ -1,7 +1,6 @@
package connection package connection
import ( import (
"bytes"
"github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/sync/wait" "github.com/hdt3213/godis/lib/sync/wait"
"net" "net"
@@ -10,21 +9,24 @@ import (
) )
const ( const (
// NormalCli is client with user // flagSlave means this a connection with slave
NormalCli = iota flagSlave = uint64(1 << iota)
// ReplicationRecvCli is fake client with replication master // flagSlave means this a connection with master
ReplicationRecvCli flagMaster
// flagMulti means this connection is within a transaction
flagMulti
) )
// Connection represents a connection with a redis-cli // Connection represents a connection with a redis-cli
type Connection struct { type Connection struct {
conn net.Conn conn net.Conn
// waiting until protocol finished // waiting until finish sending data, used for graceful shutdown
waitingReply wait.Wait sendingData wait.Wait
// lock while server sending response // lock while server sending response
mu sync.Mutex mu sync.Mutex
flags uint64
// subscribing channels // subscribing channels
subs map[string]bool subs map[string]bool
@@ -33,14 +35,12 @@ type Connection struct {
password string password string
// queued commands for `multi` // queued commands for `multi`
multiState bool queue [][][]byte
queue [][][]byte watching map[string]uint32
watching map[string]uint32 txErrors []error
txErrors []error
// selected db // selected db
selectedDB int selectedDB int
role int32
} }
var connPool = sync.Pool{ var connPool = sync.Pool{
@@ -49,11 +49,6 @@ var connPool = sync.Pool{
}, },
} }
// GetConnPool returns the connection pool pointer for putting and getting connection
func (c *Connection) GetConnPool() *sync.Pool {
return &connPool
}
// RemoteAddr returns the remote network address // RemoteAddr returns the remote network address
func (c *Connection) RemoteAddr() net.Addr { func (c *Connection) RemoteAddr() net.Addr {
return c.conn.RemoteAddr() return c.conn.RemoteAddr()
@@ -61,8 +56,15 @@ func (c *Connection) RemoteAddr() net.Addr {
// Close disconnect with the client // Close disconnect with the client
func (c *Connection) Close() error { func (c *Connection) Close() error {
c.waitingReply.WaitWithTimeout(10 * time.Second) c.sendingData.WaitWithTimeout(10 * time.Second)
_ = c.conn.Close() _ = c.conn.Close()
c.subs = nil
c.password = ""
c.queue = nil
c.watching = nil
c.txErrors = nil
c.selectedDB = 0
connPool.Put(c)
return nil return nil
} }
@@ -80,17 +82,16 @@ func NewConn(conn net.Conn) *Connection {
} }
// Write sends response to client over tcp connection // Write sends response to client over tcp connection
func (c *Connection) Write(b []byte) error { func (c *Connection) Write(b []byte) (int, error) {
if len(b) == 0 { if len(b) == 0 {
return nil return 0, nil
} }
c.waitingReply.Add(1) c.sendingData.Add(1)
defer func() { defer func() {
c.waitingReply.Done() c.sendingData.Done()
}() }()
_, err := c.conn.Write(b) return c.conn.Write(b)
return err
} }
// Subscribe add current connection into subscribers of the given channel // Subscribe add current connection into subscribers of the given channel
@@ -146,7 +147,7 @@ func (c *Connection) GetPassword() string {
// InMultiState tells is connection in an uncommitted transaction // InMultiState tells is connection in an uncommitted transaction
func (c *Connection) InMultiState() bool { func (c *Connection) InMultiState() bool {
return c.multiState return c.flags&flagMulti > 0
} }
// SetMultiState sets transaction flag // SetMultiState sets transaction flag
@@ -154,8 +155,10 @@ func (c *Connection) SetMultiState(state bool) {
if !state { // reset data when cancel multi if !state { // reset data when cancel multi
c.watching = nil c.watching = nil
c.queue = nil c.queue = nil
c.flags &= ^flagMulti // clean multi flag
return
} }
c.multiState = state c.flags |= flagMulti
} }
// GetQueuedCmdLine returns queued commands of current transaction // GetQueuedCmdLine returns queued commands of current transaction
@@ -183,18 +186,6 @@ func (c *Connection) ClearQueuedCmds() {
c.queue = nil c.queue = nil
} }
// GetRole returns role of connection, such as connection with master
func (c *Connection) GetRole() int32 {
if c == nil {
return NormalCli
}
return c.role
}
func (c *Connection) SetRole(r int32) {
c.role = r
}
// GetWatching returns watching keys and their version code when started watching // GetWatching returns watching keys and their version code when started watching
func (c *Connection) GetWatching() map[string]uint32 { func (c *Connection) GetWatching() map[string]uint32 {
if c.watching == nil { if c.watching == nil {
@@ -213,24 +204,18 @@ func (c *Connection) SelectDB(dbNum int) {
c.selectedDB = dbNum c.selectedDB = dbNum
} }
// FakeConn implements redis.Connection for test func (c *Connection) SetSlave() {
type FakeConn struct { c.flags |= flagSlave
Connection
buf bytes.Buffer
} }
// Write writes data to buffer func (c *Connection) IsSlave() bool {
func (c *FakeConn) Write(b []byte) error { return c.flags&flagSlave > 0
c.buf.Write(b)
return nil
} }
// Clean resets the buffer func (c *Connection) SetMaster() {
func (c *FakeConn) Clean() { c.flags |= flagMaster
c.buf.Reset()
} }
// Bytes returns written data func (c *Connection) IsMaster() bool {
func (c *FakeConn) Bytes() []byte { return c.flags&flagMaster > 0
return c.buf.Bytes()
} }

79
redis/connection/fake.go Normal file
View File

@@ -0,0 +1,79 @@
package connection
import (
"bytes"
"io"
"sync"
)
// FakeConn implements redis.Connection for test
type FakeConn struct {
Connection
buf bytes.Buffer
wait chan struct{}
closed bool
mu sync.Mutex
}
func NewFakeConn() *FakeConn {
c := &FakeConn{}
return c
}
// Write writes data to buffer
func (c *FakeConn) Write(b []byte) (int, error) {
if c.closed {
return 0, io.EOF
}
n, _ := c.buf.Write(b)
c.notify()
return n, nil
}
func (c *FakeConn) notify() {
if c.wait != nil {
c.mu.Lock()
if c.wait != nil {
close(c.wait)
c.wait = nil
}
c.mu.Unlock()
}
}
func (c *FakeConn) waiting() {
c.mu.Lock()
c.wait = make(chan struct{})
c.mu.Unlock()
<-c.wait
}
// Read reads data from buffer
func (c *FakeConn) Read(p []byte) (int, error) {
n, err := c.buf.Read(p)
if err == io.EOF {
if c.closed {
return 0, io.EOF
}
c.waiting()
return c.buf.Read(p)
}
return n, err
}
// Clean resets the buffer
func (c *FakeConn) Clean() {
c.wait = make(chan struct{})
c.buf.Reset()
}
// Bytes returns written data
func (c *FakeConn) Bytes() []byte {
return c.buf.Bytes()
}
func (c *FakeConn) Close() error {
c.closed = true
c.notify()
return nil
}

View File

@@ -13,7 +13,7 @@ func TestPublish(t *testing.T) {
hub := pubsub.MakeHub() hub := pubsub.MakeHub()
channel := utils.RandString(5) channel := utils.RandString(5)
msg := utils.RandString(5) msg := utils.RandString(5)
conn := &connection.FakeConn{} conn := connection.NewFakeConn()
pubsub.Subscribe(hub, conn, utils.ToCmdLine(channel)) pubsub.Subscribe(hub, conn, utils.ToCmdLine(channel))
conn.Clean() // clean subscribe success conn.Clean() // clean subscribe success
pubsub.Publish(hub, utils.ToCmdLine(channel, msg)) pubsub.Publish(hub, utils.ToCmdLine(channel, msg))

View File

@@ -50,8 +50,6 @@ func (h *Handler) closeClient(client *connection.Connection) {
_ = client.Close() _ = client.Close()
h.db.AfterClientClose(client) h.db.AfterClientClose(client)
h.activeConn.Delete(client) h.activeConn.Delete(client)
client.GetConnPool().Put(client)
} }
// Handle receives and executes redis commands // Handle receives and executes redis commands
@@ -78,7 +76,7 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
} }
// protocol err // protocol err
errReply := protocol.MakeErrReply(payload.Err.Error()) errReply := protocol.MakeErrReply(payload.Err.Error())
err := client.Write(errReply.ToBytes()) _, err := client.Write(errReply.ToBytes())
if err != nil { if err != nil {
h.closeClient(client) h.closeClient(client)
logger.Info("connection closed: " + client.RemoteAddr().String()) logger.Info("connection closed: " + client.RemoteAddr().String())
@@ -97,9 +95,9 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
} }
result := h.db.Exec(client, r.Args) result := h.db.Exec(client, r.Args)
if result != nil { if result != nil {
_ = client.Write(result.ToBytes()) _, _ = client.Write(result.ToBytes())
} else { } else {
_ = client.Write(unknownErrReplyBytes) _, _ = client.Write(unknownErrReplyBytes)
} }
} }
} }