Implement checking wether snapshot is in progress before continuing

This commit is contained in:
Kelvin Clement Mwinuka
2024-01-27 01:46:16 +08:00
parent 4895b109b8
commit 7e59de20a2
6 changed files with 42 additions and 15 deletions

View File

@@ -4,11 +4,14 @@ import (
"encoding/json" "encoding/json"
"github.com/echovault/echovault/src/utils" "github.com/echovault/echovault/src/utils"
"github.com/hashicorp/raft" "github.com/hashicorp/raft"
"time"
) )
type SnapshotOpts struct { type SnapshotOpts struct {
config utils.Config config utils.Config
data map[string]interface{} data map[string]interface{}
startSnapshot func()
finishSnapshot func()
} }
type Snapshot struct { type Snapshot struct {
@@ -23,7 +26,8 @@ func NewFSMSnapshot(opts SnapshotOpts) *Snapshot {
// Persist implements FSMSnapshot interface // Persist implements FSMSnapshot interface
func (s *Snapshot) Persist(sink raft.SnapshotSink) error { func (s *Snapshot) Persist(sink raft.SnapshotSink) error {
// TODO: Turn on snapshot in-progress flag s.options.startSnapshot()
o, err := json.Marshal(s.options.data) o, err := json.Marshal(s.options.data)
if err != nil { if err != nil {
@@ -36,10 +40,12 @@ func (s *Snapshot) Persist(sink raft.SnapshotSink) error {
return err return err
} }
<-time.After(5 * time.Second)
return nil return nil
} }
// Release implements FSMSnapshot interface // Release implements FSMSnapshot interface
func (s *Snapshot) Release() { func (s *Snapshot) Release() {
// TODO: Turn off snapshot in-progress flag s.options.finishSnapshot()
} }

View File

@@ -76,8 +76,10 @@ func (fsm *FSM) Apply(log *raft.Log) interface{} {
// Snapshot implements raft.FSM interface // Snapshot implements raft.FSM interface
func (fsm *FSM) Snapshot() (raft.FSMSnapshot, error) { func (fsm *FSM) Snapshot() (raft.FSMSnapshot, error) {
return NewFSMSnapshot(SnapshotOpts{ return NewFSMSnapshot(SnapshotOpts{
config: fsm.options.Config, config: fsm.options.Config,
data: fsm.options.Server.GetState(), data: fsm.options.Server.GetState(),
startSnapshot: fsm.options.Server.StartSnapshot,
finishSnapshot: fsm.options.Server.FinishSnapshot,
}), nil }), nil
} }

View File

@@ -144,10 +144,10 @@ func (r *Raft) HasJoinedCluster() bool {
} }
func (r *Raft) AddVoter( func (r *Raft) AddVoter(
id raft.ServerID, id raft.ServerID,
address raft.ServerAddress, address raft.ServerAddress,
prevIndex uint64, prevIndex uint64,
timeout time.Duration, timeout time.Duration,
) error { ) error {
if r.IsRaftLeader() { if r.IsRaftLeader() {
raftConfig := r.raft.GetConfiguration() raftConfig := r.raft.GetConfiguration()

View File

@@ -47,3 +47,11 @@ func (server *Server) raftApply(ctx context.Context, cmd []string) ([]byte, erro
return r.Response, nil return r.Response, nil
} }
func (server *Server) StartSnapshot() {
server.SnapshotInProgress.Store(true)
}
func (server *Server) FinishSnapshot() {
server.SnapshotInProgress.Store(false)
}

View File

@@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"github.com/echovault/echovault/src/memberlist" "github.com/echovault/echovault/src/memberlist"
"github.com/echovault/echovault/src/modules/acl" "github.com/echovault/echovault/src/modules/acl"
@@ -36,6 +37,8 @@ type Server struct {
ACL *acl.ACL ACL *acl.ACL
PubSub *pubsub.PubSub PubSub *pubsub.PubSub
SnapshotInProgress atomic.Bool
} }
func (server *Server) StartTCP(ctx context.Context) { func (server *Server) StartTCP(ctx context.Context) {
@@ -241,11 +244,17 @@ func (server *Server) Start(ctx context.Context) {
} }
func (server *Server) TakeSnapshot() error { func (server *Server) TakeSnapshot() error {
// TODO: Check if there's a snapshot currently in progress if server.SnapshotInProgress.Load() {
go func() { return errors.New("snapshot already in progress")
err := server.raft.TakeSnapshot() }
log.Println(err) if server.IsInCluster() {
}() // Handle snapshot in cluster mode
go func() {
err := server.raft.TakeSnapshot()
log.Println(err)
}()
}
// Handle snapshot in standalone mode
return nil return nil
} }

View File

@@ -19,6 +19,8 @@ type Server interface {
GetACL() interface{} GetACL() interface{}
GetPubSub() interface{} GetPubSub() interface{}
TakeSnapshot() error TakeSnapshot() error
StartSnapshot()
FinishSnapshot()
} }
type ContextServerID string type ContextServerID string