Updated all test suites to include connection and server shutdown on cleanup.

This commit is contained in:
Kelvin Clement Mwinuka
2024-05-31 01:30:18 +08:00
parent 6f8511632e
commit c7560ce9dd
13 changed files with 13350 additions and 13106 deletions

View File

@@ -1,119 +0,0 @@
// Copyright 2024 Kelvin Clement Mwinuka
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package echovault
// func Test_raftApplyDeleteKey(t *testing.T) {
// nodes, err := makeCluster(5)
// if err != nil {
// t.Error(err)
// return
// }
//
// // Prepare the write data for the cluster.
// tests := []struct {
// key string
// value string
// }{
// {
// key: "key1",
// value: "value1",
// },
// {
// key: "key2",
// value: "value2",
// },
// {
// key: "key3",
// value: "value3",
// },
// }
//
// // Write all the data to the cluster leader.
// for i, test := range tests {
// node := nodes[0]
// if err := node.client.WriteArray([]resp.Value{
// resp.StringValue("SET"),
// resp.StringValue(test.key),
// resp.StringValue(test.value),
// }); err != nil {
// t.Errorf("could not write data to leader node (test %d): %v", i, err)
// }
// // Read response and make sure we received "ok" response.
// rd, _, err := node.client.ReadValue()
// if err != nil {
// t.Errorf("could not read response from leader node (test %d): %v", i, err)
// }
// if !strings.EqualFold(rd.String(), "ok") {
// t.Errorf("expected response for test %d to be \"OK\", got %s", i, rd.String())
// }
// }
//
// // Check if the data has been replicated on a quorum (majority of the cluster).
// quorum := int(math.Ceil(float64(len(nodes)/2)) + 1)
// for i, test := range tests {
// count := 0
// for j := 0; j < len(nodes); j++ {
// node := nodes[j]
// if err := node.client.WriteArray([]resp.Value{
// resp.StringValue("GET"),
// resp.StringValue(test.key),
// }); err != nil {
// t.Errorf("could not write data to follower node %d (test %d): %v", j, i, err)
// }
// rd, _, err := node.client.ReadValue()
// if err != nil {
// t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err)
// }
// if rd.String() == test.value {
// count += 1 // If the expected value is found, increment the count.
// }
// }
// // Fail if count is less than quorum.
// if count < quorum {
// t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key)
// }
// // Delete key across raft cluster.
// if err = nodes[0].server.raftApplyDeleteKey(nodes[0].server.context, test.key); err != nil {
// t.Error(err)
// }
// }
//
// <-time.After(200 * time.Millisecond) // Yield to give key deletion time to take effect across cluster.
//
// // Check if the data is absent in quorum (majority of the cluster).
// for i, test := range tests {
// count := 0
// for j := 0; j < len(nodes); j++ {
// node := nodes[j]
// if err := node.client.WriteArray([]resp.Value{
// resp.StringValue("GET"),
// resp.StringValue(test.key),
// }); err != nil {
// t.Errorf("could not write command to follower node %d (test %d): %v", j, i, err)
// }
// rd, _, err := node.client.ReadValue()
// if err != nil {
// t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err)
// }
// if rd.IsNull() {
// count += 1 // If the expected value is found, increment the count.
// }
// }
// // Fail if count is less than quorum.
// if count < quorum {
// t.Errorf("found value %s at key %s in cluster quorum", test.value, test.key)
// }
// }
// }

View File

@@ -73,7 +73,7 @@ type EchoVault struct {
mutex sync.Mutex // Mutex as only one goroutine can edit the LFU cache at a time. mutex sync.Mutex // Mutex as only one goroutine can edit the LFU cache at a time.
cache eviction.CacheLFU // LFU cache represented by a min head. cache eviction.CacheLFU // LFU cache represented by a min head.
} }
// LRU cache used when eviction policy is allkeys-lru or volatile-lru // LRU cache used when eviction policy is allkeys-lru or volatile-lru.
lruCache struct { lruCache struct {
mutex sync.Mutex // Mutex as only one goroutine can edit the LRU at a time. mutex sync.Mutex // Mutex as only one goroutine can edit the LRU at a time.
cache eviction.CacheLRU // LRU cache represented by a max head. cache eviction.CacheLRU // LRU cache represented by a max head.
@@ -95,9 +95,12 @@ type EchoVault struct {
rewriteAOFInProgress atomic.Bool // Atomic boolean that's true when actively rewriting AOF file is in progress. rewriteAOFInProgress atomic.Bool // Atomic boolean that's true when actively rewriting AOF file is in progress.
stateCopyInProgress atomic.Bool // Atomic boolean that's true when actively copying state for snapshotting or preamble generation. stateCopyInProgress atomic.Bool // Atomic boolean that's true when actively copying state for snapshotting or preamble generation.
stateMutationInProgress atomic.Bool // Atomic boolean that is set to true when state mutation is in progress. stateMutationInProgress atomic.Bool // Atomic boolean that is set to true when state mutation is in progress.
latestSnapshotMilliseconds atomic.Int64 // Unix epoch in milliseconds latestSnapshotMilliseconds atomic.Int64 // Unix epoch in milliseconds.
snapshotEngine *snapshot.Engine // Snapshot engine for standalone mode snapshotEngine *snapshot.Engine // Snapshot engine for standalone mode.
aofEngine *aof.Engine // AOF engine for standalone mode aofEngine *aof.Engine // AOF engine for standalone mode.
listener net.Listener // TCP listener.
quit chan struct{} // Channel that signals the closing of all client connections.
} }
// WithContext is an options that for the NewEchoVault function that allows you to // WithContext is an options that for the NewEchoVault function that allows you to
@@ -142,6 +145,7 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) {
commands = append(commands, str.Commands()...) commands = append(commands, str.Commands()...)
return commands return commands
}(), }(),
quit: make(chan struct{}),
} }
for _, option := range options { for _, option := range options {
@@ -320,30 +324,35 @@ func (server *EchoVault) startTCP() {
KeepAlive: 200 * time.Millisecond, KeepAlive: 200 * time.Millisecond,
} }
listener, err := listenConfig.Listen(server.context, "tcp", fmt.Sprintf("%s:%d", conf.BindAddr, conf.Port)) listener, err := listenConfig.Listen(
server.context,
"tcp",
fmt.Sprintf("%s:%d", conf.BindAddr, conf.Port),
)
if err != nil { if err != nil {
log.Fatal(err) log.Printf("listener error: %v", err)
return
} }
if !conf.TLS { if !conf.TLS {
// TCP // TCP
log.Printf("Starting TCP echovault at Address %s, Port %d...\n", conf.BindAddr, conf.Port) log.Printf("Starting TCP server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
} }
if conf.TLS || conf.MTLS { if conf.TLS || conf.MTLS {
// TLS // TLS
if conf.TLS { if conf.TLS {
log.Printf("Starting mTLS echovault at Address %s, Port %d...\n", conf.BindAddr, conf.Port) log.Printf("Starting mTLS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
} else { } else {
log.Printf("Starting TLS echovault at Address %s, Port %d...\n", conf.BindAddr, conf.Port) log.Printf("Starting TLS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port)
} }
var certificates []tls.Certificate var certificates []tls.Certificate
for _, certKeyPair := range conf.CertKeyPairs { for _, certKeyPair := range conf.CertKeyPairs {
c, err := tls.LoadX509KeyPair(certKeyPair[0], certKeyPair[1]) c, err := tls.LoadX509KeyPair(certKeyPair[0], certKeyPair[1])
if err != nil { if err != nil {
log.Fatal(err) log.Printf("load cert key pair: %v\n", err)
return
} }
certificates = append(certificates, c) certificates = append(certificates, c)
} }
@@ -356,14 +365,15 @@ func (server *EchoVault) startTCP() {
for _, c := range conf.ClientCAs { for _, c := range conf.ClientCAs {
ca, err := os.Open(c) ca, err := os.Open(c)
if err != nil { if err != nil {
log.Fatal(err) log.Printf("client cert open: %v\n", err)
return
} }
certBytes, err := io.ReadAll(ca) certBytes, err := io.ReadAll(ca)
if err != nil { if err != nil {
log.Fatal(err) log.Printf("client cert read: %v\n", err)
} }
if ok := clientCerts.AppendCertsFromPEM(certBytes); !ok { if ok := clientCerts.AppendCertsFromPEM(certBytes); !ok {
log.Fatal(err) log.Printf("client cert append: %v\n", err)
} }
} }
} }
@@ -375,15 +385,22 @@ func (server *EchoVault) startTCP() {
}) })
} }
// Listen to connection server.listener = listener
// Listen to connection.
for { for {
conn, err := listener.Accept() select {
if err != nil { case <-server.quit:
log.Println("Could not establish connection") return
continue default:
conn, err := server.listener.Accept()
if err != nil {
log.Printf("listener error: %v\n", err)
continue
}
// Read loop for connection
go server.handleConnection(conn)
} }
// Read loop for connection
go server.handleConnection(conn)
} }
} }
@@ -536,6 +553,13 @@ func (server *EchoVault) rewriteAOF() error {
// ShutDown gracefully shuts down the EchoVault instance. // ShutDown gracefully shuts down the EchoVault instance.
// This function shuts down the memberlist and raft layers. // This function shuts down the memberlist and raft layers.
func (server *EchoVault) ShutDown() { func (server *EchoVault) ShutDown() {
if server.listener != nil {
go func() { server.quit <- struct{}{} }()
log.Println("closing tcp listener...")
if err := server.listener.Close(); err != nil {
log.Printf("listener close: %v\n", err)
}
}
if server.isInCluster() { if server.isInCluster() {
server.raft.RaftShutdown() server.raft.RaftShutdown()
server.memberList.MemberListShutdown() server.memberList.MemberListShutdown()

View File

@@ -23,12 +23,14 @@ import (
"github.com/echovault/echovault/internal" "github.com/echovault/echovault/internal"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"io" "io"
"math"
"net" "net"
"os" "os"
"path" "path"
"strings" "strings"
"sync" "sync"
"testing" "testing"
"time"
) )
type ClientServerPair struct { type ClientServerPair struct {
@@ -38,6 +40,7 @@ type ClientServerPair struct {
raftPort int raftPort int
mlPort int mlPort int
bootstrapCluster bool bootstrapCluster bool
raw net.Conn
client *resp.Conn client *resp.Conn
server *EchoVault server *EchoVault
} }
@@ -62,8 +65,6 @@ func getBindAddr() net.IP {
return getBindAddrNet(0) return getBindAddrNet(0)
} }
var setupLock sync.Mutex
func setupServer( func setupServer(
serverId string, serverId string,
bootstrapCluster bool, bootstrapCluster bool,
@@ -73,25 +74,20 @@ func setupServer(
raftPort, raftPort,
mlPort int, mlPort int,
) (*EchoVault, error) { ) (*EchoVault, error) {
setupLock.Lock()
defer setupLock.Unlock()
ctx := context.Background()
config := DefaultConfig() config := DefaultConfig()
config.DataDir = "./testdata" config.DataDir = "./testdata"
config.ForwardCommand = true config.ForwardCommand = true
config.BindAddr = bindAddr config.BindAddr = bindAddr
config.JoinAddr = joinAddr config.JoinAddr = joinAddr
config.Port = uint16(port) config.Port = uint16(port)
// config.InMemory = true config.InMemory = true
config.ServerID = serverId config.ServerID = serverId
config.RaftBindPort = uint16(raftPort) config.RaftBindPort = uint16(raftPort)
config.MemberListBindPort = uint16(mlPort) config.MemberListBindPort = uint16(mlPort)
config.BootstrapCluster = bootstrapCluster config.BootstrapCluster = bootstrapCluster
return NewEchoVault( return NewEchoVault(
WithContext(ctx), WithContext(context.Background()),
WithConfig(config), WithConfig(config),
) )
} }
@@ -163,6 +159,7 @@ func makeCluster(size int) ([]ClientServerPair, error) {
raftPort: raftPort, raftPort: raftPort,
mlPort: memberlistPort, mlPort: memberlistPort,
bootstrapCluster: bootstrapCluster, bootstrapCluster: bootstrapCluster,
raw: conn,
client: client, client: client,
server: server, server: server,
} }
@@ -171,273 +168,318 @@ func makeCluster(size int) ([]ClientServerPair, error) {
return pairs, nil return pairs, nil
} }
// func Test_ClusterReplication(t *testing.T) { func Test_Cluster(t *testing.T) {
// nodes, err := makeCluster(5) nodes, err := makeCluster(5)
// if err != nil { if err != nil {
// t.Error(err) t.Error(err)
// return return
// } }
//
// // Prepare the write data for the cluster.
// tests := []struct {
// key string
// value string
// }{
// {
// key: "key1",
// value: "value1",
// },
// {
// key: "key2",
// value: "value2",
// },
// {
// key: "key3",
// value: "value3",
// },
// }
//
// // Write all the data to the cluster leader
// for i, test := range tests {
// node := nodes[0]
// if err := node.client.WriteArray([]resp.Value{
// resp.StringValue("SET"),
// resp.StringValue(test.key),
// resp.StringValue(test.value),
// }); err != nil {
// t.Errorf("could not write data to leader node (test %d): %v", i, err)
// }
// // Read response and make sure we received "ok" response.
// rd, _, err := node.client.ReadValue()
// if err != nil {
// t.Errorf("could not read response from leader node (test %d): %v", i, err)
// }
// if !strings.EqualFold(rd.String(), "ok") {
// t.Errorf("expected response for test %d to be \"OK\", got %s", i, rd.String())
// }
// }
//
// // Check if the data has been replicated on a quorum (majority of the cluster).
// quorum := int(math.Ceil(float64(len(nodes)/2)) + 1)
// for i, test := range tests {
// count := 0
// for j := 0; j < len(nodes); j++ {
// node := nodes[j]
// if err := node.client.WriteArray([]resp.Value{
// resp.StringValue("GET"),
// resp.StringValue(test.key),
// }); err != nil {
// t.Errorf("could not write data to follower node %d (test %d): %v", j, i, err)
// }
// rd, _, err := node.client.ReadValue()
// if err != nil {
// t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err)
// }
// if rd.String() == test.value {
// count += 1 // If the expected value is found, increment the count.
// }
// }
// // Fail if count is less than quorum.
// if count < quorum {
// t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key)
// }
// }
// }
// func Test_ClusterDeleteKey(t *testing.T) { defer func() {
// nodes, err := makeCluster(5) for _, node := range nodes {
// if err != nil { _ = node.raw.Close()
// t.Error(err) node.server.ShutDown()
// return }
// } }()
//
// // Prepare the write data for the cluster
// tests := []struct {
// key string
// value string
// }{
// {
// key: "key1",
// value: "value1",
// },
// {
// key: "key2",
// value: "value2",
// },
// {
// key: "key3",
// value: "value3",
// },
// }
//
// // Write all the data to the cluster leader
// for i, test := range tests {
// node := nodes[0]
// if err := node.client.WriteArray([]resp.Value{
// resp.StringValue("SET"),
// resp.StringValue(test.key),
// resp.StringValue(test.value),
// }); err != nil {
// t.Errorf("could not write command to leader node (test %d): %v", i, err)
// }
// // Read response and make sure we received "ok" response.
// rd, _, err := node.client.ReadValue()
// if err != nil {
// t.Errorf("could not read response from leader node (test %d): %v", i, err)
// }
// if !strings.EqualFold(rd.String(), "ok") {
// t.Errorf("expected response for test %d to be \"OK\", got %s", i, rd.String())
// }
// }
//
// quorum := int(math.Ceil(float64(len(nodes)/2)) + 1)
//
// // Check if the data has been replicated on a quorum (majority of the cluster).
// for i, test := range tests {
// count := 0
// for j := 0; j < len(nodes); j++ {
// node := nodes[j]
// if err := node.client.WriteArray([]resp.Value{
// resp.StringValue("GET"),
// resp.StringValue(test.key),
// }); err != nil {
// t.Errorf("could not write command to follower node %d (test %d): %v", j, i, err)
// }
// rd, _, err := node.client.ReadValue()
// if err != nil {
// t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err)
// }
// if rd.String() == test.value {
// count += 1 // If the expected value is found, increment the count.
// }
// }
// // Fail if count is less than quorum.
// if count < quorum {
// t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key)
// return
// }
// }
//
// // Delete the key on the leader node
// for i, test := range tests {
// node := nodes[0]
// if err := node.client.WriteArray([]resp.Value{
// resp.StringValue("DEL"),
// resp.StringValue(test.key),
// }); err != nil {
// t.Errorf("could not write command to leader node (test %d): %v", i, err)
// }
// // Read response and make sure we received "ok" response.
// rd, _, err := node.client.ReadValue()
// if err != nil {
// t.Errorf("could not read response from leader node (test %d): %v", i, err)
// }
// if rd.Integer() != 1 {
// t.Errorf("expected response for test %d to be 1, got %d", i, rd.Integer())
// }
// }
//
// // Check if the data is absent in quorum (majority of the cluster).
// for i, test := range tests {
// count := 0
// for j := 0; j < len(nodes); j++ {
// node := nodes[j]
// if err := node.client.WriteArray([]resp.Value{
// resp.StringValue("GET"),
// resp.StringValue(test.key),
// }); err != nil {
// t.Errorf("could not write command to follower node %d (test %d): %v", j, i, err)
// }
// rd, _, err := node.client.ReadValue()
// if err != nil {
// t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err)
// }
// if rd.IsNull() {
// count += 1 // If the expected value is found, increment the count.
// }
// }
// // Fail if count is less than quorum.
// if count < quorum {
// t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key)
// }
// }
// }
// func Test_CommandForwarded(t *testing.T) { // Prepare the write data for the cluster.
// nodes, err := makeCluster(5) tests := map[string][]struct {
// if err != nil { key string
// t.Error(err) value string
// return }{
// } "replication": {
// {key: "key1", value: "value1"},
// // Prepare the write data for the cluster {key: "key2", value: "value2"},
// tests := []struct { {key: "key3", value: "value3"},
// key string },
// value string "deletion": {
// }{ {key: "key4", value: "value4"},
// { {key: "key5", value: "value4"},
// key: "key1", {key: "key6", value: "value5"},
// value: "value1", },
// }, "raft-apply-delete": {
// { {key: "key7", value: "value7"},
// key: "key2", {key: "key8", value: "value8"},
// value: "value2", {key: "key9", value: "value9"},
// }, },
// { "forward": {
// key: "key3", {key: "key10", value: "value10"},
// value: "value3", {key: "key11", value: "value11"},
// }, {key: "key12", value: "value12"},
// } },
// }
// // Write all the data a random cluster follower.
// for i, test := range tests { t.Run("Test_Replication", func(t *testing.T) {
// // Send write command to follower node. tests := tests["replication"]
// node := nodes[1] // Write all the data to the cluster leader.
// if err := node.client.WriteArray([]resp.Value{ for i, test := range tests {
// resp.StringValue("SET"), node := nodes[0]
// resp.StringValue(test.key), if err := node.client.WriteArray([]resp.Value{
// resp.StringValue(test.value), resp.StringValue("SET"), resp.StringValue(test.key), resp.StringValue(test.value),
// }); err != nil { }); err != nil {
// t.Errorf("could not write data to leader node (test %d): %v", i, err) t.Errorf("could not write data to leader node (test %d): %v", i, err)
// } }
// // Read response and make sure we received "ok" response. // Read response and make sure we received "ok" response.
// rd, _, err := node.client.ReadValue() rd, _, err := node.client.ReadValue()
// if err != nil { if err != nil {
// t.Errorf("could not read response from leader node (test %d): %v", i, err) t.Errorf("could not read response from leader node (test %d): %v", i, err)
// } }
// if !strings.EqualFold(rd.String(), "ok") { if !strings.EqualFold(rd.String(), "ok") {
// t.Errorf("expected response for test %d to be \"OK\", got %s", i, rd.String()) t.Errorf("expected response for test %d to be \"OK\", got %s", i, rd.String())
// } }
// } }
//
// <-time.After(250 * time.Millisecond) // Short yield to allow change to take effect. <-time.After(200 * time.Millisecond) // Yield
//
// // Check if the data has been replicated on a quorum (majority of the cluster). // Check if the data has been replicated on a quorum (majority of the cluster).
// quorum := int(math.Ceil(float64(len(nodes)/2)) + 1) quorum := int(math.Ceil(float64(len(nodes)/2)) + 1)
// for i, test := range tests { for i, test := range tests {
// count := 0 count := 0
// for j := 0; j < len(nodes); j++ { for j := 0; j < len(nodes); j++ {
// node := nodes[j] node := nodes[j]
// if err := node.client.WriteArray([]resp.Value{ if err := node.client.WriteArray([]resp.Value{
// resp.StringValue("GET"), resp.StringValue("GET"),
// resp.StringValue(test.key), resp.StringValue(test.key),
// }); err != nil { }); err != nil {
// t.Errorf("could not write data to follower node %d (test %d): %v", j, i, err) t.Errorf("could not write data to follower node %d (test %d): %v", j, i, err)
// } }
// rd, _, err := node.client.ReadValue() rd, _, err := node.client.ReadValue()
// if err != nil { if err != nil {
// t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err) t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err)
// } }
// if rd.String() == test.value { if rd.String() == test.value {
// count += 1 // If the expected value is found, increment the count. count += 1 // If the expected value is found, increment the count.
// } }
// } }
// // Fail if count is less than quorum. // Fail if count is less than quorum.
// if count < quorum { if count < quorum {
// t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key) t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key)
// } }
// } }
// } })
t.Run("Test_DeleteKey", func(t *testing.T) {
tests := tests["deletion"]
// Write all the data to the cluster leader.
for i, test := range tests {
node := nodes[0]
_, ok, err := node.server.Set(test.key, test.value, SetOptions{})
if err != nil {
t.Errorf("could not write command to leader node (test %d): %v", i, err)
}
if !ok {
t.Errorf("expected set for test %d ok = true, got ok = false", i)
}
}
<-time.After(200 * time.Millisecond) // Yield
// Check if the data has been replicated on a quorum (majority of the cluster).
quorum := int(math.Ceil(float64(len(nodes)/2)) + 1)
for i, test := range tests {
count := 0
for j := 0; j < len(nodes); j++ {
node := nodes[j]
if err := node.client.WriteArray([]resp.Value{
resp.StringValue("GET"),
resp.StringValue(test.key),
}); err != nil {
t.Errorf("could not write command to follower node %d (test %d): %v", j, i, err)
}
rd, _, err := node.client.ReadValue()
if err != nil {
t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err)
}
if rd.String() == test.value {
count += 1 // If the expected value is found, increment the count.
}
}
// Fail if count is less than quorum.
if count < quorum {
t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key)
return
}
}
// Delete the key on the leader node
// 1. Prepare delete command.
command := []resp.Value{resp.StringValue("DEL")}
for _, test := range tests {
command = append(command, resp.StringValue(test.key))
}
// 2. Send delete command.
if err := nodes[0].client.WriteArray(command); err != nil {
t.Error(err)
return
}
res, _, err := nodes[0].client.ReadValue()
if err != nil {
t.Error(err)
return
}
// 3. Check the delete count is equal to length of tests.
if res.Integer() != len(tests) {
t.Errorf("expected delete response to be %d, got %d", len(tests), res.Integer())
}
<-time.After(200 * time.Millisecond) // Yield
// Check if the data is absent in quorum (majority of the cluster).
for i, test := range tests {
count := 0
for j := 0; j < len(nodes); j++ {
node := nodes[j]
if err := node.client.WriteArray([]resp.Value{
resp.StringValue("GET"),
resp.StringValue(test.key),
}); err != nil {
t.Errorf("could not write command to follower node %d (test %d): %v", j, i, err)
}
rd, _, err := node.client.ReadValue()
if err != nil {
t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err)
}
if rd.IsNull() {
count += 1 // If the expected value is found, increment the count.
}
}
// Fail if count is less than quorum.
if count < quorum {
t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key)
}
}
})
t.Run("Test_raftApplyDeleteKey", func(t *testing.T) {
tests := tests["raft-apply-delete"]
// Write all the data to the cluster leader.
for i, test := range tests {
node := nodes[0]
_, ok, err := node.server.Set(test.key, test.value, SetOptions{})
if err != nil {
t.Errorf("could not write command to leader node (test %d): %v", i, err)
}
if !ok {
t.Errorf("expected set for test %d ok = true, got ok = false", i)
}
}
<-time.After(200 * time.Millisecond) // Yield
// Check if the data has been replicated on a quorum (majority of the cluster).
quorum := int(math.Ceil(float64(len(nodes)/2)) + 1)
for i, test := range tests {
count := 0
for j := 0; j < len(nodes); j++ {
node := nodes[j]
if err := node.client.WriteArray([]resp.Value{
resp.StringValue("GET"),
resp.StringValue(test.key),
}); err != nil {
t.Errorf("could not write command to follower node %d (test %d): %v", j, i, err)
}
rd, _, err := node.client.ReadValue()
if err != nil {
t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err)
}
if rd.String() == test.value {
count += 1 // If the expected value is found, increment the count.
}
}
// Fail if count is less than quorum.
if count < quorum {
t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key)
return
}
}
// Delete the keys using raftApplyDelete method.
for _, test := range tests {
if err := nodes[0].server.raftApplyDeleteKey(nodes[0].server.context, test.key); err != nil {
t.Error(err)
}
}
<-time.After(200 * time.Millisecond) // Yield to give key deletion time to take effect across cluster.
// Check if the data is absent in quorum (majority of the cluster).
for i, test := range tests {
count := 0
for j := 0; j < len(nodes); j++ {
node := nodes[j]
if err := node.client.WriteArray([]resp.Value{
resp.StringValue("GET"),
resp.StringValue(test.key),
}); err != nil {
t.Errorf("could not write command to follower node %d (test %d): %v", j, i, err)
}
rd, _, err := node.client.ReadValue()
if err != nil {
t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err)
}
if rd.IsNull() {
count += 1 // If the expected value is found, increment the count.
}
}
// Fail if count is less than quorum.
if count < quorum {
t.Errorf("found value %s at key %s in cluster quorum", test.value, test.key)
}
}
})
// t.Run("Test_ForwardCommand", func(t *testing.T) {
// tests := tests["forward"]
// // Write all the data a random cluster follower.
// for i, test := range tests {
// // Send write command to follower node.
// node := nodes[1]
// if err := node.client.WriteArray([]resp.Value{
// resp.StringValue("SET"),
// resp.StringValue(test.key),
// resp.StringValue(test.value),
// }); err != nil {
// t.Errorf("could not write data to follower node (test %d): %v", i, err)
// }
// // Read response and make sure we received "ok" response.
// rd, _, err := node.client.ReadValue()
// if err != nil {
// t.Errorf("could not read response from follower node (test %d): %v", i, err)
// }
// if !strings.EqualFold(rd.String(), "ok") {
// t.Errorf("expected response for test %d to be \"OK\", got %s", i, rd.String())
// }
// }
//
// <-time.After(200 * time.Millisecond) // Short yield to allow change to take effect.
//
// // Check if the data has been replicated on a quorum (majority of the cluster).
// quorum := int(math.Ceil(float64(len(nodes)/2)) + 1)
// for i, test := range tests {
// count := 0
// for j := 0; j < len(nodes); j++ {
// node := nodes[j]
// if err := node.client.WriteArray([]resp.Value{
// resp.StringValue("GET"),
// resp.StringValue(test.key),
// }); err != nil {
// t.Errorf("could not write data to follower node %d (test %d): %v", j, i, err)
// }
// rd, _, err := node.client.ReadValue()
// if err != nil {
// t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err)
// }
// if rd.String() == test.value {
// count += 1 // If the expected value is found, increment the count.
// }
// }
// // Fail if count is less than quorum.
// if count < quorum {
// t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key)
// }
// }
// })
}
func Test_TLS(t *testing.T) { func Test_TLS(t *testing.T) {
port, err := internal.GetFreePort() port, err := internal.GetFreePort()
@@ -464,6 +506,7 @@ func Test_TLS(t *testing.T) {
server, err := NewEchoVault(WithConfig(conf)) server, err := NewEchoVault(WithConfig(conf))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return
} }
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
@@ -494,8 +537,12 @@ func Test_TLS(t *testing.T) {
}) })
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return
} }
defer func() {
_ = conn.Close()
server.ShutDown()
}()
client := resp.NewConn(conn) client := resp.NewConn(conn)
// Test that we can set and get a value from the server. // Test that we can set and get a value from the server.
@@ -561,6 +608,7 @@ func Test_MTLS(t *testing.T) {
server, err := NewEchoVault(WithConfig(conf)) server, err := NewEchoVault(WithConfig(conf))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return
} }
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
@@ -613,7 +661,10 @@ func Test_MTLS(t *testing.T) {
t.Error(err) t.Error(err)
return return
} }
defer func() {
_ = conn.Close()
server.ShutDown()
}()
client := resp.NewConn(conn) client := resp.NewConn(conn)
// Test that we can set and get a value from the server. // Test that we can set and get a value from the server.

View File

@@ -160,13 +160,15 @@ func (m *MemberList) MemberListShutdown() {
// Gracefully leave memberlist cluster // Gracefully leave memberlist cluster
err := m.memberList.Leave(500 * time.Millisecond) err := m.memberList.Leave(500 * time.Millisecond)
if err != nil { if err != nil {
log.Fatal("Could not gracefully leave memberlist cluster") log.Printf("memberlist leave: %v\n", err)
return
} }
err = m.memberList.Shutdown() err = m.memberList.Shutdown()
if err != nil { if err != nil {
log.Fatal("Could not gracefully shutdown memberlist background maintenance") log.Printf("memberlist shutdown: %v\n", err)
return
} }
log.Println("Successfully shutdown memberlist") log.Println("successfully shutdown memberlist")
} }

View File

@@ -32,7 +32,6 @@ import (
str "github.com/echovault/echovault/internal/modules/string" str "github.com/echovault/echovault/internal/modules/string"
"github.com/tidwall/resp" "github.com/tidwall/resp"
"net" "net"
"os"
"path" "path"
"slices" "slices"
"strings" "strings"
@@ -50,38 +49,41 @@ func setupServer(port uint16) (*echovault.EchoVault, error) {
} }
func Test_AdminCommands(t *testing.T) { func Test_AdminCommands(t *testing.T) {
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err := setupServer(uint16(port))
if err != nil {
t.Error(err)
return
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
wg.Done()
mockServer.Start()
}()
wg.Wait()
t.Cleanup(func() { t.Cleanup(func() {
_ = os.RemoveAll("./testdata") mockServer.ShutDown()
}) })
t.Run("Test COMMANDS command", func(t *testing.T) { t.Run("Test COMMANDS command", func(t *testing.T) {
t.Parallel() t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err := setupServer(uint16(port))
if err != nil {
t.Error(err)
return
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
wg.Done()
mockServer.Start()
}()
wg.Wait()
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer func() {
_ = conn.Close()
}()
client := resp.NewConn(conn) client := resp.NewConn(conn)
if err = client.WriteArray([]resp.Value{resp.StringValue("COMMANDS")}); err != nil { if err = client.WriteArray([]resp.Value{resp.StringValue("COMMANDS")}); err != nil {
@@ -128,31 +130,14 @@ func Test_AdminCommands(t *testing.T) {
t.Run("Test COMMAND COUNT command", func(t *testing.T) { t.Run("Test COMMAND COUNT command", func(t *testing.T) {
t.Parallel() t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err := setupServer(uint16(port))
if err != nil {
t.Error(err)
return
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
wg.Done()
mockServer.Start()
}()
wg.Wait()
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer func() {
_ = conn.Close()
}()
client := resp.NewConn(conn) client := resp.NewConn(conn)
if err = client.WriteArray([]resp.Value{resp.StringValue("COMMAND"), resp.StringValue("COUNT")}); err != nil { if err = client.WriteArray([]resp.Value{resp.StringValue("COMMAND"), resp.StringValue("COUNT")}); err != nil {
@@ -199,31 +184,14 @@ func Test_AdminCommands(t *testing.T) {
t.Run("Test COMMAND LIST command", func(t *testing.T) { t.Run("Test COMMAND LIST command", func(t *testing.T) {
t.Parallel() t.Parallel()
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err := setupServer(uint16(port))
if err != nil {
t.Error(err)
return
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
wg.Done()
mockServer.Start()
}()
wg.Wait()
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer func() {
_ = conn.Close()
}()
client := resp.NewConn(conn) client := resp.NewConn(conn)
// Get all the commands from the existing modules. // Get all the commands from the existing modules.
@@ -336,17 +304,6 @@ func Test_AdminCommands(t *testing.T) {
}) })
t.Run("Test MODULE LOAD command", func(t *testing.T) { t.Run("Test MODULE LOAD command", func(t *testing.T) {
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err := setupServer(uint16(port))
if err != nil {
t.Error(err)
return
}
tests := []struct { tests := []struct {
name string name string
execCommand []resp.Value execCommand []resp.Value
@@ -433,20 +390,14 @@ func Test_AdminCommands(t *testing.T) {
}, },
} }
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
wg.Done()
mockServer.Start()
}()
wg.Wait()
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer func() {
_ = conn.Close()
}()
respConn := resp.NewConn(conn) respConn := resp.NewConn(conn)
for i := 0; i < len(tests); i++ { for i := 0; i < len(tests); i++ {
@@ -505,31 +456,14 @@ func Test_AdminCommands(t *testing.T) {
}) })
t.Run("Test MODULE UNLOAD command", func(t *testing.T) { t.Run("Test MODULE UNLOAD command", func(t *testing.T) {
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err := setupServer(uint16(port))
if err != nil {
t.Error(err)
return
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
wg.Done()
mockServer.Start()
}()
wg.Wait()
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer func() {
_ = conn.Close()
}()
respConn := resp.NewConn(conn) respConn := resp.NewConn(conn)
// Load module.set module // Load module.set module
@@ -693,31 +627,14 @@ func Test_AdminCommands(t *testing.T) {
}) })
t.Run("Test MODULE LIST command", func(t *testing.T) { t.Run("Test MODULE LIST command", func(t *testing.T) {
port, err := internal.GetFreePort()
if err != nil {
t.Error(err)
return
}
mockServer, err := setupServer(uint16(port))
if err != nil {
t.Error(err)
return
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
wg.Done()
mockServer.Start()
}()
wg.Wait()
conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
} }
defer func() {
_ = conn.Close()
}()
respConn := resp.NewConn(conn) respConn := resp.NewConn(conn)
// Load module.get module with arg // Load module.get module with arg

View File

@@ -28,20 +28,26 @@ import (
"testing" "testing"
) )
var mockServer *echovault.EchoVault func Test_Connection(t *testing.T) {
var port int port, err := internal.GetFreePort()
var addr = "localhost" if err != nil {
t.Error(err)
return
}
func init() { mockServer, err := echovault.NewEchoVault(
port, _ = internal.GetFreePort()
mockServer, _ = echovault.NewEchoVault(
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
DataDir: "", DataDir: "",
EvictionPolicy: constants.NoEviction, EvictionPolicy: constants.NoEviction,
BindAddr: addr, BindAddr: "localhost",
Port: uint16(port), Port: uint16(port),
}), }),
) )
if err != nil {
t.Error(err)
return
}
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
go func() { go func() {
@@ -49,62 +55,70 @@ func init() {
mockServer.Start() mockServer.Start()
}() }()
wg.Wait() wg.Wait()
}
func Test_HandlePing(t *testing.T) { t.Cleanup(func() {
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) mockServer.ShutDown()
if err != nil { })
t.Error(err)
return
}
client := resp.NewConn(conn)
tests := []struct { t.Run("Test_HandlePing", func(t *testing.T) {
command []resp.Value conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
expected string if err != nil {
expectedErr error
}{
{
command: []resp.Value{resp.StringValue("PING")},
expected: "PONG",
expectedErr: nil,
},
{
command: []resp.Value{resp.StringValue("PING"), resp.StringValue("Hello, world!")},
expected: "Hello, world!",
expectedErr: nil,
},
{
command: []resp.Value{
resp.StringValue("PING"),
resp.StringValue("Hello, world!"),
resp.StringValue("Once more"),
},
expected: "",
expectedErr: errors.New(constants.WrongArgsResponse),
},
}
for _, test := range tests {
if err = client.WriteArray(test.command); err != nil {
t.Error(err) t.Error(err)
return return
} }
defer func() {
_ = conn.Close()
}()
client := resp.NewConn(conn)
res, _, err := client.ReadValue() tests := []struct {
if err != nil { command []resp.Value
t.Error(err) expected string
expectedErr error
}{
{
command: []resp.Value{resp.StringValue("PING")},
expected: "PONG",
expectedErr: nil,
},
{
command: []resp.Value{resp.StringValue("PING"), resp.StringValue("Hello, world!")},
expected: "Hello, world!",
expectedErr: nil,
},
{
command: []resp.Value{
resp.StringValue("PING"),
resp.StringValue("Hello, world!"),
resp.StringValue("Once more"),
},
expected: "",
expectedErr: errors.New(constants.WrongArgsResponse),
},
} }
if test.expectedErr != nil { for _, test := range tests {
if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) { if err = client.WriteArray(test.command); err != nil {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedErr.Error(), res.Error().Error()) t.Error(err)
return
} }
continue
}
if res.String() != test.expected { res, _, err := client.ReadValue()
t.Errorf("expected response \"%s\", got \"%s\"", test.expected, res.String()) if err != nil {
t.Error(err)
}
if test.expectedErr != nil {
if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedErr.Error(), res.Error().Error())
}
continue
}
if res.String() != test.expected {
t.Errorf("expected response \"%s\", got \"%s\"", test.expected, res.String())
}
} }
} })
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -29,20 +29,26 @@ import (
"testing" "testing"
) )
var mockServer *echovault.EchoVault func Test_String(t *testing.T) {
var addr = "localhost" port, err := internal.GetFreePort()
var port int if err != nil {
t.Error()
return
}
func init() { mockServer, err := echovault.NewEchoVault(
port, _ = internal.GetFreePort()
mockServer, _ = echovault.NewEchoVault(
echovault.WithConfig(config.Config{ echovault.WithConfig(config.Config{
BindAddr: addr, BindAddr: "localhost",
Port: uint16(port), Port: uint16(port),
DataDir: "", DataDir: "",
EvictionPolicy: constants.NoEviction, EvictionPolicy: constants.NoEviction,
}), }),
) )
if err != nil {
t.Error(err)
return
}
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
go func() { go func() {
@@ -50,115 +56,140 @@ func init() {
mockServer.Start() mockServer.Start()
}() }()
wg.Wait() wg.Wait()
}
func Test_HandleSetRange(t *testing.T) { t.Cleanup(func() {
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) mockServer.ShutDown()
if err != nil { })
t.Error(err)
return
}
client := resp.NewConn(conn)
tests := []struct { t.Run("Test_HandleSetRange", func(t *testing.T) {
name string t.Parallel()
key string conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
presetValue string if err != nil {
command []string t.Error(err)
expectedValue string return
expectedResponse int }
expectedError error defer func() {
}{ _ = conn.Close()
{ }()
name: "Test that SETRANGE on non-existent string creates new string", client := resp.NewConn(conn)
key: "SetRangeKey1",
presetValue: "",
command: []string{"SETRANGE", "SetRangeKey1", "10", "New String Value"},
expectedValue: "New String Value",
expectedResponse: len("New String Value"),
expectedError: nil,
},
{
name: "Test SETRANGE with an offset that leads to a longer resulting string",
key: "SetRangeKey2",
presetValue: "Original String Value",
command: []string{"SETRANGE", "SetRangeKey2", "16", "Portion Replaced With This New String"},
expectedValue: "Original String Portion Replaced With This New String",
expectedResponse: len("Original String Portion Replaced With This New String"),
expectedError: nil,
},
{
name: "SETRANGE with negative offset prepends the string",
key: "SetRangeKey3",
presetValue: "This is a preset value",
command: []string{"SETRANGE", "SetRangeKey3", "-10", "Prepended "},
expectedValue: "Prepended This is a preset value",
expectedResponse: len("Prepended This is a preset value"),
expectedError: nil,
},
{
name: "SETRANGE with offset that embeds new string inside the old string",
key: "SetRangeKey4",
presetValue: "This is a preset value",
command: []string{"SETRANGE", "SetRangeKey4", "0", "That"},
expectedValue: "That is a preset value",
expectedResponse: len("That is a preset value"),
expectedError: nil,
},
{
name: "SETRANGE with offset longer than original lengths appends the string",
key: "SetRangeKey5",
presetValue: "This is a preset value",
command: []string{"SETRANGE", "SetRangeKey5", "100", " Appended"},
expectedValue: "This is a preset value Appended",
expectedResponse: len("This is a preset value Appended"),
expectedError: nil,
},
{
name: "SETRANGE with offset on the last character replaces last character with new string",
key: "SetRangeKey6",
presetValue: "This is a preset value",
command: []string{"SETRANGE", "SetRangeKey6", strconv.Itoa(len("This is a preset value") - 1), " replaced"},
expectedValue: "This is a preset valu replaced",
expectedResponse: len("This is a preset valu replaced"),
expectedError: nil,
},
{
name: " Offset not integer",
command: []string{"SETRANGE", "key", "offset", "value"},
expectedResponse: 0,
expectedError: errors.New("offset must be an integer"),
},
{
name: "SETRANGE target is not a string",
key: "test-int",
presetValue: "10",
command: []string{"SETRANGE", "test-int", "10", "value"},
expectedResponse: 0,
expectedError: errors.New("value at key test-int is not a string"),
},
{
name: "Command too short",
command: []string{"SETRANGE", "key"},
expectedResponse: 0,
expectedError: errors.New(constants.WrongArgsResponse),
},
{
name: "Command too long",
command: []string{"SETRANGE", "key", "offset", "value", "value1"},
expectedResponse: 0,
expectedError: errors.New(constants.WrongArgsResponse),
},
}
for _, test := range tests { tests := []struct {
t.Run(test.name, func(t *testing.T) { name string
if test.presetValue != "" { key string
if err = client.WriteArray([]resp.Value{ presetValue string
resp.StringValue("SET"), command []string
resp.StringValue(test.key), expectedValue string
resp.StringValue(test.presetValue), expectedResponse int
}); err != nil { expectedError error
}{
{
name: "Test that SETRANGE on non-existent string creates new string",
key: "SetRangeKey1",
presetValue: "",
command: []string{"SETRANGE", "SetRangeKey1", "10", "New String Value"},
expectedValue: "New String Value",
expectedResponse: len("New String Value"),
expectedError: nil,
},
{
name: "Test SETRANGE with an offset that leads to a longer resulting string",
key: "SetRangeKey2",
presetValue: "Original String Value",
command: []string{"SETRANGE", "SetRangeKey2", "16", "Portion Replaced With This New String"},
expectedValue: "Original String Portion Replaced With This New String",
expectedResponse: len("Original String Portion Replaced With This New String"),
expectedError: nil,
},
{
name: "SETRANGE with negative offset prepends the string",
key: "SetRangeKey3",
presetValue: "This is a preset value",
command: []string{"SETRANGE", "SetRangeKey3", "-10", "Prepended "},
expectedValue: "Prepended This is a preset value",
expectedResponse: len("Prepended This is a preset value"),
expectedError: nil,
},
{
name: "SETRANGE with offset that embeds new string inside the old string",
key: "SetRangeKey4",
presetValue: "This is a preset value",
command: []string{"SETRANGE", "SetRangeKey4", "0", "That"},
expectedValue: "That is a preset value",
expectedResponse: len("That is a preset value"),
expectedError: nil,
},
{
name: "SETRANGE with offset longer than original lengths appends the string",
key: "SetRangeKey5",
presetValue: "This is a preset value",
command: []string{"SETRANGE", "SetRangeKey5", "100", " Appended"},
expectedValue: "This is a preset value Appended",
expectedResponse: len("This is a preset value Appended"),
expectedError: nil,
},
{
name: "SETRANGE with offset on the last character replaces last character with new string",
key: "SetRangeKey6",
presetValue: "This is a preset value",
command: []string{"SETRANGE", "SetRangeKey6", strconv.Itoa(len("This is a preset value") - 1), " replaced"},
expectedValue: "This is a preset valu replaced",
expectedResponse: len("This is a preset valu replaced"),
expectedError: nil,
},
{
name: " Offset not integer",
command: []string{"SETRANGE", "key", "offset", "value"},
expectedResponse: 0,
expectedError: errors.New("offset must be an integer"),
},
{
name: "SETRANGE target is not a string",
key: "test-int",
presetValue: "10",
command: []string{"SETRANGE", "test-int", "10", "value"},
expectedResponse: 0,
expectedError: errors.New("value at key test-int is not a string"),
},
{
name: "Command too short",
command: []string{"SETRANGE", "key"},
expectedResponse: 0,
expectedError: errors.New(constants.WrongArgsResponse),
},
{
name: "Command too long",
command: []string{"SETRANGE", "key", "offset", "value", "value1"},
expectedResponse: 0,
expectedError: errors.New(constants.WrongArgsResponse),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.presetValue != "" {
if err = client.WriteArray([]resp.Value{
resp.StringValue("SET"),
resp.StringValue(test.key),
resp.StringValue(test.presetValue),
}); err != nil {
t.Error(err)
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
}
if !strings.EqualFold(res.String(), "ok") {
t.Errorf("expected preset response to be OK, got %s", res.String())
}
}
command := make([]resp.Value, len(test.command))
for i, c := range test.command {
command[i] = resp.StringValue(c)
}
if err = client.WriteArray(command); err != nil {
t.Error(err) t.Error(err)
} }
res, _, err := client.ReadValue() res, _, err := client.ReadValue()
@@ -166,95 +197,100 @@ func Test_HandleSetRange(t *testing.T) {
t.Error(err) t.Error(err)
} }
if !strings.EqualFold(res.String(), "ok") { if test.expectedError != nil {
t.Errorf("expected preset response to be OK, got %s", res.String()) if !strings.Contains(res.Error().Error(), test.expectedError.Error()) {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
}
return
} }
}
command := make([]resp.Value, len(test.command)) if res.Integer() != test.expectedResponse {
for i, c := range test.command { t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer())
command[i] = resp.StringValue(c)
}
if err = client.WriteArray(command); err != nil {
t.Error(err)
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
}
if test.expectedError != nil {
if !strings.Contains(res.Error().Error(), test.expectedError.Error()) {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
} }
return })
} }
})
if res.Integer() != test.expectedResponse { t.Run("Test_HandleStrLen", func(t *testing.T) {
t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) t.Parallel()
} conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
}) if err != nil {
} t.Error(err)
} return
}
defer func() {
_ = conn.Close()
}()
client := resp.NewConn(conn)
func Test_HandleStrLen(t *testing.T) { tests := []struct {
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) name string
if err != nil { key string
t.Error(err) presetValue string
} command []string
client := resp.NewConn(conn) expectedResponse int
expectedError error
}{
{
name: "Return the correct string length for an existing string",
key: "StrLenKey1",
presetValue: "Test String",
command: []string{"STRLEN", "StrLenKey1"},
expectedResponse: len("Test String"),
expectedError: nil,
},
{
name: "If the string does not exist, return 0",
key: "StrLenKey2",
presetValue: "",
command: []string{"STRLEN", "StrLenKey2"},
expectedResponse: 0,
expectedError: nil,
},
{
name: "Too few args",
key: "StrLenKey3",
presetValue: "",
command: []string{"STRLEN"},
expectedResponse: 0,
expectedError: errors.New(constants.WrongArgsResponse),
},
{
name: "Too many args",
key: "StrLenKey4",
presetValue: "",
command: []string{"STRLEN", "StrLenKey4", "StrLenKey5"},
expectedResponse: 0,
expectedError: errors.New(constants.WrongArgsResponse),
},
}
tests := []struct { for _, test := range tests {
name string t.Run(test.name, func(t *testing.T) {
key string if test.presetValue != "" {
presetValue string if err = client.WriteArray([]resp.Value{
command []string resp.StringValue("SET"),
expectedResponse int resp.StringValue(test.key),
expectedError error resp.StringValue(test.presetValue),
}{ }); err != nil {
{ t.Error(err)
name: "Return the correct string length for an existing string", }
key: "StrLenKey1", res, _, err := client.ReadValue()
presetValue: "Test String", if err != nil {
command: []string{"STRLEN", "StrLenKey1"}, t.Error(err)
expectedResponse: len("Test String"), }
expectedError: nil,
},
{
name: "If the string does not exist, return 0",
key: "StrLenKey2",
presetValue: "",
command: []string{"STRLEN", "StrLenKey2"},
expectedResponse: 0,
expectedError: nil,
},
{
name: "Too few args",
key: "StrLenKey3",
presetValue: "",
command: []string{"STRLEN"},
expectedResponse: 0,
expectedError: errors.New(constants.WrongArgsResponse),
},
{
name: "Too many args",
key: "StrLenKey4",
presetValue: "",
command: []string{"STRLEN", "StrLenKey4", "StrLenKey5"},
expectedResponse: 0,
expectedError: errors.New(constants.WrongArgsResponse),
},
}
for _, test := range tests { if !strings.EqualFold(res.String(), "ok") {
t.Run(test.name, func(t *testing.T) { t.Errorf("expected preset response to be OK, got %s", res.String())
if test.presetValue != "" { }
if err = client.WriteArray([]resp.Value{ }
resp.StringValue("SET"),
resp.StringValue(test.key), command := make([]resp.Value, len(test.command))
resp.StringValue(test.presetValue), for i, c := range test.command {
}); err != nil { command[i] = resp.StringValue(c)
}
if err = client.WriteArray(command); err != nil {
t.Error(err) t.Error(err)
} }
res, _, err := client.ReadValue() res, _, err := client.ReadValue()
@@ -262,140 +298,145 @@ func Test_HandleStrLen(t *testing.T) {
t.Error(err) t.Error(err)
} }
if !strings.EqualFold(res.String(), "ok") { if test.expectedError != nil {
t.Errorf("expected preset response to be OK, got %s", res.String()) if !strings.Contains(res.Error().Error(), test.expectedError.Error()) {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
}
return
} }
}
command := make([]resp.Value, len(test.command)) if res.Integer() != test.expectedResponse {
for i, c := range test.command { t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer())
command[i] = resp.StringValue(c)
}
if err = client.WriteArray(command); err != nil {
t.Error(err)
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
}
if test.expectedError != nil {
if !strings.Contains(res.Error().Error(), test.expectedError.Error()) {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
} }
return })
} }
})
if res.Integer() != test.expectedResponse { t.Run("Test_HandleSubStr", func(t *testing.T) {
t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) t.Parallel()
} conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
}) if err != nil {
} t.Error(err)
} return
}
defer func() {
_ = conn.Close()
}()
client := resp.NewConn(conn)
func Test_HandleSubStr(t *testing.T) { tests := []struct {
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) name string
if err != nil { key string
t.Error(err) presetValue string
} command []string
client := resp.NewConn(conn) expectedResponse string
expectedError error
}{
{
name: "Return substring within the range of the string",
key: "SubStrKey1",
presetValue: "Test String One",
command: []string{"SUBSTR", "SubStrKey1", "5", "10"},
expectedResponse: "String",
expectedError: nil,
},
{
name: "Return substring at the end of the string with exact end index",
key: "SubStrKey2",
presetValue: "Test String Two",
command: []string{"SUBSTR", "SubStrKey2", "12", "14"},
expectedResponse: "Two",
expectedError: nil,
},
{
name: "Return substring at the end of the string with end index greater than length",
key: "SubStrKey3",
presetValue: "Test String Three",
command: []string{"SUBSTR", "SubStrKey3", "12", "75"},
expectedResponse: "Three",
expectedError: nil,
},
{
name: "Return the substring at the start of the string with 0 start index",
key: "SubStrKey4",
presetValue: "Test String Four",
command: []string{"SUBSTR", "SubStrKey4", "0", "3"},
expectedResponse: "Test",
expectedError: nil,
},
{
// Return the substring with negative start index.
// Substring should begin abs(start) from the end of the string when start is negative.
name: "Return the substring with negative start index",
key: "SubStrKey5",
presetValue: "Test String Five",
command: []string{"SUBSTR", "SubStrKey5", "-11", "10"},
expectedResponse: "String",
expectedError: nil,
},
{
// Return reverse substring with end index smaller than start index.
// When end index is smaller than start index, the 2 indices are reversed.
name: "Return reverse substring with end index smaller than start index",
key: "SubStrKey6",
presetValue: "Test String Six",
command: []string{"SUBSTR", "SubStrKey6", "4", "0"},
expectedResponse: "tseT",
expectedError: nil,
},
{
name: "Command too short",
command: []string{"SUBSTR", "key", "10"},
expectedError: errors.New(constants.WrongArgsResponse),
},
{
name: "Command too long",
command: []string{"SUBSTR", "key", "10", "15", "20"},
expectedError: errors.New(constants.WrongArgsResponse),
},
{
name: "Start index is not an integer",
command: []string{"SUBSTR", "key", "start", "10"},
expectedError: errors.New("start and end indices must be integers"),
},
{
name: "End index is not an integer",
command: []string{"SUBSTR", "key", "0", "end"},
expectedError: errors.New("start and end indices must be integers"),
},
{
name: "Non-existent key",
command: []string{"SUBSTR", "non-existent-key", "0", "10"},
expectedError: errors.New("key non-existent-key does not exist"),
},
}
tests := []struct { for _, test := range tests {
name string t.Run(test.name, func(t *testing.T) {
key string if test.presetValue != "" {
presetValue string if err = client.WriteArray([]resp.Value{
command []string resp.StringValue("SET"),
expectedResponse string resp.StringValue(test.key),
expectedError error resp.StringValue(test.presetValue),
}{ }); err != nil {
{ t.Error(err)
name: "Return substring within the range of the string", }
key: "SubStrKey1", res, _, err := client.ReadValue()
presetValue: "Test String One", if err != nil {
command: []string{"SUBSTR", "SubStrKey1", "5", "10"}, t.Error(err)
expectedResponse: "String", }
expectedError: nil,
},
{
name: "Return substring at the end of the string with exact end index",
key: "SubStrKey2",
presetValue: "Test String Two",
command: []string{"SUBSTR", "SubStrKey2", "12", "14"},
expectedResponse: "Two",
expectedError: nil,
},
{
name: "Return substring at the end of the string with end index greater than length",
key: "SubStrKey3",
presetValue: "Test String Three",
command: []string{"SUBSTR", "SubStrKey3", "12", "75"},
expectedResponse: "Three",
expectedError: nil,
},
{
name: "Return the substring at the start of the string with 0 start index",
key: "SubStrKey4",
presetValue: "Test String Four",
command: []string{"SUBSTR", "SubStrKey4", "0", "3"},
expectedResponse: "Test",
expectedError: nil,
},
{
// Return the substring with negative start index.
// Substring should begin abs(start) from the end of the string when start is negative.
name: "Return the substring with negative start index",
key: "SubStrKey5",
presetValue: "Test String Five",
command: []string{"SUBSTR", "SubStrKey5", "-11", "10"},
expectedResponse: "String",
expectedError: nil,
},
{
// Return reverse substring with end index smaller than start index.
// When end index is smaller than start index, the 2 indices are reversed.
name: "Return reverse substring with end index smaller than start index",
key: "SubStrKey6",
presetValue: "Test String Six",
command: []string{"SUBSTR", "SubStrKey6", "4", "0"},
expectedResponse: "tseT",
expectedError: nil,
},
{
name: "Command too short",
command: []string{"SUBSTR", "key", "10"},
expectedError: errors.New(constants.WrongArgsResponse),
},
{
name: "Command too long",
command: []string{"SUBSTR", "key", "10", "15", "20"},
expectedError: errors.New(constants.WrongArgsResponse),
},
{
name: "Start index is not an integer",
command: []string{"SUBSTR", "key", "start", "10"},
expectedError: errors.New("start and end indices must be integers"),
},
{
name: "End index is not an integer",
command: []string{"SUBSTR", "key", "0", "end"},
expectedError: errors.New("start and end indices must be integers"),
},
{
name: "Non-existent key",
command: []string{"SUBSTR", "non-existent-key", "0", "10"},
expectedError: errors.New("key non-existent-key does not exist"),
},
}
for _, test := range tests { if !strings.EqualFold(res.String(), "ok") {
t.Run(test.name, func(t *testing.T) { t.Errorf("expected preset response to be OK, got %s", res.String())
if test.presetValue != "" { }
if err = client.WriteArray([]resp.Value{ }
resp.StringValue("SET"),
resp.StringValue(test.key), command := make([]resp.Value, len(test.command))
resp.StringValue(test.presetValue), for i, c := range test.command {
}); err != nil { command[i] = resp.StringValue(c)
}
if err = client.WriteArray(command); err != nil {
t.Error(err) t.Error(err)
} }
res, _, err := client.ReadValue() res, _, err := client.ReadValue()
@@ -403,34 +444,17 @@ func Test_HandleSubStr(t *testing.T) {
t.Error(err) t.Error(err)
} }
if !strings.EqualFold(res.String(), "ok") { if test.expectedError != nil {
t.Errorf("expected preset response to be OK, got %s", res.String()) if !strings.Contains(res.Error().Error(), test.expectedError.Error()) {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
}
return
} }
}
command := make([]resp.Value, len(test.command)) if res.String() != test.expectedResponse {
for i, c := range test.command { t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String())
command[i] = resp.StringValue(c)
}
if err = client.WriteArray(command); err != nil {
t.Error(err)
}
res, _, err := client.ReadValue()
if err != nil {
t.Error(err)
}
if test.expectedError != nil {
if !strings.Contains(res.Error().Error(), test.expectedError.Error()) {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
} }
return })
} }
})
if res.String() != test.expectedResponse {
t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String())
}
})
}
} }

View File

@@ -216,12 +216,13 @@ func (r *Raft) TakeSnapshot() error {
} }
func (r *Raft) RaftShutdown() { func (r *Raft) RaftShutdown() {
// Leadership transfer if current node is the leader // Leadership transfer if current node is the leader.
if r.IsRaftLeader() { if r.IsRaftLeader() {
err := r.raft.LeadershipTransfer().Error() err := r.raft.LeadershipTransfer().Error()
if err != nil { if err != nil {
log.Fatal(err) log.Printf("raft shutdown: %v\n", err)
return
} }
log.Println("Leadership transfer successful.") log.Println("leadership transfer successful.")
} }
} }