diff --git a/src/modules/get/commands.go b/src/modules/get/commands.go index 53f0555..7222b7e 100644 --- a/src/modules/get/commands.go +++ b/src/modules/get/commands.go @@ -35,7 +35,7 @@ func handleGet(ctx context.Context, cmd []string, server utils.Server, conn *net key := cmd[1] if !server.KeyExists(key) { - return []byte("+nil\r\n\n"), nil + return []byte("+nil\r\n\r\n"), nil } _, err := server.KeyRLock(ctx, key) @@ -47,9 +47,9 @@ func handleGet(ctx context.Context, cmd []string, server utils.Server, conn *net switch value.(type) { default: - return []byte(fmt.Sprintf("+%v\r\n\n", value)), nil + return []byte(fmt.Sprintf("+%v\r\n\r\n", value)), nil case nil: - return []byte("+nil\r\n\n"), nil + return []byte("+nil\r\n\r\n"), nil } } @@ -84,7 +84,7 @@ func handleMGet(ctx context.Context, cmd []string, server utils.Server, conn *ne bytes = append(bytes, []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(val), val))...) } - bytes = append(bytes, []byte("\n")...) + bytes = append(bytes, []byte("\r\n")...) return bytes, nil } diff --git a/src/server/aof/aof.go b/src/server/aof/aof.go index 7e4b8b0..e211c0d 100644 --- a/src/server/aof/aof.go +++ b/src/server/aof/aof.go @@ -13,8 +13,10 @@ import ( // Logging in clusters is handled in the raft layer. type Opts struct { - Config utils.Config - GetState func() map[string]interface{} + Config utils.Config + GetState func() map[string]interface{} + StartRewriteAOF func() + FinishRewriteAOF func() } type Engine struct { @@ -65,6 +67,9 @@ func (engine *Engine) RewriteLog() error { engine.mut.Lock() defer engine.mut.Unlock() + engine.options.StartRewriteAOF() + defer engine.options.FinishRewriteAOF() + // Get current state. state := engine.options.GetState() o, err := json.Marshal(state) diff --git a/src/server/cluster.go b/src/server/cluster.go index 453eeb6..4e2e5ad 100644 --- a/src/server/cluster.go +++ b/src/server/cluster.go @@ -47,11 +47,3 @@ func (server *Server) raftApply(ctx context.Context, cmd []string) ([]byte, erro return r.Response, nil } - -func (server *Server) StartSnapshot() { - server.SnapshotInProgress.Store(true) -} - -func (server *Server) FinishSnapshot() { - server.SnapshotInProgress.Store(false) -} diff --git a/src/server/keyspace.go b/src/server/keyspace.go index dcb0bd8..05db192 100644 --- a/src/server/keyspace.go +++ b/src/server/keyspace.go @@ -76,5 +76,16 @@ func (server *Server) SetValue(ctx context.Context, key string, value interface{ } func (server *Server) GetState() map[string]interface{} { - return server.store + for { + if !server.StateCopyInProgress.Load() && !server.StateMutationInProgress.Load() { + server.StateCopyInProgress.Store(true) + break + } + } + data := make(map[string]interface{}) + for k, v := range server.store { + data[k] = v + } + server.StateCopyInProgress.Store(false) + return data } diff --git a/src/server/server.go b/src/server/server.go index f2da619..bb1b6a1 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -16,7 +16,6 @@ import ( "log" "net" "os" - "slices" "sync" "sync/atomic" "time" @@ -42,6 +41,9 @@ type Server struct { PubSub *pubsub.PubSub SnapshotInProgress atomic.Bool + RewriteAOFInProgress atomic.Bool + StateCopyInProgress atomic.Bool + StateMutationInProgress atomic.Bool LatestSnapshotMilliseconds atomic.Int64 // Unix epoch in milliseconds SnapshotEngine *snapshot.Engine AOFEngine *aof.Engine @@ -154,6 +156,16 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) { continue } + // If we're not in cluster mode and command/subcommand is a write command, wait for state copy to finish. + if utils.IsWriteCommand(command, subCommand) { + for { + if !server.StateCopyInProgress.Load() { + server.StateMutationInProgress.Store(true) + break + } + } + } + if !server.IsInCluster() || !synchronize { if res, err := handler(ctx, cmd, server, &conn); err != nil { if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil { @@ -163,14 +175,11 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) { if _, err := w.Write(res); err != nil { log.Println(err) } - if slices.Contains(append(command.Categories, subCommand.Categories...), utils.WriteCategory) { - // Log successful write command - err := server.AOFEngine.LogCommand(message) // TODO: Handle error - if err != nil { - log.Println(err) - } + if utils.IsWriteCommand(command, subCommand) { + // TODO: Queue successful write command instead of logging it directly } } + server.StateMutationInProgress.Store(false) continue } @@ -242,8 +251,10 @@ func (server *Server) Start(ctx context.Context) { } else { // Initialize standalone AOF engine server.AOFEngine = aof.NewAOFEngine(aof.Opts{ - Config: conf, - GetState: server.GetState, + Config: conf, + GetState: server.GetState, + StartRewriteAOF: server.StartRewriteAOF, + FinishRewriteAOF: server.FinishRewriteAOF, }) // Initialize and start standalone snapshot engine server.SnapshotEngine = snapshot.NewSnapshotEngine(snapshot.Opts{ @@ -288,6 +299,14 @@ func (server *Server) TakeSnapshot() error { return nil } +func (server *Server) StartSnapshot() { + server.SnapshotInProgress.Store(true) +} + +func (server *Server) FinishSnapshot() { + server.SnapshotInProgress.Store(false) +} + func (server *Server) SetLatestSnapshot(msec int64) { server.LatestSnapshotMilliseconds.Store(msec) } @@ -296,9 +315,24 @@ func (server *Server) GetLatestSnapshot() int64 { return server.LatestSnapshotMilliseconds.Load() } +func (server *Server) StartRewriteAOF() { + server.RewriteAOFInProgress.Store(true) +} + +func (server *Server) FinishRewriteAOF() { + server.RewriteAOFInProgress.Store(false) +} + func (server *Server) RewriteAOF() error { - // TODO: Make this concurrent - return server.AOFEngine.RewriteLog() + if server.RewriteAOFInProgress.Load() { + return errors.New("aof rewrite in progress") + } + go func() { + if err := server.AOFEngine.RewriteLog(); err != nil { + log.Println(err) + } + }() + return nil } func (server *Server) ShutDown(ctx context.Context) { diff --git a/src/utils/utils.go b/src/utils/utils.go index 20ae105..d4c19d4 100644 --- a/src/utils/utils.go +++ b/src/utils/utils.go @@ -128,6 +128,10 @@ func GetSubCommand(command Command, cmd []string) interface{} { return nil } +func IsWriteCommand(command Command, subCommand SubCommand) bool { + return slices.Contains(append(command.Categories, subCommand.Categories...), WriteCategory) +} + func AbsInt(n int) int { if n < 0 { return -n