Added multi-database support to snapshot module

This commit is contained in:
Kelvin Mwinuka
2024-06-29 23:31:55 +08:00
parent 56f0a5ce61
commit 182195ebc3
6 changed files with 307 additions and 278 deletions

View File

@@ -250,34 +250,36 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
ApplyDeleteKey: echovault.raftApplyDeleteKey,
})
} else {
// TODO: Update snapshot engine to support multiple databases.
// Set up standalone snapshot engine
// echovault.snapshotEngine = snapshot.NewSnapshotEngine(
// snapshot.WithClock(echovault.clock),
// snapshot.WithDirectory(echovault.config.DataDir),
// snapshot.WithThreshold(echovault.config.SnapShotThreshold),
// snapshot.WithInterval(echovault.config.SnapshotInterval),
// snapshot.WithStartSnapshotFunc(echovault.startSnapshot),
// snapshot.WithFinishSnapshotFunc(echovault.finishSnapshot),
// snapshot.WithSetLatestSnapshotTimeFunc(echovault.setLatestSnapshot),
// snapshot.WithGetLatestSnapshotTimeFunc(echovault.getLatestSnapshotTime),
// snapshot.WithGetStateFunc(func() map[string]internal.KeyData {
// state := make(map[string]internal.KeyData)
// for k, v := range echovault.getState() {
// if data, ok := v.(internal.KeyData); ok {
// state[k] = data
// }
// }
// return state
// }),
// snapshot.WithSetKeyDataFunc(func(key string, data internal.KeyData) {
// ctx := context.Background()
// if err := echovault.setValues(ctx, map[string]interface{}{key: data.Value}); err != nil {
// log.Println(err)
// }
// echovault.setExpiry(ctx, key, data.ExpireAt, false)
// }),
// )
echovault.snapshotEngine = snapshot.NewSnapshotEngine(
snapshot.WithClock(echovault.clock),
snapshot.WithDirectory(echovault.config.DataDir),
snapshot.WithThreshold(echovault.config.SnapShotThreshold),
snapshot.WithInterval(echovault.config.SnapshotInterval),
snapshot.WithStartSnapshotFunc(echovault.startSnapshot),
snapshot.WithFinishSnapshotFunc(echovault.finishSnapshot),
snapshot.WithSetLatestSnapshotTimeFunc(echovault.setLatestSnapshot),
snapshot.WithGetLatestSnapshotTimeFunc(echovault.getLatestSnapshotTime),
snapshot.WithGetStateFunc(func() map[int]map[string]internal.KeyData {
state := make(map[int]map[string]internal.KeyData)
for database, data := range echovault.getState() {
state[database] = make(map[string]internal.KeyData)
for key, value := range data {
if keyData, ok := value.(internal.KeyData); ok {
state[database][key] = keyData
}
}
}
return state
}),
snapshot.WithSetKeyDataFunc(func(database int, key string, data internal.KeyData) {
ctx := context.WithValue(context.Background(), "Database", database)
if err := echovault.setValues(ctx, map[string]interface{}{key: data.Value}); err != nil {
log.Println(err)
}
echovault.setExpiry(ctx, key, data.ExpireAt, false)
}),
)
// TODO: Update AOF engine to support multiple databases.
// Set up standalone AOF engine

View File

@@ -883,7 +883,7 @@ func Test_Standalone(t *testing.T) {
tests := []struct {
name string
dataDir string
values map[string]string
values map[int]map[string]string
snapshotFunc func(mockServer *EchoVault) error
lastSaveFunc func(mockServer *EchoVault) (int, error)
wantLastSave int
@@ -891,11 +891,9 @@ func Test_Standalone(t *testing.T) {
{
name: "1. Snapshot in embedded instance",
dataDir: path.Join(dataDir, "embedded_instance"),
values: map[string]string{
"key5": "value5",
"key6": "value6",
"key7": "value7",
"key8": "value8",
values: map[int]map[string]string{
0: {"key5": "value-05", "key6": "value-06", "key7": "value-07", "key8": "value-08"},
1: {"key5": "value-15", "key6": "value-16", "key7": "value-17", "key8": "value-18"},
},
snapshotFunc: func(mockServer *EchoVault) error {
if _, err := mockServer.Save(); err != nil {
@@ -937,12 +935,15 @@ func Test_Standalone(t *testing.T) {
}()
// Trigger some write commands
for key, value := range test.values {
for database, data := range test.values {
_ = mockServer.SelectDB(database)
for key, value := range data {
if _, _, err = mockServer.Set(key, value, SetOptions{}); err != nil {
t.Error(err)
return
}
}
}
// Function to trigger snapshot save
if err = test.snapshotFunc(mockServer); err != nil {
@@ -962,7 +963,9 @@ func Test_Standalone(t *testing.T) {
}
// Check that all the key/value pairs have been restored into the store.
for key, value := range test.values {
for database, data := range test.values {
_ = mockServer.SelectDB(database)
for key, value := range data {
res, err := mockServer.Get(key)
if err != nil {
t.Error(err)
@@ -973,6 +976,7 @@ func Test_Standalone(t *testing.T) {
return
}
}
}
// Check that the lastsave is the time the last snapshot was taken.
lastSave, err := test.lastSaveFunc(mockServer)

View File

@@ -215,8 +215,7 @@ func (server *EchoVault) setValues(ctx context.Context, entries map[string]inter
ExpireAt: expireAt,
}
if !server.isInCluster() {
// TODO: Enable this when snapshot engine has support for multiple databases.
// server.snapshotEngine.IncrementChangeCount()
server.snapshotEngine.IncrementChangeCount()
}
}
@@ -318,6 +317,7 @@ func (server *EchoVault) getState() map[int]map[string]interface{} {
}
data := make(map[int]map[string]interface{})
for db, store := range server.store {
data[db] = make(map[string]interface{})
for k, v := range store {
data[db][k] = v
}

View File

@@ -25,7 +25,7 @@ func createEchoVaultWithConfig(conf config.Config) *EchoVault {
}
func presetValue(server *EchoVault, ctx context.Context, key string, value interface{}) error {
ctx = context.WithValue(ctx, "Database", "0")
ctx = context.WithValue(ctx, "Database", 0)
if err := server.setValues(ctx, map[string]interface{}{key: value}); err != nil {
return err
}
@@ -33,7 +33,7 @@ func presetValue(server *EchoVault, ctx context.Context, key string, value inter
}
func presetKeyData(server *EchoVault, ctx context.Context, key string, data internal.KeyData) {
ctx = context.WithValue(ctx, "Database", "0")
ctx = context.WithValue(ctx, "Database", 0)
_ = server.setValues(ctx, map[string]interface{}{key: data.Value})
server.setExpiry(ctx, key, data.ExpireAt, false)
}

View File

@@ -15,9 +15,17 @@
package snapshot
import (
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"github.com/echovault/echovault/internal"
"github.com/echovault/echovault/internal/clock"
"io"
"io/fs"
"log"
"os"
"path"
"sync/atomic"
"time"
)
@@ -38,10 +46,10 @@ type Engine struct {
snapshotThreshold uint64
startSnapshotFunc func()
finishSnapshotFunc func()
getStateFunc func() map[string]internal.KeyData
getStateFunc func() map[int]map[string]internal.KeyData
setLatestSnapshotTimeFunc func(msec int64)
getLatestSnapshotTimeFunc func() int64
setKeyDataFunc func(key string, data internal.KeyData)
setKeyDataFunc func(database int, key string, data internal.KeyData)
}
func WithClock(clock clock.Clock) func(engine *Engine) {
@@ -80,7 +88,7 @@ func WithFinishSnapshotFunc(f func()) func(engine *Engine) {
}
}
func WithGetStateFunc(f func() map[string]internal.KeyData) func(engine *Engine) {
func WithGetStateFunc(f func() map[int]map[string]internal.KeyData) func(engine *Engine) {
return func(engine *Engine) {
engine.getStateFunc = f
}
@@ -98,7 +106,7 @@ func WithGetLatestSnapshotTimeFunc(f func() int64) func(engine *Engine) {
}
}
func WithSetKeyDataFunc(f func(key string, data internal.KeyData)) func(engine *Engine) {
func WithSetKeyDataFunc(f func(database int, key string, data internal.KeyData)) func(engine *Engine) {
return func(engine *Engine) {
engine.setKeyDataFunc = f
}
@@ -113,10 +121,10 @@ func NewSnapshotEngine(options ...func(engine *Engine)) *Engine {
snapshotThreshold: 1000,
startSnapshotFunc: func() {},
finishSnapshotFunc: func() {},
getStateFunc: func() map[string]internal.KeyData {
return map[string]internal.KeyData{}
getStateFunc: func() map[int]map[string]internal.KeyData {
return make(map[int]map[string]internal.KeyData)
},
setKeyDataFunc: func(key string, data internal.KeyData) {},
setKeyDataFunc: func(database int, key string, data internal.KeyData) {},
setLatestSnapshotTimeFunc: func(msec int64) {},
getLatestSnapshotTimeFunc: func() int64 {
return 0
@@ -148,213 +156,211 @@ func NewSnapshotEngine(options ...func(engine *Engine)) *Engine {
}
func (engine *Engine) TakeSnapshot() error {
// TODO: Update to support multiple databases.
engine.startSnapshotFunc()
defer engine.finishSnapshotFunc()
//engine.startSnapshotFunc()
//defer engine.finishSnapshotFunc()
//
//// Extract current time
//msec := engine.clock.Now().UnixMilli()
//
//// Update manifest file to indicate the latest snapshot.
//// If manifest file does not exist, create it.
//// Manifest object will contain the following information:
//// 1. Hash of the snapshot contents.
//// 2. Unix time of the latest snapshot taken.
//// The information above will be used to determine whether a snapshot should be taken.
//// If the hash of the current state equals the hash in the manifest file, skip the snapshot.
//// Otherwise, take the snapshot and update the latest snapshot timestamp and hash in the manifest file.
//
//var firstSnapshot bool // Tracks whether the snapshot being attempted is the first one
//
//dirname := path.Join(engine.directory, "snapshots")
//if err := os.MkdirAll(dirname, os.ModePerm); err != nil {
// log.Println(err)
// return err
//}
//
//// Open manifest file
//var mf *os.File
//mf, err := os.Open(path.Join(dirname, "manifest.bin"))
//if err != nil {
// if errors.Is(err, fs.ErrNotExist) {
// // Create file if it does not exist
// mf, err = os.Create(path.Join(dirname, "manifest.bin"))
// if err != nil {
// log.Println(err)
// return err
// }
// firstSnapshot = true
// } else {
// log.Println(err)
// return err
// }
//}
//
//md, err := io.ReadAll(mf)
//if err != nil {
// log.Println(err)
// return err
//}
//if err := mf.Close(); err != nil {
// log.Println(err)
// return err
//}
//
//manifest := new(Manifest)
//
//if !firstSnapshot {
// if err = json.Unmarshal(md, manifest); err != nil {
// log.Println(err)
// return err
// }
//}
//
//// Get current state
//snapshotObject := internal.SnapshotObject{
// State: internal.FilterExpiredKeys(engine.clock.Now(), engine.getStateFunc()),
// LatestSnapshotMilliseconds: engine.getLatestSnapshotTimeFunc(),
//}
//out, err := json.Marshal(snapshotObject)
//if err != nil {
// log.Println(err)
// return err
//}
//
//snapshotHash := md5.Sum(out)
//if snapshotHash == manifest.LatestSnapshotHash {
// return errors.New("nothing new to snapshot")
//}
//
//// Update the snapshotObject
//snapshotObject.LatestSnapshotMilliseconds = msec
//// Marshal the updated snapshotObject
//out, err = json.Marshal(snapshotObject)
//if err != nil {
// log.Println(err)
// return err
//}
//
//// os.Create will replace the old manifest file
//mf, err = os.Create(path.Join(dirname, "manifest.bin"))
//if err != nil {
// log.Println(err)
// return err
//}
//
//// Write the latest manifest data
//manifest = &Manifest{
// LatestSnapshotHash: md5.Sum(out),
// LatestSnapshotMilliseconds: msec,
//}
//mo, err := json.Marshal(manifest)
//if err != nil {
// log.Println(err)
// return err
//}
//if _, err = mf.Write(mo); err != nil {
// log.Println(err)
// return err
//}
//if err = mf.Sync(); err != nil {
// log.Println(err)
//}
//if err = mf.Close(); err != nil {
// log.Println(err)
// return err
//}
//
//// Create snapshot directory
//dirname = path.Join(engine.directory, "snapshots", fmt.Sprintf("%d", msec))
//if err := os.MkdirAll(dirname, os.ModePerm); err != nil {
// return err
//}
//
//// Create snapshot file
//f, err := os.OpenFile(path.Join(dirname, "state.bin"), os.O_WRONLY|os.O_CREATE, os.ModePerm)
//if err != nil {
// log.Println(err)
// return err
//}
//defer func() {
// if err := f.Close(); err != nil {
// log.Println(err)
// }
//}()
//
//// Write state to file
//if _, err = f.Write(out); err != nil {
// return err
//}
//if err = f.Sync(); err != nil {
// log.Println(err)
//}
//
//// Set the latest snapshot in unix milliseconds
//engine.setLatestSnapshotTimeFunc(msec)
//
//// Reset the change count
//engine.resetChangeCount()
// Extract current time
msec := engine.clock.Now().UnixMilli()
// Update manifest file to indicate the latest snapshot.
// If manifest file does not exist, create it.
// Manifest object will contain the following information:
// 1. Hash of the snapshot contents.
// 2. Unix time of the latest snapshot taken.
// The information above will be used to determine whether a snapshot should be taken.
// If the hash of the current state equals the hash in the manifest file, skip the snapshot.
// Otherwise, take the snapshot and update the latest snapshot timestamp and hash in the manifest file.
var firstSnapshot bool // Tracks whether the snapshot being attempted is the first one
dirname := path.Join(engine.directory, "snapshots")
if err := os.MkdirAll(dirname, os.ModePerm); err != nil {
log.Println(err)
return err
}
// Open manifest file
var mf *os.File
mf, err := os.Open(path.Join(dirname, "manifest.bin"))
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
// Create file if it does not exist
mf, err = os.Create(path.Join(dirname, "manifest.bin"))
if err != nil {
log.Println(err)
return err
}
firstSnapshot = true
} else {
log.Println(err)
return err
}
}
md, err := io.ReadAll(mf)
if err != nil {
log.Println(err)
return err
}
if err := mf.Close(); err != nil {
log.Println(err)
return err
}
manifest := new(Manifest)
if !firstSnapshot {
if err = json.Unmarshal(md, manifest); err != nil {
log.Println(err)
return err
}
}
// Get current state
snapshotObject := internal.SnapshotObject{
State: internal.FilterExpiredKeys(engine.clock.Now(), engine.getStateFunc()),
LatestSnapshotMilliseconds: engine.getLatestSnapshotTimeFunc(),
}
out, err := json.Marshal(snapshotObject)
if err != nil {
log.Println(err)
return err
}
snapshotHash := md5.Sum(out)
if snapshotHash == manifest.LatestSnapshotHash {
return errors.New("nothing new to snapshot")
}
// Update the snapshotObject
snapshotObject.LatestSnapshotMilliseconds = msec
// Marshal the updated snapshotObject
out, err = json.Marshal(snapshotObject)
if err != nil {
log.Println(err)
return err
}
// os.Create will replace the old manifest file
mf, err = os.Create(path.Join(dirname, "manifest.bin"))
if err != nil {
log.Println(err)
return err
}
// Write the latest manifest data
manifest = &Manifest{
LatestSnapshotHash: md5.Sum(out),
LatestSnapshotMilliseconds: msec,
}
mo, err := json.Marshal(manifest)
if err != nil {
log.Println(err)
return err
}
if _, err = mf.Write(mo); err != nil {
log.Println(err)
return err
}
if err = mf.Sync(); err != nil {
log.Println(err)
}
if err = mf.Close(); err != nil {
log.Println(err)
return err
}
// Create snapshot directory
dirname = path.Join(engine.directory, "snapshots", fmt.Sprintf("%d", msec))
if err := os.MkdirAll(dirname, os.ModePerm); err != nil {
return err
}
// Create snapshot file
f, err := os.OpenFile(path.Join(dirname, "state.bin"), os.O_WRONLY|os.O_CREATE, os.ModePerm)
if err != nil {
log.Println(err)
return err
}
defer func() {
if err := f.Close(); err != nil {
log.Println(err)
}
}()
// Write state to file
if _, err = f.Write(out); err != nil {
return err
}
if err = f.Sync(); err != nil {
log.Println(err)
}
// Set the latest snapshot in unix milliseconds
engine.setLatestSnapshotTimeFunc(msec)
// Reset the change count
engine.resetChangeCount()
return nil
}
func (engine *Engine) Restore() error {
// TODO: Update to support multiple databases.
mf, err := os.Open(path.Join(engine.directory, "snapshots", "manifest.bin"))
if err != nil && errors.Is(err, fs.ErrNotExist) {
return errors.New("no snapshot manifest, skipping snapshot restore")
}
if err != nil {
return err
}
//mf, err := os.Open(path.Join(engine.directory, "snapshots", "manifest.bin"))
//if err != nil && errors.Is(err, fs.ErrNotExist) {
// return errors.New("no snapshot manifest, skipping snapshot restore")
//}
//if err != nil {
// return err
//}
//
//manifest := new(Manifest)
//
//md, err := io.ReadAll(mf)
//if err != nil {
// return err
//}
//
//if err = json.Unmarshal(md, manifest); err != nil {
// return err
//}
//
//if manifest.LatestSnapshotMilliseconds == 0 {
// return errors.New("no snapshot to restore")
//}
//
//sf, err := os.Open(path.Join(
// engine.directory,
// "snapshots",
// fmt.Sprintf("%d", manifest.LatestSnapshotMilliseconds),
// "state.bin"))
//if err != nil && errors.Is(err, fs.ErrNotExist) {
// return fmt.Errorf("snapshot file %d/state.bin not found, skipping snapshot", manifest.LatestSnapshotMilliseconds)
//}
//if err != nil {
// return err
//}
//
//sd, err := io.ReadAll(sf)
//if err != nil {
// return nil
//}
//
//snapshotObject := new(internal.SnapshotObject)
//
//if err = json.Unmarshal(sd, snapshotObject); err != nil {
// return err
//}
//
//engine.setLatestSnapshotTimeFunc(snapshotObject.LatestSnapshotMilliseconds)
//
//for key, data := range internal.FilterExpiredKeys(engine.clock.Now(), snapshotObject.State) {
// engine.setKeyDataFunc(key, data)
//}
//
//log.Println("successfully restored latest snapshot")
manifest := new(Manifest)
md, err := io.ReadAll(mf)
if err != nil {
return err
}
if err = json.Unmarshal(md, manifest); err != nil {
return err
}
if manifest.LatestSnapshotMilliseconds == 0 {
return errors.New("no snapshot to restore")
}
sf, err := os.Open(path.Join(
engine.directory,
"snapshots",
fmt.Sprintf("%d", manifest.LatestSnapshotMilliseconds),
"state.bin"))
if err != nil && errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("snapshot file %d/state.bin not found, skipping snapshot", manifest.LatestSnapshotMilliseconds)
}
if err != nil {
return err
}
sd, err := io.ReadAll(sf)
if err != nil {
return nil
}
snapshotObject := new(internal.SnapshotObject)
if err = json.Unmarshal(sd, snapshotObject); err != nil {
return err
}
engine.setLatestSnapshotTimeFunc(snapshotObject.LatestSnapshotMilliseconds)
for database, data := range internal.FilterExpiredKeys(engine.clock.Now(), snapshotObject.State) {
for key, value := range data {
engine.setKeyDataFunc(database, key, value)
}
}
log.Println("successfully restored latest snapshot")
return nil
}

View File

@@ -44,20 +44,33 @@ func Test_SnapshotEngine(t *testing.T) {
snapshotInProgress.Store(false)
}
state := map[string]internal.KeyData{
state := map[int]map[string]internal.KeyData{
0: {
"key1": {Value: "value-01", ExpireAt: clock.NewClock().Now().Add(13 * time.Second)},
"key2": {Value: "value-02", ExpireAt: clock.NewClock().Now().Add(43 * time.Minute)},
"key3": {Value: "value-03", ExpireAt: clock.NewClock().Now().Add(112 * time.Millisecond)},
"key4": {Value: "value-04", ExpireAt: clock.NewClock().Now().Add(23 * time.Second)},
"key5": {Value: "value-45", ExpireAt: clock.NewClock().Now().Add(121 * time.Millisecond)},
},
1: {
"key1": {Value: "value1", ExpireAt: clock.NewClock().Now().Add(13 * time.Second)},
"key2": {Value: "value2", ExpireAt: clock.NewClock().Now().Add(43 * time.Minute)},
"key3": {Value: "value3", ExpireAt: clock.NewClock().Now().Add(112 * time.Millisecond)},
"key4": {Value: "value4", ExpireAt: clock.NewClock().Now().Add(23 * time.Second)},
"key5": {Value: "value5", ExpireAt: clock.NewClock().Now().Add(121 * time.Millisecond)},
},
}
getStateFunc := func() map[string]internal.KeyData {
getStateFunc := func() map[int]map[string]internal.KeyData {
return state
}
restoredState := map[string]internal.KeyData{}
setKeyDataFunc := func(key string, data internal.KeyData) {
restoredState[key] = data
restoredState := make(map[int]map[string]internal.KeyData)
setKeyDataFunc := func(database int, key string, data internal.KeyData) {
if restoredState[database] == nil {
restoredState[database] = make(map[string]internal.KeyData)
}
restoredState[database][key] = data
}
var latestSnapshotTime int64
@@ -85,13 +98,15 @@ func Test_SnapshotEngine(t *testing.T) {
t.Error(err)
}
// Add more records to the state
// Add more records to each database in the state
for database, _ := range state {
for i := 0; i < 5; i++ {
state[fmt.Sprintf("key%d", i)] = internal.KeyData{
state[database][fmt.Sprintf("key%d", i)] = internal.KeyData{
Value: fmt.Sprintf("value%d", i),
ExpireAt: clock.NewClock().Now().Add(time.Duration(i) * time.Second),
}
}
}
// Take another snapshot
if err := snapshotEngine.TakeSnapshot(); err != nil {
@@ -106,12 +121,14 @@ func Test_SnapshotEngine(t *testing.T) {
t.Errorf("expected restored state to be length %d, got %d", len(state), len(restoredState))
}
for key, data := range restoredState {
if state[key].Value != data.Value {
t.Errorf("expected value %v for key %s, got %v", state[key].Value, key, data.Value)
for database, data := range restoredState {
for key, keyData := range data {
if state[database][key].Value != keyData.Value {
t.Errorf("expected value %v for key %s, got %v", state[database][key].Value, key, keyData.Value)
}
if !state[database][key].ExpireAt.Equal(keyData.ExpireAt) {
t.Errorf("expected expiry time %v for key %s, got %v", state[database][key].ExpireAt, key, keyData.ExpireAt)
}
if !state[key].ExpireAt.Equal(data.ExpireAt) {
t.Errorf("expected expiry time %v for key %s, got %v", state[key].ExpireAt, key, data.ExpireAt)
}
}