replication master side

This commit is contained in:
finley
2022-11-21 23:36:35 +08:00
parent a7a3da6e49
commit ba7ea942cb
23 changed files with 886 additions and 217 deletions

View File

@@ -26,6 +26,10 @@ type payload struct {
dbIndex int
}
// Listener is a channel receive a replication of all aof payloads
// with a listener we can forward the updates to slave nodes etc.
type Listener chan<- CmdLine
// Handler receive msgs from channel and write to AOF file
type Handler struct {
db database.EmbedDB
@@ -38,6 +42,7 @@ type Handler struct {
// pause aof for start/finish aof rewrite progress
pausingAof sync.RWMutex
currentDB int
listeners map[Listener]struct{}
}
// NewAOFHandler creates a new aof.Handler
@@ -54,12 +59,20 @@ func NewAOFHandler(db database.EmbedDB, tmpDBMaker func() database.EmbedDB) (*Ha
handler.aofFile = aofFile
handler.aofChan = make(chan *payload, aofQueueSize)
handler.aofFinished = make(chan struct{})
handler.listeners = make(map[Listener]struct{})
go func() {
handler.handleAof()
}()
return handler, nil
}
// RemoveListener removes a listener from aof handler, so we can close the listener
func (handler *Handler) RemoveListener(listener Listener) {
handler.pausingAof.Lock()
defer handler.pausingAof.Unlock()
delete(handler.listeners, listener)
}
// AddAof send command to aof goroutine through channel
func (handler *Handler) AddAof(dbIndex int, cmdLine CmdLine) {
if config.Properties.AppendOnly && handler.aofChan != nil {
@@ -73,12 +86,16 @@ func (handler *Handler) AddAof(dbIndex int, cmdLine CmdLine) {
// handleAof listen aof channel and write into file
func (handler *Handler) handleAof() {
// serialized execution
var cmdLines []CmdLine
handler.currentDB = 0
for p := range handler.aofChan {
cmdLines = cmdLines[:0] // reuse underlying array
handler.pausingAof.RLock() // prevent other goroutines from pausing aof
if p.dbIndex != handler.currentDB {
// select db
data := protocol.MakeMultiBulkReply(utils.ToCmdLine("SELECT", strconv.Itoa(p.dbIndex))).ToBytes()
selectCmd := utils.ToCmdLine("SELECT", strconv.Itoa(p.dbIndex))
cmdLines = append(cmdLines, selectCmd)
data := protocol.MakeMultiBulkReply(selectCmd).ToBytes()
_, err := handler.aofFile.Write(data)
if err != nil {
logger.Warn(err)
@@ -88,11 +105,17 @@ func (handler *Handler) handleAof() {
handler.currentDB = p.dbIndex
}
data := protocol.MakeMultiBulkReply(p.cmdLine).ToBytes()
cmdLines = append(cmdLines, p.cmdLine)
_, err := handler.aofFile.Write(data)
if err != nil {
logger.Warn(err)
}
handler.pausingAof.RUnlock()
for listener := range handler.listeners {
for _, line := range cmdLines {
listener <- line
}
}
}
handler.aofFinished <- struct{}{}
}
@@ -123,7 +146,7 @@ func (handler *Handler) LoadAof(maxBytes int) {
reader = file
}
ch := parser.ParseStream(reader)
fakeConn := &connection.FakeConn{} // only used for save dbIndex
fakeConn := connection.NewFakeConn() // only used for save dbIndex
for p := range ch {
if p.Err != nil {
if p.Err == io.EOF {
@@ -143,7 +166,7 @@ func (handler *Handler) LoadAof(maxBytes int) {
}
ret := handler.db.Exec(fakeConn, r.Args)
if protocol.IsErrorReply(ret) {
logger.Error("exec err", ret.ToBytes())
logger.Error("exec err", string(ret.ToBytes()))
}
}
}

View File

@@ -16,8 +16,12 @@ import (
"time"
)
func (handler *Handler) Rewrite2RDB() error {
ctx, err := handler.startRewrite2RDB()
// todo: forbid concurrent rewrite
// Rewrite2RDB rewrite aof data into rdb
// if extraListener is not nil, it will be appended to Handler.listeners, it will receive all updates after rdb
func (handler *Handler) Rewrite2RDB(rdbFilename string, extraListener Listener) error {
ctx, err := handler.startRewrite2RDB(extraListener)
if err != nil {
return err
}
@@ -25,10 +29,6 @@ func (handler *Handler) Rewrite2RDB() error {
if err != nil {
return err
}
rdbFilename := config.Properties.RDBFilename
if rdbFilename == "" {
rdbFilename = "dump.rdb"
}
err = ctx.tmpFile.Close()
if err != nil {
return err
@@ -40,7 +40,7 @@ func (handler *Handler) Rewrite2RDB() error {
return nil
}
func (handler *Handler) startRewrite2RDB() (*RewriteCtx, error) {
func (handler *Handler) startRewrite2RDB(extraListener Listener) (*RewriteCtx, error) {
handler.pausingAof.Lock() // pausing aof
defer handler.pausingAof.Unlock()
@@ -59,6 +59,9 @@ func (handler *Handler) startRewrite2RDB() (*RewriteCtx, error) {
logger.Warn("tmp file create failed")
return nil, err
}
if extraListener != nil {
handler.listeners[extraListener] = struct{}{}
}
return &RewriteCtx{
tmpFile: file,
fileSize: filesize,

View File

@@ -10,7 +10,7 @@ import (
func TestExec(t *testing.T) {
testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"})
conn := &connection.FakeConn{}
conn := connection.NewFakeConn()
for i := 0; i < 1000; i++ {
key := RandString(4)
value := RandString(4)
@@ -26,7 +26,7 @@ func TestAuth(t *testing.T) {
defer func() {
config.Properties.RequirePass = ""
}()
conn := &connection.FakeConn{}
conn := connection.NewFakeConn()
ret := testNodeA.Exec(conn, toArgs("GET", "a"))
asserts.AssertErrReply(t, ret, "NOAUTH Authentication required")
ret = testNodeA.Exec(conn, toArgs("AUTH", passwd))
@@ -39,7 +39,7 @@ func TestRelay(t *testing.T) {
testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"})
key := RandString(4)
value := RandString(4)
conn := &connection.FakeConn{}
conn := connection.NewFakeConn()
ret := testCluster2.relay("127.0.0.1:6379", conn, toArgs("SET", key, value))
asserts.AssertNotError(t, ret)
ret = testCluster2.relay("127.0.0.1:6379", conn, toArgs("GET", key))
@@ -50,7 +50,7 @@ func TestBroadcast(t *testing.T) {
testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"})
key := RandString(4)
value := RandString(4)
rets := testCluster2.broadcast(&connection.FakeConn{}, toArgs("SET", key, value))
rets := testCluster2.broadcast(connection.NewFakeConn(), toArgs("SET", key, value))
for _, v := range rets {
asserts.AssertNotError(t, v)
}

View File

@@ -7,7 +7,7 @@ import (
)
func TestDel(t *testing.T) {
conn := &connection.FakeConn{}
conn := connection.NewFakeConn()
allowFastTransaction = false
testNodeA.Exec(conn, toArgs("SET", "a", "a"))
ret := Del(testNodeA, conn, toArgs("DEL", "a", "b", "c"))

View File

@@ -7,7 +7,7 @@ import (
)
func TestMSet(t *testing.T) {
conn := &connection.FakeConn{}
conn := connection.NewFakeConn()
allowFastTransaction = false
ret := MSet(testNodeA, conn, toArgs("MSET", "a", "a", "b", "b"))
asserts.AssertNotError(t, ret)
@@ -16,7 +16,7 @@ func TestMSet(t *testing.T) {
}
func TestMSetNx(t *testing.T) {
conn := &connection.FakeConn{}
conn := connection.NewFakeConn()
allowFastTransaction = false
FlushAll(testNodeA, conn, toArgs("FLUSHALL"))
ret := MSetNX(testNodeA, conn, toArgs("MSETNX", "a", "a", "b", "b"))

View File

@@ -11,7 +11,7 @@ import (
func TestPublish(t *testing.T) {
channel := utils.RandString(5)
msg := utils.RandString(5)
conn := &connection.FakeConn{}
conn := connection.NewFakeConn()
Subscribe(testNodeA, conn, utils.ToCmdLine("SUBSCRIBE", channel))
conn.Clean() // clean subscribe success
Publish(testNodeA, conn, utils.ToCmdLine("PUBLISH", channel, msg))

View File

@@ -17,7 +17,7 @@ import (
)
func makeTestData(db database.DB, dbIndex int, prefix string, size int) {
conn := &connection.FakeConn{}
conn := connection.NewFakeConn()
conn.SelectDB(dbIndex)
db.Exec(conn, utils.ToCmdLine("FlushDB"))
cursor := 0
@@ -49,7 +49,7 @@ func makeTestData(db database.DB, dbIndex int, prefix string, size int) {
}
func validateTestData(t *testing.T, db database.DB, dbIndex int, prefix string, size int) {
conn := &connection.FakeConn{}
conn := connection.NewFakeConn()
conn.SelectDB(dbIndex)
cursor := 0
var ret redis.Reply
@@ -146,7 +146,7 @@ func TestRDB(t *testing.T) {
dbNum := 4
size := 10
var prefixes []string
conn := &connection.FakeConn{}
conn := connection.NewFakeConn()
writeDB := NewStandaloneServer()
for i := 0; i < dbNum; i++ {
prefix := utils.RandString(8)
@@ -216,7 +216,7 @@ func TestRewriteAOF2(t *testing.T) {
}
aofWriteDB := NewStandaloneServer()
dbNum := 4
conn := &connection.FakeConn{}
conn := connection.NewFakeConn()
for i := 0; i < dbNum; i++ {
conn.SelectDB(i)
key := strconv.Itoa(i)

View File

@@ -9,7 +9,6 @@ import (
"github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/pubsub"
"github.com/hdt3213/godis/redis/connection"
"github.com/hdt3213/godis/redis/protocol"
"runtime/debug"
"strconv"
@@ -27,10 +26,10 @@ type MultiDB struct {
// handle aof persistence
aofHandler *aof.Handler
// store master node address
slaveOf string
role int32
replication *slaveStatus
// for replication
role int32
slaveStatus *slaveStatus
masterStatus *masterStatus
}
// NewStandaloneServer creates a standalone redis server, with multi database and all other funtions
@@ -50,31 +49,36 @@ func NewStandaloneServer() *MultiDB {
mdb.hub = pubsub.MakeHub()
validAof := false
if config.Properties.AppendOnly {
aofHandler, err := aof.NewAOFHandler(mdb, func() database.EmbedDB {
return MakeBasicMultiDB()
})
if err != nil {
panic(err)
}
mdb.aofHandler = aofHandler
for _, db := range mdb.dbSet {
singleDB := db.Load().(*DB)
singleDB.addAof = func(line CmdLine) {
mdb.aofHandler.AddAof(singleDB.index, line)
}
}
mdb.initAof()
validAof = true
}
if config.Properties.RDBFilename != "" && !validAof {
// load rdb
loadRdbFile(mdb)
}
mdb.replication = initReplStatus()
mdb.slaveStatus = initReplSlaveStatus()
mdb.startAsMaster()
mdb.startReplCron()
mdb.role = masterRole // The initialization process does not require atomicity
return mdb
}
func (mdb *MultiDB) initAof() {
aofHandler, err := aof.NewAOFHandler(mdb, func() database.EmbedDB {
return MakeBasicMultiDB()
})
if err != nil {
panic(err)
}
mdb.aofHandler = aofHandler
for _, db := range mdb.dbSet {
singleDB := db.Load().(*DB)
singleDB.addAof = func(line CmdLine) {
mdb.aofHandler.AddAof(singleDB.index, line)
}
}
}
// MakeBasicMultiDB create a MultiDB only with basic abilities for aof rewrite and other usages
func MakeBasicMultiDB() *MultiDB {
mdb := &MultiDB{}
@@ -117,8 +121,7 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep
// read only slave
role := atomic.LoadInt32(&mdb.role)
if role == slaveRole &&
c.GetRole() != connection.ReplicationRecvCli {
if role == slaveRole && !c.IsMaster() {
// only allow read only command, forbid all special commands except `auth` and `slaveof`
if !isReadOnlyCommand(cmdName) {
return protocol.MakeErrReply("READONLY You can't write against a read only slave.")
@@ -167,6 +170,10 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep
return protocol.MakeArgNumErrReply("copy")
}
return execCopy(mdb, c, cmdLine[1:])
} else if cmdName == "replconf" {
return mdb.execReplConf(c, cmdLine[1:])
} else if cmdName == "psync" {
return mdb.execPSync(c, cmdLine[1:])
}
// todo: support multi database transaction
@@ -186,8 +193,8 @@ func (mdb *MultiDB) AfterClientClose(c redis.Connection) {
// Close graceful shutdown database
func (mdb *MultiDB) Close() {
// stop replication first
mdb.replication.close()
// stop slaveStatus first
mdb.slaveStatus.close()
if mdb.aofHandler != nil {
mdb.aofHandler.Close()
}
@@ -308,7 +315,11 @@ func SaveRDB(db *MultiDB, args [][]byte) redis.Reply {
if db.aofHandler == nil {
return protocol.MakeErrReply("please enable aof before using save")
}
err := db.aofHandler.Rewrite2RDB()
rdbFilename := config.Properties.RDBFilename
if rdbFilename == "" {
rdbFilename = "dump.rdb"
}
err := db.aofHandler.Rewrite2RDB(rdbFilename, nil)
if err != nil {
return protocol.MakeErrReply(err.Error())
}
@@ -326,7 +337,11 @@ func BGSaveRDB(db *MultiDB, args [][]byte) redis.Reply {
logger.Error(err)
}
}()
err := db.aofHandler.Rewrite2RDB()
rdbFilename := config.Properties.RDBFilename
if rdbFilename == "" {
rdbFilename = "dump.rdb"
}
err := db.aofHandler.Rewrite2RDB(rdbFilename, nil)
if err != nil {
logger.Error(err)
}
@@ -339,3 +354,18 @@ func (mdb *MultiDB) GetDBSize(dbIndex int) (int, int) {
db := mdb.mustSelectDB(dbIndex)
return db.data.Len(), db.ttlMap.Len()
}
func (mdb *MultiDB) startReplCron() {
go func() {
defer func() {
if err := recover(); err != nil {
logger.Error("panic", err)
}
}()
ticker := time.Tick(time.Second * 10)
for range ticker {
mdb.slaveCron()
mdb.masterCron()
}
}()
}

View File

@@ -17,7 +17,7 @@ func TestLoadRDB(t *testing.T) {
AppendOnly: false,
RDBFilename: filepath.Join(projectRoot, "test.rdb"), // set working directory to project root
}
conn := &connection.FakeConn{}
conn := connection.NewFakeConn()
rdbDB := NewStandaloneServer()
result := rdbDB.Exec(conn, utils.ToCmdLine("Get", "str"))
asserts.AssertBulkReply(t, result, "str")

View File

@@ -0,0 +1,362 @@
package database
import (
"context"
"errors"
"fmt"
"github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/protocol"
"io"
"io/ioutil"
"os"
"strconv"
"strings"
"sync"
"time"
)
const (
slaveStateHandShake = uint8(iota)
slaveStateWaitSaveEnd
slaveStateSendingRDB
slaveStateOnline
)
const (
bgSaveIdle = uint8(iota)
bgSaveRunning
bgSaveFinish
)
const (
salveCapacityNone = 0
salveCapacityEOF = 1 << iota
salveCapacityPsync2
)
// slaveClient stores slave status in the view of master
type slaveClient struct {
conn redis.Connection
state uint8
offset int64
lastAckTime time.Time
announceIp string
announcePort int
capacity uint8
}
type masterStatus struct {
ctx context.Context
mu sync.RWMutex
replId string
backlog []byte // backlog can be appended or replaced as a whole, cannot be modified(insert/set/delete)
beginOffset int64
currentOffset int64
slaveMap map[redis.Connection]*slaveClient
waitSlaves map[*slaveClient]struct{}
onlineSlaves map[*slaveClient]struct{}
bgSaveState uint8
rdbFilename string
}
func (master *masterStatus) appendBacklog(bin []byte) {
master.backlog = append(master.backlog, bin...)
master.currentOffset += int64(len(bin))
}
func (mdb *MultiDB) bgSaveForReplication() {
go func() {
defer func() {
if e := recover(); e != nil {
logger.Errorf("panic: %v", e)
}
}()
if err := mdb.saveForReplication(); err != nil {
logger.Errorf("save for replication error: %v", err)
}
}()
}
// saveForReplication does bg-save and send rdb to waiting slaves
func (mdb *MultiDB) saveForReplication() error {
rdbFile, err := ioutil.TempFile("", "*.rdb")
if err != nil {
return fmt.Errorf("create temp rdb failed: %v", err)
}
rdbFilename := rdbFile.Name()
mdb.masterStatus.mu.Lock()
mdb.masterStatus.bgSaveState = bgSaveRunning
mdb.masterStatus.rdbFilename = rdbFilename // todo: can reuse config.Properties.RDBFilename?
mdb.masterStatus.mu.Unlock()
aofListener := make(chan CmdLine, 1024) // give channel enough capacity to store all updates during rewrite to db
err = mdb.aofHandler.Rewrite2RDB(rdbFilename, aofListener)
if err != nil {
return err
}
go func() {
mdb.masterListenAof(aofListener)
}()
// change bgSaveState and get waitSlaves for sending
waitSlaves := make(map[*slaveClient]struct{})
mdb.masterStatus.mu.Lock()
mdb.masterStatus.bgSaveState = bgSaveFinish
for slave := range mdb.masterStatus.waitSlaves {
waitSlaves[slave] = struct{}{}
}
mdb.masterStatus.waitSlaves = nil
mdb.masterStatus.mu.Unlock()
for slave := range waitSlaves {
err = mdb.masterFullReSyncWithSlave(slave)
if err != nil {
mdb.removeSlave(slave)
logger.Errorf("masterFullReSyncWithSlave error: %v", err)
continue
}
}
return nil
}
// masterFullReSyncWithSlave send replication header, rdb file and all backlogs to slave
func (mdb *MultiDB) masterFullReSyncWithSlave(slave *slaveClient) error {
// write replication header
header := "+FULLRESYNC " + mdb.masterStatus.replId + " " +
strconv.FormatInt(mdb.masterStatus.beginOffset, 10) + protocol.CRLF
_, err := slave.conn.Write([]byte(header))
if err != nil {
return fmt.Errorf("write replication header to slave failed: %v", err)
}
// send rdb
rdbFile, err := os.Open(mdb.masterStatus.rdbFilename)
if err != nil {
return fmt.Errorf("open rdb file %s for replication error: %v", mdb.masterStatus.rdbFilename, err)
}
slave.state = slaveStateSendingRDB
rdbInfo, _ := os.Stat(mdb.masterStatus.rdbFilename)
rdbSize := rdbInfo.Size()
rdbHeader := "$" + strconv.FormatInt(rdbSize, 10) + protocol.CRLF
_, err = slave.conn.Write([]byte(rdbHeader))
if err != nil {
return fmt.Errorf("write rdb header to slave failed: %v", err)
}
_, err = io.Copy(slave.conn, rdbFile)
if err != nil {
return fmt.Errorf("write rdb file to slave failed: %v", err)
}
// send backlog
mdb.masterStatus.mu.RLock()
currentOffset := mdb.masterStatus.currentOffset
backlog := mdb.masterStatus.backlog[:currentOffset-mdb.masterStatus.beginOffset]
mdb.masterStatus.mu.RUnlock()
_, err = slave.conn.Write(backlog)
if err != nil {
return fmt.Errorf("full resync write backlog to slave failed: %v", err)
}
// set slave as online
mdb.setSlaveOnline(slave, currentOffset)
return nil
}
var cannotPartialSync = errors.New("cannot do partial sync")
func (mdb *MultiDB) masterTryPartialSyncWithSlave(slave *slaveClient, replId string, slaveOffset int64) error {
mdb.masterStatus.mu.RLock()
if replId != mdb.masterStatus.replId {
mdb.masterStatus.mu.RUnlock()
return cannotPartialSync
}
if slaveOffset < mdb.masterStatus.beginOffset || slaveOffset > mdb.masterStatus.currentOffset {
mdb.masterStatus.mu.RUnlock()
return cannotPartialSync
}
currentOffset := mdb.masterStatus.currentOffset
backlog := mdb.masterStatus.backlog[slaveOffset-mdb.masterStatus.beginOffset : currentOffset-mdb.masterStatus.beginOffset]
mdb.masterStatus.mu.RUnlock()
// send replication header
header := "+CONTINUE " + mdb.masterStatus.replId + protocol.CRLF
_, err := slave.conn.Write([]byte(header))
if err != nil {
return fmt.Errorf("write replication header to slave failed: %v", err)
}
// send backlog
_, err = slave.conn.Write(backlog)
if err != nil {
return fmt.Errorf("partial resync write backlog to slave failed: %v", err)
}
// set slave online
mdb.setSlaveOnline(slave, currentOffset)
return nil
}
func (mdb *MultiDB) masterSendUpdatesToSlave() error {
onlineSlaves := make(map[*slaveClient]struct{})
mdb.masterStatus.mu.RLock()
currentOffset := mdb.masterStatus.currentOffset
beginOffset := mdb.masterStatus.beginOffset
backlog := mdb.masterStatus.backlog[:currentOffset-beginOffset]
for slave := range mdb.masterStatus.onlineSlaves {
onlineSlaves[slave] = struct{}{}
}
mdb.masterStatus.mu.RUnlock()
for slave := range onlineSlaves {
slaveBeginOffset := slave.offset - beginOffset
_, err := slave.conn.Write(backlog[slaveBeginOffset:])
if err != nil {
logger.Errorf("send updates write backlog to slave failed: %v", err)
mdb.removeSlave(slave)
continue
}
slave.offset = currentOffset
}
return nil
}
func (mdb *MultiDB) execPSync(c redis.Connection, args [][]byte) redis.Reply {
replId := string(args[0])
replOffset, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil {
return protocol.MakeErrReply("ERR value is not an integer or out of range")
}
mdb.masterStatus.mu.Lock()
defer mdb.masterStatus.mu.Unlock()
slave := mdb.masterStatus.slaveMap[c]
if slave == nil {
slave = &slaveClient{
conn: c,
}
c.SetSlave()
mdb.masterStatus.slaveMap[c] = slave
}
if mdb.masterStatus.bgSaveState == bgSaveIdle {
slave.state = slaveStateWaitSaveEnd
mdb.masterStatus.waitSlaves[slave] = struct{}{}
mdb.bgSaveForReplication()
} else if mdb.masterStatus.bgSaveState == bgSaveRunning {
slave.state = slaveStateWaitSaveEnd
mdb.masterStatus.waitSlaves[slave] = struct{}{}
} else if mdb.masterStatus.bgSaveState == bgSaveFinish {
go func() {
defer func() {
if e := recover(); e != nil {
logger.Errorf("panic: %v", e)
}
}()
err := mdb.masterTryPartialSyncWithSlave(slave, replId, replOffset)
if err == nil {
return
}
if err != nil && err != cannotPartialSync {
mdb.removeSlave(slave)
logger.Errorf("masterTryPartialSyncWithSlave error: %v", err)
return
}
// assert err == cannotPartialSync
if err := mdb.masterFullReSyncWithSlave(slave); err != nil {
mdb.removeSlave(slave)
logger.Errorf("masterFullReSyncWithSlave error: %v", err)
return
}
}()
}
return &protocol.NoReply{}
}
func (mdb *MultiDB) execReplConf(c redis.Connection, args [][]byte) redis.Reply {
if len(args)%2 != 0 {
return protocol.MakeSyntaxErrReply()
}
mdb.masterStatus.mu.RLock()
slave := mdb.masterStatus.slaveMap[c]
mdb.masterStatus.mu.RUnlock()
for i := 0; i < len(args); i += 2 {
key := strings.ToLower(string(args[i]))
value := string(args[i+1])
switch key {
case "ack":
offset, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return protocol.MakeErrReply("ERR value is not an integer or out of range")
}
slave.offset = offset
slave.lastAckTime = time.Now()
return &protocol.NoReply{}
}
}
return protocol.MakeOkReply()
}
func (mdb *MultiDB) removeSlave(slave *slaveClient) {
mdb.masterStatus.mu.Lock()
defer mdb.masterStatus.mu.Unlock()
_ = slave.conn.Close()
delete(mdb.masterStatus.slaveMap, slave.conn)
delete(mdb.masterStatus.waitSlaves, slave)
delete(mdb.masterStatus.onlineSlaves, slave)
}
func (mdb *MultiDB) setSlaveOnline(slave *slaveClient, currentOffset int64) {
mdb.masterStatus.mu.Lock()
defer mdb.masterStatus.mu.Unlock()
slave.state = slaveStateOnline
slave.offset = currentOffset
mdb.masterStatus.onlineSlaves[slave] = struct{}{}
}
var pingBytes = protocol.MakeMultiBulkReply(utils.ToCmdLine("ping")).ToBytes()
func (mdb *MultiDB) masterCron() {
if mdb.role != masterRole {
return
}
mdb.masterStatus.mu.Lock()
if mdb.masterStatus.bgSaveState == bgSaveFinish {
mdb.masterStatus.appendBacklog(pingBytes)
}
mdb.masterStatus.mu.Unlock()
if err := mdb.masterSendUpdatesToSlave(); err != nil {
logger.Errorf("masterSendUpdatesToSlave error: %v", err)
}
}
func (mdb *MultiDB) masterListenAof(listener chan CmdLine) {
for {
select {
case cmdLine := <-listener:
mdb.masterStatus.mu.Lock()
reply := protocol.MakeMultiBulkReply(cmdLine)
mdb.masterStatus.appendBacklog(reply.ToBytes())
mdb.masterStatus.mu.Unlock()
if err := mdb.masterSendUpdatesToSlave(); err != nil {
logger.Errorf("masterSendUpdatesToSlave after receive aof error: %v", err)
}
// if bgSave is running, updates will be sent after the save finished
case <-mdb.masterStatus.ctx.Done():
break
}
}
}
func (mdb *MultiDB) startAsMaster() {
mdb.masterStatus = &masterStatus{
ctx: context.Background(),
mu: sync.RWMutex{},
replId: utils.RandHexString(40),
backlog: nil,
beginOffset: 0,
currentOffset: 0,
slaveMap: make(map[redis.Connection]*slaveClient),
waitSlaves: make(map[*slaveClient]struct{}),
onlineSlaves: make(map[*slaveClient]struct{}),
bgSaveState: bgSaveIdle,
rdbFilename: "",
}
}

View File

@@ -0,0 +1,194 @@
package database
import (
"bytes"
"github.com/hdt3213/godis/config"
"github.com/hdt3213/godis/lib/utils"
"github.com/hdt3213/godis/redis/connection"
"github.com/hdt3213/godis/redis/parser"
"github.com/hdt3213/godis/redis/protocol"
"github.com/hdt3213/godis/redis/protocol/asserts"
rdb "github.com/hdt3213/rdb/parser"
"io/ioutil"
"os"
"path"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
)
func mockServer() *MultiDB {
server := &MultiDB{}
server.dbSet = make([]*atomic.Value, 16)
for i := range server.dbSet {
singleDB := makeDB()
singleDB.index = i
holder := &atomic.Value{}
holder.Store(singleDB)
server.dbSet[i] = holder
}
server.slaveStatus = initReplSlaveStatus()
return server
}
func TestReplicationMasterSide(t *testing.T) {
tmpDir, err := ioutil.TempDir("", "godis")
if err != nil {
t.Error(err)
return
}
aofFilename := path.Join(tmpDir, "a.aof")
defer func() {
_ = os.Remove(aofFilename)
}()
config.Properties = &config.ServerProperties{
Databases: 16,
AppendOnly: true,
AppendFilename: aofFilename,
}
master := mockServer()
master.initAof()
master.startAsMaster()
slave := mockServer()
replConn := connection.NewFakeConn()
// set data to master
masterConn := connection.NewFakeConn()
resp := master.Exec(masterConn, utils.ToCmdLine("SET", "a", "a"))
asserts.AssertNotError(t, resp)
time.Sleep(time.Millisecond * 100) // wait write aof
// full re-sync
master.Exec(replConn, utils.ToCmdLine("psync", "?", "-1"))
masterChan := parser.ParseStream(replConn)
psyncPayload := <-masterChan
if psyncPayload.Err != nil {
t.Errorf("master bad protocol: %v", psyncPayload.Err)
return
}
psyncHeader, ok := psyncPayload.Data.(*protocol.StatusReply)
if !ok {
t.Error("psync header is not a status reply")
return
}
headers := strings.Split(psyncHeader.Status, " ")
if len(headers) != 3 {
t.Errorf("illegal psync header: %s", psyncHeader.Status)
return
}
replId := headers[1]
replOffset, err := strconv.ParseInt(headers[2], 10, 64)
if err != nil {
t.Errorf("illegal offset: %s", headers[2])
return
}
t.Logf("repl id: %s, offset: %d", replId, replOffset)
rdbPayload := <-masterChan
if rdbPayload.Err != nil {
t.Error("read response failed: " + rdbPayload.Err.Error())
return
}
rdbReply, ok := rdbPayload.Data.(*protocol.BulkReply)
if !ok {
t.Error("illegal payload header: " + string(rdbPayload.Data.ToBytes()))
return
}
rdbDec := rdb.NewDecoder(bytes.NewReader(rdbReply.Arg))
err = importRDB(rdbDec, slave)
if err != nil {
t.Error("import rdb failed: " + err.Error())
return
}
// get a
slaveConn := connection.NewFakeConn()
resp = slave.Exec(slaveConn, utils.ToCmdLine("get", "a"))
asserts.AssertBulkReply(t, resp, "a")
/*---- test broadcast aof ----*/
masterConn = connection.NewFakeConn()
resp = master.Exec(masterConn, utils.ToCmdLine("SET", "b", "b"))
time.Sleep(time.Millisecond * 100) // wait write aof
asserts.AssertNotError(t, resp)
master.masterCron()
for {
payload := <-masterChan
if payload.Err != nil {
t.Error(payload.Err)
return
}
cmdLine, ok := payload.Data.(*protocol.MultiBulkReply)
if !ok {
t.Error("unexpected payload: " + string(payload.Data.ToBytes()))
return
}
slave.Exec(replConn, cmdLine.Args)
n := len(cmdLine.ToBytes())
slave.slaveStatus.replOffset += int64(n)
if string(cmdLine.Args[0]) != "ping" {
break
}
}
resp = slave.Exec(slaveConn, utils.ToCmdLine("get", "b"))
asserts.AssertBulkReply(t, resp, "b")
/*---- test partial reconnect ----*/
_ = replConn.Close() // mock disconnect
replConn = connection.NewFakeConn()
master.Exec(replConn, utils.ToCmdLine("psync", replId,
strconv.FormatInt(slave.slaveStatus.replOffset, 10)))
masterChan = parser.ParseStream(replConn)
psyncPayload = <-masterChan
if psyncPayload.Err != nil {
t.Errorf("master bad protocol: %v", psyncPayload.Err)
return
}
psyncHeader, ok = psyncPayload.Data.(*protocol.StatusReply)
if !ok {
t.Error("psync header is not a status reply")
return
}
headers = strings.Split(psyncHeader.Status, " ")
if len(headers) != 2 {
t.Errorf("illegal psync header: %s", psyncHeader.Status)
return
}
if headers[0] != "CONTINUE" {
t.Errorf("expect CONTINUE actual %s", headers[0])
return
}
replId = headers[1]
t.Logf("partial resync repl id: %s, offset: %d", replId, slave.slaveStatus.replOffset)
resp = master.Exec(masterConn, utils.ToCmdLine("SET", "c", "c"))
time.Sleep(time.Millisecond * 100) // wait write aof
asserts.AssertNotError(t, resp)
master.masterCron()
for {
payload := <-masterChan
if payload.Err != nil {
t.Error(payload.Err)
return
}
cmdLine, ok := payload.Data.(*protocol.MultiBulkReply)
if !ok {
t.Error("unexpected payload: " + string(payload.Data.ToBytes()))
return
}
slave.Exec(replConn, cmdLine.Args)
if string(cmdLine.Args[0]) != "ping" {
break
}
}
resp = slave.Exec(slaveConn, utils.ToCmdLine("get", "c"))
asserts.AssertBulkReply(t, resp, "c")
}

View File

@@ -31,9 +31,9 @@ type slaveStatus struct {
ctx context.Context
cancel context.CancelFunc
// configVersion stands for the version of replication config. Any change of master host/port will cause configVersion increment
// If configVersion change has been found during replication current replication procedure will stop.
// It is designed to abort a running replication procedure
// configVersion stands for the version of slaveStatus config. Any change of master host/port will cause configVersion increment
// If configVersion change has been found during slaveStatus current slaveStatus procedure will stop.
// It is designed to abort a running slaveStatus procedure
configVersion int32
masterHost string
@@ -47,26 +47,10 @@ type slaveStatus struct {
running sync.WaitGroup
}
var configChangedErr = errors.New("replication config changed")
var configChangedErr = errors.New("slaveStatus config changed")
func initReplStatus() *slaveStatus {
repl := &slaveStatus{}
// start cron
return repl
}
func (mdb *MultiDB) startReplCron() {
go func() {
defer func() {
if err := recover(); err != nil {
logger.Error("panic", err)
}
}()
ticker := time.Tick(time.Second)
for range ticker {
mdb.slaveCron()
}
}()
func initReplSlaveStatus() *slaveStatus {
return &slaveStatus{}
}
func (mdb *MultiDB) execSlaveOf(c redis.Connection, args [][]byte) redis.Reply {
@@ -80,29 +64,29 @@ func (mdb *MultiDB) execSlaveOf(c redis.Connection, args [][]byte) redis.Reply {
if err != nil {
return protocol.MakeErrReply("ERR value is not an integer or out of range")
}
mdb.replication.mutex.Lock()
mdb.slaveStatus.mutex.Lock()
atomic.StoreInt32(&mdb.role, slaveRole)
mdb.replication.masterHost = host
mdb.replication.masterPort = port
mdb.slaveStatus.masterHost = host
mdb.slaveStatus.masterPort = port
// use buffered channel in case receiver goroutine exited before controller send stop signal
atomic.AddInt32(&mdb.replication.configVersion, 1)
mdb.replication.mutex.Unlock()
atomic.AddInt32(&mdb.slaveStatus.configVersion, 1)
mdb.slaveStatus.mutex.Unlock()
go mdb.setupMaster()
return protocol.MakeOkReply()
}
func (mdb *MultiDB) slaveOfNone() {
mdb.replication.mutex.Lock()
defer mdb.replication.mutex.Unlock()
mdb.replication.masterHost = ""
mdb.replication.masterPort = 0
mdb.replication.replId = ""
mdb.replication.replOffset = -1
mdb.replication.stopSlaveWithMutex()
mdb.slaveStatus.mutex.Lock()
defer mdb.slaveStatus.mutex.Unlock()
mdb.slaveStatus.masterHost = ""
mdb.slaveStatus.masterPort = 0
mdb.slaveStatus.replId = ""
mdb.slaveStatus.replOffset = -1
mdb.slaveStatus.stopSlaveWithMutex()
}
// stopSlaveWithMutex stops in-progress connectWithMaster/fullSync/receiveAOF
// invoker should have replication mutex
// invoker should have slaveStatus mutex
func (repl *slaveStatus) stopSlaveWithMutex() {
// update configVersion to stop connectWithMaster and fullSync
atomic.AddInt32(&repl.configVersion, 1)
@@ -135,11 +119,11 @@ func (mdb *MultiDB) setupMaster() {
}()
var configVersion int32
ctx, cancel := context.WithCancel(context.Background())
mdb.replication.mutex.Lock()
mdb.replication.ctx = ctx
mdb.replication.cancel = cancel
configVersion = mdb.replication.configVersion
mdb.replication.mutex.Unlock()
mdb.slaveStatus.mutex.Lock()
mdb.slaveStatus.ctx = ctx
mdb.slaveStatus.cancel = cancel
configVersion = mdb.slaveStatus.configVersion
mdb.slaveStatus.mutex.Unlock()
isFullReSync, err := mdb.connectWithMaster(configVersion)
if err != nil {
// connect failed, abort master
@@ -167,7 +151,7 @@ func (mdb *MultiDB) setupMaster() {
// connectWithMaster finishes handshake with master
// returns: isFullReSync, error
func (mdb *MultiDB) connectWithMaster(configVersion int32) (bool, error) {
addr := mdb.replication.masterHost + ":" + strconv.Itoa(mdb.replication.masterPort)
addr := mdb.slaveStatus.masterHost + ":" + strconv.Itoa(mdb.slaveStatus.masterPort)
conn, err := net.Dial("tcp", addr)
if err != nil {
mdb.slaveOfNone() // abort
@@ -256,34 +240,34 @@ func (mdb *MultiDB) connectWithMaster(configVersion int32) (bool, error) {
}
// update connection
mdb.replication.mutex.Lock()
defer mdb.replication.mutex.Unlock()
if mdb.replication.configVersion != configVersion {
// replication conf changed during connecting and waiting mutex
mdb.slaveStatus.mutex.Lock()
defer mdb.slaveStatus.mutex.Unlock()
if mdb.slaveStatus.configVersion != configVersion {
// slaveStatus conf changed during connecting and waiting mutex
return false, configChangedErr
}
mdb.replication.masterConn = conn
mdb.replication.masterChan = masterChan
mdb.replication.lastRecvTime = time.Now()
mdb.slaveStatus.masterConn = conn
mdb.slaveStatus.masterChan = masterChan
mdb.slaveStatus.lastRecvTime = time.Now()
return mdb.psyncHandshake()
}
// psyncHandshake send `psync` to master and sync repl-id/offset with master
// invoker should provide with replication.mutex
// invoker should provide with slaveStatus.mutex
func (mdb *MultiDB) psyncHandshake() (bool, error) {
replId := "?"
var replOffset int64 = -1
if mdb.replication.replId != "" {
replId = mdb.replication.replId
replOffset = mdb.replication.replOffset
if mdb.slaveStatus.replId != "" {
replId = mdb.slaveStatus.replId
replOffset = mdb.slaveStatus.replOffset
}
psyncCmdLine := utils.ToCmdLine("psync", replId, strconv.FormatInt(replOffset, 10))
psyncReq := protocol.MakeMultiBulkReply(psyncCmdLine)
_, err := mdb.replication.masterConn.Write(psyncReq.ToBytes())
_, err := mdb.slaveStatus.masterConn.Write(psyncReq.ToBytes())
if err != nil {
return false, errors.New("send failed " + err.Error())
}
psyncPayload := <-mdb.replication.masterChan
psyncPayload := <-mdb.slaveStatus.masterChan
if psyncPayload.Err != nil {
return false, errors.New("read response failed: " + psyncPayload.Err.Error())
}
@@ -300,12 +284,12 @@ func (mdb *MultiDB) psyncHandshake() (bool, error) {
var isFullReSync bool
if headers[0] == "FULLRESYNC" {
logger.Info("full re-sync with master")
mdb.replication.replId = headers[1]
mdb.replication.replOffset, err = strconv.ParseInt(headers[2], 10, 64)
mdb.slaveStatus.replId = headers[1]
mdb.slaveStatus.replOffset, err = strconv.ParseInt(headers[2], 10, 64)
isFullReSync = true
} else if headers[0] == "CONTINUE" {
logger.Info("continue partial sync")
mdb.replication.replId = headers[1]
mdb.slaveStatus.replId = headers[1]
isFullReSync = false
} else {
return false, errors.New("illegal psync resp: " + psyncHeader.Status)
@@ -314,13 +298,13 @@ func (mdb *MultiDB) psyncHandshake() (bool, error) {
if err != nil {
return false, errors.New("get illegal repl offset: " + headers[2])
}
logger.Info(fmt.Sprintf("repl id: %s, current offset: %d", mdb.replication.replId, mdb.replication.replOffset))
logger.Info(fmt.Sprintf("repl id: %s, current offset: %d", mdb.slaveStatus.replId, mdb.slaveStatus.replOffset))
return isFullReSync, nil
}
// loadMasterRDB downloads rdb after handshake has been done
func (mdb *MultiDB) loadMasterRDB(configVersion int32) error {
rdbPayload := <-mdb.replication.masterChan
rdbPayload := <-mdb.slaveStatus.masterChan
if rdbPayload.Err != nil {
return errors.New("read response failed: " + rdbPayload.Err.Error())
}
@@ -337,10 +321,10 @@ func (mdb *MultiDB) loadMasterRDB(configVersion int32) error {
return errors.New("dump rdb failed: " + err.Error())
}
mdb.replication.mutex.Lock()
defer mdb.replication.mutex.Unlock()
if mdb.replication.configVersion != configVersion {
// replication conf changed during connecting and waiting mutex
mdb.slaveStatus.mutex.Lock()
defer mdb.slaveStatus.mutex.Unlock()
if mdb.slaveStatus.configVersion != configVersion {
// slaveStatus conf changed during connecting and waiting mutex
return configChangedErr
}
for i, h := range rdbHolder.dbSet {
@@ -353,13 +337,13 @@ func (mdb *MultiDB) loadMasterRDB(configVersion int32) error {
}
func (mdb *MultiDB) receiveAOF(ctx context.Context, configVersion int32) error {
conn := connection.NewConn(mdb.replication.masterConn)
conn.SetRole(connection.ReplicationRecvCli)
mdb.replication.running.Add(1)
defer mdb.replication.running.Done()
conn := connection.NewConn(mdb.slaveStatus.masterConn)
conn.SetMaster()
mdb.slaveStatus.running.Add(1)
defer mdb.slaveStatus.running.Done()
for {
select {
case payload, open := <-mdb.replication.masterChan:
case payload, open := <-mdb.slaveStatus.masterChan:
if !open {
return errors.New("master channel unexpected close")
}
@@ -370,31 +354,32 @@ func (mdb *MultiDB) receiveAOF(ctx context.Context, configVersion int32) error {
if !ok {
return errors.New("unexpected payload: " + string(payload.Data.ToBytes()))
}
mdb.replication.mutex.Lock()
if mdb.replication.configVersion != configVersion {
// replication conf changed during connecting and waiting mutex
mdb.slaveStatus.mutex.Lock()
if mdb.slaveStatus.configVersion != configVersion {
// slaveStatus conf changed during connecting and waiting mutex
return configChangedErr
}
mdb.Exec(conn, cmdLine.Args)
n := len(cmdLine.ToBytes()) // todo: directly get size from socket
mdb.replication.replOffset += int64(n)
mdb.replication.lastRecvTime = time.Now()
mdb.slaveStatus.replOffset += int64(n)
mdb.slaveStatus.lastRecvTime = time.Now()
logger.Info(fmt.Sprintf("receive %d bytes from master, current offset %d, %s",
n, mdb.replication.replOffset, strconv.Quote(string(cmdLine.ToBytes()))))
mdb.replication.mutex.Unlock()
n, mdb.slaveStatus.replOffset, strconv.Quote(string(cmdLine.ToBytes()))))
mdb.slaveStatus.mutex.Unlock()
case <-ctx.Done():
conn.GetConnPool().Put(conn)
_ = conn.Close()
return nil
}
}
}
func (mdb *MultiDB) slaveCron() {
repl := mdb.replication
repl := mdb.slaveStatus
if repl.masterConn == nil {
return
}
// check master timeout
replTimeout := 60 * time.Second
if config.Properties.ReplTimeout != 0 {
replTimeout = time.Duration(config.Properties.ReplTimeout) * time.Second
@@ -427,9 +412,9 @@ func (repl *slaveStatus) sendAck2Master() error {
func (mdb *MultiDB) reconnectWithMaster() error {
logger.Info("reconnecting with master")
mdb.replication.mutex.Lock()
defer mdb.replication.mutex.Unlock()
mdb.replication.stopSlaveWithMutex()
mdb.slaveStatus.mutex.Lock()
defer mdb.slaveStatus.mutex.Unlock()
mdb.slaveStatus.stopSlaveWithMutex()
go mdb.setupMaster()
return nil
}

View File

@@ -8,22 +8,12 @@ import (
"github.com/hdt3213/godis/redis/connection"
"github.com/hdt3213/godis/redis/protocol"
"github.com/hdt3213/godis/redis/protocol/asserts"
"sync/atomic"
"testing"
"time"
)
func TestReplication(t *testing.T) {
mdb := &MultiDB{}
mdb.dbSet = make([]*atomic.Value, 16)
for i := range mdb.dbSet {
singleDB := makeDB()
singleDB.index = i
holder := &atomic.Value{}
holder.Store(singleDB)
mdb.dbSet[i] = holder
}
mdb.replication = initReplStatus()
func TestReplicationSlaveSide(t *testing.T) {
mdb := mockServer()
masterCli, err := client.MakeClient("127.0.0.1:6379")
if err != nil {
t.Error(err)
@@ -33,7 +23,7 @@ func TestReplication(t *testing.T) {
// sync with master
ret := masterCli.Send(utils.ToCmdLine("set", "1", "1"))
asserts.AssertStatusReply(t, ret, "OK")
conn := &connection.FakeConn{}
conn := connection.NewFakeConn()
ret = mdb.Exec(conn, utils.ToCmdLine("SLAVEOF", "127.0.0.1", "6379"))
asserts.AssertStatusReply(t, ret, "OK")
success := false
@@ -74,7 +64,7 @@ func TestReplication(t *testing.T) {
t.Error("sync failed")
return
}
err = mdb.replication.sendAck2Master()
err = mdb.slaveStatus.sendAck2Master()
if err != nil {
t.Error(err)
return
@@ -83,8 +73,8 @@ func TestReplication(t *testing.T) {
// test reconnect
config.Properties.ReplTimeout = 1
_ = mdb.replication.masterConn.Close()
mdb.replication.lastRecvTime = time.Now().Add(-time.Hour) // mock timeout
_ = mdb.slaveStatus.masterConn.Close()
mdb.slaveStatus.lastRecvTime = time.Now().Add(-time.Hour) // mock timeout
mdb.slaveCron()
time.Sleep(3 * time.Second)
ret = masterCli.Send(utils.ToCmdLine("set", "1", "3"))
@@ -134,7 +124,7 @@ func TestReplication(t *testing.T) {
return
}
err = mdb.replication.close()
err = mdb.slaveStatus.close()
if err != nil {
t.Error("cannot close")
}

View File

@@ -20,7 +20,7 @@ func TestPing(t *testing.T) {
func TestAuth(t *testing.T) {
passwd := utils.RandString(10)
c := &connection.FakeConn{}
c := connection.NewFakeConn()
ret := testServer.Exec(c, utils.ToCmdLine("AUTH"))
asserts.AssertErrReply(t, ret, "ERR wrong number of arguments for 'auth' command")
ret = testServer.Exec(c, utils.ToCmdLine("AUTH", passwd))

View File

@@ -2,7 +2,9 @@ package redis
// Connection represents a connection with redis client
type Connection interface {
Write([]byte) error
Write([]byte) (int, error)
Close() error
SetPassword(string)
GetPassword() string
@@ -12,7 +14,6 @@ type Connection interface {
SubsCount() int
GetChannels() []string
// used for `Multi` command
InMultiState() bool
SetMultiState(bool)
GetQueuedCmdLine() [][][]byte
@@ -22,10 +23,12 @@ type Connection interface {
AddTxError(err error)
GetTxErrors() []error
// used for multi database
GetDBIndex() int
SelectDB(int)
// returns role of conn, such as connection with client, connection with master node
GetRole() int32
SetRole(int32)
SetSlave()
IsSlave() bool
SetMaster()
IsMaster() bool
}

View File

@@ -107,6 +107,13 @@ func Error(v ...interface{}) {
logger.Println(v...)
}
func Errorf(format string, v ...interface{}) {
mu.Lock()
defer mu.Unlock()
setPrefix(ERROR)
logger.Println(fmt.Sprintf(format, v...))
}
// Fatal prints error log then stop the program
func Fatal(v ...interface{}) {
mu.Lock()

View File

@@ -16,3 +16,13 @@ func RandString(n int) string {
}
return string(b)
}
var hexLetters = []rune("0123456789abcdef")
func RandHexString(n int) string {
b := make([]rune, n)
for i := range b {
b[i] = hexLetters[rand.Intn(len(hexLetters))]
}
return string(b)
}

View File

@@ -82,7 +82,7 @@ func Subscribe(hub *Hub, c redis.Connection, args [][]byte) redis.Reply {
for _, channel := range channels {
if subscribe0(hub, channel, c) {
_ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount())))
_, _ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount())))
}
}
return &protocol.NoReply{}
@@ -117,13 +117,13 @@ func UnSubscribe(db *Hub, c redis.Connection, args [][]byte) redis.Reply {
defer db.subsLocker.UnLocks(channels...)
if len(channels) == 0 {
_ = c.Write(unSubscribeNothing)
_, _ = c.Write(unSubscribeNothing)
return &protocol.NoReply{}
}
for _, channel := range channels {
if unsubscribe0(db, channel, c) {
_ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount())))
_, _ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount())))
}
}
return &protocol.NoReply{}
@@ -151,7 +151,7 @@ func Publish(hub *Hub, args [][]byte) redis.Reply {
replyArgs[0] = messageBytes
replyArgs[1] = []byte(channel)
replyArgs[2] = message
_ = client.Write(protocol.MakeMultiBulkReply(replyArgs).ToBytes())
_, _ = client.Write(protocol.MakeMultiBulkReply(replyArgs).ToBytes())
return true
})
return protocol.MakeIntReply(int64(subscribers.Len()))

View File

@@ -2,6 +2,6 @@ bind 0.0.0.0
port 6399
maxclients 128
#appendonly no
#appendfilename appendonly.aof
appendonly yes
appendfilename appendonly.aof
#dbfilename test.rdb

View File

@@ -1,7 +1,6 @@
package connection
import (
"bytes"
"github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/sync/wait"
"net"
@@ -10,21 +9,24 @@ import (
)
const (
// NormalCli is client with user
NormalCli = iota
// ReplicationRecvCli is fake client with replication master
ReplicationRecvCli
// flagSlave means this a connection with slave
flagSlave = uint64(1 << iota)
// flagSlave means this a connection with master
flagMaster
// flagMulti means this connection is within a transaction
flagMulti
)
// Connection represents a connection with a redis-cli
type Connection struct {
conn net.Conn
// waiting until protocol finished
waitingReply wait.Wait
// waiting until finish sending data, used for graceful shutdown
sendingData wait.Wait
// lock while server sending response
mu sync.Mutex
mu sync.Mutex
flags uint64
// subscribing channels
subs map[string]bool
@@ -33,14 +35,12 @@ type Connection struct {
password string
// queued commands for `multi`
multiState bool
queue [][][]byte
watching map[string]uint32
txErrors []error
queue [][][]byte
watching map[string]uint32
txErrors []error
// selected db
selectedDB int
role int32
}
var connPool = sync.Pool{
@@ -49,11 +49,6 @@ var connPool = sync.Pool{
},
}
// GetConnPool returns the connection pool pointer for putting and getting connection
func (c *Connection) GetConnPool() *sync.Pool {
return &connPool
}
// RemoteAddr returns the remote network address
func (c *Connection) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
@@ -61,8 +56,15 @@ func (c *Connection) RemoteAddr() net.Addr {
// Close disconnect with the client
func (c *Connection) Close() error {
c.waitingReply.WaitWithTimeout(10 * time.Second)
c.sendingData.WaitWithTimeout(10 * time.Second)
_ = c.conn.Close()
c.subs = nil
c.password = ""
c.queue = nil
c.watching = nil
c.txErrors = nil
c.selectedDB = 0
connPool.Put(c)
return nil
}
@@ -80,17 +82,16 @@ func NewConn(conn net.Conn) *Connection {
}
// Write sends response to client over tcp connection
func (c *Connection) Write(b []byte) error {
func (c *Connection) Write(b []byte) (int, error) {
if len(b) == 0 {
return nil
return 0, nil
}
c.waitingReply.Add(1)
c.sendingData.Add(1)
defer func() {
c.waitingReply.Done()
c.sendingData.Done()
}()
_, err := c.conn.Write(b)
return err
return c.conn.Write(b)
}
// Subscribe add current connection into subscribers of the given channel
@@ -146,7 +147,7 @@ func (c *Connection) GetPassword() string {
// InMultiState tells is connection in an uncommitted transaction
func (c *Connection) InMultiState() bool {
return c.multiState
return c.flags&flagMulti > 0
}
// SetMultiState sets transaction flag
@@ -154,8 +155,10 @@ func (c *Connection) SetMultiState(state bool) {
if !state { // reset data when cancel multi
c.watching = nil
c.queue = nil
c.flags &= ^flagMulti // clean multi flag
return
}
c.multiState = state
c.flags |= flagMulti
}
// GetQueuedCmdLine returns queued commands of current transaction
@@ -183,18 +186,6 @@ func (c *Connection) ClearQueuedCmds() {
c.queue = nil
}
// GetRole returns role of connection, such as connection with master
func (c *Connection) GetRole() int32 {
if c == nil {
return NormalCli
}
return c.role
}
func (c *Connection) SetRole(r int32) {
c.role = r
}
// GetWatching returns watching keys and their version code when started watching
func (c *Connection) GetWatching() map[string]uint32 {
if c.watching == nil {
@@ -213,24 +204,18 @@ func (c *Connection) SelectDB(dbNum int) {
c.selectedDB = dbNum
}
// FakeConn implements redis.Connection for test
type FakeConn struct {
Connection
buf bytes.Buffer
func (c *Connection) SetSlave() {
c.flags |= flagSlave
}
// Write writes data to buffer
func (c *FakeConn) Write(b []byte) error {
c.buf.Write(b)
return nil
func (c *Connection) IsSlave() bool {
return c.flags&flagSlave > 0
}
// Clean resets the buffer
func (c *FakeConn) Clean() {
c.buf.Reset()
func (c *Connection) SetMaster() {
c.flags |= flagMaster
}
// Bytes returns written data
func (c *FakeConn) Bytes() []byte {
return c.buf.Bytes()
func (c *Connection) IsMaster() bool {
return c.flags&flagMaster > 0
}

79
redis/connection/fake.go Normal file
View File

@@ -0,0 +1,79 @@
package connection
import (
"bytes"
"io"
"sync"
)
// FakeConn implements redis.Connection for test
type FakeConn struct {
Connection
buf bytes.Buffer
wait chan struct{}
closed bool
mu sync.Mutex
}
func NewFakeConn() *FakeConn {
c := &FakeConn{}
return c
}
// Write writes data to buffer
func (c *FakeConn) Write(b []byte) (int, error) {
if c.closed {
return 0, io.EOF
}
n, _ := c.buf.Write(b)
c.notify()
return n, nil
}
func (c *FakeConn) notify() {
if c.wait != nil {
c.mu.Lock()
if c.wait != nil {
close(c.wait)
c.wait = nil
}
c.mu.Unlock()
}
}
func (c *FakeConn) waiting() {
c.mu.Lock()
c.wait = make(chan struct{})
c.mu.Unlock()
<-c.wait
}
// Read reads data from buffer
func (c *FakeConn) Read(p []byte) (int, error) {
n, err := c.buf.Read(p)
if err == io.EOF {
if c.closed {
return 0, io.EOF
}
c.waiting()
return c.buf.Read(p)
}
return n, err
}
// Clean resets the buffer
func (c *FakeConn) Clean() {
c.wait = make(chan struct{})
c.buf.Reset()
}
// Bytes returns written data
func (c *FakeConn) Bytes() []byte {
return c.buf.Bytes()
}
func (c *FakeConn) Close() error {
c.closed = true
c.notify()
return nil
}

View File

@@ -13,7 +13,7 @@ func TestPublish(t *testing.T) {
hub := pubsub.MakeHub()
channel := utils.RandString(5)
msg := utils.RandString(5)
conn := &connection.FakeConn{}
conn := connection.NewFakeConn()
pubsub.Subscribe(hub, conn, utils.ToCmdLine(channel))
conn.Clean() // clean subscribe success
pubsub.Publish(hub, utils.ToCmdLine(channel, msg))

View File

@@ -50,8 +50,6 @@ func (h *Handler) closeClient(client *connection.Connection) {
_ = client.Close()
h.db.AfterClientClose(client)
h.activeConn.Delete(client)
client.GetConnPool().Put(client)
}
// Handle receives and executes redis commands
@@ -78,7 +76,7 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
}
// protocol err
errReply := protocol.MakeErrReply(payload.Err.Error())
err := client.Write(errReply.ToBytes())
_, err := client.Write(errReply.ToBytes())
if err != nil {
h.closeClient(client)
logger.Info("connection closed: " + client.RemoteAddr().String())
@@ -97,9 +95,9 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
}
result := h.db.Exec(client, r.Args)
if result != nil {
_ = client.Write(result.ToBytes())
_, _ = client.Write(result.ToBytes())
} else {
_ = client.Write(unknownErrReplyBytes)
_, _ = client.Write(unknownErrReplyBytes)
}
}
}