GetState receiver function now checks wether there's an active copy or state mutation before proceeding with the copy.

Handling of write commands is delayed until state copy is complete.
This commit is contained in:
Kelvin Clement Mwinuka
2024-01-31 01:55:51 +08:00
parent 25b2cb7154
commit ce0eabf865
6 changed files with 72 additions and 26 deletions

View File

@@ -35,7 +35,7 @@ func handleGet(ctx context.Context, cmd []string, server utils.Server, conn *net
key := cmd[1] key := cmd[1]
if !server.KeyExists(key) { if !server.KeyExists(key) {
return []byte("+nil\r\n\n"), nil return []byte("+nil\r\n\r\n"), nil
} }
_, err := server.KeyRLock(ctx, key) _, err := server.KeyRLock(ctx, key)
@@ -47,9 +47,9 @@ func handleGet(ctx context.Context, cmd []string, server utils.Server, conn *net
switch value.(type) { switch value.(type) {
default: default:
return []byte(fmt.Sprintf("+%v\r\n\n", value)), nil return []byte(fmt.Sprintf("+%v\r\n\r\n", value)), nil
case nil: case nil:
return []byte("+nil\r\n\n"), nil return []byte("+nil\r\n\r\n"), nil
} }
} }
@@ -84,7 +84,7 @@ func handleMGet(ctx context.Context, cmd []string, server utils.Server, conn *ne
bytes = append(bytes, []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(val), val))...) bytes = append(bytes, []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(val), val))...)
} }
bytes = append(bytes, []byte("\n")...) bytes = append(bytes, []byte("\r\n")...)
return bytes, nil return bytes, nil
} }

View File

@@ -13,8 +13,10 @@ import (
// Logging in clusters is handled in the raft layer. // Logging in clusters is handled in the raft layer.
type Opts struct { type Opts struct {
Config utils.Config Config utils.Config
GetState func() map[string]interface{} GetState func() map[string]interface{}
StartRewriteAOF func()
FinishRewriteAOF func()
} }
type Engine struct { type Engine struct {
@@ -65,6 +67,9 @@ func (engine *Engine) RewriteLog() error {
engine.mut.Lock() engine.mut.Lock()
defer engine.mut.Unlock() defer engine.mut.Unlock()
engine.options.StartRewriteAOF()
defer engine.options.FinishRewriteAOF()
// Get current state. // Get current state.
state := engine.options.GetState() state := engine.options.GetState()
o, err := json.Marshal(state) o, err := json.Marshal(state)

View File

@@ -47,11 +47,3 @@ func (server *Server) raftApply(ctx context.Context, cmd []string) ([]byte, erro
return r.Response, nil return r.Response, nil
} }
func (server *Server) StartSnapshot() {
server.SnapshotInProgress.Store(true)
}
func (server *Server) FinishSnapshot() {
server.SnapshotInProgress.Store(false)
}

View File

@@ -76,5 +76,16 @@ func (server *Server) SetValue(ctx context.Context, key string, value interface{
} }
func (server *Server) GetState() map[string]interface{} { func (server *Server) GetState() map[string]interface{} {
return server.store for {
if !server.StateCopyInProgress.Load() && !server.StateMutationInProgress.Load() {
server.StateCopyInProgress.Store(true)
break
}
}
data := make(map[string]interface{})
for k, v := range server.store {
data[k] = v
}
server.StateCopyInProgress.Store(false)
return data
} }

View File

@@ -16,7 +16,6 @@ import (
"log" "log"
"net" "net"
"os" "os"
"slices"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -42,6 +41,9 @@ type Server struct {
PubSub *pubsub.PubSub PubSub *pubsub.PubSub
SnapshotInProgress atomic.Bool SnapshotInProgress atomic.Bool
RewriteAOFInProgress atomic.Bool
StateCopyInProgress atomic.Bool
StateMutationInProgress atomic.Bool
LatestSnapshotMilliseconds atomic.Int64 // Unix epoch in milliseconds LatestSnapshotMilliseconds atomic.Int64 // Unix epoch in milliseconds
SnapshotEngine *snapshot.Engine SnapshotEngine *snapshot.Engine
AOFEngine *aof.Engine AOFEngine *aof.Engine
@@ -154,6 +156,16 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
continue continue
} }
// If we're not in cluster mode and command/subcommand is a write command, wait for state copy to finish.
if utils.IsWriteCommand(command, subCommand) {
for {
if !server.StateCopyInProgress.Load() {
server.StateMutationInProgress.Store(true)
break
}
}
}
if !server.IsInCluster() || !synchronize { if !server.IsInCluster() || !synchronize {
if res, err := handler(ctx, cmd, server, &conn); err != nil { if res, err := handler(ctx, cmd, server, &conn); err != nil {
if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil { if _, err := w.Write([]byte(fmt.Sprintf("-%s\r\n\r\n", err.Error()))); err != nil {
@@ -163,14 +175,11 @@ func (server *Server) handleConnection(ctx context.Context, conn net.Conn) {
if _, err := w.Write(res); err != nil { if _, err := w.Write(res); err != nil {
log.Println(err) log.Println(err)
} }
if slices.Contains(append(command.Categories, subCommand.Categories...), utils.WriteCategory) { if utils.IsWriteCommand(command, subCommand) {
// Log successful write command // TODO: Queue successful write command instead of logging it directly
err := server.AOFEngine.LogCommand(message) // TODO: Handle error
if err != nil {
log.Println(err)
}
} }
} }
server.StateMutationInProgress.Store(false)
continue continue
} }
@@ -242,8 +251,10 @@ func (server *Server) Start(ctx context.Context) {
} else { } else {
// Initialize standalone AOF engine // Initialize standalone AOF engine
server.AOFEngine = aof.NewAOFEngine(aof.Opts{ server.AOFEngine = aof.NewAOFEngine(aof.Opts{
Config: conf, Config: conf,
GetState: server.GetState, GetState: server.GetState,
StartRewriteAOF: server.StartRewriteAOF,
FinishRewriteAOF: server.FinishRewriteAOF,
}) })
// Initialize and start standalone snapshot engine // Initialize and start standalone snapshot engine
server.SnapshotEngine = snapshot.NewSnapshotEngine(snapshot.Opts{ server.SnapshotEngine = snapshot.NewSnapshotEngine(snapshot.Opts{
@@ -288,6 +299,14 @@ func (server *Server) TakeSnapshot() error {
return nil return nil
} }
func (server *Server) StartSnapshot() {
server.SnapshotInProgress.Store(true)
}
func (server *Server) FinishSnapshot() {
server.SnapshotInProgress.Store(false)
}
func (server *Server) SetLatestSnapshot(msec int64) { func (server *Server) SetLatestSnapshot(msec int64) {
server.LatestSnapshotMilliseconds.Store(msec) server.LatestSnapshotMilliseconds.Store(msec)
} }
@@ -296,9 +315,24 @@ func (server *Server) GetLatestSnapshot() int64 {
return server.LatestSnapshotMilliseconds.Load() return server.LatestSnapshotMilliseconds.Load()
} }
func (server *Server) StartRewriteAOF() {
server.RewriteAOFInProgress.Store(true)
}
func (server *Server) FinishRewriteAOF() {
server.RewriteAOFInProgress.Store(false)
}
func (server *Server) RewriteAOF() error { func (server *Server) RewriteAOF() error {
// TODO: Make this concurrent if server.RewriteAOFInProgress.Load() {
return server.AOFEngine.RewriteLog() return errors.New("aof rewrite in progress")
}
go func() {
if err := server.AOFEngine.RewriteLog(); err != nil {
log.Println(err)
}
}()
return nil
} }
func (server *Server) ShutDown(ctx context.Context) { func (server *Server) ShutDown(ctx context.Context) {

View File

@@ -128,6 +128,10 @@ func GetSubCommand(command Command, cmd []string) interface{} {
return nil return nil
} }
func IsWriteCommand(command Command, subCommand SubCommand) bool {
return slices.Contains(append(command.Categories, subCommand.Categories...), WriteCategory)
}
func AbsInt(n int) int { func AbsInt(n int) int {
if n < 0 { if n < 0 {
return -n return -n