From ba7ea942cb44e3a14d7121f380e9baa6546f7045 Mon Sep 17 00:00:00 2001 From: finley Date: Mon, 21 Nov 2022 23:36:35 +0800 Subject: [PATCH] replication master side --- aof/aof.go | 29 +- aof/rdb.go | 17 +- cluster/com_test.go | 8 +- cluster/del_test.go | 2 +- cluster/mset_test.go | 4 +- cluster/pubsub_test.go | 2 +- database/aof_test.go | 8 +- database/database.go | 80 ++-- database/rdb_test.go | 2 +- database/replication_master.go | 362 ++++++++++++++++++ database/replication_master_test.go | 194 ++++++++++ .../{replication.go => replication_slave.go} | 145 ++++--- ...tion_test.go => replication_slave_test.go} | 24 +- database/sys_test.go | 2 +- interface/redis/conn.go | 15 +- lib/logger/logger.go | 7 + lib/utils/rand_string.go | 10 + pubsub/pubsub.go | 8 +- redis.conf | 4 +- redis/connection/conn.go | 91 ++--- redis/connection/fake.go | 79 ++++ redis/server/pubsub_test.go | 2 +- redis/server/server.go | 8 +- 23 files changed, 886 insertions(+), 217 deletions(-) create mode 100644 database/replication_master.go create mode 100644 database/replication_master_test.go rename database/{replication.go => replication_slave.go} (76%) rename database/{replication_test.go => replication_slave_test.go} (84%) create mode 100644 redis/connection/fake.go diff --git a/aof/aof.go b/aof/aof.go index 312e34e..8d58acb 100644 --- a/aof/aof.go +++ b/aof/aof.go @@ -26,6 +26,10 @@ type payload struct { dbIndex int } +// Listener is a channel receive a replication of all aof payloads +// with a listener we can forward the updates to slave nodes etc. 
+type Listener chan<- CmdLine + // Handler receive msgs from channel and write to AOF file type Handler struct { db database.EmbedDB @@ -38,6 +42,7 @@ type Handler struct { // pause aof for start/finish aof rewrite progress pausingAof sync.RWMutex currentDB int + listeners map[Listener]struct{} } // NewAOFHandler creates a new aof.Handler @@ -54,12 +59,20 @@ func NewAOFHandler(db database.EmbedDB, tmpDBMaker func() database.EmbedDB) (*Ha handler.aofFile = aofFile handler.aofChan = make(chan *payload, aofQueueSize) handler.aofFinished = make(chan struct{}) + handler.listeners = make(map[Listener]struct{}) go func() { handler.handleAof() }() return handler, nil } +// RemoveListener removes a listener from aof handler, so we can close the listener +func (handler *Handler) RemoveListener(listener Listener) { + handler.pausingAof.Lock() + defer handler.pausingAof.Unlock() + delete(handler.listeners, listener) +} + // AddAof send command to aof goroutine through channel func (handler *Handler) AddAof(dbIndex int, cmdLine CmdLine) { if config.Properties.AppendOnly && handler.aofChan != nil { @@ -73,12 +86,16 @@ func (handler *Handler) AddAof(dbIndex int, cmdLine CmdLine) { // handleAof listen aof channel and write into file func (handler *Handler) handleAof() { // serialized execution + var cmdLines []CmdLine handler.currentDB = 0 for p := range handler.aofChan { + cmdLines = cmdLines[:0] // reuse underlying array handler.pausingAof.RLock() // prevent other goroutines from pausing aof if p.dbIndex != handler.currentDB { // select db - data := protocol.MakeMultiBulkReply(utils.ToCmdLine("SELECT", strconv.Itoa(p.dbIndex))).ToBytes() + selectCmd := utils.ToCmdLine("SELECT", strconv.Itoa(p.dbIndex)) + cmdLines = append(cmdLines, selectCmd) + data := protocol.MakeMultiBulkReply(selectCmd).ToBytes() _, err := handler.aofFile.Write(data) if err != nil { logger.Warn(err) @@ -88,11 +105,17 @@ func (handler *Handler) handleAof() { handler.currentDB = p.dbIndex } data := 
protocol.MakeMultiBulkReply(p.cmdLine).ToBytes() + cmdLines = append(cmdLines, p.cmdLine) _, err := handler.aofFile.Write(data) if err != nil { logger.Warn(err) } handler.pausingAof.RUnlock() + for listener := range handler.listeners { + for _, line := range cmdLines { + listener <- line + } + } } handler.aofFinished <- struct{}{} } @@ -123,7 +146,7 @@ func (handler *Handler) LoadAof(maxBytes int) { reader = file } ch := parser.ParseStream(reader) - fakeConn := &connection.FakeConn{} // only used for save dbIndex + fakeConn := connection.NewFakeConn() // only used for save dbIndex for p := range ch { if p.Err != nil { if p.Err == io.EOF { @@ -143,7 +166,7 @@ func (handler *Handler) LoadAof(maxBytes int) { } ret := handler.db.Exec(fakeConn, r.Args) if protocol.IsErrorReply(ret) { - logger.Error("exec err", ret.ToBytes()) + logger.Error("exec err", string(ret.ToBytes())) } } } diff --git a/aof/rdb.go b/aof/rdb.go index 9b6879f..893981b 100644 --- a/aof/rdb.go +++ b/aof/rdb.go @@ -16,8 +16,12 @@ import ( "time" ) -func (handler *Handler) Rewrite2RDB() error { - ctx, err := handler.startRewrite2RDB() +// todo: forbid concurrent rewrite + +// Rewrite2RDB rewrite aof data into rdb +// if extraListener is not nil, it will be appended to Handler.listeners, it will receive all updates after rdb +func (handler *Handler) Rewrite2RDB(rdbFilename string, extraListener Listener) error { + ctx, err := handler.startRewrite2RDB(extraListener) if err != nil { return err } @@ -25,10 +29,6 @@ func (handler *Handler) Rewrite2RDB() error { if err != nil { return err } - rdbFilename := config.Properties.RDBFilename - if rdbFilename == "" { - rdbFilename = "dump.rdb" - } err = ctx.tmpFile.Close() if err != nil { return err @@ -40,7 +40,7 @@ func (handler *Handler) Rewrite2RDB() error { return nil } -func (handler *Handler) startRewrite2RDB() (*RewriteCtx, error) { +func (handler *Handler) startRewrite2RDB(extraListener Listener) (*RewriteCtx, error) { handler.pausingAof.Lock() // pausing 
aof defer handler.pausingAof.Unlock() @@ -59,6 +59,9 @@ func (handler *Handler) startRewrite2RDB() (*RewriteCtx, error) { logger.Warn("tmp file create failed") return nil, err } + if extraListener != nil { + handler.listeners[extraListener] = struct{}{} + } return &RewriteCtx{ tmpFile: file, fileSize: filesize, diff --git a/cluster/com_test.go b/cluster/com_test.go index 25ee917..7978bc3 100644 --- a/cluster/com_test.go +++ b/cluster/com_test.go @@ -10,7 +10,7 @@ import ( func TestExec(t *testing.T) { testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"}) - conn := &connection.FakeConn{} + conn := connection.NewFakeConn() for i := 0; i < 1000; i++ { key := RandString(4) value := RandString(4) @@ -26,7 +26,7 @@ func TestAuth(t *testing.T) { defer func() { config.Properties.RequirePass = "" }() - conn := &connection.FakeConn{} + conn := connection.NewFakeConn() ret := testNodeA.Exec(conn, toArgs("GET", "a")) asserts.AssertErrReply(t, ret, "NOAUTH Authentication required") ret = testNodeA.Exec(conn, toArgs("AUTH", passwd)) @@ -39,7 +39,7 @@ func TestRelay(t *testing.T) { testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"}) key := RandString(4) value := RandString(4) - conn := &connection.FakeConn{} + conn := connection.NewFakeConn() ret := testCluster2.relay("127.0.0.1:6379", conn, toArgs("SET", key, value)) asserts.AssertNotError(t, ret) ret = testCluster2.relay("127.0.0.1:6379", conn, toArgs("GET", key)) @@ -50,7 +50,7 @@ func TestBroadcast(t *testing.T) { testCluster2 := MakeTestCluster([]string{"127.0.0.1:6379"}) key := RandString(4) value := RandString(4) - rets := testCluster2.broadcast(&connection.FakeConn{}, toArgs("SET", key, value)) + rets := testCluster2.broadcast(connection.NewFakeConn(), toArgs("SET", key, value)) for _, v := range rets { asserts.AssertNotError(t, v) } diff --git a/cluster/del_test.go b/cluster/del_test.go index fbd0f32..cc3205f 100644 --- a/cluster/del_test.go +++ b/cluster/del_test.go @@ -7,7 +7,7 @@ import ( ) func 
TestDel(t *testing.T) { - conn := &connection.FakeConn{} + conn := connection.NewFakeConn() allowFastTransaction = false testNodeA.Exec(conn, toArgs("SET", "a", "a")) ret := Del(testNodeA, conn, toArgs("DEL", "a", "b", "c")) diff --git a/cluster/mset_test.go b/cluster/mset_test.go index 377f575..9f81e4d 100644 --- a/cluster/mset_test.go +++ b/cluster/mset_test.go @@ -7,7 +7,7 @@ import ( ) func TestMSet(t *testing.T) { - conn := &connection.FakeConn{} + conn := connection.NewFakeConn() allowFastTransaction = false ret := MSet(testNodeA, conn, toArgs("MSET", "a", "a", "b", "b")) asserts.AssertNotError(t, ret) @@ -16,7 +16,7 @@ func TestMSet(t *testing.T) { } func TestMSetNx(t *testing.T) { - conn := &connection.FakeConn{} + conn := connection.NewFakeConn() allowFastTransaction = false FlushAll(testNodeA, conn, toArgs("FLUSHALL")) ret := MSetNX(testNodeA, conn, toArgs("MSETNX", "a", "a", "b", "b")) diff --git a/cluster/pubsub_test.go b/cluster/pubsub_test.go index 5b51dd7..7c380fb 100644 --- a/cluster/pubsub_test.go +++ b/cluster/pubsub_test.go @@ -11,7 +11,7 @@ import ( func TestPublish(t *testing.T) { channel := utils.RandString(5) msg := utils.RandString(5) - conn := &connection.FakeConn{} + conn := connection.NewFakeConn() Subscribe(testNodeA, conn, utils.ToCmdLine("SUBSCRIBE", channel)) conn.Clean() // clean subscribe success Publish(testNodeA, conn, utils.ToCmdLine("PUBLISH", channel, msg)) diff --git a/database/aof_test.go b/database/aof_test.go index 4c9f7fc..af0d353 100644 --- a/database/aof_test.go +++ b/database/aof_test.go @@ -17,7 +17,7 @@ import ( ) func makeTestData(db database.DB, dbIndex int, prefix string, size int) { - conn := &connection.FakeConn{} + conn := connection.NewFakeConn() conn.SelectDB(dbIndex) db.Exec(conn, utils.ToCmdLine("FlushDB")) cursor := 0 @@ -49,7 +49,7 @@ func makeTestData(db database.DB, dbIndex int, prefix string, size int) { } func validateTestData(t *testing.T, db database.DB, dbIndex int, prefix string, size int) { - conn 
:= &connection.FakeConn{} + conn := connection.NewFakeConn() conn.SelectDB(dbIndex) cursor := 0 var ret redis.Reply @@ -146,7 +146,7 @@ func TestRDB(t *testing.T) { dbNum := 4 size := 10 var prefixes []string - conn := &connection.FakeConn{} + conn := connection.NewFakeConn() writeDB := NewStandaloneServer() for i := 0; i < dbNum; i++ { prefix := utils.RandString(8) @@ -216,7 +216,7 @@ func TestRewriteAOF2(t *testing.T) { } aofWriteDB := NewStandaloneServer() dbNum := 4 - conn := &connection.FakeConn{} + conn := connection.NewFakeConn() for i := 0; i < dbNum; i++ { conn.SelectDB(i) key := strconv.Itoa(i) diff --git a/database/database.go b/database/database.go index 936661a..ca7a35d 100644 --- a/database/database.go +++ b/database/database.go @@ -9,7 +9,6 @@ import ( "github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/pubsub" - "github.com/hdt3213/godis/redis/connection" "github.com/hdt3213/godis/redis/protocol" "runtime/debug" "strconv" @@ -27,10 +26,10 @@ type MultiDB struct { // handle aof persistence aofHandler *aof.Handler - // store master node address - slaveOf string - role int32 - replication *slaveStatus + // for replication + role int32 + slaveStatus *slaveStatus + masterStatus *masterStatus } // NewStandaloneServer creates a standalone redis server, with multi database and all other funtions @@ -50,31 +49,36 @@ func NewStandaloneServer() *MultiDB { mdb.hub = pubsub.MakeHub() validAof := false if config.Properties.AppendOnly { - aofHandler, err := aof.NewAOFHandler(mdb, func() database.EmbedDB { - return MakeBasicMultiDB() - }) - if err != nil { - panic(err) - } - mdb.aofHandler = aofHandler - for _, db := range mdb.dbSet { - singleDB := db.Load().(*DB) - singleDB.addAof = func(line CmdLine) { - mdb.aofHandler.AddAof(singleDB.index, line) - } - } + mdb.initAof() validAof = true } if config.Properties.RDBFilename != "" && !validAof { // load rdb loadRdbFile(mdb) } - mdb.replication = initReplStatus() + 
mdb.slaveStatus = initReplSlaveStatus() + mdb.startAsMaster() mdb.startReplCron() mdb.role = masterRole // The initialization process does not require atomicity return mdb } +func (mdb *MultiDB) initAof() { + aofHandler, err := aof.NewAOFHandler(mdb, func() database.EmbedDB { + return MakeBasicMultiDB() + }) + if err != nil { + panic(err) + } + mdb.aofHandler = aofHandler + for _, db := range mdb.dbSet { + singleDB := db.Load().(*DB) + singleDB.addAof = func(line CmdLine) { + mdb.aofHandler.AddAof(singleDB.index, line) + } + } +} + // MakeBasicMultiDB create a MultiDB only with basic abilities for aof rewrite and other usages func MakeBasicMultiDB() *MultiDB { mdb := &MultiDB{} @@ -117,8 +121,7 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep // read only slave role := atomic.LoadInt32(&mdb.role) - if role == slaveRole && - c.GetRole() != connection.ReplicationRecvCli { + if role == slaveRole && !c.IsMaster() { // only allow read only command, forbid all special commands except `auth` and `slaveof` if !isReadOnlyCommand(cmdName) { return protocol.MakeErrReply("READONLY You can't write against a read only slave.") @@ -167,6 +170,10 @@ func (mdb *MultiDB) Exec(c redis.Connection, cmdLine [][]byte) (result redis.Rep return protocol.MakeArgNumErrReply("copy") } return execCopy(mdb, c, cmdLine[1:]) + } else if cmdName == "replconf" { + return mdb.execReplConf(c, cmdLine[1:]) + } else if cmdName == "psync" { + return mdb.execPSync(c, cmdLine[1:]) } // todo: support multi database transaction @@ -186,8 +193,8 @@ func (mdb *MultiDB) AfterClientClose(c redis.Connection) { // Close graceful shutdown database func (mdb *MultiDB) Close() { - // stop replication first - mdb.replication.close() + // stop slaveStatus first + mdb.slaveStatus.close() if mdb.aofHandler != nil { mdb.aofHandler.Close() } @@ -308,7 +315,11 @@ func SaveRDB(db *MultiDB, args [][]byte) redis.Reply { if db.aofHandler == nil { return protocol.MakeErrReply("please enable 
aof before using save") } - err := db.aofHandler.Rewrite2RDB() + rdbFilename := config.Properties.RDBFilename + if rdbFilename == "" { + rdbFilename = "dump.rdb" + } + err := db.aofHandler.Rewrite2RDB(rdbFilename, nil) if err != nil { return protocol.MakeErrReply(err.Error()) } @@ -326,7 +337,11 @@ func BGSaveRDB(db *MultiDB, args [][]byte) redis.Reply { logger.Error(err) } }() - err := db.aofHandler.Rewrite2RDB() + rdbFilename := config.Properties.RDBFilename + if rdbFilename == "" { + rdbFilename = "dump.rdb" + } + err := db.aofHandler.Rewrite2RDB(rdbFilename, nil) if err != nil { logger.Error(err) } @@ -339,3 +354,18 @@ func (mdb *MultiDB) GetDBSize(dbIndex int) (int, int) { db := mdb.mustSelectDB(dbIndex) return db.data.Len(), db.ttlMap.Len() } + +func (mdb *MultiDB) startReplCron() { + go func() { + defer func() { + if err := recover(); err != nil { + logger.Error("panic", err) + } + }() + ticker := time.Tick(time.Second * 10) + for range ticker { + mdb.slaveCron() + mdb.masterCron() + } + }() +} diff --git a/database/rdb_test.go b/database/rdb_test.go index fcd0dc6..8130959 100644 --- a/database/rdb_test.go +++ b/database/rdb_test.go @@ -17,7 +17,7 @@ func TestLoadRDB(t *testing.T) { AppendOnly: false, RDBFilename: filepath.Join(projectRoot, "test.rdb"), // set working directory to project root } - conn := &connection.FakeConn{} + conn := connection.NewFakeConn() rdbDB := NewStandaloneServer() result := rdbDB.Exec(conn, utils.ToCmdLine("Get", "str")) asserts.AssertBulkReply(t, result, "str") diff --git a/database/replication_master.go b/database/replication_master.go new file mode 100644 index 0000000..29a2ce3 --- /dev/null +++ b/database/replication_master.go @@ -0,0 +1,362 @@ +package database + +import ( + "context" + "errors" + "fmt" + "github.com/hdt3213/godis/interface/redis" + "github.com/hdt3213/godis/lib/logger" + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/protocol" + "io" + "io/ioutil" + "os" + "strconv" + "strings" + 
"sync" + "time" +) + +const ( + slaveStateHandShake = uint8(iota) + slaveStateWaitSaveEnd + slaveStateSendingRDB + slaveStateOnline +) + +const ( + bgSaveIdle = uint8(iota) + bgSaveRunning + bgSaveFinish +) + +const ( + salveCapacityNone = 0 + salveCapacityEOF = 1 << iota + salveCapacityPsync2 +) + +// slaveClient stores slave status in the view of master +type slaveClient struct { + conn redis.Connection + state uint8 + offset int64 + lastAckTime time.Time + announceIp string + announcePort int + capacity uint8 +} + +type masterStatus struct { + ctx context.Context + mu sync.RWMutex + replId string + backlog []byte // backlog can be appended or replaced as a whole, cannot be modified(insert/set/delete) + beginOffset int64 + currentOffset int64 + slaveMap map[redis.Connection]*slaveClient + waitSlaves map[*slaveClient]struct{} + onlineSlaves map[*slaveClient]struct{} + bgSaveState uint8 + rdbFilename string +} + +func (master *masterStatus) appendBacklog(bin []byte) { + master.backlog = append(master.backlog, bin...) + master.currentOffset += int64(len(bin)) +} + +func (mdb *MultiDB) bgSaveForReplication() { + go func() { + defer func() { + if e := recover(); e != nil { + logger.Errorf("panic: %v", e) + } + }() + if err := mdb.saveForReplication(); err != nil { + logger.Errorf("save for replication error: %v", err) + } + }() + +} + +// saveForReplication does bg-save and send rdb to waiting slaves +func (mdb *MultiDB) saveForReplication() error { + rdbFile, err := ioutil.TempFile("", "*.rdb") + if err != nil { + return fmt.Errorf("create temp rdb failed: %v", err) + } + rdbFilename := rdbFile.Name() + mdb.masterStatus.mu.Lock() + mdb.masterStatus.bgSaveState = bgSaveRunning + mdb.masterStatus.rdbFilename = rdbFilename // todo: can reuse config.Properties.RDBFilename? 
+ mdb.masterStatus.mu.Unlock() + + aofListener := make(chan CmdLine, 1024) // give channel enough capacity to store all updates during rewrite to db + err = mdb.aofHandler.Rewrite2RDB(rdbFilename, aofListener) + if err != nil { + return err + } + go func() { + mdb.masterListenAof(aofListener) + }() + + // change bgSaveState and get waitSlaves for sending + waitSlaves := make(map[*slaveClient]struct{}) + mdb.masterStatus.mu.Lock() + mdb.masterStatus.bgSaveState = bgSaveFinish + for slave := range mdb.masterStatus.waitSlaves { + waitSlaves[slave] = struct{}{} + } + mdb.masterStatus.waitSlaves = nil + mdb.masterStatus.mu.Unlock() + + for slave := range waitSlaves { + err = mdb.masterFullReSyncWithSlave(slave) + if err != nil { + mdb.removeSlave(slave) + logger.Errorf("masterFullReSyncWithSlave error: %v", err) + continue + } + } + return nil +} + +// masterFullReSyncWithSlave send replication header, rdb file and all backlogs to slave +func (mdb *MultiDB) masterFullReSyncWithSlave(slave *slaveClient) error { + // write replication header + header := "+FULLRESYNC " + mdb.masterStatus.replId + " " + + strconv.FormatInt(mdb.masterStatus.beginOffset, 10) + protocol.CRLF + _, err := slave.conn.Write([]byte(header)) + if err != nil { + return fmt.Errorf("write replication header to slave failed: %v", err) + } + // send rdb + rdbFile, err := os.Open(mdb.masterStatus.rdbFilename) + if err != nil { + return fmt.Errorf("open rdb file %s for replication error: %v", mdb.masterStatus.rdbFilename, err) + } + slave.state = slaveStateSendingRDB + rdbInfo, _ := os.Stat(mdb.masterStatus.rdbFilename) + rdbSize := rdbInfo.Size() + rdbHeader := "$" + strconv.FormatInt(rdbSize, 10) + protocol.CRLF + _, err = slave.conn.Write([]byte(rdbHeader)) + if err != nil { + return fmt.Errorf("write rdb header to slave failed: %v", err) + } + _, err = io.Copy(slave.conn, rdbFile) + if err != nil { + return fmt.Errorf("write rdb file to slave failed: %v", err) + } + + // send backlog + 
mdb.masterStatus.mu.RLock() + currentOffset := mdb.masterStatus.currentOffset + backlog := mdb.masterStatus.backlog[:currentOffset-mdb.masterStatus.beginOffset] + mdb.masterStatus.mu.RUnlock() + _, err = slave.conn.Write(backlog) + if err != nil { + return fmt.Errorf("full resync write backlog to slave failed: %v", err) + } + + // set slave as online + mdb.setSlaveOnline(slave, currentOffset) + return nil +} + +var cannotPartialSync = errors.New("cannot do partial sync") + +func (mdb *MultiDB) masterTryPartialSyncWithSlave(slave *slaveClient, replId string, slaveOffset int64) error { + mdb.masterStatus.mu.RLock() + if replId != mdb.masterStatus.replId { + mdb.masterStatus.mu.RUnlock() + return cannotPartialSync + } + if slaveOffset < mdb.masterStatus.beginOffset || slaveOffset > mdb.masterStatus.currentOffset { + mdb.masterStatus.mu.RUnlock() + return cannotPartialSync + } + currentOffset := mdb.masterStatus.currentOffset + backlog := mdb.masterStatus.backlog[slaveOffset-mdb.masterStatus.beginOffset : currentOffset-mdb.masterStatus.beginOffset] + mdb.masterStatus.mu.RUnlock() + + // send replication header + header := "+CONTINUE " + mdb.masterStatus.replId + protocol.CRLF + _, err := slave.conn.Write([]byte(header)) + if err != nil { + return fmt.Errorf("write replication header to slave failed: %v", err) + } + // send backlog + _, err = slave.conn.Write(backlog) + if err != nil { + return fmt.Errorf("partial resync write backlog to slave failed: %v", err) + } + + // set slave online + mdb.setSlaveOnline(slave, currentOffset) + return nil +} + +func (mdb *MultiDB) masterSendUpdatesToSlave() error { + onlineSlaves := make(map[*slaveClient]struct{}) + mdb.masterStatus.mu.RLock() + currentOffset := mdb.masterStatus.currentOffset + beginOffset := mdb.masterStatus.beginOffset + backlog := mdb.masterStatus.backlog[:currentOffset-beginOffset] + for slave := range mdb.masterStatus.onlineSlaves { + onlineSlaves[slave] = struct{}{} + } + mdb.masterStatus.mu.RUnlock() + for 
slave := range onlineSlaves { + slaveBeginOffset := slave.offset - beginOffset + _, err := slave.conn.Write(backlog[slaveBeginOffset:]) + if err != nil { + logger.Errorf("send updates write backlog to slave failed: %v", err) + mdb.removeSlave(slave) + continue + } + slave.offset = currentOffset + } + return nil +} + +func (mdb *MultiDB) execPSync(c redis.Connection, args [][]byte) redis.Reply { + replId := string(args[0]) + replOffset, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return protocol.MakeErrReply("ERR value is not an integer or out of range") + } + mdb.masterStatus.mu.Lock() + defer mdb.masterStatus.mu.Unlock() + slave := mdb.masterStatus.slaveMap[c] + if slave == nil { + slave = &slaveClient{ + conn: c, + } + c.SetSlave() + mdb.masterStatus.slaveMap[c] = slave + } + if mdb.masterStatus.bgSaveState == bgSaveIdle { + slave.state = slaveStateWaitSaveEnd + mdb.masterStatus.waitSlaves[slave] = struct{}{} + mdb.bgSaveForReplication() + } else if mdb.masterStatus.bgSaveState == bgSaveRunning { + slave.state = slaveStateWaitSaveEnd + mdb.masterStatus.waitSlaves[slave] = struct{}{} + } else if mdb.masterStatus.bgSaveState == bgSaveFinish { + go func() { + defer func() { + if e := recover(); e != nil { + logger.Errorf("panic: %v", e) + } + }() + err := mdb.masterTryPartialSyncWithSlave(slave, replId, replOffset) + if err == nil { + return + } + if err != nil && err != cannotPartialSync { + mdb.removeSlave(slave) + logger.Errorf("masterTryPartialSyncWithSlave error: %v", err) + return + } + // assert err == cannotPartialSync + if err := mdb.masterFullReSyncWithSlave(slave); err != nil { + mdb.removeSlave(slave) + logger.Errorf("masterFullReSyncWithSlave error: %v", err) + return + } + }() + } + return &protocol.NoReply{} +} + +func (mdb *MultiDB) execReplConf(c redis.Connection, args [][]byte) redis.Reply { + if len(args)%2 != 0 { + return protocol.MakeSyntaxErrReply() + } + mdb.masterStatus.mu.RLock() + slave := 
mdb.masterStatus.slaveMap[c] + mdb.masterStatus.mu.RUnlock() + for i := 0; i < len(args); i += 2 { + key := strings.ToLower(string(args[i])) + value := string(args[i+1]) + switch key { + case "ack": + offset, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return protocol.MakeErrReply("ERR value is not an integer or out of range") + } + slave.offset = offset + slave.lastAckTime = time.Now() + return &protocol.NoReply{} + } + } + return protocol.MakeOkReply() +} + +func (mdb *MultiDB) removeSlave(slave *slaveClient) { + mdb.masterStatus.mu.Lock() + defer mdb.masterStatus.mu.Unlock() + _ = slave.conn.Close() + delete(mdb.masterStatus.slaveMap, slave.conn) + delete(mdb.masterStatus.waitSlaves, slave) + delete(mdb.masterStatus.onlineSlaves, slave) +} + +func (mdb *MultiDB) setSlaveOnline(slave *slaveClient, currentOffset int64) { + mdb.masterStatus.mu.Lock() + defer mdb.masterStatus.mu.Unlock() + slave.state = slaveStateOnline + slave.offset = currentOffset + mdb.masterStatus.onlineSlaves[slave] = struct{}{} +} + +var pingBytes = protocol.MakeMultiBulkReply(utils.ToCmdLine("ping")).ToBytes() + +func (mdb *MultiDB) masterCron() { + if mdb.role != masterRole { + return + } + mdb.masterStatus.mu.Lock() + if mdb.masterStatus.bgSaveState == bgSaveFinish { + mdb.masterStatus.appendBacklog(pingBytes) + } + mdb.masterStatus.mu.Unlock() + if err := mdb.masterSendUpdatesToSlave(); err != nil { + logger.Errorf("masterSendUpdatesToSlave error: %v", err) + } +} + +func (mdb *MultiDB) masterListenAof(listener chan CmdLine) { + for { + select { + case cmdLine := <-listener: + mdb.masterStatus.mu.Lock() + reply := protocol.MakeMultiBulkReply(cmdLine) + mdb.masterStatus.appendBacklog(reply.ToBytes()) + mdb.masterStatus.mu.Unlock() + if err := mdb.masterSendUpdatesToSlave(); err != nil { + logger.Errorf("masterSendUpdatesToSlave after receive aof error: %v", err) + } + // if bgSave is running, updates will be sent after the save finished + case <-mdb.masterStatus.ctx.Done(): 
+ break + } + } +} + +func (mdb *MultiDB) startAsMaster() { + mdb.masterStatus = &masterStatus{ + ctx: context.Background(), + mu: sync.RWMutex{}, + replId: utils.RandHexString(40), + backlog: nil, + beginOffset: 0, + currentOffset: 0, + slaveMap: make(map[redis.Connection]*slaveClient), + waitSlaves: make(map[*slaveClient]struct{}), + onlineSlaves: make(map[*slaveClient]struct{}), + bgSaveState: bgSaveIdle, + rdbFilename: "", + } +} diff --git a/database/replication_master_test.go b/database/replication_master_test.go new file mode 100644 index 0000000..5b361d4 --- /dev/null +++ b/database/replication_master_test.go @@ -0,0 +1,194 @@ +package database + +import ( + "bytes" + "github.com/hdt3213/godis/config" + "github.com/hdt3213/godis/lib/utils" + "github.com/hdt3213/godis/redis/connection" + "github.com/hdt3213/godis/redis/parser" + "github.com/hdt3213/godis/redis/protocol" + "github.com/hdt3213/godis/redis/protocol/asserts" + rdb "github.com/hdt3213/rdb/parser" + "io/ioutil" + "os" + "path" + "strconv" + "strings" + "sync/atomic" + "testing" + "time" +) + +func mockServer() *MultiDB { + server := &MultiDB{} + server.dbSet = make([]*atomic.Value, 16) + for i := range server.dbSet { + singleDB := makeDB() + singleDB.index = i + holder := &atomic.Value{} + holder.Store(singleDB) + server.dbSet[i] = holder + } + server.slaveStatus = initReplSlaveStatus() + return server +} + +func TestReplicationMasterSide(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "godis") + if err != nil { + t.Error(err) + return + } + aofFilename := path.Join(tmpDir, "a.aof") + defer func() { + _ = os.Remove(aofFilename) + }() + config.Properties = &config.ServerProperties{ + Databases: 16, + AppendOnly: true, + AppendFilename: aofFilename, + } + master := mockServer() + master.initAof() + master.startAsMaster() + slave := mockServer() + replConn := connection.NewFakeConn() + + // set data to master + masterConn := connection.NewFakeConn() + resp := master.Exec(masterConn, 
utils.ToCmdLine("SET", "a", "a")) + asserts.AssertNotError(t, resp) + time.Sleep(time.Millisecond * 100) // wait write aof + + // full re-sync + master.Exec(replConn, utils.ToCmdLine("psync", "?", "-1")) + masterChan := parser.ParseStream(replConn) + psyncPayload := <-masterChan + if psyncPayload.Err != nil { + t.Errorf("master bad protocol: %v", psyncPayload.Err) + return + } + psyncHeader, ok := psyncPayload.Data.(*protocol.StatusReply) + if !ok { + t.Error("psync header is not a status reply") + return + } + headers := strings.Split(psyncHeader.Status, " ") + if len(headers) != 3 { + t.Errorf("illegal psync header: %s", psyncHeader.Status) + return + } + + replId := headers[1] + replOffset, err := strconv.ParseInt(headers[2], 10, 64) + if err != nil { + t.Errorf("illegal offset: %s", headers[2]) + return + } + t.Logf("repl id: %s, offset: %d", replId, replOffset) + + rdbPayload := <-masterChan + if rdbPayload.Err != nil { + t.Error("read response failed: " + rdbPayload.Err.Error()) + return + } + rdbReply, ok := rdbPayload.Data.(*protocol.BulkReply) + if !ok { + t.Error("illegal payload header: " + string(rdbPayload.Data.ToBytes())) + return + } + + rdbDec := rdb.NewDecoder(bytes.NewReader(rdbReply.Arg)) + err = importRDB(rdbDec, slave) + if err != nil { + t.Error("import rdb failed: " + err.Error()) + return + } + + // get a + slaveConn := connection.NewFakeConn() + resp = slave.Exec(slaveConn, utils.ToCmdLine("get", "a")) + asserts.AssertBulkReply(t, resp, "a") + + /*---- test broadcast aof ----*/ + masterConn = connection.NewFakeConn() + resp = master.Exec(masterConn, utils.ToCmdLine("SET", "b", "b")) + time.Sleep(time.Millisecond * 100) // wait write aof + asserts.AssertNotError(t, resp) + master.masterCron() + for { + payload := <-masterChan + if payload.Err != nil { + t.Error(payload.Err) + return + } + cmdLine, ok := payload.Data.(*protocol.MultiBulkReply) + if !ok { + t.Error("unexpected payload: " + string(payload.Data.ToBytes())) + return + } + 
slave.Exec(replConn, cmdLine.Args) + n := len(cmdLine.ToBytes()) + slave.slaveStatus.replOffset += int64(n) + if string(cmdLine.Args[0]) != "ping" { + break + } + } + + resp = slave.Exec(slaveConn, utils.ToCmdLine("get", "b")) + asserts.AssertBulkReply(t, resp, "b") + + /*---- test partial reconnect ----*/ + _ = replConn.Close() // mock disconnect + + replConn = connection.NewFakeConn() + + master.Exec(replConn, utils.ToCmdLine("psync", replId, + strconv.FormatInt(slave.slaveStatus.replOffset, 10))) + masterChan = parser.ParseStream(replConn) + psyncPayload = <-masterChan + if psyncPayload.Err != nil { + t.Errorf("master bad protocol: %v", psyncPayload.Err) + return + } + psyncHeader, ok = psyncPayload.Data.(*protocol.StatusReply) + if !ok { + t.Error("psync header is not a status reply") + return + } + headers = strings.Split(psyncHeader.Status, " ") + if len(headers) != 2 { + t.Errorf("illegal psync header: %s", psyncHeader.Status) + return + } + if headers[0] != "CONTINUE" { + t.Errorf("expect CONTINUE actual %s", headers[0]) + return + } + replId = headers[1] + t.Logf("partial resync repl id: %s, offset: %d", replId, slave.slaveStatus.replOffset) + + resp = master.Exec(masterConn, utils.ToCmdLine("SET", "c", "c")) + time.Sleep(time.Millisecond * 100) // wait write aof + asserts.AssertNotError(t, resp) + master.masterCron() + for { + payload := <-masterChan + if payload.Err != nil { + t.Error(payload.Err) + return + } + cmdLine, ok := payload.Data.(*protocol.MultiBulkReply) + if !ok { + t.Error("unexpected payload: " + string(payload.Data.ToBytes())) + return + } + slave.Exec(replConn, cmdLine.Args) + if string(cmdLine.Args[0]) != "ping" { + break + } + } + + resp = slave.Exec(slaveConn, utils.ToCmdLine("get", "c")) + asserts.AssertBulkReply(t, resp, "c") +} diff --git a/database/replication.go b/database/replication_slave.go similarity index 76% rename from database/replication.go rename to database/replication_slave.go index 6bccd95..d02f66c 100644 --- 
a/database/replication.go +++ b/database/replication_slave.go @@ -31,9 +31,9 @@ type slaveStatus struct { ctx context.Context cancel context.CancelFunc - // configVersion stands for the version of replication config. Any change of master host/port will cause configVersion increment - // If configVersion change has been found during replication current replication procedure will stop. - // It is designed to abort a running replication procedure + // configVersion stands for the version of slaveStatus config. Any change of master host/port will cause configVersion increment + // If configVersion change has been found during slaveStatus current slaveStatus procedure will stop. + // It is designed to abort a running slaveStatus procedure configVersion int32 masterHost string @@ -47,26 +47,10 @@ type slaveStatus struct { running sync.WaitGroup } -var configChangedErr = errors.New("replication config changed") +var configChangedErr = errors.New("slaveStatus config changed") -func initReplStatus() *slaveStatus { - repl := &slaveStatus{} - // start cron - return repl -} - -func (mdb *MultiDB) startReplCron() { - go func() { - defer func() { - if err := recover(); err != nil { - logger.Error("panic", err) - } - }() - ticker := time.Tick(time.Second) - for range ticker { - mdb.slaveCron() - } - }() +func initReplSlaveStatus() *slaveStatus { + return &slaveStatus{} } func (mdb *MultiDB) execSlaveOf(c redis.Connection, args [][]byte) redis.Reply { @@ -80,29 +64,29 @@ func (mdb *MultiDB) execSlaveOf(c redis.Connection, args [][]byte) redis.Reply { if err != nil { return protocol.MakeErrReply("ERR value is not an integer or out of range") } - mdb.replication.mutex.Lock() + mdb.slaveStatus.mutex.Lock() atomic.StoreInt32(&mdb.role, slaveRole) - mdb.replication.masterHost = host - mdb.replication.masterPort = port + mdb.slaveStatus.masterHost = host + mdb.slaveStatus.masterPort = port // use buffered channel in case receiver goroutine exited before controller send stop signal - 
atomic.AddInt32(&mdb.replication.configVersion, 1) - mdb.replication.mutex.Unlock() + atomic.AddInt32(&mdb.slaveStatus.configVersion, 1) + mdb.slaveStatus.mutex.Unlock() go mdb.setupMaster() return protocol.MakeOkReply() } func (mdb *MultiDB) slaveOfNone() { - mdb.replication.mutex.Lock() - defer mdb.replication.mutex.Unlock() - mdb.replication.masterHost = "" - mdb.replication.masterPort = 0 - mdb.replication.replId = "" - mdb.replication.replOffset = -1 - mdb.replication.stopSlaveWithMutex() + mdb.slaveStatus.mutex.Lock() + defer mdb.slaveStatus.mutex.Unlock() + mdb.slaveStatus.masterHost = "" + mdb.slaveStatus.masterPort = 0 + mdb.slaveStatus.replId = "" + mdb.slaveStatus.replOffset = -1 + mdb.slaveStatus.stopSlaveWithMutex() } // stopSlaveWithMutex stops in-progress connectWithMaster/fullSync/receiveAOF -// invoker should have replication mutex +// invoker should have slaveStatus mutex func (repl *slaveStatus) stopSlaveWithMutex() { // update configVersion to stop connectWithMaster and fullSync atomic.AddInt32(&repl.configVersion, 1) @@ -135,11 +119,11 @@ func (mdb *MultiDB) setupMaster() { }() var configVersion int32 ctx, cancel := context.WithCancel(context.Background()) - mdb.replication.mutex.Lock() - mdb.replication.ctx = ctx - mdb.replication.cancel = cancel - configVersion = mdb.replication.configVersion - mdb.replication.mutex.Unlock() + mdb.slaveStatus.mutex.Lock() + mdb.slaveStatus.ctx = ctx + mdb.slaveStatus.cancel = cancel + configVersion = mdb.slaveStatus.configVersion + mdb.slaveStatus.mutex.Unlock() isFullReSync, err := mdb.connectWithMaster(configVersion) if err != nil { // connect failed, abort master @@ -167,7 +151,7 @@ func (mdb *MultiDB) setupMaster() { // connectWithMaster finishes handshake with master // returns: isFullReSync, error func (mdb *MultiDB) connectWithMaster(configVersion int32) (bool, error) { - addr := mdb.replication.masterHost + ":" + strconv.Itoa(mdb.replication.masterPort) + addr := mdb.slaveStatus.masterHost + ":" + 
strconv.Itoa(mdb.slaveStatus.masterPort) conn, err := net.Dial("tcp", addr) if err != nil { mdb.slaveOfNone() // abort @@ -256,34 +240,34 @@ func (mdb *MultiDB) connectWithMaster(configVersion int32) (bool, error) { } // update connection - mdb.replication.mutex.Lock() - defer mdb.replication.mutex.Unlock() - if mdb.replication.configVersion != configVersion { - // replication conf changed during connecting and waiting mutex + mdb.slaveStatus.mutex.Lock() + defer mdb.slaveStatus.mutex.Unlock() + if mdb.slaveStatus.configVersion != configVersion { + // slaveStatus conf changed during connecting and waiting mutex return false, configChangedErr } - mdb.replication.masterConn = conn - mdb.replication.masterChan = masterChan - mdb.replication.lastRecvTime = time.Now() + mdb.slaveStatus.masterConn = conn + mdb.slaveStatus.masterChan = masterChan + mdb.slaveStatus.lastRecvTime = time.Now() return mdb.psyncHandshake() } // psyncHandshake send `psync` to master and sync repl-id/offset with master -// invoker should provide with replication.mutex +// invoker should provide with slaveStatus.mutex func (mdb *MultiDB) psyncHandshake() (bool, error) { replId := "?" 
var replOffset int64 = -1 - if mdb.replication.replId != "" { - replId = mdb.replication.replId - replOffset = mdb.replication.replOffset + if mdb.slaveStatus.replId != "" { + replId = mdb.slaveStatus.replId + replOffset = mdb.slaveStatus.replOffset } psyncCmdLine := utils.ToCmdLine("psync", replId, strconv.FormatInt(replOffset, 10)) psyncReq := protocol.MakeMultiBulkReply(psyncCmdLine) - _, err := mdb.replication.masterConn.Write(psyncReq.ToBytes()) + _, err := mdb.slaveStatus.masterConn.Write(psyncReq.ToBytes()) if err != nil { return false, errors.New("send failed " + err.Error()) } - psyncPayload := <-mdb.replication.masterChan + psyncPayload := <-mdb.slaveStatus.masterChan if psyncPayload.Err != nil { return false, errors.New("read response failed: " + psyncPayload.Err.Error()) } @@ -300,12 +284,12 @@ func (mdb *MultiDB) psyncHandshake() (bool, error) { var isFullReSync bool if headers[0] == "FULLRESYNC" { logger.Info("full re-sync with master") - mdb.replication.replId = headers[1] - mdb.replication.replOffset, err = strconv.ParseInt(headers[2], 10, 64) + mdb.slaveStatus.replId = headers[1] + mdb.slaveStatus.replOffset, err = strconv.ParseInt(headers[2], 10, 64) isFullReSync = true } else if headers[0] == "CONTINUE" { logger.Info("continue partial sync") - mdb.replication.replId = headers[1] + mdb.slaveStatus.replId = headers[1] isFullReSync = false } else { return false, errors.New("illegal psync resp: " + psyncHeader.Status) @@ -314,13 +298,13 @@ func (mdb *MultiDB) psyncHandshake() (bool, error) { if err != nil { return false, errors.New("get illegal repl offset: " + headers[2]) } - logger.Info(fmt.Sprintf("repl id: %s, current offset: %d", mdb.replication.replId, mdb.replication.replOffset)) + logger.Info(fmt.Sprintf("repl id: %s, current offset: %d", mdb.slaveStatus.replId, mdb.slaveStatus.replOffset)) return isFullReSync, nil } // loadMasterRDB downloads rdb after handshake has been done func (mdb *MultiDB) loadMasterRDB(configVersion int32) error { - 
rdbPayload := <-mdb.replication.masterChan + rdbPayload := <-mdb.slaveStatus.masterChan if rdbPayload.Err != nil { return errors.New("read response failed: " + rdbPayload.Err.Error()) } @@ -337,10 +321,10 @@ func (mdb *MultiDB) loadMasterRDB(configVersion int32) error { return errors.New("dump rdb failed: " + err.Error()) } - mdb.replication.mutex.Lock() - defer mdb.replication.mutex.Unlock() - if mdb.replication.configVersion != configVersion { - // replication conf changed during connecting and waiting mutex + mdb.slaveStatus.mutex.Lock() + defer mdb.slaveStatus.mutex.Unlock() + if mdb.slaveStatus.configVersion != configVersion { + // slaveStatus conf changed during connecting and waiting mutex return configChangedErr } for i, h := range rdbHolder.dbSet { @@ -353,13 +337,13 @@ func (mdb *MultiDB) loadMasterRDB(configVersion int32) error { } func (mdb *MultiDB) receiveAOF(ctx context.Context, configVersion int32) error { - conn := connection.NewConn(mdb.replication.masterConn) - conn.SetRole(connection.ReplicationRecvCli) - mdb.replication.running.Add(1) - defer mdb.replication.running.Done() + conn := connection.NewConn(mdb.slaveStatus.masterConn) + conn.SetMaster() + mdb.slaveStatus.running.Add(1) + defer mdb.slaveStatus.running.Done() for { select { - case payload, open := <-mdb.replication.masterChan: + case payload, open := <-mdb.slaveStatus.masterChan: if !open { return errors.New("master channel unexpected close") } @@ -370,31 +354,32 @@ func (mdb *MultiDB) receiveAOF(ctx context.Context, configVersion int32) error { if !ok { return errors.New("unexpected payload: " + string(payload.Data.ToBytes())) } - mdb.replication.mutex.Lock() - if mdb.replication.configVersion != configVersion { - // replication conf changed during connecting and waiting mutex + mdb.slaveStatus.mutex.Lock() + if mdb.slaveStatus.configVersion != configVersion { + // slaveStatus conf changed during connecting and waiting mutex return configChangedErr } mdb.Exec(conn, cmdLine.Args) n := 
len(cmdLine.ToBytes()) // todo: directly get size from socket - mdb.replication.replOffset += int64(n) - mdb.replication.lastRecvTime = time.Now() + mdb.slaveStatus.replOffset += int64(n) + mdb.slaveStatus.lastRecvTime = time.Now() logger.Info(fmt.Sprintf("receive %d bytes from master, current offset %d, %s", - n, mdb.replication.replOffset, strconv.Quote(string(cmdLine.ToBytes())))) - mdb.replication.mutex.Unlock() + n, mdb.slaveStatus.replOffset, strconv.Quote(string(cmdLine.ToBytes())))) + mdb.slaveStatus.mutex.Unlock() case <-ctx.Done(): - conn.GetConnPool().Put(conn) + _ = conn.Close() return nil } } } func (mdb *MultiDB) slaveCron() { - repl := mdb.replication + repl := mdb.slaveStatus if repl.masterConn == nil { return } + // check master timeout replTimeout := 60 * time.Second if config.Properties.ReplTimeout != 0 { replTimeout = time.Duration(config.Properties.ReplTimeout) * time.Second @@ -427,9 +412,9 @@ func (repl *slaveStatus) sendAck2Master() error { func (mdb *MultiDB) reconnectWithMaster() error { logger.Info("reconnecting with master") - mdb.replication.mutex.Lock() - defer mdb.replication.mutex.Unlock() - mdb.replication.stopSlaveWithMutex() + mdb.slaveStatus.mutex.Lock() + defer mdb.slaveStatus.mutex.Unlock() + mdb.slaveStatus.stopSlaveWithMutex() go mdb.setupMaster() return nil } diff --git a/database/replication_test.go b/database/replication_slave_test.go similarity index 84% rename from database/replication_test.go rename to database/replication_slave_test.go index d37fb05..107e919 100644 --- a/database/replication_test.go +++ b/database/replication_slave_test.go @@ -8,22 +8,12 @@ import ( "github.com/hdt3213/godis/redis/connection" "github.com/hdt3213/godis/redis/protocol" "github.com/hdt3213/godis/redis/protocol/asserts" - "sync/atomic" "testing" "time" ) -func TestReplication(t *testing.T) { - mdb := &MultiDB{} - mdb.dbSet = make([]*atomic.Value, 16) - for i := range mdb.dbSet { - singleDB := makeDB() - singleDB.index = i - holder := 
&atomic.Value{} - holder.Store(singleDB) - mdb.dbSet[i] = holder - } - mdb.replication = initReplStatus() +func TestReplicationSlaveSide(t *testing.T) { + mdb := mockServer() masterCli, err := client.MakeClient("127.0.0.1:6379") if err != nil { t.Error(err) @@ -33,7 +23,7 @@ func TestReplication(t *testing.T) { // sync with master ret := masterCli.Send(utils.ToCmdLine("set", "1", "1")) asserts.AssertStatusReply(t, ret, "OK") - conn := &connection.FakeConn{} + conn := connection.NewFakeConn() ret = mdb.Exec(conn, utils.ToCmdLine("SLAVEOF", "127.0.0.1", "6379")) asserts.AssertStatusReply(t, ret, "OK") success := false @@ -74,7 +64,7 @@ func TestReplication(t *testing.T) { t.Error("sync failed") return } - err = mdb.replication.sendAck2Master() + err = mdb.slaveStatus.sendAck2Master() if err != nil { t.Error(err) return @@ -83,8 +73,8 @@ func TestReplication(t *testing.T) { // test reconnect config.Properties.ReplTimeout = 1 - _ = mdb.replication.masterConn.Close() - mdb.replication.lastRecvTime = time.Now().Add(-time.Hour) // mock timeout + _ = mdb.slaveStatus.masterConn.Close() + mdb.slaveStatus.lastRecvTime = time.Now().Add(-time.Hour) // mock timeout mdb.slaveCron() time.Sleep(3 * time.Second) ret = masterCli.Send(utils.ToCmdLine("set", "1", "3")) @@ -134,7 +124,7 @@ func TestReplication(t *testing.T) { return } - err = mdb.replication.close() + err = mdb.slaveStatus.close() if err != nil { t.Error("cannot close") } diff --git a/database/sys_test.go b/database/sys_test.go index 7d8828f..816c5be 100644 --- a/database/sys_test.go +++ b/database/sys_test.go @@ -20,7 +20,7 @@ func TestPing(t *testing.T) { func TestAuth(t *testing.T) { passwd := utils.RandString(10) - c := &connection.FakeConn{} + c := connection.NewFakeConn() ret := testServer.Exec(c, utils.ToCmdLine("AUTH")) asserts.AssertErrReply(t, ret, "ERR wrong number of arguments for 'auth' command") ret = testServer.Exec(c, utils.ToCmdLine("AUTH", passwd)) diff --git a/interface/redis/conn.go 
b/interface/redis/conn.go index 05ca534..785538f 100644 --- a/interface/redis/conn.go +++ b/interface/redis/conn.go @@ -2,7 +2,9 @@ package redis // Connection represents a connection with redis client type Connection interface { - Write([]byte) error + Write([]byte) (int, error) + Close() error + SetPassword(string) GetPassword() string @@ -12,7 +14,6 @@ type Connection interface { SubsCount() int GetChannels() []string - // used for `Multi` command InMultiState() bool SetMultiState(bool) GetQueuedCmdLine() [][][]byte @@ -22,10 +23,12 @@ type Connection interface { AddTxError(err error) GetTxErrors() []error - // used for multi database GetDBIndex() int SelectDB(int) - // returns role of conn, such as connection with client, connection with master node - GetRole() int32 - SetRole(int32) + + SetSlave() + IsSlave() bool + + SetMaster() + IsMaster() bool } diff --git a/lib/logger/logger.go b/lib/logger/logger.go index cff2589..f268df4 100644 --- a/lib/logger/logger.go +++ b/lib/logger/logger.go @@ -107,6 +107,13 @@ func Error(v ...interface{}) { logger.Println(v...) 
} +func Errorf(format string, v ...interface{}) { + mu.Lock() + defer mu.Unlock() + setPrefix(ERROR) + logger.Println(fmt.Sprintf(format, v...)) +} + // Fatal prints error log then stop the program func Fatal(v ...interface{}) { mu.Lock() diff --git a/lib/utils/rand_string.go b/lib/utils/rand_string.go index ba675dc..2aed53c 100644 --- a/lib/utils/rand_string.go +++ b/lib/utils/rand_string.go @@ -16,3 +16,13 @@ func RandString(n int) string { } return string(b) } + +var hexLetters = []rune("0123456789abcdef") + +func RandHexString(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = hexLetters[rand.Intn(len(hexLetters))] + } + return string(b) +} diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go index 1c1ef06..f701f88 100644 --- a/pubsub/pubsub.go +++ b/pubsub/pubsub.go @@ -82,7 +82,7 @@ func Subscribe(hub *Hub, c redis.Connection, args [][]byte) redis.Reply { for _, channel := range channels { if subscribe0(hub, channel, c) { - _ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount()))) + _, _ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount()))) } } return &protocol.NoReply{} @@ -117,13 +117,13 @@ func UnSubscribe(db *Hub, c redis.Connection, args [][]byte) redis.Reply { defer db.subsLocker.UnLocks(channels...) 
if len(channels) == 0 { - _ = c.Write(unSubscribeNothing) + _, _ = c.Write(unSubscribeNothing) return &protocol.NoReply{} } for _, channel := range channels { if unsubscribe0(db, channel, c) { - _ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount()))) + _, _ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount()))) } } return &protocol.NoReply{} @@ -151,7 +151,7 @@ func Publish(hub *Hub, args [][]byte) redis.Reply { replyArgs[0] = messageBytes replyArgs[1] = []byte(channel) replyArgs[2] = message - _ = client.Write(protocol.MakeMultiBulkReply(replyArgs).ToBytes()) + _, _ = client.Write(protocol.MakeMultiBulkReply(replyArgs).ToBytes()) return true }) return protocol.MakeIntReply(int64(subscribers.Len())) diff --git a/redis.conf b/redis.conf index e35e2c1..c15e7d4 100644 --- a/redis.conf +++ b/redis.conf @@ -2,6 +2,6 @@ bind 0.0.0.0 port 6399 maxclients 128 -#appendonly no -#appendfilename appendonly.aof +appendonly yes +appendfilename appendonly.aof #dbfilename test.rdb diff --git a/redis/connection/conn.go b/redis/connection/conn.go index 7a0a09f..66fbfbd 100644 --- a/redis/connection/conn.go +++ b/redis/connection/conn.go @@ -1,7 +1,6 @@ package connection import ( - "bytes" "github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/lib/sync/wait" "net" @@ -10,21 +9,24 @@ import ( ) const ( - // NormalCli is client with user - NormalCli = iota - // ReplicationRecvCli is fake client with replication master - ReplicationRecvCli + // flagSlave means this is a connection with a slave + flagSlave = uint64(1 << iota) + // flagMaster means this is a connection with the master + flagMaster + // flagMulti means this connection is within a transaction + flagMulti ) // Connection represents a connection with a redis-cli type Connection struct { conn net.Conn - // waiting until protocol finished - waitingReply wait.Wait + // waiting until finish sending data, used for graceful shutdown + sendingData wait.Wait // lock while server sending response - mu sync.Mutex + mu
sync.Mutex + flags uint64 // subscribing channels subs map[string]bool @@ -33,14 +35,12 @@ type Connection struct { password string // queued commands for `multi` - multiState bool - queue [][][]byte - watching map[string]uint32 - txErrors []error + queue [][][]byte + watching map[string]uint32 + txErrors []error // selected db selectedDB int - role int32 } var connPool = sync.Pool{ @@ -49,11 +49,6 @@ var connPool = sync.Pool{ }, } -// GetConnPool returns the connection pool pointer for putting and getting connection -func (c *Connection) GetConnPool() *sync.Pool { - return &connPool -} - // RemoteAddr returns the remote network address func (c *Connection) RemoteAddr() net.Addr { return c.conn.RemoteAddr() @@ -61,8 +56,15 @@ func (c *Connection) RemoteAddr() net.Addr { // Close disconnect with the client func (c *Connection) Close() error { - c.waitingReply.WaitWithTimeout(10 * time.Second) + c.sendingData.WaitWithTimeout(10 * time.Second) _ = c.conn.Close() + c.subs = nil + c.password = "" + c.queue = nil + c.watching = nil + c.txErrors = nil + c.selectedDB = 0 + connPool.Put(c) return nil } @@ -80,17 +82,16 @@ func NewConn(conn net.Conn) *Connection { } // Write sends response to client over tcp connection -func (c *Connection) Write(b []byte) error { +func (c *Connection) Write(b []byte) (int, error) { if len(b) == 0 { - return nil + return 0, nil } - c.waitingReply.Add(1) + c.sendingData.Add(1) defer func() { - c.waitingReply.Done() + c.sendingData.Done() }() - _, err := c.conn.Write(b) - return err + return c.conn.Write(b) } // Subscribe add current connection into subscribers of the given channel @@ -146,7 +147,7 @@ func (c *Connection) GetPassword() string { // InMultiState tells is connection in an uncommitted transaction func (c *Connection) InMultiState() bool { - return c.multiState + return c.flags&flagMulti > 0 } // SetMultiState sets transaction flag @@ -154,8 +155,10 @@ func (c *Connection) SetMultiState(state bool) { if !state { // reset data when 
cancel multi c.watching = nil c.queue = nil + c.flags &= ^flagMulti // clean multi flag + return } - c.multiState = state + c.flags |= flagMulti } // GetQueuedCmdLine returns queued commands of current transaction @@ -183,18 +186,6 @@ func (c *Connection) ClearQueuedCmds() { c.queue = nil } -// GetRole returns role of connection, such as connection with master -func (c *Connection) GetRole() int32 { - if c == nil { - return NormalCli - } - return c.role -} - -func (c *Connection) SetRole(r int32) { - c.role = r -} - // GetWatching returns watching keys and their version code when started watching func (c *Connection) GetWatching() map[string]uint32 { if c.watching == nil { @@ -213,24 +204,18 @@ func (c *Connection) SelectDB(dbNum int) { c.selectedDB = dbNum } -// FakeConn implements redis.Connection for test -type FakeConn struct { - Connection - buf bytes.Buffer +func (c *Connection) SetSlave() { + c.flags |= flagSlave } -// Write writes data to buffer -func (c *FakeConn) Write(b []byte) error { - c.buf.Write(b) - return nil +func (c *Connection) IsSlave() bool { + return c.flags&flagSlave > 0 } -// Clean resets the buffer -func (c *FakeConn) Clean() { - c.buf.Reset() +func (c *Connection) SetMaster() { + c.flags |= flagMaster } -// Bytes returns written data -func (c *FakeConn) Bytes() []byte { - return c.buf.Bytes() +func (c *Connection) IsMaster() bool { + return c.flags&flagMaster > 0 } diff --git a/redis/connection/fake.go b/redis/connection/fake.go new file mode 100644 index 0000000..4718164 --- /dev/null +++ b/redis/connection/fake.go @@ -0,0 +1,79 @@ +package connection + +import ( + "bytes" + "io" + "sync" +) + +// FakeConn implements redis.Connection for test +type FakeConn struct { + Connection + buf bytes.Buffer + wait chan struct{} + closed bool + mu sync.Mutex +} + +func NewFakeConn() *FakeConn { + c := &FakeConn{} + return c +} + +// Write writes data to buffer +func (c *FakeConn) Write(b []byte) (int, error) { + if c.closed { + return 0, io.EOF + } 
+ n, _ := c.buf.Write(b) + c.notify() + return n, nil +} + +func (c *FakeConn) notify() { + if c.wait != nil { + c.mu.Lock() + if c.wait != nil { + close(c.wait) + c.wait = nil + } + c.mu.Unlock() + } +} + +func (c *FakeConn) waiting() { + c.mu.Lock() + c.wait = make(chan struct{}) + c.mu.Unlock() + <-c.wait +} + +// Read reads data from buffer +func (c *FakeConn) Read(p []byte) (int, error) { + n, err := c.buf.Read(p) + if err == io.EOF { + if c.closed { + return 0, io.EOF + } + c.waiting() + return c.buf.Read(p) + } + return n, err +} + +// Clean resets the buffer +func (c *FakeConn) Clean() { + c.wait = make(chan struct{}) + c.buf.Reset() +} + +// Bytes returns written data +func (c *FakeConn) Bytes() []byte { + return c.buf.Bytes() +} + +func (c *FakeConn) Close() error { + c.closed = true + c.notify() + return nil +} diff --git a/redis/server/pubsub_test.go b/redis/server/pubsub_test.go index 40dfb71..bada3f4 100644 --- a/redis/server/pubsub_test.go +++ b/redis/server/pubsub_test.go @@ -13,7 +13,7 @@ func TestPublish(t *testing.T) { hub := pubsub.MakeHub() channel := utils.RandString(5) msg := utils.RandString(5) - conn := &connection.FakeConn{} + conn := connection.NewFakeConn() pubsub.Subscribe(hub, conn, utils.ToCmdLine(channel)) conn.Clean() // clean subscribe success pubsub.Publish(hub, utils.ToCmdLine(channel, msg)) diff --git a/redis/server/server.go b/redis/server/server.go index a48962f..3c81dbc 100644 --- a/redis/server/server.go +++ b/redis/server/server.go @@ -50,8 +50,6 @@ func (h *Handler) closeClient(client *connection.Connection) { _ = client.Close() h.db.AfterClientClose(client) h.activeConn.Delete(client) - - client.GetConnPool().Put(client) } // Handle receives and executes redis commands @@ -78,7 +76,7 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) { } // protocol err errReply := protocol.MakeErrReply(payload.Err.Error()) - err := client.Write(errReply.ToBytes()) + _, err := client.Write(errReply.ToBytes()) if err != nil { 
h.closeClient(client) logger.Info("connection closed: " + client.RemoteAddr().String()) @@ -97,9 +95,9 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) { } result := h.db.Exec(client, r.Args) if result != nil { - _ = client.Write(result.ToBytes()) + _, _ = client.Write(result.ToBytes()) } else { - _ = client.Write(unknownErrReplyBytes) + _, _ = client.Write(unknownErrReplyBytes) } } }