mirror of
https://github.com/HDT3213/godis.git
synced 2025-10-05 08:46:56 +08:00
replication master side
This commit is contained in:
29
aof/aof.go
29
aof/aof.go
@@ -26,6 +26,10 @@ type payload struct {
|
||||
dbIndex int
|
||||
}
|
||||
|
||||
// Listener is a channel receive a replication of all aof payloads
|
||||
// with a listener we can forward the updates to slave nodes etc.
|
||||
type Listener chan<- CmdLine
|
||||
|
||||
// Handler receive msgs from channel and write to AOF file
|
||||
type Handler struct {
|
||||
db database.EmbedDB
|
||||
@@ -38,6 +42,7 @@ type Handler struct {
|
||||
// pause aof for start/finish aof rewrite progress
|
||||
pausingAof sync.RWMutex
|
||||
currentDB int
|
||||
listeners map[Listener]struct{}
|
||||
}
|
||||
|
||||
// NewAOFHandler creates a new aof.Handler
|
||||
@@ -54,12 +59,20 @@ func NewAOFHandler(db database.EmbedDB, tmpDBMaker func() database.EmbedDB) (*Ha
|
||||
handler.aofFile = aofFile
|
||||
handler.aofChan = make(chan *payload, aofQueueSize)
|
||||
handler.aofFinished = make(chan struct{})
|
||||
handler.listeners = make(map[Listener]struct{})
|
||||
go func() {
|
||||
handler.handleAof()
|
||||
}()
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
// RemoveListener removes a listener from aof handler, so we can close the listener
|
||||
func (handler *Handler) RemoveListener(listener Listener) {
|
||||
handler.pausingAof.Lock()
|
||||
defer handler.pausingAof.Unlock()
|
||||
delete(handler.listeners, listener)
|
||||
}
|
||||
|
||||
// AddAof send command to aof goroutine through channel
|
||||
func (handler *Handler) AddAof(dbIndex int, cmdLine CmdLine) {
|
||||
if config.Properties.AppendOnly && handler.aofChan != nil {
|
||||
@@ -73,12 +86,16 @@ func (handler *Handler) AddAof(dbIndex int, cmdLine CmdLine) {
|
||||
// handleAof listen aof channel and write into file
|
||||
func (handler *Handler) handleAof() {
|
||||
// serialized execution
|
||||
var cmdLines []CmdLine
|
||||
handler.currentDB = 0
|
||||
for p := range handler.aofChan {
|
||||
cmdLines = cmdLines[:0] // reuse underlying array
|
||||
handler.pausingAof.RLock() // prevent other goroutines from pausing aof
|
||||
if p.dbIndex != handler.currentDB {
|
||||
// select db
|
||||
data := protocol.MakeMultiBulkReply(utils.ToCmdLine("SELECT", strconv.Itoa(p.dbIndex))).ToBytes()
|
||||
selectCmd := utils.ToCmdLine("SELECT", strconv.Itoa(p.dbIndex))
|
||||
cmdLines = append(cmdLines, selectCmd)
|
||||
data := protocol.MakeMultiBulkReply(selectCmd).ToBytes()
|
||||
_, err := handler.aofFile.Write(data)
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
@@ -88,11 +105,17 @@ func (handler *Handler) handleAof() {
|
||||
handler.currentDB = p.dbIndex
|
||||
}
|
||||
data := protocol.MakeMultiBulkReply(p.cmdLine).ToBytes()
|
||||
cmdLines = append(cmdLines, p.cmdLine)
|
||||
_, err := handler.aofFile.Write(data)
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
}
|
||||
handler.pausingAof.RUnlock()
|
||||
for listener := range handler.listeners {
|
||||
for _, line := range cmdLines {
|
||||
listener <- line
|
||||
}
|
||||
}
|
||||
}
|
||||
handler.aofFinished <- struct{}{}
|
||||
}
|
||||
@@ -123,7 +146,7 @@ func (handler *Handler) LoadAof(maxBytes int) {
|
||||
reader = file
|
||||
}
|
||||
ch := parser.ParseStream(reader)
|
||||
fakeConn := &connection.FakeConn{} // only used for save dbIndex
|
||||
fakeConn := connection.NewFakeConn() // only used for save dbIndex
|
||||
for p := range ch {
|
||||
if p.Err != nil {
|
||||
if p.Err == io.EOF {
|
||||
@@ -143,7 +166,7 @@ func (handler *Handler) LoadAof(maxBytes int) {
|
||||
}
|
||||
ret := handler.db.Exec(fakeConn, r.Args)
|
||||
if protocol.IsErrorReply(ret) {
|
||||
logger.Error("exec err", ret.ToBytes())
|
||||
logger.Error("exec err", string(ret.ToBytes()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
17
aof/rdb.go
17
aof/rdb.go
@@ -16,8 +16,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func (handler *Handler) Rewrite2RDB() error {
|
||||
ctx, err := handler.startRewrite2RDB()
|
||||
// todo: forbid concurrent rewrite
|
||||
|
||||
// Rewrite2RDB rewrite aof data into rdb
|
||||
// if extraListener is not nil, it will be appended to Handler.listeners, it will receive all updates after rdb
|
||||
func (handler *Handler) Rewrite2RDB(rdbFilename string, extraListener Listener) error {
|
||||
ctx, err := handler.startRewrite2RDB(extraListener)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -25,10 +29,6 @@ func (handler *Handler) Rewrite2RDB() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rdbFilename := config.Properties.RDBFilename
|
||||
if rdbFilename == "" {
|
||||
rdbFilename = "dump.rdb"
|
||||
}
|
||||
err = ctx.tmpFile.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -40,7 +40,7 @@ func (handler *Handler) Rewrite2RDB() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (handler *Handler) startRewrite2RDB() (*RewriteCtx, error) {
|
||||
func (handler *Handler) startRewrite2RDB(extraListener Listener) (*RewriteCtx, error) {
|
||||
handler.pausingAof.Lock() // pausing aof
|
||||
defer handler.pausingAof.Unlock()
|
||||
|
||||
@@ -59,6 +59,9 @@ func (handler *Handler) startRewrite2RDB() (*RewriteCtx, error) {
|
||||
logger.Warn("tmp file create failed")
|
||||
return nil, err
|
||||
}
|
||||
if extraListener != nil {
|
||||
handler.listeners[extraListener] = struct{}{}
|
||||
}
|
||||
return &RewriteCtx{
|
||||
tmpFile: file,
|
||||
fileSize: filesize,
|
||||
|
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
func TestExec(t *testing.T) {
|
||||
testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"})
|
||||
conn := &connection.FakeConn{}
|
||||
conn := connection.NewFakeConn()
|
||||
for i := 0; i < 1000; i++ {
|
||||
key := RandString(4)
|
||||
value := RandString(4)
|
||||
@@ -26,7 +26,7 @@ func TestAuth(t *testing.T) {
|
||||
defer func() {
|
||||
config.Properties.RequirePass = ""
|
||||
}()
|
||||
conn := &connection.FakeConn{}
|
||||
conn := connection.NewFakeConn()
|
||||
ret := testNodeA.Exec(conn, toArgs("GET", "a"))
|
||||
asserts.AssertErrReply(t, ret, "NOAUTH Authentication required")
|
||||
ret = testNodeA.Exec(conn, toArgs("AUTH", passwd))
|
||||
@@ -39,7 +39,7 @@ func TestRelay(t *testing.T) {
|
||||
testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"})
|
||||
key := RandString(4)
|
||||
value := RandString(4)
|
||||
conn := &connection.FakeConn{}
|
||||
conn := connection.NewFakeConn()
|
||||
ret := testCluster2.relay("127.0.0.1:6379", conn, toArgs("SET", key, value))
|
||||
asserts.AssertNotError(t, ret)
|
||||
ret = testCluster2.relay("127.0.0.1:6379", conn, toArgs("GET", key))
|
||||
@@ -50,7 +50,7 @@ func TestBroadcast(t *testing.T) {
|
||||
testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"})
|
||||
key := RandString(4)
|
||||
value := RandString(4)
|
||||
rets := testCluster2.broadcast(&connection.FakeConn{}, toArgs("SET", key, value))
|
||||
rets := testCluster2.broadcast(connection.NewFakeConn(), toArgs("SET", key, value))
|
||||
for _, v := range rets {
|
||||
asserts.AssertNotError(t, v)
|
||||
}
|
||||
|
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
func TestDel(t *testing.T) {
|
||||
conn := &connection.FakeConn{}
|
||||
conn := connection.NewFakeConn()
|
||||
allowFastTransaction = false
|
||||
testNodeA.Exec(conn, toArgs("SET", "a", "a"))
|
||||
ret := Del(testNodeA, conn, toArgs("DEL", "a", "b", "c"))
|
||||
|
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
func TestMSet(t *testing.T) {
|
||||
conn := &connection.FakeConn{}
|
||||
conn := connection.NewFakeConn()
|
||||
allowFastTransaction = false
|
||||
ret := MSet(testNodeA, conn, toArgs("MSET", "a", "a", "b", "b"))
|
||||
asserts.AssertNotError(t, ret)
|
||||
@@ -16,7 +16,7 @@ func TestMSet(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMSetNx(t *testing.T) {
|
||||
conn := &connection.FakeConn{}
|
||||
conn := connection.NewFakeConn()
|
||||
allowFastTransaction = false
|
||||
FlushAll(testNodeA, conn, toArgs("FLUSHALL"))
|
||||
ret := MSetNX(testNodeA, conn, toArgs("MSETNX", "a", "a", "b", "b"))
|
||||
|
@@ -11,7 +11,7 @@ import (
|
||||
func TestPublish(t *testing.T) {
|
||||
channel := utils.RandString(5)
|
||||
msg := utils.RandString(5)
|
||||
conn := &connection.FakeConn{}
|
||||
conn := connection.NewFakeConn()
|
||||
Subscribe(testNodeA, conn, utils.ToCmdLine("SUBSCRIBE", channel))
|
||||
conn.Clean() // clean subscribe success
|
||||
Publish(testNodeA, conn, utils.ToCmdLine("PUBLISH", channel, msg))
|
||||
|
@@ -17,7 +17,7 @@ import (
|
||||
)
|
||||
|
||||
func makeTestData(db database.DB, dbIndex int, prefix string, size int) {
|
||||
conn := &connection.FakeConn{}
|
||||
conn := connection.NewFakeConn()
|
||||
conn.SelectDB(dbIndex)
|
||||
db.Exec(conn, utils.ToCmdLine("FlushDB"))
|
||||
cursor := 0
|
||||
@@ -49,7 +49,7 @@ func makeTestData(db database.DB, dbIndex int, prefix string, size int) {
|
||||
}
|
||||
|
||||
func validateTestData(t *testing.T, db database.DB, dbIndex int, prefix string, size int) {
|
||||
conn := &connection.FakeConn{}
|
||||
conn := connection.NewFakeConn()
|
||||
conn.SelectDB(dbIndex)
|
||||
cursor := 0
|
||||
var ret redis.Reply
|
||||
@@ -146,7 +146,7 @@ func TestRDB(t *testing.T) {
|
||||
dbNum := 4
|
||||
size := 10
|
||||
var prefixes []string
|
||||
conn := &connection.FakeConn{}
|
||||
conn := connection.NewFakeConn()
|
||||
writeDB := NewStandaloneServer()
|
||||
for i := 0; i < dbNum; i++ {
|
||||
prefix := utils.RandString(8)
|
||||
@@ -216,7 +216,7 @@ func TestRewriteAOF2(t *testing.T) {
|
||||
}
|
||||
aofWriteDB := NewStandaloneServer()
|
||||
dbNum := 4
|
||||
conn := &connection.FakeConn{}
|
||||
conn := connection.NewFakeConn()
|
||||
for i := 0; i < dbNum; i++ {
|
||||
conn.SelectDB(i)
|
||||
key := strconv.Itoa(i)
|
||||
|
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/hdt3213/godis/lib/logger"
|
||||
"github.com/hdt3213/godis/lib/utils"
|
||||
"github.com/hdt3213/godis/pubsub"
|
||||
"github.com/hdt3213/godis/redis/connection"
|
||||
"github.com/hdt3213/godis/redis/protocol"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
@@ -27,10 +26,10 @@ type MultiDB struct {
|
||||
// handle aof persistence
|
||||
aofHandler *aof.Handler
|
||||
|
||||
// store master node address
|
||||
slaveOf string
|
||||
// for replication
|
||||
role int32
|
||||
replication *slaveStatus
|
||||
slaveStatus *slaveStatus
|
||||
masterStatus *masterStatus
|
||||
}
|
||||
|
||||
// NewStandaloneServer creates a standalone redis server, with multi database and all other funtions
|
||||
@@ -50,6 +49,21 @@ func NewStandaloneServer() *MultiDB {
|
||||
mdb.hub = pubsub.MakeHub()
|
||||
validAof := false
|
||||
if config.Properties.AppendOnly {
|
||||
mdb.initAof()
|
||||
validAof = true
|
||||
}
|
||||
if config.Properties.RDBFilename != "" && !validAof {
|
||||
// load rdb
|
||||
loadRdbFile(mdb)
|
||||
}
|
||||
mdb.slaveStatus = initReplSlaveStatus()
|
||||
mdb.startAsMaster()
|
||||
mdb.startReplCron()
|
||||
mdb.role = masterRole // The initialization process does not require atomicity
|
||||
return mdb
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) initAof() {
|
||||
aofHandler, err := aof.NewAOFHandler(mdb, func() database.EmbedDB {
|
||||
return MakeBasicMultiDB()
|
||||
})
|
||||
@@ -63,16 +77,6 @@ func NewStandaloneServer() *MultiDB {
|
||||
mdb.aofHandler.AddAof(singleDB.index, line)
|
||||
}
|
||||
}
|
||||
validAof = true
|
||||
}
|
||||
if config.Properties.RDBFilename != "" && !validAof {
|
||||
// load rdb
|
||||
loadRdbFile(mdb)
|
||||
}
|
||||
mdb.replication = initReplStatus()
|
||||
mdb.startReplCron()
|
||||
mdb.role = masterRole // The initialization process does not require atomicity
|
||||
return mdb
|
||||
}
|
||||
|
||||
// MakeBasicMultiDB create a MultiDB only with basic abilities for aof rewrite and other usages
|
||||
@@ -117,8 +121,7 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep
|
||||
|
||||
// read only slave
|
||||
role := atomic.LoadInt32(&mdb.role)
|
||||
if role == slaveRole &&
|
||||
c.GetRole() != connection.ReplicationRecvCli {
|
||||
if role == slaveRole && !c.IsMaster() {
|
||||
// only allow read only command, forbid all special commands except `auth` and `slaveof`
|
||||
if !isReadOnlyCommand(cmdName) {
|
||||
return protocol.MakeErrReply("READONLY You can't write against a read only slave.")
|
||||
@@ -167,6 +170,10 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep
|
||||
return protocol.MakeArgNumErrReply("copy")
|
||||
}
|
||||
return execCopy(mdb, c, cmdLine[1:])
|
||||
} else if cmdName == "replconf" {
|
||||
return mdb.execReplConf(c, cmdLine[1:])
|
||||
} else if cmdName == "psync" {
|
||||
return mdb.execPSync(c, cmdLine[1:])
|
||||
}
|
||||
// todo: support multi database transaction
|
||||
|
||||
@@ -186,8 +193,8 @@ func (mdb *MultiDB) AfterClientClose(c redis.Connection) {
|
||||
|
||||
// Close graceful shutdown database
|
||||
func (mdb *MultiDB) Close() {
|
||||
// stop replication first
|
||||
mdb.replication.close()
|
||||
// stop slaveStatus first
|
||||
mdb.slaveStatus.close()
|
||||
if mdb.aofHandler != nil {
|
||||
mdb.aofHandler.Close()
|
||||
}
|
||||
@@ -308,7 +315,11 @@ func SaveRDB(db *MultiDB, args [][]byte) redis.Reply {
|
||||
if db.aofHandler == nil {
|
||||
return protocol.MakeErrReply("please enable aof before using save")
|
||||
}
|
||||
err := db.aofHandler.Rewrite2RDB()
|
||||
rdbFilename := config.Properties.RDBFilename
|
||||
if rdbFilename == "" {
|
||||
rdbFilename = "dump.rdb"
|
||||
}
|
||||
err := db.aofHandler.Rewrite2RDB(rdbFilename, nil)
|
||||
if err != nil {
|
||||
return protocol.MakeErrReply(err.Error())
|
||||
}
|
||||
@@ -326,7 +337,11 @@ func BGSaveRDB(db *MultiDB, args [][]byte) redis.Reply {
|
||||
logger.Error(err)
|
||||
}
|
||||
}()
|
||||
err := db.aofHandler.Rewrite2RDB()
|
||||
rdbFilename := config.Properties.RDBFilename
|
||||
if rdbFilename == "" {
|
||||
rdbFilename = "dump.rdb"
|
||||
}
|
||||
err := db.aofHandler.Rewrite2RDB(rdbFilename, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
@@ -339,3 +354,18 @@ func (mdb *MultiDB) GetDBSize(dbIndex int) (int, int) {
|
||||
db := mdb.mustSelectDB(dbIndex)
|
||||
return db.data.Len(), db.ttlMap.Len()
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) startReplCron() {
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Error("panic", err)
|
||||
}
|
||||
}()
|
||||
ticker := time.Tick(time.Second * 10)
|
||||
for range ticker {
|
||||
mdb.slaveCron()
|
||||
mdb.masterCron()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
@@ -17,7 +17,7 @@ func TestLoadRDB(t *testing.T) {
|
||||
AppendOnly: false,
|
||||
RDBFilename: filepath.Join(projectRoot, "test.rdb"), // set working directory to project root
|
||||
}
|
||||
conn := &connection.FakeConn{}
|
||||
conn := connection.NewFakeConn()
|
||||
rdbDB := NewStandaloneServer()
|
||||
result := rdbDB.Exec(conn, utils.ToCmdLine("Get", "str"))
|
||||
asserts.AssertBulkReply(t, result, "str")
|
||||
|
362
database/replication_master.go
Normal file
362
database/replication_master.go
Normal file
@@ -0,0 +1,362 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/hdt3213/godis/interface/redis"
|
||||
"github.com/hdt3213/godis/lib/logger"
|
||||
"github.com/hdt3213/godis/lib/utils"
|
||||
"github.com/hdt3213/godis/redis/protocol"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
slaveStateHandShake = uint8(iota)
|
||||
slaveStateWaitSaveEnd
|
||||
slaveStateSendingRDB
|
||||
slaveStateOnline
|
||||
)
|
||||
|
||||
const (
|
||||
bgSaveIdle = uint8(iota)
|
||||
bgSaveRunning
|
||||
bgSaveFinish
|
||||
)
|
||||
|
||||
const (
|
||||
salveCapacityNone = 0
|
||||
salveCapacityEOF = 1 << iota
|
||||
salveCapacityPsync2
|
||||
)
|
||||
|
||||
// slaveClient stores slave status in the view of master
|
||||
type slaveClient struct {
|
||||
conn redis.Connection
|
||||
state uint8
|
||||
offset int64
|
||||
lastAckTime time.Time
|
||||
announceIp string
|
||||
announcePort int
|
||||
capacity uint8
|
||||
}
|
||||
|
||||
type masterStatus struct {
|
||||
ctx context.Context
|
||||
mu sync.RWMutex
|
||||
replId string
|
||||
backlog []byte // backlog can be appended or replaced as a whole, cannot be modified(insert/set/delete)
|
||||
beginOffset int64
|
||||
currentOffset int64
|
||||
slaveMap map[redis.Connection]*slaveClient
|
||||
waitSlaves map[*slaveClient]struct{}
|
||||
onlineSlaves map[*slaveClient]struct{}
|
||||
bgSaveState uint8
|
||||
rdbFilename string
|
||||
}
|
||||
|
||||
func (master *masterStatus) appendBacklog(bin []byte) {
|
||||
master.backlog = append(master.backlog, bin...)
|
||||
master.currentOffset += int64(len(bin))
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) bgSaveForReplication() {
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
logger.Errorf("panic: %v", e)
|
||||
}
|
||||
}()
|
||||
if err := mdb.saveForReplication(); err != nil {
|
||||
logger.Errorf("save for replication error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
}
|
||||
|
||||
// saveForReplication does bg-save and send rdb to waiting slaves
|
||||
func (mdb *MultiDB) saveForReplication() error {
|
||||
rdbFile, err := ioutil.TempFile("", "*.rdb")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp rdb failed: %v", err)
|
||||
}
|
||||
rdbFilename := rdbFile.Name()
|
||||
mdb.masterStatus.mu.Lock()
|
||||
mdb.masterStatus.bgSaveState = bgSaveRunning
|
||||
mdb.masterStatus.rdbFilename = rdbFilename // todo: can reuse config.Properties.RDBFilename?
|
||||
mdb.masterStatus.mu.Unlock()
|
||||
|
||||
aofListener := make(chan CmdLine, 1024) // give channel enough capacity to store all updates during rewrite to db
|
||||
err = mdb.aofHandler.Rewrite2RDB(rdbFilename, aofListener)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
mdb.masterListenAof(aofListener)
|
||||
}()
|
||||
|
||||
// change bgSaveState and get waitSlaves for sending
|
||||
waitSlaves := make(map[*slaveClient]struct{})
|
||||
mdb.masterStatus.mu.Lock()
|
||||
mdb.masterStatus.bgSaveState = bgSaveFinish
|
||||
for slave := range mdb.masterStatus.waitSlaves {
|
||||
waitSlaves[slave] = struct{}{}
|
||||
}
|
||||
mdb.masterStatus.waitSlaves = nil
|
||||
mdb.masterStatus.mu.Unlock()
|
||||
|
||||
for slave := range waitSlaves {
|
||||
err = mdb.masterFullReSyncWithSlave(slave)
|
||||
if err != nil {
|
||||
mdb.removeSlave(slave)
|
||||
logger.Errorf("masterFullReSyncWithSlave error: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// masterFullReSyncWithSlave send replication header, rdb file and all backlogs to slave
|
||||
func (mdb *MultiDB) masterFullReSyncWithSlave(slave *slaveClient) error {
|
||||
// write replication header
|
||||
header := "+FULLRESYNC " + mdb.masterStatus.replId + " " +
|
||||
strconv.FormatInt(mdb.masterStatus.beginOffset, 10) + protocol.CRLF
|
||||
_, err := slave.conn.Write([]byte(header))
|
||||
if err != nil {
|
||||
return fmt.Errorf("write replication header to slave failed: %v", err)
|
||||
}
|
||||
// send rdb
|
||||
rdbFile, err := os.Open(mdb.masterStatus.rdbFilename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open rdb file %s for replication error: %v", mdb.masterStatus.rdbFilename, err)
|
||||
}
|
||||
slave.state = slaveStateSendingRDB
|
||||
rdbInfo, _ := os.Stat(mdb.masterStatus.rdbFilename)
|
||||
rdbSize := rdbInfo.Size()
|
||||
rdbHeader := "$" + strconv.FormatInt(rdbSize, 10) + protocol.CRLF
|
||||
_, err = slave.conn.Write([]byte(rdbHeader))
|
||||
if err != nil {
|
||||
return fmt.Errorf("write rdb header to slave failed: %v", err)
|
||||
}
|
||||
_, err = io.Copy(slave.conn, rdbFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write rdb file to slave failed: %v", err)
|
||||
}
|
||||
|
||||
// send backlog
|
||||
mdb.masterStatus.mu.RLock()
|
||||
currentOffset := mdb.masterStatus.currentOffset
|
||||
backlog := mdb.masterStatus.backlog[:currentOffset-mdb.masterStatus.beginOffset]
|
||||
mdb.masterStatus.mu.RUnlock()
|
||||
_, err = slave.conn.Write(backlog)
|
||||
if err != nil {
|
||||
return fmt.Errorf("full resync write backlog to slave failed: %v", err)
|
||||
}
|
||||
|
||||
// set slave as online
|
||||
mdb.setSlaveOnline(slave, currentOffset)
|
||||
return nil
|
||||
}
|
||||
|
||||
var cannotPartialSync = errors.New("cannot do partial sync")
|
||||
|
||||
func (mdb *MultiDB) masterTryPartialSyncWithSlave(slave *slaveClient, replId string, slaveOffset int64) error {
|
||||
mdb.masterStatus.mu.RLock()
|
||||
if replId != mdb.masterStatus.replId {
|
||||
mdb.masterStatus.mu.RUnlock()
|
||||
return cannotPartialSync
|
||||
}
|
||||
if slaveOffset < mdb.masterStatus.beginOffset || slaveOffset > mdb.masterStatus.currentOffset {
|
||||
mdb.masterStatus.mu.RUnlock()
|
||||
return cannotPartialSync
|
||||
}
|
||||
currentOffset := mdb.masterStatus.currentOffset
|
||||
backlog := mdb.masterStatus.backlog[slaveOffset-mdb.masterStatus.beginOffset : currentOffset-mdb.masterStatus.beginOffset]
|
||||
mdb.masterStatus.mu.RUnlock()
|
||||
|
||||
// send replication header
|
||||
header := "+CONTINUE " + mdb.masterStatus.replId + protocol.CRLF
|
||||
_, err := slave.conn.Write([]byte(header))
|
||||
if err != nil {
|
||||
return fmt.Errorf("write replication header to slave failed: %v", err)
|
||||
}
|
||||
// send backlog
|
||||
_, err = slave.conn.Write(backlog)
|
||||
if err != nil {
|
||||
return fmt.Errorf("partial resync write backlog to slave failed: %v", err)
|
||||
}
|
||||
|
||||
// set slave online
|
||||
mdb.setSlaveOnline(slave, currentOffset)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) masterSendUpdatesToSlave() error {
|
||||
onlineSlaves := make(map[*slaveClient]struct{})
|
||||
mdb.masterStatus.mu.RLock()
|
||||
currentOffset := mdb.masterStatus.currentOffset
|
||||
beginOffset := mdb.masterStatus.beginOffset
|
||||
backlog := mdb.masterStatus.backlog[:currentOffset-beginOffset]
|
||||
for slave := range mdb.masterStatus.onlineSlaves {
|
||||
onlineSlaves[slave] = struct{}{}
|
||||
}
|
||||
mdb.masterStatus.mu.RUnlock()
|
||||
for slave := range onlineSlaves {
|
||||
slaveBeginOffset := slave.offset - beginOffset
|
||||
_, err := slave.conn.Write(backlog[slaveBeginOffset:])
|
||||
if err != nil {
|
||||
logger.Errorf("send updates write backlog to slave failed: %v", err)
|
||||
mdb.removeSlave(slave)
|
||||
continue
|
||||
}
|
||||
slave.offset = currentOffset
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) execPSync(c redis.Connection, args [][]byte) redis.Reply {
|
||||
replId := string(args[0])
|
||||
replOffset, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||
if err != nil {
|
||||
return protocol.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
mdb.masterStatus.mu.Lock()
|
||||
defer mdb.masterStatus.mu.Unlock()
|
||||
slave := mdb.masterStatus.slaveMap[c]
|
||||
if slave == nil {
|
||||
slave = &slaveClient{
|
||||
conn: c,
|
||||
}
|
||||
c.SetSlave()
|
||||
mdb.masterStatus.slaveMap[c] = slave
|
||||
}
|
||||
if mdb.masterStatus.bgSaveState == bgSaveIdle {
|
||||
slave.state = slaveStateWaitSaveEnd
|
||||
mdb.masterStatus.waitSlaves[slave] = struct{}{}
|
||||
mdb.bgSaveForReplication()
|
||||
} else if mdb.masterStatus.bgSaveState == bgSaveRunning {
|
||||
slave.state = slaveStateWaitSaveEnd
|
||||
mdb.masterStatus.waitSlaves[slave] = struct{}{}
|
||||
} else if mdb.masterStatus.bgSaveState == bgSaveFinish {
|
||||
go func() {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
logger.Errorf("panic: %v", e)
|
||||
}
|
||||
}()
|
||||
err := mdb.masterTryPartialSyncWithSlave(slave, replId, replOffset)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if err != nil && err != cannotPartialSync {
|
||||
mdb.removeSlave(slave)
|
||||
logger.Errorf("masterTryPartialSyncWithSlave error: %v", err)
|
||||
return
|
||||
}
|
||||
// assert err == cannotPartialSync
|
||||
if err := mdb.masterFullReSyncWithSlave(slave); err != nil {
|
||||
mdb.removeSlave(slave)
|
||||
logger.Errorf("masterFullReSyncWithSlave error: %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
return &protocol.NoReply{}
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) execReplConf(c redis.Connection, args [][]byte) redis.Reply {
|
||||
if len(args)%2 != 0 {
|
||||
return protocol.MakeSyntaxErrReply()
|
||||
}
|
||||
mdb.masterStatus.mu.RLock()
|
||||
slave := mdb.masterStatus.slaveMap[c]
|
||||
mdb.masterStatus.mu.RUnlock()
|
||||
for i := 0; i < len(args); i += 2 {
|
||||
key := strings.ToLower(string(args[i]))
|
||||
value := string(args[i+1])
|
||||
switch key {
|
||||
case "ack":
|
||||
offset, err := strconv.ParseInt(value, 10, 64)
|
||||
if err != nil {
|
||||
return protocol.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
slave.offset = offset
|
||||
slave.lastAckTime = time.Now()
|
||||
return &protocol.NoReply{}
|
||||
}
|
||||
}
|
||||
return protocol.MakeOkReply()
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) removeSlave(slave *slaveClient) {
|
||||
mdb.masterStatus.mu.Lock()
|
||||
defer mdb.masterStatus.mu.Unlock()
|
||||
_ = slave.conn.Close()
|
||||
delete(mdb.masterStatus.slaveMap, slave.conn)
|
||||
delete(mdb.masterStatus.waitSlaves, slave)
|
||||
delete(mdb.masterStatus.onlineSlaves, slave)
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) setSlaveOnline(slave *slaveClient, currentOffset int64) {
|
||||
mdb.masterStatus.mu.Lock()
|
||||
defer mdb.masterStatus.mu.Unlock()
|
||||
slave.state = slaveStateOnline
|
||||
slave.offset = currentOffset
|
||||
mdb.masterStatus.onlineSlaves[slave] = struct{}{}
|
||||
}
|
||||
|
||||
var pingBytes = protocol.MakeMultiBulkReply(utils.ToCmdLine("ping")).ToBytes()
|
||||
|
||||
func (mdb *MultiDB) masterCron() {
|
||||
if mdb.role != masterRole {
|
||||
return
|
||||
}
|
||||
mdb.masterStatus.mu.Lock()
|
||||
if mdb.masterStatus.bgSaveState == bgSaveFinish {
|
||||
mdb.masterStatus.appendBacklog(pingBytes)
|
||||
}
|
||||
mdb.masterStatus.mu.Unlock()
|
||||
if err := mdb.masterSendUpdatesToSlave(); err != nil {
|
||||
logger.Errorf("masterSendUpdatesToSlave error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) masterListenAof(listener chan CmdLine) {
|
||||
for {
|
||||
select {
|
||||
case cmdLine := <-listener:
|
||||
mdb.masterStatus.mu.Lock()
|
||||
reply := protocol.MakeMultiBulkReply(cmdLine)
|
||||
mdb.masterStatus.appendBacklog(reply.ToBytes())
|
||||
mdb.masterStatus.mu.Unlock()
|
||||
if err := mdb.masterSendUpdatesToSlave(); err != nil {
|
||||
logger.Errorf("masterSendUpdatesToSlave after receive aof error: %v", err)
|
||||
}
|
||||
// if bgSave is running, updates will be sent after the save finished
|
||||
case <-mdb.masterStatus.ctx.Done():
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) startAsMaster() {
|
||||
mdb.masterStatus = &masterStatus{
|
||||
ctx: context.Background(),
|
||||
mu: sync.RWMutex{},
|
||||
replId: utils.RandHexString(40),
|
||||
backlog: nil,
|
||||
beginOffset: 0,
|
||||
currentOffset: 0,
|
||||
slaveMap: make(map[redis.Connection]*slaveClient),
|
||||
waitSlaves: make(map[*slaveClient]struct{}),
|
||||
onlineSlaves: make(map[*slaveClient]struct{}),
|
||||
bgSaveState: bgSaveIdle,
|
||||
rdbFilename: "",
|
||||
}
|
||||
}
|
194
database/replication_master_test.go
Normal file
194
database/replication_master_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/hdt3213/godis/config"
|
||||
"github.com/hdt3213/godis/lib/utils"
|
||||
"github.com/hdt3213/godis/redis/connection"
|
||||
"github.com/hdt3213/godis/redis/parser"
|
||||
"github.com/hdt3213/godis/redis/protocol"
|
||||
"github.com/hdt3213/godis/redis/protocol/asserts"
|
||||
rdb "github.com/hdt3213/rdb/parser"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func mockServer() *MultiDB {
|
||||
server := &MultiDB{}
|
||||
server.dbSet = make([]*atomic.Value, 16)
|
||||
for i := range server.dbSet {
|
||||
singleDB := makeDB()
|
||||
singleDB.index = i
|
||||
holder := &atomic.Value{}
|
||||
holder.Store(singleDB)
|
||||
server.dbSet[i] = holder
|
||||
}
|
||||
server.slaveStatus = initReplSlaveStatus()
|
||||
return server
|
||||
}
|
||||
|
||||
func TestReplicationMasterSide(t *testing.T) {
|
||||
tmpDir, err := ioutil.TempDir("", "godis")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
aofFilename := path.Join(tmpDir, "a.aof")
|
||||
defer func() {
|
||||
_ = os.Remove(aofFilename)
|
||||
}()
|
||||
config.Properties = &config.ServerProperties{
|
||||
Databases: 16,
|
||||
AppendOnly: true,
|
||||
AppendFilename: aofFilename,
|
||||
}
|
||||
master := mockServer()
|
||||
master.initAof()
|
||||
master.startAsMaster()
|
||||
slave := mockServer()
|
||||
replConn := connection.NewFakeConn()
|
||||
|
||||
// set data to master
|
||||
masterConn := connection.NewFakeConn()
|
||||
resp := master.Exec(masterConn, utils.ToCmdLine("SET", "a", "a"))
|
||||
asserts.AssertNotError(t, resp)
|
||||
time.Sleep(time.Millisecond * 100) // wait write aof
|
||||
|
||||
// full re-sync
|
||||
master.Exec(replConn, utils.ToCmdLine("psync", "?", "-1"))
|
||||
masterChan := parser.ParseStream(replConn)
|
||||
psyncPayload := <-masterChan
|
||||
if psyncPayload.Err != nil {
|
||||
t.Errorf("master bad protocol: %v", psyncPayload.Err)
|
||||
return
|
||||
}
|
||||
psyncHeader, ok := psyncPayload.Data.(*protocol.StatusReply)
|
||||
if !ok {
|
||||
t.Error("psync header is not a status reply")
|
||||
return
|
||||
}
|
||||
headers := strings.Split(psyncHeader.Status, " ")
|
||||
if len(headers) != 3 {
|
||||
t.Errorf("illegal psync header: %s", psyncHeader.Status)
|
||||
return
|
||||
}
|
||||
|
||||
replId := headers[1]
|
||||
replOffset, err := strconv.ParseInt(headers[2], 10, 64)
|
||||
if err != nil {
|
||||
t.Errorf("illegal offset: %s", headers[2])
|
||||
return
|
||||
}
|
||||
t.Logf("repl id: %s, offset: %d", replId, replOffset)
|
||||
|
||||
rdbPayload := <-masterChan
|
||||
if rdbPayload.Err != nil {
|
||||
t.Error("read response failed: " + rdbPayload.Err.Error())
|
||||
return
|
||||
}
|
||||
rdbReply, ok := rdbPayload.Data.(*protocol.BulkReply)
|
||||
if !ok {
|
||||
t.Error("illegal payload header: " + string(rdbPayload.Data.ToBytes()))
|
||||
return
|
||||
}
|
||||
|
||||
rdbDec := rdb.NewDecoder(bytes.NewReader(rdbReply.Arg))
|
||||
err = importRDB(rdbDec, slave)
|
||||
if err != nil {
|
||||
t.Error("import rdb failed: " + err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// get a
|
||||
slaveConn := connection.NewFakeConn()
|
||||
resp = slave.Exec(slaveConn, utils.ToCmdLine("get", "a"))
|
||||
asserts.AssertBulkReply(t, resp, "a")
|
||||
|
||||
/*---- test broadcast aof ----*/
|
||||
masterConn = connection.NewFakeConn()
|
||||
resp = master.Exec(masterConn, utils.ToCmdLine("SET", "b", "b"))
|
||||
time.Sleep(time.Millisecond * 100) // wait write aof
|
||||
asserts.AssertNotError(t, resp)
|
||||
master.masterCron()
|
||||
for {
|
||||
payload := <-masterChan
|
||||
if payload.Err != nil {
|
||||
t.Error(payload.Err)
|
||||
return
|
||||
}
|
||||
cmdLine, ok := payload.Data.(*protocol.MultiBulkReply)
|
||||
if !ok {
|
||||
t.Error("unexpected payload: " + string(payload.Data.ToBytes()))
|
||||
return
|
||||
}
|
||||
slave.Exec(replConn, cmdLine.Args)
|
||||
n := len(cmdLine.ToBytes())
|
||||
slave.slaveStatus.replOffset += int64(n)
|
||||
if string(cmdLine.Args[0]) != "ping" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
resp = slave.Exec(slaveConn, utils.ToCmdLine("get", "b"))
|
||||
asserts.AssertBulkReply(t, resp, "b")
|
||||
|
||||
/*---- test partial reconnect ----*/
|
||||
_ = replConn.Close() // mock disconnect
|
||||
|
||||
replConn = connection.NewFakeConn()
|
||||
|
||||
master.Exec(replConn, utils.ToCmdLine("psync", replId,
|
||||
strconv.FormatInt(slave.slaveStatus.replOffset, 10)))
|
||||
masterChan = parser.ParseStream(replConn)
|
||||
psyncPayload = <-masterChan
|
||||
if psyncPayload.Err != nil {
|
||||
t.Errorf("master bad protocol: %v", psyncPayload.Err)
|
||||
return
|
||||
}
|
||||
psyncHeader, ok = psyncPayload.Data.(*protocol.StatusReply)
|
||||
if !ok {
|
||||
t.Error("psync header is not a status reply")
|
||||
return
|
||||
}
|
||||
headers = strings.Split(psyncHeader.Status, " ")
|
||||
if len(headers) != 2 {
|
||||
t.Errorf("illegal psync header: %s", psyncHeader.Status)
|
||||
return
|
||||
}
|
||||
if headers[0] != "CONTINUE" {
|
||||
t.Errorf("expect CONTINUE actual %s", headers[0])
|
||||
return
|
||||
}
|
||||
replId = headers[1]
|
||||
t.Logf("partial resync repl id: %s, offset: %d", replId, slave.slaveStatus.replOffset)
|
||||
|
||||
resp = master.Exec(masterConn, utils.ToCmdLine("SET", "c", "c"))
|
||||
time.Sleep(time.Millisecond * 100) // wait write aof
|
||||
asserts.AssertNotError(t, resp)
|
||||
master.masterCron()
|
||||
for {
|
||||
payload := <-masterChan
|
||||
if payload.Err != nil {
|
||||
t.Error(payload.Err)
|
||||
return
|
||||
}
|
||||
cmdLine, ok := payload.Data.(*protocol.MultiBulkReply)
|
||||
if !ok {
|
||||
t.Error("unexpected payload: " + string(payload.Data.ToBytes()))
|
||||
return
|
||||
}
|
||||
slave.Exec(replConn, cmdLine.Args)
|
||||
if string(cmdLine.Args[0]) != "ping" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
resp = slave.Exec(slaveConn, utils.ToCmdLine("get", "c"))
|
||||
asserts.AssertBulkReply(t, resp, "c")
|
||||
}
|
@@ -31,9 +31,9 @@ type slaveStatus struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// configVersion stands for the version of replication config. Any change of master host/port will cause configVersion increment
|
||||
// If configVersion change has been found during replication current replication procedure will stop.
|
||||
// It is designed to abort a running replication procedure
|
||||
// configVersion stands for the version of slaveStatus config. Any change of master host/port will cause configVersion increment
|
||||
// If configVersion change has been found during slaveStatus current slaveStatus procedure will stop.
|
||||
// It is designed to abort a running slaveStatus procedure
|
||||
configVersion int32
|
||||
|
||||
masterHost string
|
||||
@@ -47,26 +47,10 @@ type slaveStatus struct {
|
||||
running sync.WaitGroup
|
||||
}
|
||||
|
||||
var configChangedErr = errors.New("replication config changed")
|
||||
var configChangedErr = errors.New("slaveStatus config changed")
|
||||
|
||||
func initReplStatus() *slaveStatus {
|
||||
repl := &slaveStatus{}
|
||||
// start cron
|
||||
return repl
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) startReplCron() {
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Error("panic", err)
|
||||
}
|
||||
}()
|
||||
ticker := time.Tick(time.Second)
|
||||
for range ticker {
|
||||
mdb.slaveCron()
|
||||
}
|
||||
}()
|
||||
func initReplSlaveStatus() *slaveStatus {
|
||||
return &slaveStatus{}
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) execSlaveOf(c redis.Connection, args [][]byte) redis.Reply {
|
||||
@@ -80,29 +64,29 @@ func (mdb *MultiDB) execSlaveOf(c redis.Connection, args [][]byte) redis.Reply {
|
||||
if err != nil {
|
||||
return protocol.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
mdb.replication.mutex.Lock()
|
||||
mdb.slaveStatus.mutex.Lock()
|
||||
atomic.StoreInt32(&mdb.role, slaveRole)
|
||||
mdb.replication.masterHost = host
|
||||
mdb.replication.masterPort = port
|
||||
mdb.slaveStatus.masterHost = host
|
||||
mdb.slaveStatus.masterPort = port
|
||||
// use buffered channel in case receiver goroutine exited before controller send stop signal
|
||||
atomic.AddInt32(&mdb.replication.configVersion, 1)
|
||||
mdb.replication.mutex.Unlock()
|
||||
atomic.AddInt32(&mdb.slaveStatus.configVersion, 1)
|
||||
mdb.slaveStatus.mutex.Unlock()
|
||||
go mdb.setupMaster()
|
||||
return protocol.MakeOkReply()
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) slaveOfNone() {
|
||||
mdb.replication.mutex.Lock()
|
||||
defer mdb.replication.mutex.Unlock()
|
||||
mdb.replication.masterHost = ""
|
||||
mdb.replication.masterPort = 0
|
||||
mdb.replication.replId = ""
|
||||
mdb.replication.replOffset = -1
|
||||
mdb.replication.stopSlaveWithMutex()
|
||||
mdb.slaveStatus.mutex.Lock()
|
||||
defer mdb.slaveStatus.mutex.Unlock()
|
||||
mdb.slaveStatus.masterHost = ""
|
||||
mdb.slaveStatus.masterPort = 0
|
||||
mdb.slaveStatus.replId = ""
|
||||
mdb.slaveStatus.replOffset = -1
|
||||
mdb.slaveStatus.stopSlaveWithMutex()
|
||||
}
|
||||
|
||||
// stopSlaveWithMutex stops in-progress connectWithMaster/fullSync/receiveAOF
|
||||
// invoker should have replication mutex
|
||||
// invoker should have slaveStatus mutex
|
||||
func (repl *slaveStatus) stopSlaveWithMutex() {
|
||||
// update configVersion to stop connectWithMaster and fullSync
|
||||
atomic.AddInt32(&repl.configVersion, 1)
|
||||
@@ -135,11 +119,11 @@ func (mdb *MultiDB) setupMaster() {
|
||||
}()
|
||||
var configVersion int32
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
mdb.replication.mutex.Lock()
|
||||
mdb.replication.ctx = ctx
|
||||
mdb.replication.cancel = cancel
|
||||
configVersion = mdb.replication.configVersion
|
||||
mdb.replication.mutex.Unlock()
|
||||
mdb.slaveStatus.mutex.Lock()
|
||||
mdb.slaveStatus.ctx = ctx
|
||||
mdb.slaveStatus.cancel = cancel
|
||||
configVersion = mdb.slaveStatus.configVersion
|
||||
mdb.slaveStatus.mutex.Unlock()
|
||||
isFullReSync, err := mdb.connectWithMaster(configVersion)
|
||||
if err != nil {
|
||||
// connect failed, abort master
|
||||
@@ -167,7 +151,7 @@ func (mdb *MultiDB) setupMaster() {
|
||||
// connectWithMaster finishes handshake with master
|
||||
// returns: isFullReSync, error
|
||||
func (mdb *MultiDB) connectWithMaster(configVersion int32) (bool, error) {
|
||||
addr := mdb.replication.masterHost + ":" + strconv.Itoa(mdb.replication.masterPort)
|
||||
addr := mdb.slaveStatus.masterHost + ":" + strconv.Itoa(mdb.slaveStatus.masterPort)
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
mdb.slaveOfNone() // abort
|
||||
@@ -256,34 +240,34 @@ func (mdb *MultiDB) connectWithMaster(configVersion int32) (bool, error) {
|
||||
}
|
||||
|
||||
// update connection
|
||||
mdb.replication.mutex.Lock()
|
||||
defer mdb.replication.mutex.Unlock()
|
||||
if mdb.replication.configVersion != configVersion {
|
||||
// replication conf changed during connecting and waiting mutex
|
||||
mdb.slaveStatus.mutex.Lock()
|
||||
defer mdb.slaveStatus.mutex.Unlock()
|
||||
if mdb.slaveStatus.configVersion != configVersion {
|
||||
// slaveStatus conf changed during connecting and waiting mutex
|
||||
return false, configChangedErr
|
||||
}
|
||||
mdb.replication.masterConn = conn
|
||||
mdb.replication.masterChan = masterChan
|
||||
mdb.replication.lastRecvTime = time.Now()
|
||||
mdb.slaveStatus.masterConn = conn
|
||||
mdb.slaveStatus.masterChan = masterChan
|
||||
mdb.slaveStatus.lastRecvTime = time.Now()
|
||||
return mdb.psyncHandshake()
|
||||
}
|
||||
|
||||
// psyncHandshake send `psync` to master and sync repl-id/offset with master
|
||||
// invoker should provide with replication.mutex
|
||||
// invoker should provide with slaveStatus.mutex
|
||||
func (mdb *MultiDB) psyncHandshake() (bool, error) {
|
||||
replId := "?"
|
||||
var replOffset int64 = -1
|
||||
if mdb.replication.replId != "" {
|
||||
replId = mdb.replication.replId
|
||||
replOffset = mdb.replication.replOffset
|
||||
if mdb.slaveStatus.replId != "" {
|
||||
replId = mdb.slaveStatus.replId
|
||||
replOffset = mdb.slaveStatus.replOffset
|
||||
}
|
||||
psyncCmdLine := utils.ToCmdLine("psync", replId, strconv.FormatInt(replOffset, 10))
|
||||
psyncReq := protocol.MakeMultiBulkReply(psyncCmdLine)
|
||||
_, err := mdb.replication.masterConn.Write(psyncReq.ToBytes())
|
||||
_, err := mdb.slaveStatus.masterConn.Write(psyncReq.ToBytes())
|
||||
if err != nil {
|
||||
return false, errors.New("send failed " + err.Error())
|
||||
}
|
||||
psyncPayload := <-mdb.replication.masterChan
|
||||
psyncPayload := <-mdb.slaveStatus.masterChan
|
||||
if psyncPayload.Err != nil {
|
||||
return false, errors.New("read response failed: " + psyncPayload.Err.Error())
|
||||
}
|
||||
@@ -300,12 +284,12 @@ func (mdb *MultiDB) psyncHandshake() (bool, error) {
|
||||
var isFullReSync bool
|
||||
if headers[0] == "FULLRESYNC" {
|
||||
logger.Info("full re-sync with master")
|
||||
mdb.replication.replId = headers[1]
|
||||
mdb.replication.replOffset, err = strconv.ParseInt(headers[2], 10, 64)
|
||||
mdb.slaveStatus.replId = headers[1]
|
||||
mdb.slaveStatus.replOffset, err = strconv.ParseInt(headers[2], 10, 64)
|
||||
isFullReSync = true
|
||||
} else if headers[0] == "CONTINUE" {
|
||||
logger.Info("continue partial sync")
|
||||
mdb.replication.replId = headers[1]
|
||||
mdb.slaveStatus.replId = headers[1]
|
||||
isFullReSync = false
|
||||
} else {
|
||||
return false, errors.New("illegal psync resp: " + psyncHeader.Status)
|
||||
@@ -314,13 +298,13 @@ func (mdb *MultiDB) psyncHandshake() (bool, error) {
|
||||
if err != nil {
|
||||
return false, errors.New("get illegal repl offset: " + headers[2])
|
||||
}
|
||||
logger.Info(fmt.Sprintf("repl id: %s, current offset: %d", mdb.replication.replId, mdb.replication.replOffset))
|
||||
logger.Info(fmt.Sprintf("repl id: %s, current offset: %d", mdb.slaveStatus.replId, mdb.slaveStatus.replOffset))
|
||||
return isFullReSync, nil
|
||||
}
|
||||
|
||||
// loadMasterRDB downloads rdb after handshake has been done
|
||||
func (mdb *MultiDB) loadMasterRDB(configVersion int32) error {
|
||||
rdbPayload := <-mdb.replication.masterChan
|
||||
rdbPayload := <-mdb.slaveStatus.masterChan
|
||||
if rdbPayload.Err != nil {
|
||||
return errors.New("read response failed: " + rdbPayload.Err.Error())
|
||||
}
|
||||
@@ -337,10 +321,10 @@ func (mdb *MultiDB) loadMasterRDB(configVersion int32) error {
|
||||
return errors.New("dump rdb failed: " + err.Error())
|
||||
}
|
||||
|
||||
mdb.replication.mutex.Lock()
|
||||
defer mdb.replication.mutex.Unlock()
|
||||
if mdb.replication.configVersion != configVersion {
|
||||
// replication conf changed during connecting and waiting mutex
|
||||
mdb.slaveStatus.mutex.Lock()
|
||||
defer mdb.slaveStatus.mutex.Unlock()
|
||||
if mdb.slaveStatus.configVersion != configVersion {
|
||||
// slaveStatus conf changed during connecting and waiting mutex
|
||||
return configChangedErr
|
||||
}
|
||||
for i, h := range rdbHolder.dbSet {
|
||||
@@ -353,13 +337,13 @@ func (mdb *MultiDB) loadMasterRDB(configVersion int32) error {
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) receiveAOF(ctx context.Context, configVersion int32) error {
|
||||
conn := connection.NewConn(mdb.replication.masterConn)
|
||||
conn.SetRole(connection.ReplicationRecvCli)
|
||||
mdb.replication.running.Add(1)
|
||||
defer mdb.replication.running.Done()
|
||||
conn := connection.NewConn(mdb.slaveStatus.masterConn)
|
||||
conn.SetMaster()
|
||||
mdb.slaveStatus.running.Add(1)
|
||||
defer mdb.slaveStatus.running.Done()
|
||||
for {
|
||||
select {
|
||||
case payload, open := <-mdb.replication.masterChan:
|
||||
case payload, open := <-mdb.slaveStatus.masterChan:
|
||||
if !open {
|
||||
return errors.New("master channel unexpected close")
|
||||
}
|
||||
@@ -370,31 +354,32 @@ func (mdb *MultiDB) receiveAOF(ctx context.Context, configVersion int32) error {
|
||||
if !ok {
|
||||
return errors.New("unexpected payload: " + string(payload.Data.ToBytes()))
|
||||
}
|
||||
mdb.replication.mutex.Lock()
|
||||
if mdb.replication.configVersion != configVersion {
|
||||
// replication conf changed during connecting and waiting mutex
|
||||
mdb.slaveStatus.mutex.Lock()
|
||||
if mdb.slaveStatus.configVersion != configVersion {
|
||||
// slaveStatus conf changed during connecting and waiting mutex
|
||||
return configChangedErr
|
||||
}
|
||||
mdb.Exec(conn, cmdLine.Args)
|
||||
n := len(cmdLine.ToBytes()) // todo: directly get size from socket
|
||||
mdb.replication.replOffset += int64(n)
|
||||
mdb.replication.lastRecvTime = time.Now()
|
||||
mdb.slaveStatus.replOffset += int64(n)
|
||||
mdb.slaveStatus.lastRecvTime = time.Now()
|
||||
logger.Info(fmt.Sprintf("receive %d bytes from master, current offset %d, %s",
|
||||
n, mdb.replication.replOffset, strconv.Quote(string(cmdLine.ToBytes()))))
|
||||
mdb.replication.mutex.Unlock()
|
||||
n, mdb.slaveStatus.replOffset, strconv.Quote(string(cmdLine.ToBytes()))))
|
||||
mdb.slaveStatus.mutex.Unlock()
|
||||
case <-ctx.Done():
|
||||
conn.GetConnPool().Put(conn)
|
||||
_ = conn.Close()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mdb *MultiDB) slaveCron() {
|
||||
repl := mdb.replication
|
||||
repl := mdb.slaveStatus
|
||||
if repl.masterConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// check master timeout
|
||||
replTimeout := 60 * time.Second
|
||||
if config.Properties.ReplTimeout != 0 {
|
||||
replTimeout = time.Duration(config.Properties.ReplTimeout) * time.Second
|
||||
@@ -427,9 +412,9 @@ func (repl *slaveStatus) sendAck2Master() error {
|
||||
|
||||
func (mdb *MultiDB) reconnectWithMaster() error {
|
||||
logger.Info("reconnecting with master")
|
||||
mdb.replication.mutex.Lock()
|
||||
defer mdb.replication.mutex.Unlock()
|
||||
mdb.replication.stopSlaveWithMutex()
|
||||
mdb.slaveStatus.mutex.Lock()
|
||||
defer mdb.slaveStatus.mutex.Unlock()
|
||||
mdb.slaveStatus.stopSlaveWithMutex()
|
||||
go mdb.setupMaster()
|
||||
return nil
|
||||
}
|
@@ -8,22 +8,12 @@ import (
|
||||
"github.com/hdt3213/godis/redis/connection"
|
||||
"github.com/hdt3213/godis/redis/protocol"
|
||||
"github.com/hdt3213/godis/redis/protocol/asserts"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestReplication(t *testing.T) {
|
||||
mdb := &MultiDB{}
|
||||
mdb.dbSet = make([]*atomic.Value, 16)
|
||||
for i := range mdb.dbSet {
|
||||
singleDB := makeDB()
|
||||
singleDB.index = i
|
||||
holder := &atomic.Value{}
|
||||
holder.Store(singleDB)
|
||||
mdb.dbSet[i] = holder
|
||||
}
|
||||
mdb.replication = initReplStatus()
|
||||
func TestReplicationSlaveSide(t *testing.T) {
|
||||
mdb := mockServer()
|
||||
masterCli, err := client.MakeClient("127.0.0.1:6379")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@@ -33,7 +23,7 @@ func TestReplication(t *testing.T) {
|
||||
// sync with master
|
||||
ret := masterCli.Send(utils.ToCmdLine("set", "1", "1"))
|
||||
asserts.AssertStatusReply(t, ret, "OK")
|
||||
conn := &connection.FakeConn{}
|
||||
conn := connection.NewFakeConn()
|
||||
ret = mdb.Exec(conn, utils.ToCmdLine("SLAVEOF", "127.0.0.1", "6379"))
|
||||
asserts.AssertStatusReply(t, ret, "OK")
|
||||
success := false
|
||||
@@ -74,7 +64,7 @@ func TestReplication(t *testing.T) {
|
||||
t.Error("sync failed")
|
||||
return
|
||||
}
|
||||
err = mdb.replication.sendAck2Master()
|
||||
err = mdb.slaveStatus.sendAck2Master()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -83,8 +73,8 @@ func TestReplication(t *testing.T) {
|
||||
|
||||
// test reconnect
|
||||
config.Properties.ReplTimeout = 1
|
||||
_ = mdb.replication.masterConn.Close()
|
||||
mdb.replication.lastRecvTime = time.Now().Add(-time.Hour) // mock timeout
|
||||
_ = mdb.slaveStatus.masterConn.Close()
|
||||
mdb.slaveStatus.lastRecvTime = time.Now().Add(-time.Hour) // mock timeout
|
||||
mdb.slaveCron()
|
||||
time.Sleep(3 * time.Second)
|
||||
ret = masterCli.Send(utils.ToCmdLine("set", "1", "3"))
|
||||
@@ -134,7 +124,7 @@ func TestReplication(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
err = mdb.replication.close()
|
||||
err = mdb.slaveStatus.close()
|
||||
if err != nil {
|
||||
t.Error("cannot close")
|
||||
}
|
@@ -20,7 +20,7 @@ func TestPing(t *testing.T) {
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
passwd := utils.RandString(10)
|
||||
c := &connection.FakeConn{}
|
||||
c := connection.NewFakeConn()
|
||||
ret := testServer.Exec(c, utils.ToCmdLine("AUTH"))
|
||||
asserts.AssertErrReply(t, ret, "ERR wrong number of arguments for 'auth' command")
|
||||
ret = testServer.Exec(c, utils.ToCmdLine("AUTH", passwd))
|
||||
|
@@ -2,7 +2,9 @@ package redis
|
||||
|
||||
// Connection represents a connection with redis client
|
||||
type Connection interface {
|
||||
Write([]byte) error
|
||||
Write([]byte) (int, error)
|
||||
Close() error
|
||||
|
||||
SetPassword(string)
|
||||
GetPassword() string
|
||||
|
||||
@@ -12,7 +14,6 @@ type Connection interface {
|
||||
SubsCount() int
|
||||
GetChannels() []string
|
||||
|
||||
// used for `Multi` command
|
||||
InMultiState() bool
|
||||
SetMultiState(bool)
|
||||
GetQueuedCmdLine() [][][]byte
|
||||
@@ -22,10 +23,12 @@ type Connection interface {
|
||||
AddTxError(err error)
|
||||
GetTxErrors() []error
|
||||
|
||||
// used for multi database
|
||||
GetDBIndex() int
|
||||
SelectDB(int)
|
||||
// returns role of conn, such as connection with client, connection with master node
|
||||
GetRole() int32
|
||||
SetRole(int32)
|
||||
|
||||
SetSlave()
|
||||
IsSlave() bool
|
||||
|
||||
SetMaster()
|
||||
IsMaster() bool
|
||||
}
|
||||
|
@@ -107,6 +107,13 @@ func Error(v ...interface{}) {
|
||||
logger.Println(v...)
|
||||
}
|
||||
|
||||
func Errorf(format string, v ...interface{}) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
setPrefix(ERROR)
|
||||
logger.Println(fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
// Fatal prints error log then stop the program
|
||||
func Fatal(v ...interface{}) {
|
||||
mu.Lock()
|
||||
|
@@ -16,3 +16,13 @@ func RandString(n int) string {
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
var hexLetters = []rune("0123456789abcdef")
|
||||
|
||||
func RandHexString(n int) string {
|
||||
b := make([]rune, n)
|
||||
for i := range b {
|
||||
b[i] = hexLetters[rand.Intn(len(hexLetters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
@@ -82,7 +82,7 @@ func Subscribe(hub *Hub, c redis.Connection, args [][]byte) redis.Reply {
|
||||
|
||||
for _, channel := range channels {
|
||||
if subscribe0(hub, channel, c) {
|
||||
_ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount())))
|
||||
_, _ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount())))
|
||||
}
|
||||
}
|
||||
return &protocol.NoReply{}
|
||||
@@ -117,13 +117,13 @@ func UnSubscribe(db *Hub, c redis.Connection, args [][]byte) redis.Reply {
|
||||
defer db.subsLocker.UnLocks(channels...)
|
||||
|
||||
if len(channels) == 0 {
|
||||
_ = c.Write(unSubscribeNothing)
|
||||
_, _ = c.Write(unSubscribeNothing)
|
||||
return &protocol.NoReply{}
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
if unsubscribe0(db, channel, c) {
|
||||
_ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount())))
|
||||
_, _ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount())))
|
||||
}
|
||||
}
|
||||
return &protocol.NoReply{}
|
||||
@@ -151,7 +151,7 @@ func Publish(hub *Hub, args [][]byte) redis.Reply {
|
||||
replyArgs[0] = messageBytes
|
||||
replyArgs[1] = []byte(channel)
|
||||
replyArgs[2] = message
|
||||
_ = client.Write(protocol.MakeMultiBulkReply(replyArgs).ToBytes())
|
||||
_, _ = client.Write(protocol.MakeMultiBulkReply(replyArgs).ToBytes())
|
||||
return true
|
||||
})
|
||||
return protocol.MakeIntReply(int64(subscribers.Len()))
|
||||
|
@@ -2,6 +2,6 @@ bind 0.0.0.0
|
||||
port 6399
|
||||
maxclients 128
|
||||
|
||||
#appendonly no
|
||||
#appendfilename appendonly.aof
|
||||
appendonly yes
|
||||
appendfilename appendonly.aof
|
||||
#dbfilename test.rdb
|
||||
|
@@ -1,7 +1,6 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/hdt3213/godis/lib/logger"
|
||||
"github.com/hdt3213/godis/lib/sync/wait"
|
||||
"net"
|
||||
@@ -10,21 +9,24 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// NormalCli is client with user
|
||||
NormalCli = iota
|
||||
// ReplicationRecvCli is fake client with replication master
|
||||
ReplicationRecvCli
|
||||
// flagSlave means this a connection with slave
|
||||
flagSlave = uint64(1 << iota)
|
||||
// flagSlave means this a connection with master
|
||||
flagMaster
|
||||
// flagMulti means this connection is within a transaction
|
||||
flagMulti
|
||||
)
|
||||
|
||||
// Connection represents a connection with a redis-cli
|
||||
type Connection struct {
|
||||
conn net.Conn
|
||||
|
||||
// waiting until protocol finished
|
||||
waitingReply wait.Wait
|
||||
// waiting until finish sending data, used for graceful shutdown
|
||||
sendingData wait.Wait
|
||||
|
||||
// lock while server sending response
|
||||
mu sync.Mutex
|
||||
flags uint64
|
||||
|
||||
// subscribing channels
|
||||
subs map[string]bool
|
||||
@@ -33,14 +35,12 @@ type Connection struct {
|
||||
password string
|
||||
|
||||
// queued commands for `multi`
|
||||
multiState bool
|
||||
queue [][][]byte
|
||||
watching map[string]uint32
|
||||
txErrors []error
|
||||
|
||||
// selected db
|
||||
selectedDB int
|
||||
role int32
|
||||
}
|
||||
|
||||
var connPool = sync.Pool{
|
||||
@@ -49,11 +49,6 @@ var connPool = sync.Pool{
|
||||
},
|
||||
}
|
||||
|
||||
// GetConnPool returns the connection pool pointer for putting and getting connection
|
||||
func (c *Connection) GetConnPool() *sync.Pool {
|
||||
return &connPool
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote network address
|
||||
func (c *Connection) RemoteAddr() net.Addr {
|
||||
return c.conn.RemoteAddr()
|
||||
@@ -61,8 +56,15 @@ func (c *Connection) RemoteAddr() net.Addr {
|
||||
|
||||
// Close disconnect with the client
|
||||
func (c *Connection) Close() error {
|
||||
c.waitingReply.WaitWithTimeout(10 * time.Second)
|
||||
c.sendingData.WaitWithTimeout(10 * time.Second)
|
||||
_ = c.conn.Close()
|
||||
c.subs = nil
|
||||
c.password = ""
|
||||
c.queue = nil
|
||||
c.watching = nil
|
||||
c.txErrors = nil
|
||||
c.selectedDB = 0
|
||||
connPool.Put(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -80,17 +82,16 @@ func NewConn(conn net.Conn) *Connection {
|
||||
}
|
||||
|
||||
// Write sends response to client over tcp connection
|
||||
func (c *Connection) Write(b []byte) error {
|
||||
func (c *Connection) Write(b []byte) (int, error) {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
return 0, nil
|
||||
}
|
||||
c.waitingReply.Add(1)
|
||||
c.sendingData.Add(1)
|
||||
defer func() {
|
||||
c.waitingReply.Done()
|
||||
c.sendingData.Done()
|
||||
}()
|
||||
|
||||
_, err := c.conn.Write(b)
|
||||
return err
|
||||
return c.conn.Write(b)
|
||||
}
|
||||
|
||||
// Subscribe add current connection into subscribers of the given channel
|
||||
@@ -146,7 +147,7 @@ func (c *Connection) GetPassword() string {
|
||||
|
||||
// InMultiState tells is connection in an uncommitted transaction
|
||||
func (c *Connection) InMultiState() bool {
|
||||
return c.multiState
|
||||
return c.flags&flagMulti > 0
|
||||
}
|
||||
|
||||
// SetMultiState sets transaction flag
|
||||
@@ -154,8 +155,10 @@ func (c *Connection) SetMultiState(state bool) {
|
||||
if !state { // reset data when cancel multi
|
||||
c.watching = nil
|
||||
c.queue = nil
|
||||
c.flags &= ^flagMulti // clean multi flag
|
||||
return
|
||||
}
|
||||
c.multiState = state
|
||||
c.flags |= flagMulti
|
||||
}
|
||||
|
||||
// GetQueuedCmdLine returns queued commands of current transaction
|
||||
@@ -183,18 +186,6 @@ func (c *Connection) ClearQueuedCmds() {
|
||||
c.queue = nil
|
||||
}
|
||||
|
||||
// GetRole returns role of connection, such as connection with master
|
||||
func (c *Connection) GetRole() int32 {
|
||||
if c == nil {
|
||||
return NormalCli
|
||||
}
|
||||
return c.role
|
||||
}
|
||||
|
||||
func (c *Connection) SetRole(r int32) {
|
||||
c.role = r
|
||||
}
|
||||
|
||||
// GetWatching returns watching keys and their version code when started watching
|
||||
func (c *Connection) GetWatching() map[string]uint32 {
|
||||
if c.watching == nil {
|
||||
@@ -213,24 +204,18 @@ func (c *Connection) SelectDB(dbNum int) {
|
||||
c.selectedDB = dbNum
|
||||
}
|
||||
|
||||
// FakeConn implements redis.Connection for test
|
||||
type FakeConn struct {
|
||||
Connection
|
||||
buf bytes.Buffer
|
||||
func (c *Connection) SetSlave() {
|
||||
c.flags |= flagSlave
|
||||
}
|
||||
|
||||
// Write writes data to buffer
|
||||
func (c *FakeConn) Write(b []byte) error {
|
||||
c.buf.Write(b)
|
||||
return nil
|
||||
func (c *Connection) IsSlave() bool {
|
||||
return c.flags&flagSlave > 0
|
||||
}
|
||||
|
||||
// Clean resets the buffer
|
||||
func (c *FakeConn) Clean() {
|
||||
c.buf.Reset()
|
||||
func (c *Connection) SetMaster() {
|
||||
c.flags |= flagMaster
|
||||
}
|
||||
|
||||
// Bytes returns written data
|
||||
func (c *FakeConn) Bytes() []byte {
|
||||
return c.buf.Bytes()
|
||||
func (c *Connection) IsMaster() bool {
|
||||
return c.flags&flagMaster > 0
|
||||
}
|
||||
|
79
redis/connection/fake.go
Normal file
79
redis/connection/fake.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FakeConn implements redis.Connection for test
|
||||
type FakeConn struct {
|
||||
Connection
|
||||
buf bytes.Buffer
|
||||
wait chan struct{}
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewFakeConn() *FakeConn {
|
||||
c := &FakeConn{}
|
||||
return c
|
||||
}
|
||||
|
||||
// Write writes data to buffer
|
||||
func (c *FakeConn) Write(b []byte) (int, error) {
|
||||
if c.closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n, _ := c.buf.Write(b)
|
||||
c.notify()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *FakeConn) notify() {
|
||||
if c.wait != nil {
|
||||
c.mu.Lock()
|
||||
if c.wait != nil {
|
||||
close(c.wait)
|
||||
c.wait = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *FakeConn) waiting() {
|
||||
c.mu.Lock()
|
||||
c.wait = make(chan struct{})
|
||||
c.mu.Unlock()
|
||||
<-c.wait
|
||||
}
|
||||
|
||||
// Read reads data from buffer
|
||||
func (c *FakeConn) Read(p []byte) (int, error) {
|
||||
n, err := c.buf.Read(p)
|
||||
if err == io.EOF {
|
||||
if c.closed {
|
||||
return 0, io.EOF
|
||||
}
|
||||
c.waiting()
|
||||
return c.buf.Read(p)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Clean resets the buffer
|
||||
func (c *FakeConn) Clean() {
|
||||
c.wait = make(chan struct{})
|
||||
c.buf.Reset()
|
||||
}
|
||||
|
||||
// Bytes returns written data
|
||||
func (c *FakeConn) Bytes() []byte {
|
||||
return c.buf.Bytes()
|
||||
}
|
||||
|
||||
func (c *FakeConn) Close() error {
|
||||
c.closed = true
|
||||
c.notify()
|
||||
return nil
|
||||
}
|
@@ -13,7 +13,7 @@ func TestPublish(t *testing.T) {
|
||||
hub := pubsub.MakeHub()
|
||||
channel := utils.RandString(5)
|
||||
msg := utils.RandString(5)
|
||||
conn := &connection.FakeConn{}
|
||||
conn := connection.NewFakeConn()
|
||||
pubsub.Subscribe(hub, conn, utils.ToCmdLine(channel))
|
||||
conn.Clean() // clean subscribe success
|
||||
pubsub.Publish(hub, utils.ToCmdLine(channel, msg))
|
||||
|
@@ -50,8 +50,6 @@ func (h *Handler) closeClient(client *connection.Connection) {
|
||||
_ = client.Close()
|
||||
h.db.AfterClientClose(client)
|
||||
h.activeConn.Delete(client)
|
||||
|
||||
client.GetConnPool().Put(client)
|
||||
}
|
||||
|
||||
// Handle receives and executes redis commands
|
||||
@@ -78,7 +76,7 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
// protocol err
|
||||
errReply := protocol.MakeErrReply(payload.Err.Error())
|
||||
err := client.Write(errReply.ToBytes())
|
||||
_, err := client.Write(errReply.ToBytes())
|
||||
if err != nil {
|
||||
h.closeClient(client)
|
||||
logger.Info("connection closed: " + client.RemoteAddr().String())
|
||||
@@ -97,9 +95,9 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
result := h.db.Exec(client, r.Args)
|
||||
if result != nil {
|
||||
_ = client.Write(result.ToBytes())
|
||||
_, _ = client.Write(result.ToBytes())
|
||||
} else {
|
||||
_ = client.Write(unknownErrReplyBytes)
|
||||
_, _ = client.Write(unknownErrReplyBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user