diff --git a/echovault/api_connection_test.go b/echovault/api_connection_test.go index 9507fa1..638e2b1 100644 --- a/echovault/api_connection_test.go +++ b/echovault/api_connection_test.go @@ -15,10 +15,138 @@ package echovault import ( + "bufio" + "bytes" + "github.com/echovault/echovault/internal" + "github.com/echovault/echovault/internal/constants" + "github.com/echovault/echovault/internal/modules/connection" + "github.com/tidwall/resp" "reflect" "testing" ) +func TestEchoVault_Hello(t *testing.T) { + t.Parallel() + + port, err := internal.GetFreePort() + if err != nil { + t.Error(err) + return + } + + conf := DefaultConfig() + conf.Port = uint16(port) + conf.RequirePass = false + + mockServer := createEchoVaultWithConfig(conf) + if err != nil { + t.Error(err) + return + } + go func() { + mockServer.Start() + }() + t.Cleanup(func() { + mockServer.ShutDown() + }) + + tests := []struct { + name string + command []resp.Value + wantRes []byte + }{ + { + name: "1. Hello", + command: []resp.Value{resp.StringValue("HELLO")}, + wantRes: connection.BuildHelloResponse( + internal.ServerInfo{ + Server: "echovault", + Version: constants.Version, + Id: "", + Mode: "standalone", + Role: "master", + Modules: mockServer.ListModules(), + }, + internal.ConnectionInfo{ + Id: 1, + Name: "", + Protocol: 2, + Database: 0, + }, + ), + }, + { + name: "2. Hello 2", + command: []resp.Value{resp.StringValue("HELLO"), resp.StringValue("2")}, + wantRes: connection.BuildHelloResponse( + internal.ServerInfo{ + Server: "echovault", + Version: constants.Version, + Id: "", + Mode: "standalone", + Role: "master", + Modules: mockServer.ListModules(), + }, + internal.ConnectionInfo{ + Id: 2, + Name: "", + Protocol: 2, + Database: 0, + }, + ), + }, + { + name: "3. Hello 3", + command: []resp.Value{resp.StringValue("HELLO"), resp.StringValue("3")}, + wantRes: connection.BuildHelloResponse( + internal.ServerInfo{ + Server: "echovault", + Version: constants.Version, + Id: "", + Mode: "standalone", + Role: "master", + Modules: mockServer.ListModules(), + }, + internal.ConnectionInfo{ + Id: 3, + Name: "", + Protocol: 3, + Database: 0, + }, + ), + }, + } + + for i := 0; i < len(tests); i++ { + conn, err := internal.GetConnection("localhost", port) + if err != nil { + t.Error(err) + return + } + client := resp.NewConn(conn) + + if err = client.WriteArray(tests[i].command); err != nil { + t.Error(err) + return + } + + buf := bufio.NewReader(conn) + res, err := internal.ReadMessage(buf) + if err != nil { + t.Error(err) + return + } + + if !bytes.Equal(tests[i].wantRes, res) { + t.Errorf("expected byte resposne:\n%s, \n\ngot:\n%s", string(tests[i].wantRes), string(res)) + return + } + + // Close connection + _ = conn.Close() + } +} + func TestEchoVault_SelectDB(t *testing.T) { t.Parallel() tests := []struct { diff --git a/echovault/echovault.go b/echovault/echovault.go index 83355c6..d2f6075 100644 --- a/echovault/echovault.go +++ b/echovault/echovault.go @@ -322,7 +322,7 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) { echovault.aofEngine = aofEngine } - // If eviction policy is not noeviction, start a goroutine to evict keys every 100 milliseconds. + // If eviction policy is not noeviction, start a goroutine to evict keys at the configured interval. if echovault.config.EvictionPolicy != constants.NoEviction { go func() { ticker := time.NewTicker(echovault.config.EvictionInterval) @@ -639,6 +639,9 @@ func (server *EchoVault) ShutDown() { log.Printf("listener close: %v\n", err) } } + if !server.isInCluster() { + server.aofEngine.Close() + } if server.isInCluster() { server.raft.RaftShutdown() server.memberList.MemberListShutdown() diff --git a/echovault/echovault_test.go b/echovault/echovault_test.go index c043c24..6696d10 100644 --- a/echovault/echovault_test.go +++ b/echovault/echovault_test.go @@ -24,6 +24,7 @@ import ( "github.com/echovault/echovault/internal/clock" "github.com/echovault/echovault/internal/config" "github.com/echovault/echovault/internal/constants" + "github.com/go-test/deep" "github.com/tidwall/resp" "io" "math" @@ -591,6 +592,65 @@ func Test_Cluster(t *testing.T) { t.Errorf("expected response to contain \"%s\", got \"%s\"", expected, res.Error().Error()) } }) + + t.Run("Test_SnapshotRestore", func(t *testing.T) { + // TODO: Test snapshot creation and restoration on the cluster. + }) + + t.Run("Test_EvictExpiredTTL", func(t *testing.T) { + // TODO: Implement test for evicting expired keys on the cluster. + }) + + t.Run("Test_GetServerInfo", func(t *testing.T) { + nodeInfo := []internal.ServerInfo{ + { + Server: "echovault", + Version: constants.Version, + Id: nodes[0].serverId, + Mode: "cluster", + Role: "master", + Modules: nodes[0].server.ListModules(), + }, + { + Server: "echovault", + Version: constants.Version, + Id: nodes[1].serverId, + Mode: "cluster", + Role: "replica", + Modules: nodes[1].server.ListModules(), + }, + { + Server: "echovault", + Version: constants.Version, + Id: nodes[2].serverId, + Mode: "cluster", + Role: "replica", + Modules: nodes[2].server.ListModules(), + }, + { + Server: "echovault", + Version: constants.Version, + Id: nodes[3].serverId, + Mode: "cluster", + Role: "replica", + Modules: nodes[3].server.ListModules(), + }, + { + Server: "echovault", + Version: constants.Version, + Id: nodes[4].serverId, + Mode: "cluster", + Role: "replica", + Modules: nodes[4].server.ListModules(), + }, + } + for i := 0; i < len(nodes); i++ { + if diff := deep.Equal(nodes[i].server.GetServerInfo(), nodeInfo[i]); diff != nil { + t.Errorf("GetServerInfo() - node %d: %+v", i, err) + return + } + } + }) } func Test_Standalone(t *testing.T) { @@ -606,6 +666,7 @@ func Test_Standalone(t *testing.T) { Port: uint16(port), DataDir: "", EvictionPolicy: constants.NoEviction, + ServerID: "Server_1", }), ) if err != nil { @@ -1089,4 +1150,23 @@ func Test_Standalone(t *testing.T) { } } }) + + t.Run("Test_EvictExpiredTTL", func(t *testing.T) { + // TODO: Implement test for evicting expired keys in standalone mode. + }) + + t.Run("Test_GetServerInfo", func(t *testing.T) { + wantInfo := internal.ServerInfo{ + Server: "echovault", + Version: constants.Version, + Id: mockServer.config.ServerID, + Mode: "standalone", + Role: "master", + Modules: mockServer.ListModules(), + } + info := mockServer.GetServerInfo() + if diff := deep.Equal(wantInfo, info); diff != nil { + t.Errorf("GetServerInfo(): %+v", err) + } + }) } diff --git a/echovault/keyspace.go b/echovault/keyspace.go index 90020e4..9bff841 100644 --- a/echovault/keyspace.go +++ b/echovault/keyspace.go @@ -333,6 +333,7 @@ func (server *EchoVault) getState() map[int]map[string]interface{} { // depending on whether an LFU or LRU strategy was used. func (server *EchoVault) updateKeysInCache(ctx context.Context, keys []string) error { database := ctx.Value("Database").(int) + for _, key := range keys { // Only update cache when in standalone mode or when raft leader. if server.isInCluster() || (server.isInCluster() && !server.raft.IsRaftLeader()) { diff --git a/go.mod b/go.mod index 7125dfb..90a80b8 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/armon/go-metrics v0.4.1 // indirect github.com/boltdb/bolt v1.3.1 // indirect github.com/fatih/color v1.13.0 // indirect + github.com/go-test/deep v1.1.1 // indirect github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c // indirect github.com/hashicorp/errwrap v1.0.0 // indirect github.com/hashicorp/go-hclog v1.5.0 // indirect diff --git a/go.sum b/go.sum index 2b1b8da..7a0739b 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,8 @@ github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2 github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= +github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= diff --git a/internal/aof/engine.go b/internal/aof/engine.go index c88555d..4319b7a 100644 --- a/internal/aof/engine.go +++ b/internal/aof/engine.go @@ -189,3 +189,12 @@ func (engine *Engine) Restore() error { } return nil } + +func (engine *Engine) Close() { + if err := engine.preambleStore.Close(); err != nil { + log.Printf("close preamble store error: %+v\n", engine) + } + if err := engine.appendStore.Close(); err != nil { + log.Printf("close append store error: %+v\n", engine) + } +} diff --git a/internal/aof/engine_test.go b/internal/aof/engine_test.go index c248fa8..47ecc76 100644 --- a/internal/aof/engine_test.go +++ b/internal/aof/engine_test.go @@ -216,5 +216,6 @@ func Test_AOFEngine(t *testing.T) { } } + engine.Close() _ = os.RemoveAll(directory) } diff --git a/internal/aof/log/store.go b/internal/aof/log/store.go index 2fd4975..09c71cf 100644 --- a/internal/aof/log/store.go +++ b/internal/aof/log/store.go @@ -253,5 +253,11 @@ func (store *Store) Truncate() error { func (store *Store) Close() error { store.mut.Lock() defer store.mut.Unlock() - return store.rw.Close() + if store.rw == nil { + return nil + } + if err := store.rw.Close(); err != nil { + return err + } + return nil } diff --git a/internal/aof/preamble/store.go b/internal/aof/preamble/store.go index e8cafa3..4c22235 100644 --- a/internal/aof/preamble/store.go +++ b/internal/aof/preamble/store.go @@ -172,5 +172,11 @@ func (store *Store) Restore() error { func (store *Store) Close() error { store.mut.Lock() defer store.mut.Unlock() - return store.rw.Close() + if store.rw == nil { + return nil + } + if err := store.rw.Close(); err != nil { + return err + } + return nil } diff --git a/internal/config/config.go b/internal/config/config.go index a528e4e..c01299f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -224,7 +224,7 @@ It is a plain text value by default but you can provide a SHA256 hash by adding } if len(*config) > 0 { - // Override configurations from file + // Override configurations from file. if f, err := os.Open(*config); err != nil { panic(err) } else { @@ -250,11 +250,11 @@ It is a plain text value by default but you can provide a SHA256 hash by adding } } - // If requirePass is set to true, then password must be provided as well + // If requirePass is set to true, then password must be provided as well. var err error = nil if conf.RequirePass && conf.Password == "" { - err = errors.New("password cannot be empty if requirePass is generic to true") + err = errors.New("password cannot be empty if requirePass is true") } return conf, err diff --git a/internal/snapshot/snapshot.go b/internal/snapshot/snapshot.go index 6c5132d..017e3e7 100644 --- a/internal/snapshot/snapshot.go +++ b/internal/snapshot/snapshot.go @@ -313,6 +313,11 @@ func (engine *Engine) Restore() error { if err != nil { return err } + defer func() { + if err := mf.Close(); err != nil { + log.Println(err) + } + }() manifest := new(Manifest) @@ -340,6 +345,11 @@ func (engine *Engine) Restore() error { if err != nil { return err } + defer func() { + if err := sf.Close(); err != nil { + log.Println(err) + } + }() sd, err := io.ReadAll(sf) if err != nil {