diff --git a/aof/aof.go b/aof/aof.go index e72e840..bfe02a1 100644 --- a/aof/aof.go +++ b/aof/aof.go @@ -61,6 +61,8 @@ type Persister struct { pausingAof sync.Mutex currentDB int listeners map[Listener]struct{} + // reuse cmdLine buffer + buffer []CmdLine } // NewPersister creates a new aof.Persister @@ -70,6 +72,7 @@ func NewPersister(db database.DBEngine, filename string, load bool, fsync string persister.aofFsync = strings.ToLower(fsync) persister.db = db persister.tmpDBMaker = tmpDBMaker + persister.currentDB = 0 if load { persister.LoadAof(0) } @@ -100,20 +103,6 @@ func (persister *Persister) RemoveListener(listener Listener) { delete(persister.listeners, listener) } -var wgPool = sync.Pool{ - New: func() interface{} { - return &sync.WaitGroup{} - }, -} - -func getWg() *sync.WaitGroup { - return wgPool.Get().(*sync.WaitGroup) -} - -func returnWg(wg *sync.WaitGroup) { - wgPool.Put(wg) -} - // SaveCmdLine send command to aof goroutine through channel func (persister *Persister) SaveCmdLine(dbIndex int, cmdLine CmdLine) { // aofChan will be set as nil temporarily during load aof see Persister.LoadAof @@ -121,16 +110,12 @@ func (persister *Persister) SaveCmdLine(dbIndex int, cmdLine CmdLine) { return } if persister.aofFsync == FsyncAlways { - // use WaitGroup to wait for saving finished - wg := getWg() - defer returnWg(wg) - wg.Add(1) - persister.aofChan <- &payload{ + p := &payload{ cmdLine: cmdLine, dbIndex: dbIndex, - wg: wg, } - wg.Wait() + persister.writeAof(p) + return } persister.aofChan <- &payload{ cmdLine: cmdLine, @@ -140,46 +125,42 @@ func (persister *Persister) SaveCmdLine(dbIndex int, cmdLine CmdLine) { // listenCmd listen aof channel and write into file func (persister *Persister) listenCmd() { - // serialized execution - var cmdLines []CmdLine - persister.currentDB = 0 for p := range persister.aofChan { - cmdLines = cmdLines[:0] // reuse underlying array - persister.pausingAof.Lock() // prevent other goroutines from pausing aof - // ensure aof is in the right database - if p.dbIndex != persister.currentDB { - // select db - selectCmd := utils.ToCmdLine("SELECT", strconv.Itoa(p.dbIndex)) - cmdLines = append(cmdLines, selectCmd) - data := protocol.MakeMultiBulkReply(selectCmd).ToBytes() - _, err := persister.aofFile.Write(data) - if err != nil { - logger.Warn(err) - persister.pausingAof.Unlock() - continue // skip this command - } - persister.currentDB = p.dbIndex - } - // save command - data := protocol.MakeMultiBulkReply(p.cmdLine).ToBytes() - cmdLines = append(cmdLines, p.cmdLine) + persister.writeAof(p) + } + persister.aofFinished <- struct{}{} +} + +func (persister *Persister) writeAof(p *payload) { + persister.buffer = persister.buffer[:0] // reuse underlying array + persister.pausingAof.Lock() // prevent other goroutines from pausing aof + defer persister.pausingAof.Unlock() + // ensure aof is in the right database + if p.dbIndex != persister.currentDB { + // select db + selectCmd := utils.ToCmdLine("SELECT", strconv.Itoa(p.dbIndex)) + persister.buffer = append(persister.buffer, selectCmd) + data := protocol.MakeMultiBulkReply(selectCmd).ToBytes() _, err := persister.aofFile.Write(data) if err != nil { logger.Warn(err) + return // skip this command } - for listener := range persister.listeners { - listener.Callback(cmdLines) - } - if persister.aofFsync == FsyncAlways { - _ = persister.aofFile.Sync() - } - if p.wg != nil { - p.wg.Done() - } - persister.pausingAof.Unlock() - + persister.currentDB = p.dbIndex + } + // save command + data := protocol.MakeMultiBulkReply(p.cmdLine).ToBytes() + persister.buffer = append(persister.buffer, p.cmdLine) + _, err := persister.aofFile.Write(data) + if err != nil { + logger.Warn(err) + } + for listener := range persister.listeners { + listener.Callback(persister.buffer) + } + if persister.aofFsync == FsyncAlways { + _ = persister.aofFile.Sync() } - persister.aofFinished <- struct{}{} } // LoadAof read aof file, can only be used before Persister.listenCmd started @@ -231,6 +212,13 @@ func (persister *Persister) LoadAof(maxBytes int) { if protocol.IsErrorReply(ret) { logger.Error("exec err", string(ret.ToBytes())) } + if strings.ToLower(string(r.Args[0])) == "select" { + // execSelect success, here must be no error + dbIndex, err := strconv.Atoi(string(r.Args[1])) + if err == nil { + persister.currentDB = dbIndex + } + } } } diff --git a/database/persistence_test.go b/database/persistence_test.go index d8421d7..41a77c6 100644 --- a/database/persistence_test.go +++ b/database/persistence_test.go @@ -53,7 +53,8 @@ func TestServerFsyncAlways(t *testing.T) { config.Properties.AppendFsync = aof.FsyncAlways server := NewStandaloneServer() conn := connection.NewFakeConn() - ret := server.Exec(conn, utils.ToCmdLine("set", "1", "1")) + server.Exec(conn, utils.ToCmdLine("del", "1")) + ret := server.Exec(conn, utils.ToCmdLine("incr", "1")) asserts.AssertNotError(t, ret) reader := NewStandaloneServer() ret = reader.Exec(conn, utils.ToCmdLine("get", "1")) @@ -71,7 +72,8 @@ func TestServerFsyncEverySec(t *testing.T) { config.Properties.AppendFsync = aof.FsyncEverySec server := NewStandaloneServer() conn := connection.NewFakeConn() - ret := server.Exec(conn, utils.ToCmdLine("set", "1", "1")) + server.Exec(conn, utils.ToCmdLine("del", "1")) + ret := server.Exec(conn, utils.ToCmdLine("incr", "1")) asserts.AssertNotError(t, ret) time.Sleep(1500 * time.Millisecond) reader := NewStandaloneServer()