From c7560ce9dd8eec6ac04c1e351c65e6641e775c8f Mon Sep 17 00:00:00 2001 From: Kelvin Clement Mwinuka Date: Fri, 31 May 2024 01:30:18 +0800 Subject: [PATCH] Updated all test suites to include connection and server shutdown on cleanup. --- echovault/cluster_test.go | 119 - echovault/echovault.go | 66 +- echovault/echovault_test.go | 603 +- internal/memberlist/memberlist.go | 8 +- internal/modules/admin/commands_test.go | 161 +- internal/modules/connection/commands_test.go | 122 +- internal/modules/generic/commands_test.go | 3390 +++--- internal/modules/hash/commands_test.go | 3220 ++--- internal/modules/list/commands_test.go | 3250 ++--- internal/modules/set/commands_test.go | 4386 +++---- internal/modules/sorted_set/commands_test.go | 10402 +++++++++-------- internal/modules/string/commands_test.go | 722 +- internal/raft/raft.go | 7 +- 13 files changed, 13350 insertions(+), 13106 deletions(-) delete mode 100644 echovault/cluster_test.go diff --git a/echovault/cluster_test.go b/echovault/cluster_test.go deleted file mode 100644 index f5bd3bb..0000000 --- a/echovault/cluster_test.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright 2024 Kelvin Clement Mwinuka -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package echovault - -// func Test_raftApplyDeleteKey(t *testing.T) { -// nodes, err := makeCluster(5) -// if err != nil { -// t.Error(err) -// return -// } -// -// // Prepare the write data for the cluster. -// tests := []struct { -// key string -// value string -// }{ -// { -// key: "key1", -// value: "value1", -// }, -// { -// key: "key2", -// value: "value2", -// }, -// { -// key: "key3", -// value: "value3", -// }, -// } -// -// // Write all the data to the cluster leader. -// for i, test := range tests { -// node := nodes[0] -// if err := node.client.WriteArray([]resp.Value{ -// resp.StringValue("SET"), -// resp.StringValue(test.key), -// resp.StringValue(test.value), -// }); err != nil { -// t.Errorf("could not write data to leader node (test %d): %v", i, err) -// } -// // Read response and make sure we received "ok" response. -// rd, _, err := node.client.ReadValue() -// if err != nil { -// t.Errorf("could not read response from leader node (test %d): %v", i, err) -// } -// if !strings.EqualFold(rd.String(), "ok") { -// t.Errorf("expected response for test %d to be \"OK\", got %s", i, rd.String()) -// } -// } -// -// // Check if the data has been replicated on a quorum (majority of the cluster). -// quorum := int(math.Ceil(float64(len(nodes)/2)) + 1) -// for i, test := range tests { -// count := 0 -// for j := 0; j < len(nodes); j++ { -// node := nodes[j] -// if err := node.client.WriteArray([]resp.Value{ -// resp.StringValue("GET"), -// resp.StringValue(test.key), -// }); err != nil { -// t.Errorf("could not write data to follower node %d (test %d): %v", j, i, err) -// } -// rd, _, err := node.client.ReadValue() -// if err != nil { -// t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err) -// } -// if rd.String() == test.value { -// count += 1 // If the expected value is found, increment the count. -// } -// } -// // Fail if count is less than quorum. -// if count < quorum { -// t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key) -// } -// // Delete key across raft cluster. -// if err = nodes[0].server.raftApplyDeleteKey(nodes[0].server.context, test.key); err != nil { -// t.Error(err) -// } -// } -// -// <-time.After(200 * time.Millisecond) // Yield to give key deletion time to take effect across cluster. -// -// // Check if the data is absent in quorum (majority of the cluster). -// for i, test := range tests { -// count := 0 -// for j := 0; j < len(nodes); j++ { -// node := nodes[j] -// if err := node.client.WriteArray([]resp.Value{ -// resp.StringValue("GET"), -// resp.StringValue(test.key), -// }); err != nil { -// t.Errorf("could not write command to follower node %d (test %d): %v", j, i, err) -// } -// rd, _, err := node.client.ReadValue() -// if err != nil { -// t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err) -// } -// if rd.IsNull() { -// count += 1 // If the expected value is found, increment the count. -// } -// } -// // Fail if count is less than quorum. -// if count < quorum { -// t.Errorf("found value %s at key %s in cluster quorum", test.value, test.key) -// } -// } -// } diff --git a/echovault/echovault.go b/echovault/echovault.go index e7fe359..52d40aa 100644 --- a/echovault/echovault.go +++ b/echovault/echovault.go @@ -73,7 +73,7 @@ type EchoVault struct { mutex sync.Mutex // Mutex as only one goroutine can edit the LFU cache at a time. cache eviction.CacheLFU // LFU cache represented by a min head. } - // LRU cache used when eviction policy is allkeys-lru or volatile-lru + // LRU cache used when eviction policy is allkeys-lru or volatile-lru. lruCache struct { mutex sync.Mutex // Mutex as only one goroutine can edit the LRU at a time. cache eviction.CacheLRU // LRU cache represented by a max head. @@ -95,9 +95,12 @@ type EchoVault struct { rewriteAOFInProgress atomic.Bool // Atomic boolean that's true when actively rewriting AOF file is in progress. stateCopyInProgress atomic.Bool // Atomic boolean that's true when actively copying state for snapshotting or preamble generation. stateMutationInProgress atomic.Bool // Atomic boolean that is set to true when state mutation is in progress. - latestSnapshotMilliseconds atomic.Int64 // Unix epoch in milliseconds - snapshotEngine *snapshot.Engine // Snapshot engine for standalone mode - aofEngine *aof.Engine // AOF engine for standalone mode + latestSnapshotMilliseconds atomic.Int64 // Unix epoch in milliseconds. + snapshotEngine *snapshot.Engine // Snapshot engine for standalone mode. + aofEngine *aof.Engine // AOF engine for standalone mode. + + listener net.Listener // TCP listener. + quit chan struct{} // Channel that signals the closing of all client connections. } // WithContext is an options that for the NewEchoVault function that allows you to @@ -142,6 +145,7 @@ func NewEchoVault(options ...func(echovault *EchoVault)) (*EchoVault, error) { commands = append(commands, str.Commands()...) return commands }(), + quit: make(chan struct{}), } for _, option := range options { @@ -320,30 +324,35 @@ func (server *EchoVault) startTCP() { KeepAlive: 200 * time.Millisecond, } - listener, err := listenConfig.Listen(server.context, "tcp", fmt.Sprintf("%s:%d", conf.BindAddr, conf.Port)) - + listener, err := listenConfig.Listen( + server.context, + "tcp", + fmt.Sprintf("%s:%d", conf.BindAddr, conf.Port), + ) if err != nil { - log.Fatal(err) + log.Printf("listener error: %v", err) + return } if !conf.TLS { // TCP - log.Printf("Starting TCP echovault at Address %s, Port %d...\n", conf.BindAddr, conf.Port) + log.Printf("Starting TCP server at Address %s, Port %d...\n", conf.BindAddr, conf.Port) } if conf.TLS || conf.MTLS { // TLS if conf.TLS { - log.Printf("Starting mTLS echovault at Address %s, Port %d...\n", conf.BindAddr, conf.Port) + log.Printf("Starting mTLS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port) } else { - log.Printf("Starting TLS echovault at Address %s, Port %d...\n", conf.BindAddr, conf.Port) + log.Printf("Starting TLS server at Address %s, Port %d...\n", conf.BindAddr, conf.Port) } var certificates []tls.Certificate for _, certKeyPair := range conf.CertKeyPairs { c, err := tls.LoadX509KeyPair(certKeyPair[0], certKeyPair[1]) if err != nil { - log.Fatal(err) + log.Printf("load cert key pair: %v\n", err) + return } certificates = append(certificates, c) } @@ -356,14 +365,15 @@ func (server *EchoVault) startTCP() { for _, c := range conf.ClientCAs { ca, err := os.Open(c) if err != nil { - log.Fatal(err) + log.Printf("client cert open: %v\n", err) + return } certBytes, err := io.ReadAll(ca) if err != nil { - log.Fatal(err) + log.Printf("client cert read: %v\n", err) } if ok := clientCerts.AppendCertsFromPEM(certBytes); !ok { - log.Fatal(err) + log.Printf("client cert append: %v\n", err) } } } @@ -375,15 +385,22 @@ func (server *EchoVault) startTCP() { }) } - // Listen to connection + server.listener = listener + + // Listen to connection. for { - conn, err := listener.Accept() - if err != nil { - log.Println("Could not establish connection") - continue + select { + case <-server.quit: + return + default: + conn, err := server.listener.Accept() + if err != nil { + log.Printf("listener error: %v\n", err) + continue + } + // Read loop for connection + go server.handleConnection(conn) } - // Read loop for connection - go server.handleConnection(conn) } } @@ -536,6 +553,13 @@ func (server *EchoVault) rewriteAOF() error { // ShutDown gracefully shuts down the EchoVault instance. // This function shuts down the memberlist and raft layers. func (server *EchoVault) ShutDown() { + if server.listener != nil { + go func() { server.quit <- struct{}{} }() + log.Println("closing tcp listener...") + if err := server.listener.Close(); err != nil { + log.Printf("listener close: %v\n", err) + } + } if server.isInCluster() { server.raft.RaftShutdown() server.memberList.MemberListShutdown() diff --git a/echovault/echovault_test.go b/echovault/echovault_test.go index 0527676..f8899c0 100644 --- a/echovault/echovault_test.go +++ b/echovault/echovault_test.go @@ -23,12 +23,14 @@ import ( "github.com/echovault/echovault/internal" "github.com/tidwall/resp" "io" + "math" "net" "os" "path" "strings" "sync" "testing" + "time" ) type ClientServerPair struct { @@ -38,6 +40,7 @@ type ClientServerPair struct { raftPort int mlPort int bootstrapCluster bool + raw net.Conn client *resp.Conn server *EchoVault } @@ -62,8 +65,6 @@ func getBindAddr() net.IP { return getBindAddrNet(0) } -var setupLock sync.Mutex - func setupServer( serverId string, bootstrapCluster bool, @@ -73,25 +74,20 @@ func setupServer( raftPort, mlPort int, ) (*EchoVault, error) { - setupLock.Lock() - defer setupLock.Unlock() - - ctx := context.Background() - config := DefaultConfig() config.DataDir = "./testdata" config.ForwardCommand = true config.BindAddr = bindAddr config.JoinAddr = joinAddr config.Port = uint16(port) - // config.InMemory = true + config.InMemory = true config.ServerID = serverId config.RaftBindPort = uint16(raftPort) config.MemberListBindPort = uint16(mlPort) config.BootstrapCluster = bootstrapCluster return NewEchoVault( - WithContext(ctx), + WithContext(context.Background()), WithConfig(config), ) } @@ -163,6 +159,7 @@ func makeCluster(size int) ([]ClientServerPair, error) { raftPort: raftPort, mlPort: memberlistPort, bootstrapCluster: bootstrapCluster, + raw: conn, client: client, server: server, } @@ -171,273 +168,318 @@ func makeCluster(size int) ([]ClientServerPair, error) { return pairs, nil } -// func Test_ClusterReplication(t *testing.T) { -// nodes, err := makeCluster(5) -// if err != nil { -// t.Error(err) -// return -// } -// -// // Prepare the write data for the cluster. -// tests := []struct { -// key string -// value string -// }{ -// { -// key: "key1", -// value: "value1", -// }, -// { -// key: "key2", -// value: "value2", -// }, -// { -// key: "key3", -// value: "value3", -// }, -// } -// -// // Write all the data to the cluster leader -// for i, test := range tests { -// node := nodes[0] -// if err := node.client.WriteArray([]resp.Value{ -// resp.StringValue("SET"), -// resp.StringValue(test.key), -// resp.StringValue(test.value), -// }); err != nil { -// t.Errorf("could not write data to leader node (test %d): %v", i, err) -// } -// // Read response and make sure we received "ok" response. -// rd, _, err := node.client.ReadValue() -// if err != nil { -// t.Errorf("could not read response from leader node (test %d): %v", i, err) -// } -// if !strings.EqualFold(rd.String(), "ok") { -// t.Errorf("expected response for test %d to be \"OK\", got %s", i, rd.String()) -// } -// } -// -// // Check if the data has been replicated on a quorum (majority of the cluster). -// quorum := int(math.Ceil(float64(len(nodes)/2)) + 1) -// for i, test := range tests { -// count := 0 -// for j := 0; j < len(nodes); j++ { -// node := nodes[j] -// if err := node.client.WriteArray([]resp.Value{ -// resp.StringValue("GET"), -// resp.StringValue(test.key), -// }); err != nil { -// t.Errorf("could not write data to follower node %d (test %d): %v", j, i, err) -// } -// rd, _, err := node.client.ReadValue() -// if err != nil { -// t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err) -// } -// if rd.String() == test.value { -// count += 1 // If the expected value is found, increment the count. -// } -// } -// // Fail if count is less than quorum. -// if count < quorum { -// t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key) -// } -// } -// } +func Test_Cluster(t *testing.T) { + nodes, err := makeCluster(5) + if err != nil { + t.Error(err) + return + } -// func Test_ClusterDeleteKey(t *testing.T) { -// nodes, err := makeCluster(5) -// if err != nil { -// t.Error(err) -// return -// } -// -// // Prepare the write data for the cluster -// tests := []struct { -// key string -// value string -// }{ -// { -// key: "key1", -// value: "value1", -// }, -// { -// key: "key2", -// value: "value2", -// }, -// { -// key: "key3", -// value: "value3", -// }, -// } -// -// // Write all the data to the cluster leader -// for i, test := range tests { -// node := nodes[0] -// if err := node.client.WriteArray([]resp.Value{ -// resp.StringValue("SET"), -// resp.StringValue(test.key), -// resp.StringValue(test.value), -// }); err != nil { -// t.Errorf("could not write command to leader node (test %d): %v", i, err) -// } -// // Read response and make sure we received "ok" response. -// rd, _, err := node.client.ReadValue() -// if err != nil { -// t.Errorf("could not read response from leader node (test %d): %v", i, err) -// } -// if !strings.EqualFold(rd.String(), "ok") { -// t.Errorf("expected response for test %d to be \"OK\", got %s", i, rd.String()) -// } -// } -// -// quorum := int(math.Ceil(float64(len(nodes)/2)) + 1) -// -// // Check if the data has been replicated on a quorum (majority of the cluster). -// for i, test := range tests { -// count := 0 -// for j := 0; j < len(nodes); j++ { -// node := nodes[j] -// if err := node.client.WriteArray([]resp.Value{ -// resp.StringValue("GET"), -// resp.StringValue(test.key), -// }); err != nil { -// t.Errorf("could not write command to follower node %d (test %d): %v", j, i, err) -// } -// rd, _, err := node.client.ReadValue() -// if err != nil { -// t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err) -// } -// if rd.String() == test.value { -// count += 1 // If the expected value is found, increment the count. -// } -// } -// // Fail if count is less than quorum. -// if count < quorum { -// t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key) -// return -// } -// } -// -// // Delete the key on the leader node -// for i, test := range tests { -// node := nodes[0] -// if err := node.client.WriteArray([]resp.Value{ -// resp.StringValue("DEL"), -// resp.StringValue(test.key), -// }); err != nil { -// t.Errorf("could not write command to leader node (test %d): %v", i, err) -// } -// // Read response and make sure we received "ok" response. -// rd, _, err := node.client.ReadValue() -// if err != nil { -// t.Errorf("could not read response from leader node (test %d): %v", i, err) -// } -// if rd.Integer() != 1 { -// t.Errorf("expected response for test %d to be 1, got %d", i, rd.Integer()) -// } -// } -// -// // Check if the data is absent in quorum (majority of the cluster). -// for i, test := range tests { -// count := 0 -// for j := 0; j < len(nodes); j++ { -// node := nodes[j] -// if err := node.client.WriteArray([]resp.Value{ -// resp.StringValue("GET"), -// resp.StringValue(test.key), -// }); err != nil { -// t.Errorf("could not write command to follower node %d (test %d): %v", j, i, err) -// } -// rd, _, err := node.client.ReadValue() -// if err != nil { -// t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err) -// } -// if rd.IsNull() { -// count += 1 // If the expected value is found, increment the count. -// } -// } -// // Fail if count is less than quorum. -// if count < quorum { -// t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key) -// } -// } -// } + defer func() { + for _, node := range nodes { + _ = node.raw.Close() + node.server.ShutDown() + } + }() -// func Test_CommandForwarded(t *testing.T) { -// nodes, err := makeCluster(5) -// if err != nil { -// t.Error(err) -// return -// } -// -// // Prepare the write data for the cluster -// tests := []struct { -// key string -// value string -// }{ -// { -// key: "key1", -// value: "value1", -// }, -// { -// key: "key2", -// value: "value2", -// }, -// { -// key: "key3", -// value: "value3", -// }, -// } -// -// // Write all the data a random cluster follower. -// for i, test := range tests { -// // Send write command to follower node. -// node := nodes[1] -// if err := node.client.WriteArray([]resp.Value{ -// resp.StringValue("SET"), -// resp.StringValue(test.key), -// resp.StringValue(test.value), -// }); err != nil { -// t.Errorf("could not write data to leader node (test %d): %v", i, err) -// } -// // Read response and make sure we received "ok" response. -// rd, _, err := node.client.ReadValue() -// if err != nil { -// t.Errorf("could not read response from leader node (test %d): %v", i, err) -// } -// if !strings.EqualFold(rd.String(), "ok") { -// t.Errorf("expected response for test %d to be \"OK\", got %s", i, rd.String()) -// } -// } -// -// <-time.After(250 * time.Millisecond) // Short yield to allow change to take effect. -// -// // Check if the data has been replicated on a quorum (majority of the cluster). -// quorum := int(math.Ceil(float64(len(nodes)/2)) + 1) -// for i, test := range tests { -// count := 0 -// for j := 0; j < len(nodes); j++ { -// node := nodes[j] -// if err := node.client.WriteArray([]resp.Value{ -// resp.StringValue("GET"), -// resp.StringValue(test.key), -// }); err != nil { -// t.Errorf("could not write data to follower node %d (test %d): %v", j, i, err) -// } -// rd, _, err := node.client.ReadValue() -// if err != nil { -// t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err) -// } -// if rd.String() == test.value { -// count += 1 // If the expected value is found, increment the count. -// } -// } -// // Fail if count is less than quorum. -// if count < quorum { -// t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key) -// } -// } -// } + // Prepare the write data for the cluster. + tests := map[string][]struct { + key string + value string + }{ + "replication": { + {key: "key1", value: "value1"}, + {key: "key2", value: "value2"}, + {key: "key3", value: "value3"}, + }, + "deletion": { + {key: "key4", value: "value4"}, + {key: "key5", value: "value4"}, + {key: "key6", value: "value5"}, + }, + "raft-apply-delete": { + {key: "key7", value: "value7"}, + {key: "key8", value: "value8"}, + {key: "key9", value: "value9"}, + }, + "forward": { + {key: "key10", value: "value10"}, + {key: "key11", value: "value11"}, + {key: "key12", value: "value12"}, + }, + } + + t.Run("Test_Replication", func(t *testing.T) { + tests := tests["replication"] + // Write all the data to the cluster leader. + for i, test := range tests { + node := nodes[0] + if err := node.client.WriteArray([]resp.Value{ + resp.StringValue("SET"), resp.StringValue(test.key), resp.StringValue(test.value), + }); err != nil { + t.Errorf("could not write data to leader node (test %d): %v", i, err) + } + // Read response and make sure we received "ok" response. + rd, _, err := node.client.ReadValue() + if err != nil { + t.Errorf("could not read response from leader node (test %d): %v", i, err) + } + if !strings.EqualFold(rd.String(), "ok") { + t.Errorf("expected response for test %d to be \"OK\", got %s", i, rd.String()) + } + } + + <-time.After(200 * time.Millisecond) // Yield + + // Check if the data has been replicated on a quorum (majority of the cluster). + quorum := int(math.Ceil(float64(len(nodes)/2)) + 1) + for i, test := range tests { + count := 0 + for j := 0; j < len(nodes); j++ { + node := nodes[j] + if err := node.client.WriteArray([]resp.Value{ + resp.StringValue("GET"), + resp.StringValue(test.key), + }); err != nil { + t.Errorf("could not write data to follower node %d (test %d): %v", j, i, err) + } + rd, _, err := node.client.ReadValue() + if err != nil { + t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err) + } + if rd.String() == test.value { + count += 1 // If the expected value is found, increment the count. + } + } + // Fail if count is less than quorum. + if count < quorum { + t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key) + } + } + }) + + t.Run("Test_DeleteKey", func(t *testing.T) { + tests := tests["deletion"] + // Write all the data to the cluster leader. + for i, test := range tests { + node := nodes[0] + _, ok, err := node.server.Set(test.key, test.value, SetOptions{}) + if err != nil { + t.Errorf("could not write command to leader node (test %d): %v", i, err) + } + if !ok { + t.Errorf("expected set for test %d ok = true, got ok = false", i) + } + } + + <-time.After(200 * time.Millisecond) // Yield + + // Check if the data has been replicated on a quorum (majority of the cluster). + quorum := int(math.Ceil(float64(len(nodes)/2)) + 1) + for i, test := range tests { + count := 0 + for j := 0; j < len(nodes); j++ { + node := nodes[j] + if err := node.client.WriteArray([]resp.Value{ + resp.StringValue("GET"), + resp.StringValue(test.key), + }); err != nil { + t.Errorf("could not write command to follower node %d (test %d): %v", j, i, err) + } + rd, _, err := node.client.ReadValue() + if err != nil { + t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err) + } + if rd.String() == test.value { + count += 1 // If the expected value is found, increment the count. + } + } + // Fail if count is less than quorum. + if count < quorum { + t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key) + return + } + } + + // Delete the key on the leader node + // 1. Prepare delete command. + command := []resp.Value{resp.StringValue("DEL")} + for _, test := range tests { + command = append(command, resp.StringValue(test.key)) + } + // 2. Send delete command. + if err := nodes[0].client.WriteArray(command); err != nil { + t.Error(err) + return + } + res, _, err := nodes[0].client.ReadValue() + if err != nil { + t.Error(err) + return + } + // 3. Check the delete count is equal to length of tests. + if res.Integer() != len(tests) { + t.Errorf("expected delete response to be %d, got %d", len(tests), res.Integer()) + } + + <-time.After(200 * time.Millisecond) // Yield + + // Check if the data is absent in quorum (majority of the cluster). + for i, test := range tests { + count := 0 + for j := 0; j < len(nodes); j++ { + node := nodes[j] + if err := node.client.WriteArray([]resp.Value{ + resp.StringValue("GET"), + resp.StringValue(test.key), + }); err != nil { + t.Errorf("could not write command to follower node %d (test %d): %v", j, i, err) + } + rd, _, err := node.client.ReadValue() + if err != nil { + t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err) + } + if rd.IsNull() { + count += 1 // If the expected value is found, increment the count. + } + } + // Fail if count is less than quorum. + if count < quorum { + t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key) + } + } + }) + + t.Run("Test_raftApplyDeleteKey", func(t *testing.T) { + tests := tests["raft-apply-delete"] + // Write all the data to the cluster leader. + for i, test := range tests { + node := nodes[0] + _, ok, err := node.server.Set(test.key, test.value, SetOptions{}) + if err != nil { + t.Errorf("could not write command to leader node (test %d): %v", i, err) + } + if !ok { + t.Errorf("expected set for test %d ok = true, got ok = false", i) + } + } + + <-time.After(200 * time.Millisecond) // Yield + + // Check if the data has been replicated on a quorum (majority of the cluster). + quorum := int(math.Ceil(float64(len(nodes)/2)) + 1) + for i, test := range tests { + count := 0 + for j := 0; j < len(nodes); j++ { + node := nodes[j] + if err := node.client.WriteArray([]resp.Value{ + resp.StringValue("GET"), + resp.StringValue(test.key), + }); err != nil { + t.Errorf("could not write command to follower node %d (test %d): %v", j, i, err) + } + rd, _, err := node.client.ReadValue() + if err != nil { + t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err) + } + if rd.String() == test.value { + count += 1 // If the expected value is found, increment the count. + } + } + // Fail if count is less than quorum. + if count < quorum { + t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key) + return + } + } + + // Delete the keys using raftApplyDelete method. + for _, test := range tests { + if err := nodes[0].server.raftApplyDeleteKey(nodes[0].server.context, test.key); err != nil { + t.Error(err) + } + } + + <-time.After(200 * time.Millisecond) // Yield to give key deletion time to take effect across cluster. + + // Check if the data is absent in quorum (majority of the cluster). + for i, test := range tests { + count := 0 + for j := 0; j < len(nodes); j++ { + node := nodes[j] + if err := node.client.WriteArray([]resp.Value{ + resp.StringValue("GET"), + resp.StringValue(test.key), + }); err != nil { + t.Errorf("could not write command to follower node %d (test %d): %v", j, i, err) + } + rd, _, err := node.client.ReadValue() + if err != nil { + t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err) + } + if rd.IsNull() { + count += 1 // If the expected value is found, increment the count. + } + } + // Fail if count is less than quorum. + if count < quorum { + t.Errorf("found value %s at key %s in cluster quorum", test.value, test.key) + } + } + }) + + // t.Run("Test_ForwardCommand", func(t *testing.T) { + // tests := tests["forward"] + // // Write all the data a random cluster follower. + // for i, test := range tests { + // // Send write command to follower node. + // node := nodes[1] + // if err := node.client.WriteArray([]resp.Value{ + // resp.StringValue("SET"), + // resp.StringValue(test.key), + // resp.StringValue(test.value), + // }); err != nil { + // t.Errorf("could not write data to follower node (test %d): %v", i, err) + // } + // // Read response and make sure we received "ok" response. + // rd, _, err := node.client.ReadValue() + // if err != nil { + // t.Errorf("could not read response from follower node (test %d): %v", i, err) + // } + // if !strings.EqualFold(rd.String(), "ok") { + // t.Errorf("expected response for test %d to be \"OK\", got %s", i, rd.String()) + // } + // } + // + // <-time.After(200 * time.Millisecond) // Short yield to allow change to take effect. + // + // // Check if the data has been replicated on a quorum (majority of the cluster). + // quorum := int(math.Ceil(float64(len(nodes)/2)) + 1) + // for i, test := range tests { + // count := 0 + // for j := 0; j < len(nodes); j++ { + // node := nodes[j] + // if err := node.client.WriteArray([]resp.Value{ + // resp.StringValue("GET"), + // resp.StringValue(test.key), + // }); err != nil { + // t.Errorf("could not write data to follower node %d (test %d): %v", j, i, err) + // } + // rd, _, err := node.client.ReadValue() + // if err != nil { + // t.Errorf("could not read data from follower node %d (test %d): %v", j, i, err) + // } + // if rd.String() == test.value { + // count += 1 // If the expected value is found, increment the count. + // } + // } + // // Fail if count is less than quorum. + // if count < quorum { + // t.Errorf("could not find value %s at key %s in cluster quorum", test.value, test.key) + // } + // } + // }) +} func Test_TLS(t *testing.T) { port, err := internal.GetFreePort() @@ -464,6 +506,7 @@ func Test_TLS(t *testing.T) { server, err := NewEchoVault(WithConfig(conf)) if err != nil { t.Error(err) + return } wg := sync.WaitGroup{} @@ -494,8 +537,12 @@ func Test_TLS(t *testing.T) { }) if err != nil { t.Error(err) + return } - + defer func() { + _ = conn.Close() + server.ShutDown() + }() client := resp.NewConn(conn) // Test that we can set and get a value from the server. @@ -561,6 +608,7 @@ func Test_MTLS(t *testing.T) { server, err := NewEchoVault(WithConfig(conf)) if err != nil { t.Error(err) + return } wg := sync.WaitGroup{} @@ -613,7 +661,10 @@ func Test_MTLS(t *testing.T) { t.Error(err) return } - + defer func() { + _ = conn.Close() + server.ShutDown() + }() client := resp.NewConn(conn) // Test that we can set and get a value from the server. diff --git a/internal/memberlist/memberlist.go b/internal/memberlist/memberlist.go index 4b5e498..cd21091 100644 --- a/internal/memberlist/memberlist.go +++ b/internal/memberlist/memberlist.go @@ -160,13 +160,15 @@ func (m *MemberList) MemberListShutdown() { // Gracefully leave memberlist cluster err := m.memberList.Leave(500 * time.Millisecond) if err != nil { - log.Fatal("Could not gracefully leave memberlist cluster") + log.Printf("memberlist leave: %v\n", err) + return } err = m.memberList.Shutdown() if err != nil { - log.Fatal("Could not gracefully shutdown memberlist background maintenance") + log.Printf("memberlist shutdown: %v\n", err) + return } - log.Println("Successfully shutdown memberlist") + log.Println("successfully shutdown memberlist") } diff --git a/internal/modules/admin/commands_test.go b/internal/modules/admin/commands_test.go index 642d4aa..894e042 100644 --- a/internal/modules/admin/commands_test.go +++ b/internal/modules/admin/commands_test.go @@ -32,7 +32,6 @@ import ( str "github.com/echovault/echovault/internal/modules/string" "github.com/tidwall/resp" "net" - "os" "path" "slices" "strings" @@ -50,38 +49,41 @@ func setupServer(port uint16) (*echovault.EchoVault, error) { } func Test_AdminCommands(t *testing.T) { + port, err := internal.GetFreePort() + if err != nil { + t.Error(err) + return + } + + mockServer, err := setupServer(uint16(port)) + if err != nil { + t.Error(err) + return + } + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + wg.Done() + mockServer.Start() + }() + wg.Wait() + t.Cleanup(func() { - _ = os.RemoveAll("./testdata") + mockServer.ShutDown() }) t.Run("Test COMMANDS command", func(t *testing.T) { t.Parallel() - port, err := internal.GetFreePort() - if err != nil { - t.Error(err) - return - } - - mockServer, err := setupServer(uint16(port)) - if err != nil { - t.Error(err) - return - } - - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - wg.Done() - mockServer.Start() - }() - wg.Wait() - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) if err != nil { t.Error(err) return } + defer func() { + _ = conn.Close() + }() client := resp.NewConn(conn) if err = client.WriteArray([]resp.Value{resp.StringValue("COMMANDS")}); err != nil { @@ -128,31 +130,14 @@ func Test_AdminCommands(t *testing.T) { t.Run("Test COMMAND COUNT command", func(t *testing.T) { t.Parallel() - port, err := internal.GetFreePort() - if err != nil { - t.Error(err) - return - } - - mockServer, err := setupServer(uint16(port)) - if err != nil { - t.Error(err) - return - } - - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - wg.Done() - mockServer.Start() - }() - wg.Wait() - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) if err != nil { t.Error(err) return } + defer func() { + _ = conn.Close() + }() client := resp.NewConn(conn) if err = client.WriteArray([]resp.Value{resp.StringValue("COMMAND"), resp.StringValue("COUNT")}); err != nil { @@ -199,31 +184,14 @@ func Test_AdminCommands(t *testing.T) { t.Run("Test COMMAND LIST command", func(t *testing.T) { t.Parallel() - port, err := internal.GetFreePort() - if err != nil { - t.Error(err) - return - } - - mockServer, err := setupServer(uint16(port)) - if err != nil { - t.Error(err) - return - } - - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - wg.Done() - mockServer.Start() - }() - wg.Wait() - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) if err != nil { t.Error(err) return } + defer func() { + _ = conn.Close() + }() client := resp.NewConn(conn) // Get all the commands from the existing modules. @@ -336,17 +304,6 @@ func Test_AdminCommands(t *testing.T) { }) t.Run("Test MODULE LOAD command", func(t *testing.T) { - port, err := internal.GetFreePort() - if err != nil { - t.Error(err) - return - } - mockServer, err := setupServer(uint16(port)) - if err != nil { - t.Error(err) - return - } - tests := []struct { name string execCommand []resp.Value @@ -433,20 +390,14 @@ func Test_AdminCommands(t *testing.T) { }, } - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - wg.Done() - mockServer.Start() - }() - wg.Wait() - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) if err != nil { t.Error(err) return } - + defer func() { + _ = conn.Close() + }() respConn := resp.NewConn(conn) for i := 0; i < len(tests); i++ { @@ -505,31 +456,14 @@ func Test_AdminCommands(t *testing.T) { }) t.Run("Test MODULE UNLOAD command", func(t *testing.T) { - port, err := internal.GetFreePort() - if err != nil { - t.Error(err) - return - } - mockServer, err := setupServer(uint16(port)) - if err != nil { - t.Error(err) - return - } - - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - wg.Done() - mockServer.Start() - }() - wg.Wait() - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) if err != nil { t.Error(err) return } - + defer func() { + _ = conn.Close() + }() respConn := resp.NewConn(conn) // Load module.set module @@ -693,31 +627,14 @@ func Test_AdminCommands(t *testing.T) { }) t.Run("Test MODULE LIST command", func(t *testing.T) { - port, err := internal.GetFreePort() - if err != nil { - t.Error(err) - return - } - mockServer, err := setupServer(uint16(port)) - if err != nil { - t.Error(err) - return - } - - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - wg.Done() - mockServer.Start() - }() - wg.Wait() - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) if err != nil { t.Error(err) return } - + defer func() { + _ = conn.Close() + }() respConn := resp.NewConn(conn) // Load module.get module with arg diff --git a/internal/modules/connection/commands_test.go b/internal/modules/connection/commands_test.go index 83632a7..8107e10 100644 --- a/internal/modules/connection/commands_test.go +++ b/internal/modules/connection/commands_test.go @@ -28,20 +28,26 @@ import ( "testing" ) -var mockServer *echovault.EchoVault -var port int -var addr = "localhost" +func Test_Connection(t *testing.T) { + port, err := internal.GetFreePort() + if err != nil { + t.Error(err) + return + } -func init() { - port, _ = internal.GetFreePort() - mockServer, _ = echovault.NewEchoVault( + mockServer, err := echovault.NewEchoVault( echovault.WithConfig(config.Config{ DataDir: "", EvictionPolicy: constants.NoEviction, - BindAddr: addr, + BindAddr: "localhost", Port: uint16(port), }), ) + if err != nil { + t.Error(err) + return + } + wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -49,62 +55,70 @@ func init() { mockServer.Start() }() wg.Wait() -} -func Test_HandlePing(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) + t.Cleanup(func() { + mockServer.ShutDown() + }) - tests := []struct { - command []resp.Value - expected string - expectedErr error - }{ - { - command: []resp.Value{resp.StringValue("PING")}, - expected: "PONG", - expectedErr: nil, - }, - { - command: []resp.Value{resp.StringValue("PING"), resp.StringValue("Hello, world!")}, - expected: "Hello, world!", - expectedErr: nil, - }, - { - command: []resp.Value{ - resp.StringValue("PING"), - resp.StringValue("Hello, world!"), - resp.StringValue("Once more"), - }, - expected: "", - expectedErr: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - if err = client.WriteArray(test.command); err != nil { + t.Run("Test_HandlePing", func(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { t.Error(err) return } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) + tests := []struct { + command []resp.Value + expected string + expectedErr error + }{ + { + command: []resp.Value{resp.StringValue("PING")}, + expected: "PONG", + expectedErr: nil, + }, + { + command: []resp.Value{resp.StringValue("PING"), resp.StringValue("Hello, world!")}, + expected: "Hello, world!", + expectedErr: nil, + }, + { + command: []resp.Value{ + resp.StringValue("PING"), + resp.StringValue("Hello, world!"), + resp.StringValue("Once more"), + }, + expected: "", + expectedErr: errors.New(constants.WrongArgsResponse), + }, } - if test.expectedErr != nil { - if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedErr.Error(), res.Error().Error()) + for _, test := range tests { + if err = client.WriteArray(test.command); err != nil { + t.Error(err) + return } - continue - } - if res.String() != test.expected { - t.Errorf("expected response \"%s\", got \"%s\"", test.expected, res.String()) + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedErr != nil { + if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedErr.Error(), res.Error().Error()) + } + continue + } + + if res.String() != test.expected { + t.Errorf("expected response \"%s\", got \"%s\"", test.expected, res.String()) + } } - } + }) + } diff --git a/internal/modules/generic/commands_test.go b/internal/modules/generic/commands_test.go index 40c6044..c034ff3 100644 --- a/internal/modules/generic/commands_test.go +++ b/internal/modules/generic/commands_test.go @@ -30,27 +30,32 @@ import ( "time" ) -var addr string -var port int -var mockServer *echovault.EchoVault -var mockClock clock.Clock - type KeyData struct { Value interface{} ExpireAt time.Time } -func init() { - mockClock = clock.NewClock() - port, _ = internal.GetFreePort() - mockServer, _ = echovault.NewEchoVault( +func Test_Generic(t *testing.T) { + mockClock := clock.NewClock() + port, err := internal.GetFreePort() + if err != nil { + t.Error(err) + return + } + + mockServer, err := echovault.NewEchoVault( echovault.WithConfig(config.Config{ - BindAddr: addr, + BindAddr: "localhost", Port: uint16(port), DataDir: "", EvictionPolicy: constants.NoEviction, }), ) + if err != nil { + t.Error(err) + return + } + wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -58,533 +63,459 @@ func init() { mockServer.Start() }() wg.Wait() -} -func Test_HandleSET(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) + t.Cleanup(func() { + mockServer.ShutDown() + }) - tests := []struct { - name string - command []string - presetValues map[string]KeyData - expectedResponse interface{} - expectedValue interface{} - expectedExpiry time.Time - expectedErr error - }{ - { - name: "1. Set normal string value", - command: []string{"SET", "SetKey1", "value1"}, - presetValues: nil, - expectedResponse: "OK", - expectedValue: "value1", - expectedExpiry: time.Time{}, - expectedErr: nil, - }, - { - name: "2. Set normal integer value", - command: []string{"SET", "SetKey2", "1245678910"}, - presetValues: nil, - expectedResponse: "OK", - expectedValue: "1245678910", - expectedExpiry: time.Time{}, - expectedErr: nil, - }, - { - name: "3. Set normal float value", - command: []string{"SET", "SetKey3", "45782.11341"}, - presetValues: nil, - expectedResponse: "OK", - expectedValue: "45782.11341", - expectedExpiry: time.Time{}, - expectedErr: nil, - }, - { - name: "4. Only set the value if the key does not exist", - command: []string{"SET", "SetKey4", "value4", "NX"}, - presetValues: nil, - expectedResponse: "OK", - expectedValue: "value4", - expectedExpiry: time.Time{}, - expectedErr: nil, - }, - { - name: "5. Throw error when value already exists with NX flag passed", - command: []string{"SET", "SetKey5", "value5", "NX"}, - presetValues: map[string]KeyData{ - "SetKey5": { - Value: "preset-value5", - ExpireAt: time.Time{}, - }, - }, - expectedResponse: nil, - expectedValue: "preset-value5", - expectedExpiry: time.Time{}, - expectedErr: errors.New("key SetKey5 already exists"), - }, - { - name: "6. Set new key value when key exists with XX flag passed", - command: []string{"SET", "SetKey6", "value6", "XX"}, - presetValues: map[string]KeyData{ - "SetKey6": { - Value: "preset-value6", - ExpireAt: time.Time{}, - }, - }, - expectedResponse: "OK", - expectedValue: "value6", - expectedExpiry: time.Time{}, - expectedErr: nil, - }, - { - name: "7. Return error when setting non-existent key with XX flag", - command: []string{"SET", "SetKey7", "value7", "XX"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: nil, - expectedExpiry: time.Time{}, - expectedErr: errors.New("key SetKey7 does not exist"), - }, - { - name: "8. Return error when NX flag is provided after XX flag", - command: []string{"SET", "SetKey8", "value8", "XX", "NX"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: nil, - expectedExpiry: time.Time{}, - expectedErr: errors.New("cannot specify NX when XX is already specified"), - }, - { - name: "9. Return error when XX flag is provided after NX flag", - command: []string{"SET", "SetKey9", "value9", "NX", "XX"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: nil, - expectedExpiry: time.Time{}, - expectedErr: errors.New("cannot specify XX when NX is already specified"), - }, - { - name: "10. Set expiry time on the key to 100 seconds from now", - command: []string{"SET", "SetKey10", "value10", "EX", "100"}, - presetValues: nil, - expectedResponse: "OK", - expectedValue: "value10", - expectedExpiry: mockClock.Now().Add(100 * time.Second), - expectedErr: nil, - }, - { - name: "11. Return error when EX flag is passed without seconds value", - command: []string{"SET", "SetKey11", "value11", "EX"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: "", - expectedExpiry: time.Time{}, - expectedErr: errors.New("seconds value required after EX"), - }, - { - name: "12. Return error when EX flag is passed with invalid (non-integer) value", - command: []string{"SET", "SetKey12", "value12", "EX", "seconds"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: "", - expectedExpiry: time.Time{}, - expectedErr: errors.New("seconds value should be an integer"), - }, - { - name: "13. Return error when trying to set expiry seconds when expiry is already set", - command: []string{"SET", "SetKey13", "value13", "PX", "100000", "EX", "100"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: nil, - expectedExpiry: time.Time{}, - expectedErr: errors.New("cannot specify EX when expiry time is already set"), - }, - { - name: "14. Set expiry time on the key in unix milliseconds", - command: []string{"SET", "SetKey14", "value14", "PX", "4096"}, - presetValues: nil, - expectedResponse: "OK", - expectedValue: "value14", - expectedExpiry: mockClock.Now().Add(4096 * time.Millisecond), - expectedErr: nil, - }, - { - name: "15. Return error when PX flag is passed without milliseconds value", - command: []string{"SET", "SetKey15", "value15", "PX"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: "", - expectedExpiry: time.Time{}, - expectedErr: errors.New("milliseconds value required after PX"), - }, - { - name: "16. Return error when PX flag is passed with invalid (non-integer) value", - command: []string{"SET", "SetKey16", "value16", "PX", "milliseconds"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: "", - expectedExpiry: time.Time{}, - expectedErr: errors.New("milliseconds value should be an integer"), - }, - { - name: "17. Return error when trying to set expiry milliseconds when expiry is already provided", - command: []string{"SET", "SetKey17", "value17", "EX", "10", "PX", "1000000"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: nil, - expectedExpiry: time.Time{}, - expectedErr: errors.New("cannot specify PX when expiry time is already set"), - }, - { - name: "18. Set exact expiry time in seconds from unix epoch", - command: []string{ - "SET", "SetKey18", "value18", - "EXAT", fmt.Sprintf("%d", mockClock.Now().Add(200*time.Second).Unix()), - }, - presetValues: nil, - expectedResponse: "OK", - expectedValue: "value18", - expectedExpiry: mockClock.Now().Add(200 * time.Second), - expectedErr: nil, - }, - { - name: "19. Return error when trying to set exact seconds expiry time when expiry time is already provided", - command: []string{ - "SET", "SetKey19", "value19", - "EX", "10", - "EXAT", fmt.Sprintf("%d", mockClock.Now().Add(200*time.Second).Unix()), - }, - presetValues: nil, - expectedResponse: nil, - expectedValue: "", - expectedExpiry: time.Time{}, - expectedErr: errors.New("cannot specify EXAT when expiry time is already set"), - }, - { - name: "20. Return error when no seconds value is provided after EXAT flag", - command: []string{"SET", "SetKey20", "value20", "EXAT"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: "", - expectedExpiry: time.Time{}, - expectedErr: errors.New("seconds value required after EXAT"), - }, - { - name: "21. Return error when invalid (non-integer) value is passed after EXAT flag", - command: []string{"SET", "SekKey21", "value21", "EXAT", "seconds"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: "", - expectedExpiry: time.Time{}, - expectedErr: errors.New("seconds value should be an integer"), - }, - { - name: "22. Set exact expiry time in milliseconds from unix epoch", - command: []string{ - "SET", "SetKey22", "value22", - "PXAT", fmt.Sprintf("%d", mockClock.Now().Add(4096*time.Millisecond).UnixMilli()), - }, - presetValues: nil, - expectedResponse: "OK", - expectedValue: "value22", - expectedExpiry: mockClock.Now().Add(4096 * time.Millisecond), - expectedErr: nil, - }, - { - name: "23. Return error when trying to set exact milliseconds expiry time when expiry time is already provided", - command: []string{ - "SET", "SetKey23", "value23", - "PX", "1000", - "PXAT", fmt.Sprintf("%d", mockClock.Now().Add(4096*time.Millisecond).UnixMilli()), - }, - presetValues: nil, - expectedResponse: nil, - expectedValue: "", - expectedExpiry: time.Time{}, - expectedErr: errors.New("cannot specify PXAT when expiry time is already set"), - }, - { - name: "24. Return error when no milliseconds value is provided after PXAT flag", - command: []string{"SET", "SetKey24", "value24", "PXAT"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: "", - expectedExpiry: time.Time{}, - expectedErr: errors.New("milliseconds value required after PXAT"), - }, - { - name: "25. Return error when invalid (non-integer) value is passed after EXAT flag", - command: []string{"SET", "SetKey25", "value25", "PXAT", "unix-milliseconds"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: "", - expectedExpiry: time.Time{}, - expectedErr: errors.New("milliseconds value should be an integer"), - }, - { - name: "26. Get the previous value when GET flag is passed", - command: []string{"SET", "SetKey26", "value26", "GET", "EX", "1000"}, - presetValues: map[string]KeyData{ - "SetKey26": { - Value: "previous-value", - ExpireAt: time.Time{}, - }, - }, - expectedResponse: "previous-value", - expectedValue: "value26", - expectedExpiry: mockClock.Now().Add(1000 * time.Second), - expectedErr: nil, - }, - { - name: "27. Return nil when GET value is passed and no previous value exists", - command: []string{"SET", "SetKey27", "value27", "GET", "EX", "1000"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: "value27", - expectedExpiry: mockClock.Now().Add(1000 * time.Second), - expectedErr: nil, - }, - { - name: "28. Throw error when unknown optional flag is passed to SET command.", - command: []string{"SET", "SetKey28", "value28", "UNKNOWN-OPTION"}, - presetValues: nil, - expectedResponse: nil, - expectedValue: nil, - expectedExpiry: time.Time{}, - expectedErr: errors.New("unknown option UNKNOWN-OPTION for set command"), - }, - { - name: "29. Command too short", - command: []string{"SET"}, - expectedResponse: nil, - expectedValue: nil, - expectedErr: errors.New(constants.WrongArgsResponse), - }, - { - name: "30. Command too long", - command: []string{"SET", "SetKey30", "value1", "value2", "value3", "value4", "value5", "value6"}, - expectedResponse: nil, - expectedValue: nil, - expectedErr: errors.New(constants.WrongArgsResponse), - }, - } + t.Run("Test_HandleSET", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - for k, v := range test.presetValues { - cmd := []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(k), - resp.StringValue(v.Value.(string))} - err := client.WriteArray(cmd) - if err != nil { - t.Error(err) - } - rd, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - if !strings.EqualFold(rd.String(), "ok") { - t.Errorf("expected preset response to be \"OK\", got %s", rd.String()) + tests := []struct { + name string + command []string + presetValues map[string]KeyData + expectedResponse interface{} + expectedValue interface{} + expectedExpiry time.Time + expectedErr error + }{ + { + name: "1. Set normal string value", + command: []string{"SET", "SetKey1", "value1"}, + presetValues: nil, + expectedResponse: "OK", + expectedValue: "value1", + expectedExpiry: time.Time{}, + expectedErr: nil, + }, + { + name: "2. Set normal integer value", + command: []string{"SET", "SetKey2", "1245678910"}, + presetValues: nil, + expectedResponse: "OK", + expectedValue: "1245678910", + expectedExpiry: time.Time{}, + expectedErr: nil, + }, + { + name: "3. Set normal float value", + command: []string{"SET", "SetKey3", "45782.11341"}, + presetValues: nil, + expectedResponse: "OK", + expectedValue: "45782.11341", + expectedExpiry: time.Time{}, + expectedErr: nil, + }, + { + name: "4. Only set the value if the key does not exist", + command: []string{"SET", "SetKey4", "value4", "NX"}, + presetValues: nil, + expectedResponse: "OK", + expectedValue: "value4", + expectedExpiry: time.Time{}, + expectedErr: nil, + }, + { + name: "5. Throw error when value already exists with NX flag passed", + command: []string{"SET", "SetKey5", "value5", "NX"}, + presetValues: map[string]KeyData{ + "SetKey5": { + Value: "preset-value5", + ExpireAt: time.Time{}, + }, + }, + expectedResponse: nil, + expectedValue: "preset-value5", + expectedExpiry: time.Time{}, + expectedErr: errors.New("key SetKey5 already exists"), + }, + { + name: "6. Set new key value when key exists with XX flag passed", + command: []string{"SET", "SetKey6", "value6", "XX"}, + presetValues: map[string]KeyData{ + "SetKey6": { + Value: "preset-value6", + ExpireAt: time.Time{}, + }, + }, + expectedResponse: "OK", + expectedValue: "value6", + expectedExpiry: time.Time{}, + expectedErr: nil, + }, + { + name: "7. Return error when setting non-existent key with XX flag", + command: []string{"SET", "SetKey7", "value7", "XX"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: nil, + expectedExpiry: time.Time{}, + expectedErr: errors.New("key SetKey7 does not exist"), + }, + { + name: "8. Return error when NX flag is provided after XX flag", + command: []string{"SET", "SetKey8", "value8", "XX", "NX"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: nil, + expectedExpiry: time.Time{}, + expectedErr: errors.New("cannot specify NX when XX is already specified"), + }, + { + name: "9. Return error when XX flag is provided after NX flag", + command: []string{"SET", "SetKey9", "value9", "NX", "XX"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: nil, + expectedExpiry: time.Time{}, + expectedErr: errors.New("cannot specify XX when NX is already specified"), + }, + { + name: "10. Set expiry time on the key to 100 seconds from now", + command: []string{"SET", "SetKey10", "value10", "EX", "100"}, + presetValues: nil, + expectedResponse: "OK", + expectedValue: "value10", + expectedExpiry: mockClock.Now().Add(100 * time.Second), + expectedErr: nil, + }, + { + name: "11. Return error when EX flag is passed without seconds value", + command: []string{"SET", "SetKey11", "value11", "EX"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: "", + expectedExpiry: time.Time{}, + expectedErr: errors.New("seconds value required after EX"), + }, + { + name: "12. Return error when EX flag is passed with invalid (non-integer) value", + command: []string{"SET", "SetKey12", "value12", "EX", "seconds"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: "", + expectedExpiry: time.Time{}, + expectedErr: errors.New("seconds value should be an integer"), + }, + { + name: "13. Return error when trying to set expiry seconds when expiry is already set", + command: []string{"SET", "SetKey13", "value13", "PX", "100000", "EX", "100"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: nil, + expectedExpiry: time.Time{}, + expectedErr: errors.New("cannot specify EX when expiry time is already set"), + }, + { + name: "14. Set expiry time on the key in unix milliseconds", + command: []string{"SET", "SetKey14", "value14", "PX", "4096"}, + presetValues: nil, + expectedResponse: "OK", + expectedValue: "value14", + expectedExpiry: mockClock.Now().Add(4096 * time.Millisecond), + expectedErr: nil, + }, + { + name: "15. Return error when PX flag is passed without milliseconds value", + command: []string{"SET", "SetKey15", "value15", "PX"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: "", + expectedExpiry: time.Time{}, + expectedErr: errors.New("milliseconds value required after PX"), + }, + { + name: "16. Return error when PX flag is passed with invalid (non-integer) value", + command: []string{"SET", "SetKey16", "value16", "PX", "milliseconds"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: "", + expectedExpiry: time.Time{}, + expectedErr: errors.New("milliseconds value should be an integer"), + }, + { + name: "17. Return error when trying to set expiry milliseconds when expiry is already provided", + command: []string{"SET", "SetKey17", "value17", "EX", "10", "PX", "1000000"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: nil, + expectedExpiry: time.Time{}, + expectedErr: errors.New("cannot specify PX when expiry time is already set"), + }, + { + name: "18. Set exact expiry time in seconds from unix epoch", + command: []string{ + "SET", "SetKey18", "value18", + "EXAT", fmt.Sprintf("%d", mockClock.Now().Add(200*time.Second).Unix()), + }, + presetValues: nil, + expectedResponse: "OK", + expectedValue: "value18", + expectedExpiry: mockClock.Now().Add(200 * time.Second), + expectedErr: nil, + }, + { + name: "19. Return error when trying to set exact seconds expiry time when expiry time is already provided", + command: []string{ + "SET", "SetKey19", "value19", + "EX", "10", + "EXAT", fmt.Sprintf("%d", mockClock.Now().Add(200*time.Second).Unix()), + }, + presetValues: nil, + expectedResponse: nil, + expectedValue: "", + expectedExpiry: time.Time{}, + expectedErr: errors.New("cannot specify EXAT when expiry time is already set"), + }, + { + name: "20. Return error when no seconds value is provided after EXAT flag", + command: []string{"SET", "SetKey20", "value20", "EXAT"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: "", + expectedExpiry: time.Time{}, + expectedErr: errors.New("seconds value required after EXAT"), + }, + { + name: "21. Return error when invalid (non-integer) value is passed after EXAT flag", + command: []string{"SET", "SekKey21", "value21", "EXAT", "seconds"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: "", + expectedExpiry: time.Time{}, + expectedErr: errors.New("seconds value should be an integer"), + }, + { + name: "22. Set exact expiry time in milliseconds from unix epoch", + command: []string{ + "SET", "SetKey22", "value22", + "PXAT", fmt.Sprintf("%d", mockClock.Now().Add(4096*time.Millisecond).UnixMilli()), + }, + presetValues: nil, + expectedResponse: "OK", + expectedValue: "value22", + expectedExpiry: mockClock.Now().Add(4096 * time.Millisecond), + expectedErr: nil, + }, + { + name: "23. Return error when trying to set exact milliseconds expiry time when expiry time is already provided", + command: []string{ + "SET", "SetKey23", "value23", + "PX", "1000", + "PXAT", fmt.Sprintf("%d", mockClock.Now().Add(4096*time.Millisecond).UnixMilli()), + }, + presetValues: nil, + expectedResponse: nil, + expectedValue: "", + expectedExpiry: time.Time{}, + expectedErr: errors.New("cannot specify PXAT when expiry time is already set"), + }, + { + name: "24. Return error when no milliseconds value is provided after PXAT flag", + command: []string{"SET", "SetKey24", "value24", "PXAT"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: "", + expectedExpiry: time.Time{}, + expectedErr: errors.New("milliseconds value required after PXAT"), + }, + { + name: "25. Return error when invalid (non-integer) value is passed after EXAT flag", + command: []string{"SET", "SetKey25", "value25", "PXAT", "unix-milliseconds"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: "", + expectedExpiry: time.Time{}, + expectedErr: errors.New("milliseconds value should be an integer"), + }, + { + name: "26. Get the previous value when GET flag is passed", + command: []string{"SET", "SetKey26", "value26", "GET", "EX", "1000"}, + presetValues: map[string]KeyData{ + "SetKey26": { + Value: "previous-value", + ExpireAt: time.Time{}, + }, + }, + expectedResponse: "previous-value", + expectedValue: "value26", + expectedExpiry: mockClock.Now().Add(1000 * time.Second), + expectedErr: nil, + }, + { + name: "27. Return nil when GET value is passed and no previous value exists", + command: []string{"SET", "SetKey27", "value27", "GET", "EX", "1000"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: "value27", + expectedExpiry: mockClock.Now().Add(1000 * time.Second), + expectedErr: nil, + }, + { + name: "28. Throw error when unknown optional flag is passed to SET command.", + command: []string{"SET", "SetKey28", "value28", "UNKNOWN-OPTION"}, + presetValues: nil, + expectedResponse: nil, + expectedValue: nil, + expectedExpiry: time.Time{}, + expectedErr: errors.New("unknown option UNKNOWN-OPTION for set command"), + }, + { + name: "29. Command too short", + command: []string{"SET"}, + expectedResponse: nil, + expectedValue: nil, + expectedErr: errors.New(constants.WrongArgsResponse), + }, + { + name: "30. Command too long", + command: []string{"SET", "SetKey30", "value1", "value2", "value3", "value4", "value5", "value6"}, + expectedResponse: nil, + expectedValue: nil, + expectedErr: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + for k, v := range test.presetValues { + cmd := []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(k), + resp.StringValue(v.Value.(string))} + err := client.WriteArray(cmd) + if err != nil { + t.Error(err) + } + rd, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + if !strings.EqualFold(rd.String(), "ok") { + t.Errorf("expected preset response to be \"OK\", got %s", rd.String()) + } } } - } - command := make([]resp.Value, len(test.command)) - for j, c := range test.command { - command[j] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - - res, _, err := client.ReadValue() - - if test.expectedErr != nil { - if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedErr.Error(), err.Error()) + command := make([]resp.Value, len(test.command)) + for j, c := range test.command { + command[j] = resp.StringValue(c) } - return - } - if err != nil { - t.Error(err) - } - switch test.expectedResponse.(type) { - case string: - if test.expectedResponse != res.String() { - t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) - } - case nil: - if !res.IsNull() { - t.Errorf("expcted nil response, got %+v", res) - } - default: - t.Error("test expected result should be nil or string") - } - - key := test.command[1] - - // Compare expected value to response value - if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - if res.String() != test.expectedValue.(string) { - t.Errorf("expected value %s, got %s", test.expectedValue.(string), res.String()) - } - - // Compare expected expiry to response expiry - if !test.expectedExpiry.Equal(time.Time{}) { - if err = client.WriteArray([]resp.Value{resp.StringValue("EXPIRETIME"), resp.StringValue(key)}); err != nil { + if err = client.WriteArray(command); err != nil { t.Error(err) } - res, _, err = client.ReadValue() + + res, _, err := client.ReadValue() + + if test.expectedErr != nil { + if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedErr.Error(), err.Error()) + } + return + } if err != nil { t.Error(err) } - if res.Integer() != int(test.expectedExpiry.Unix()) { - t.Errorf("expected expiry time %d, got %d", test.expectedExpiry.Unix(), res.Integer()) - } - } - }) - } -} -func Test_HandleMSET(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - command []string - expectedResponse string - expectedValues map[string]interface{} - expectedErr error - }{ - { - name: "1. Set multiple key value pairs", - command: []string{"MSET", "MsetKey1", "value1", "MsetKey2", "10", "MsetKey3", "3.142"}, - expectedResponse: "OK", - expectedValues: map[string]interface{}{"MsetKey1": "value1", "MsetKey2": 10, "MsetKey3": 3.142}, - expectedErr: nil, - }, - { - name: "2. Return error when keys and values are not even", - command: []string{"MSET", "MsetKey1", "value1", "MsetKey2", "10", "MsetKey3"}, - expectedResponse: "", - expectedValues: make(map[string]interface{}), - expectedErr: errors.New("each key must be paired with a value"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - command := make([]resp.Value, len(test.command)) - for j, c := range test.command { - command[j] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedErr != nil { - if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) { - t.Errorf("expected error %s, got %s", test.expectedErr.Error(), err.Error()) - } - return - } - - if res.String() != test.expectedResponse { - t.Errorf("expected response %s, got %s", test.expectedResponse, res.String()) - } - - for key, expectedValue := range test.expectedValues { - // Get value from server - if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - switch expectedValue.(type) { - default: - t.Error("unexpected type for expectedValue") - case int: - ev, _ := expectedValue.(int) - if res.Integer() != ev { - t.Errorf("expected value %d for key %s, got %d", ev, key, res.Integer()) - } - case float64: - ev, _ := expectedValue.(float64) - if res.Float() != ev { - t.Errorf("expected value %f for key %s, got %f", ev, key, res.Float()) - } + switch test.expectedResponse.(type) { case string: - ev, _ := expectedValue.(string) - if res.String() != ev { - t.Errorf("expected value %s for key %s, got %s", ev, key, res.String()) + if test.expectedResponse != res.String() { + t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) + } + case nil: + if !res.IsNull() { + t.Errorf("expcted nil response, got %+v", res) + } + default: + t.Error("test expected result should be nil or string") + } + + key := test.command[1] + + // Compare expected value to response value + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + if res.String() != test.expectedValue.(string) { + t.Errorf("expected value %s, got %s", test.expectedValue.(string), res.String()) + } + + // Compare expected expiry to response expiry + if !test.expectedExpiry.Equal(time.Time{}) { + if err = client.WriteArray([]resp.Value{resp.StringValue("EXPIRETIME"), resp.StringValue(key)}); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + if res.Integer() != int(test.expectedExpiry.Unix()) { + t.Errorf("expected expiry time %d, got %d", test.expectedExpiry.Unix(), res.Integer()) } } - } - }) - } -} + }) + } + }) -func Test_HandleGET(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) + t.Run("Test_HandleMSET", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - tests := []struct { - name string - key string - value string - }{ - { - name: "1. String", - key: "GetKey1", - value: "value1", - }, - { - name: "2. Integer", - key: "GetKey2", - value: "10", - }, - { - name: "3. Float", - key: "GetKey3", - value: "3.142", - }, - } - // Test successful Get command - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - func(key, value string) { - // Preset the values - err = client.WriteArray([]resp.Value{resp.StringValue("SET"), resp.StringValue(key), resp.StringValue(value)}) - if err != nil { + tests := []struct { + name string + command []string + expectedResponse string + expectedValues map[string]interface{} + expectedErr error + }{ + { + name: "1. Set multiple key value pairs", + command: []string{"MSET", "MsetKey1", "value1", "MsetKey2", "10", "MsetKey3", "3.142"}, + expectedResponse: "OK", + expectedValues: map[string]interface{}{"MsetKey1": "value1", "MsetKey2": 10, "MsetKey3": 3.142}, + expectedErr: nil, + }, + { + name: "2. Return error when keys and values are not even", + command: []string{"MSET", "MsetKey1", "value1", "MsetKey2", "10", "MsetKey3"}, + expectedResponse: "", + expectedValues: make(map[string]interface{}), + expectedErr: errors.New("each key must be paired with a value"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + command := make([]resp.Value, len(test.command)) + for j, c := range test.command { + command[j] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } @@ -593,11 +524,151 @@ func Test_HandleGET(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected preset response to be \"OK\", got %s", res.String()) + if test.expectedErr != nil { + if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) { + t.Errorf("expected error %s, got %s", test.expectedErr.Error(), err.Error()) + } + return } - if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { + if res.String() != test.expectedResponse { + t.Errorf("expected response %s, got %s", test.expectedResponse, res.String()) + } + + for key, expectedValue := range test.expectedValues { + // Get value from server + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + switch expectedValue.(type) { + default: + t.Error("unexpected type for expectedValue") + case int: + ev, _ := expectedValue.(int) + if res.Integer() != ev { + t.Errorf("expected value %d for key %s, got %d", ev, key, res.Integer()) + } + case float64: + ev, _ := expectedValue.(float64) + if res.Float() != ev { + t.Errorf("expected value %f for key %s, got %f", ev, key, res.Float()) + } + case string: + ev, _ := expectedValue.(string) + if res.String() != ev { + t.Errorf("expected value %s for key %s, got %s", ev, key, res.String()) + } + } + } + }) + } + }) + + t.Run("Test_HandleGET", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + value string + }{ + { + name: "1. String", + key: "GetKey1", + value: "value1", + }, + { + name: "2. Integer", + key: "GetKey2", + value: "10", + }, + { + name: "3. Float", + key: "GetKey3", + value: "3.142", + }, + } + // Test successful Get command + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + func(key, value string) { + // Preset the values + err = client.WriteArray([]resp.Value{resp.StringValue("SET"), resp.StringValue(key), resp.StringValue(value)}) + if err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be \"OK\", got %s", res.String()) + } + + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if res.String() != test.value { + t.Errorf("expected value %s, got %s", test.value, res.String()) + } + }(test.key, test.value) + }) + } + + // Test get non-existent key + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue("test4")}); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + if !res.IsNull() { + t.Errorf("expected nil, got: %+v", res) + } + + errorTests := []struct { + name string + command []string + expected string + }{ + { + name: "1. Return error when no GET key is passed", + command: []string{"GET"}, + expected: constants.WrongArgsResponse, + }, + { + name: "2. Return error when too many GET keys are passed", + command: []string{"GET", "GetKey1", "test"}, + expected: constants.WrongArgsResponse, + }, + } + for _, test := range errorTests { + t.Run(test.name, func(t *testing.T) { + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } @@ -606,220 +677,67 @@ func Test_HandleGET(t *testing.T) { t.Error(err) } - if res.String() != test.value { - t.Errorf("expected value %s, got %s", test.value, res.String()) + if !strings.Contains(res.Error().Error(), test.expected) { + t.Errorf("expected error '%s', got: %s", test.expected, err.Error()) } - }(test.key, test.value) - }) - } + }) + } + }) - // Test get non-existent key - if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue("test4")}); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - if !res.IsNull() { - t.Errorf("expected nil, got: %+v", res) - } + t.Run("Test_HandleMGET", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - errorTests := []struct { - name string - command []string - expected string - }{ - { - name: "1. Return error when no GET key is passed", - command: []string{"GET"}, - expected: constants.WrongArgsResponse, - }, - { - name: "2. Return error when too many GET keys are passed", - command: []string{"GET", "GetKey1", "test"}, - expected: constants.WrongArgsResponse, - }, - } - for _, test := range errorTests { - t.Run(test.name, func(t *testing.T) { - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.Contains(res.Error().Error(), test.expected) { - t.Errorf("expected error '%s', got: %s", test.expected, err.Error()) - } - }) - } -} - -func Test_HandleMGET(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetKeys []string - presetValues []string - command []string - expected []interface{} - expectedError error - }{ - { - name: "1. MGET multiple existing values", - presetKeys: []string{"MgetKey1", "MgetKey2", "MgetKey3", "MgetKey4"}, - presetValues: []string{"value1", "value2", "value3", "value4"}, - command: []string{"MGET", "MgetKey1", "MgetKey4", "MgetKey2", "MgetKey3", "MgetKey1"}, - expected: []interface{}{"value1", "value4", "value2", "value3", "value1"}, - expectedError: nil, - }, - { - name: "2. MGET multiple values with nil values spliced in", - presetKeys: []string{"MgetKey5", "MgetKey6", "MgetKey7"}, - presetValues: []string{"value5", "value6", "value7"}, - command: []string{"MGET", "MgetKey5", "MgetKey6", "non-existent", "non-existent", "MgetKey7", "non-existent"}, - expected: []interface{}{"value5", "value6", nil, nil, "value7", nil}, - expectedError: nil, - }, - { - name: "3. Return error when MGET is invoked with no keys", - presetKeys: []string{"MgetKey5"}, - presetValues: []string{"value5"}, - command: []string{"MGET"}, - expected: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - // Set up the values - for i, key := range test.presetKeys { - if err = client.WriteArray([]resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(test.presetValues[i]), - }); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected preset response to be \"OK\", got \"%s\"", res.String()) - } - } - - // Test the command and its results - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - // If we expect and error, branch out and check error - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error %+v, got: %+v", test.expectedError, err) - } - return - } - - if res.Type().String() != "Array" { - t.Errorf("expected type Array, got: %s", res.Type().String()) - } - for i, value := range res.Array() { - if test.expected[i] == nil { - if !value.IsNull() { - t.Errorf("expected nil value, got %+v", value) - } - continue - } - if value.String() != test.expected[i] { - t.Errorf("expected value %s, got: %s", test.expected[i], value.String()) - } - } - }) - } -} - -func Test_HandleDEL(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - command []string - presetValues map[string]string - expectedResponse int - expectToExist map[string]bool - expectedErr error - }{ - { - name: "1. Delete multiple keys", - command: []string{"DEL", "DelKey1", "DelKey2", "DelKey3", "DelKey4", "DelKey5"}, - presetValues: map[string]string{ - "DelKey1": "value1", - "DelKey2": "value2", - "DelKey3": "value3", - "DelKey4": "value4", + tests := []struct { + name string + presetKeys []string + presetValues []string + command []string + expected []interface{} + expectedError error + }{ + { + name: "1. MGET multiple existing values", + presetKeys: []string{"MgetKey1", "MgetKey2", "MgetKey3", "MgetKey4"}, + presetValues: []string{"value1", "value2", "value3", "value4"}, + command: []string{"MGET", "MgetKey1", "MgetKey4", "MgetKey2", "MgetKey3", "MgetKey1"}, + expected: []interface{}{"value1", "value4", "value2", "value3", "value1"}, + expectedError: nil, }, - expectedResponse: 4, - expectToExist: map[string]bool{ - "DelKey1": false, - "DelKey2": false, - "DelKey3": false, - "DelKey4": false, - "DelKey5": false, + { + name: "2. MGET multiple values with nil values spliced in", + presetKeys: []string{"MgetKey5", "MgetKey6", "MgetKey7"}, + presetValues: []string{"value5", "value6", "value7"}, + command: []string{"MGET", "MgetKey5", "MgetKey6", "non-existent", "non-existent", "MgetKey7", "non-existent"}, + expected: []interface{}{"value5", "value6", nil, nil, "value7", nil}, + expectedError: nil, }, - expectedErr: nil, - }, - { - name: "2. Return error when DEL is called with no keys", - command: []string{"DEL"}, - presetValues: nil, - expectedResponse: 0, - expectToExist: nil, - expectedErr: errors.New(constants.WrongArgsResponse), - }, - } + { + name: "3. Return error when MGET is invoked with no keys", + presetKeys: []string{"MgetKey5"}, + presetValues: []string{"value5"}, + command: []string{"MGET"}, + expected: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - for k, v := range test.presetValues { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Set up the values + for i, key := range test.presetKeys { if err = client.WriteArray([]resp.Value{ resp.StringValue("SET"), - resp.StringValue(k), - resp.StringValue(v), + resp.StringValue(key), + resp.StringValue(test.presetValues[i]), }); err != nil { t.Error(err) } @@ -828,1026 +746,1158 @@ func Test_HandleDEL(t *testing.T) { t.Error(err) } if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected preset response to be \"OK\", got %s", res.String()) + t.Errorf("expected preset response to be \"OK\", got \"%s\"", res.String()) } } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedErr != nil { - if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedErr.Error(), res.Error().Error()) + // Test the command and its results + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - return - } - if res.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) - } - - for key, expected := range test.expectToExist { - if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { + if err = client.WriteArray(command); err != nil { t.Error(err) } - res, _, err = client.ReadValue() + + res, _, err := client.ReadValue() if err != nil { t.Error(err) } - exists := !res.IsNull() - if exists != expected { - t.Errorf("expected existence of key %s to be %v, got %v", key, expected, exists) - } - } - }) - } -} -func Test_HandlePERSIST(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - command []string - presetValues map[string]KeyData - expectedResponse int - expectedValues map[string]KeyData - expectedError error - }{ - { - name: "1. Successfully persist a volatile key", - command: []string{"PERSIST", "PersistKey1"}, - presetValues: map[string]KeyData{ - "PersistKey1": {Value: "value1", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "PersistKey1": {Value: "value1", ExpireAt: time.Time{}}, - }, - expectedError: nil, - }, - { - name: "2. Return 0 when trying to persist a non-existent key", - command: []string{"PERSIST", "PersistKey2"}, - presetValues: nil, - expectedResponse: 0, - expectedValues: nil, - expectedError: nil, - }, - { - name: "3. Return 0 when trying to persist a non-volatile key", - command: []string{"PERSIST", "PersistKey3"}, - presetValues: map[string]KeyData{ - "PersistKey3": {Value: "value3", ExpireAt: time.Time{}}, - }, - expectedResponse: 0, - expectedValues: map[string]KeyData{ - "PersistKey3": {Value: "value3", ExpireAt: time.Time{}}, - }, - expectedError: nil, - }, - { - name: "4. Command too short", - command: []string{"PERSIST"}, - presetValues: nil, - expectedResponse: 0, - expectedValues: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Command too long", - command: []string{"PERSIST", "PersistKey5", "key6"}, - presetValues: nil, - expectedResponse: 0, - expectedValues: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - for k, v := range test.presetValues { - command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} - if !v.ExpireAt.Equal(time.Time{}) { - command = append(command, []resp.Value{ - resp.StringValue("PX"), - resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), - }...) + if test.expectedError != nil { + // If we expect and error, branch out and check error + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error %+v, got: %+v", test.expectedError, err) } - if err = client.WriteArray(command); err != nil { + return + } + + if res.Type().String() != "Array" { + t.Errorf("expected type Array, got: %s", res.Type().String()) + } + for i, value := range res.Array() { + if test.expected[i] == nil { + if !value.IsNull() { + t.Errorf("expected nil value, got %+v", value) + } + continue + } + if value.String() != test.expected[i] { + t.Errorf("expected value %s, got: %s", test.expected[i], value.String()) + } + } + }) + } + }) + + t.Run("Test_HandleDEL", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + command []string + presetValues map[string]string + expectedResponse int + expectToExist map[string]bool + expectedErr error + }{ + { + name: "1. Delete multiple keys", + command: []string{"DEL", "DelKey1", "DelKey2", "DelKey3", "DelKey4", "DelKey5"}, + presetValues: map[string]string{ + "DelKey1": "value1", + "DelKey2": "value2", + "DelKey3": "value3", + "DelKey4": "value4", + }, + expectedResponse: 4, + expectToExist: map[string]bool{ + "DelKey1": false, + "DelKey2": false, + "DelKey3": false, + "DelKey4": false, + "DelKey5": false, + }, + expectedErr: nil, + }, + { + name: "2. Return error when DEL is called with no keys", + command: []string{"DEL"}, + presetValues: nil, + expectedResponse: 0, + expectToExist: nil, + expectedErr: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + for k, v := range test.presetValues { + if err = client.WriteArray([]resp.Value{ + resp.StringValue("SET"), + resp.StringValue(k), + resp.StringValue(v), + }); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be \"OK\", got %s", res.String()) + } + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedErr != nil { + if !strings.Contains(res.Error().Error(), test.expectedErr.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedErr.Error(), res.Error().Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + + for key, expected := range test.expectToExist { + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { t.Error(err) } - res, _, err := client.ReadValue() + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected preset response to be OK, got %s", res.String()) + exists := !res.IsNull() + if exists != expected { + t.Errorf("expected existence of key %s to be %v, got %v", key, expected, exists) } } - } + }) + } + }) - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } + t.Run("Test_HandlePERSIST", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - if err = client.WriteArray(command); err != nil { - t.Error(err) - } + tests := []struct { + name string + command []string + presetValues map[string]KeyData + expectedResponse int + expectedValues map[string]KeyData + expectedError error + }{ + { + name: "1. Successfully persist a volatile key", + command: []string{"PERSIST", "PersistKey1"}, + presetValues: map[string]KeyData{ + "PersistKey1": {Value: "value1", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "PersistKey1": {Value: "value1", ExpireAt: time.Time{}}, + }, + expectedError: nil, + }, + { + name: "2. Return 0 when trying to persist a non-existent key", + command: []string{"PERSIST", "PersistKey2"}, + presetValues: nil, + expectedResponse: 0, + expectedValues: nil, + expectedError: nil, + }, + { + name: "3. Return 0 when trying to persist a non-volatile key", + command: []string{"PERSIST", "PersistKey3"}, + presetValues: map[string]KeyData{ + "PersistKey3": {Value: "value3", ExpireAt: time.Time{}}, + }, + expectedResponse: 0, + expectedValues: map[string]KeyData{ + "PersistKey3": {Value: "value3", ExpireAt: time.Time{}}, + }, + expectedError: nil, + }, + { + name: "4. Command too short", + command: []string{"PERSIST"}, + presetValues: nil, + expectedResponse: 0, + expectedValues: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Command too long", + command: []string{"PERSIST", "PersistKey5", "key6"}, + presetValues: nil, + expectedResponse: 0, + expectedValues: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + for k, v := range test.presetValues { + command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} + if !v.ExpireAt.Equal(time.Time{}) { + command = append(command, []resp.Value{ + resp.StringValue("PX"), + resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), + }...) + } + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } + } } - return - } - if res.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) - } + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } - if test.expectedValues == nil { - return - } - - for key, expected := range test.expectedValues { - // Compare the value of the key with what's expected - if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { + if err = client.WriteArray(command); err != nil { t.Error(err) } - res, _, err = client.ReadValue() + + res, _, err := client.ReadValue() if err != nil { t.Error(err) } - if res.String() != expected.Value.(string) { - t.Errorf("expected value %s, got %s", expected.Value.(string), res.String()) - } - // Compare the expiry of the key with what's expected - if err = client.WriteArray([]resp.Value{resp.StringValue("PTTL"), resp.StringValue(key)}); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - if expected.ExpireAt.Equal(time.Time{}) { - if res.Integer() != -1 { - t.Error("expected key to be persisted, it was not.") + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } - continue + return } - if res.Integer() != int(expected.ExpireAt.UnixMilli()) { - t.Errorf("expected expiry %d, got %d", expected.ExpireAt.UnixMilli(), res.Integer()) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) } - } - }) - } -} -func Test_HandleEXPIRETIME(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) + if test.expectedValues == nil { + return + } - tests := []struct { - name string - command []string - presetValues map[string]KeyData - expectedResponse int - expectedError error - }{ - { - name: "1. Return expire time in seconds", - command: []string{"EXPIRETIME", "ExpireTimeKey1"}, - presetValues: map[string]KeyData{ - "ExpireTimeKey1": {Value: "value1", ExpireAt: mockClock.Now().Add(100 * time.Second)}, - }, - expectedResponse: int(mockClock.Now().Add(100 * time.Second).Unix()), - expectedError: nil, - }, - { - name: "2. Return expire time in milliseconds", - command: []string{"PEXPIRETIME", "ExpireTimeKey2"}, - presetValues: map[string]KeyData{ - "ExpireTimeKey2": {Value: "value2", ExpireAt: mockClock.Now().Add(4096 * time.Millisecond)}, - }, - expectedResponse: int(mockClock.Now().Add(4096 * time.Millisecond).UnixMilli()), - expectedError: nil, - }, - { - name: "3. If the key is non-volatile, return -1", - command: []string{"PEXPIRETIME", "ExpireTimeKey3"}, - presetValues: map[string]KeyData{ - "ExpireTimeKey3": {Value: "value3", ExpireAt: time.Time{}}, - }, - expectedResponse: -1, - expectedError: nil, - }, - { - name: "4. If the key is non-existent return -2", - command: []string{"PEXPIRETIME", "ExpireTimeKey4"}, - presetValues: nil, - expectedResponse: -2, - expectedError: nil, - }, - { - name: "5. Command too short", - command: []string{"PEXPIRETIME"}, - presetValues: nil, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "6. Command too long", - command: []string{"PEXPIRETIME", "ExpireTimeKey5", "ExpireTimeKey6"}, - presetValues: nil, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - for k, v := range test.presetValues { - command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} - if !v.ExpireAt.Equal(time.Time{}) { - command = append(command, []resp.Value{ - resp.StringValue("PX"), - resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), - }...) - } - if err = client.WriteArray(command); err != nil { + for key, expected := range test.expectedValues { + // Compare the value of the key with what's expected + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { t.Error(err) } - res, _, err := client.ReadValue() + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected preset response to be OK, got %s", res.String()) + if res.String() != expected.Value.(string) { + t.Errorf("expected value %s, got %s", expected.Value.(string), res.String()) } - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) - } - }) - } -} - -func Test_HandleTTL(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - command []string - presetValues map[string]KeyData - expectedResponse int - expectedError error - }{ - { - name: "1. Return TTL time in seconds", - command: []string{"TTL", "TTLKey1"}, - presetValues: map[string]KeyData{ - "TTLKey1": {Value: "value1", ExpireAt: mockClock.Now().Add(100 * time.Second)}, - }, - expectedResponse: 100, - expectedError: nil, - }, - { - name: "2. Return TTL time in milliseconds", - command: []string{"PTTL", "TTLKey2"}, - presetValues: map[string]KeyData{ - "TTLKey2": {Value: "value2", ExpireAt: mockClock.Now().Add(4096 * time.Millisecond)}, - }, - expectedResponse: 4096, - expectedError: nil, - }, - { - name: "3. If the key is non-volatile, return -1", - command: []string{"TTL", "TTLKey3"}, - presetValues: map[string]KeyData{ - "TTLKey3": {Value: "value3", ExpireAt: time.Time{}}, - }, - expectedResponse: -1, - expectedError: nil, - }, - { - name: "4. If the key is non-existent return -2", - command: []string{"TTL", "TTLKey4"}, - presetValues: nil, - expectedResponse: -2, - expectedError: nil, - }, - { - name: "5. Command too short", - command: []string{"TTL"}, - presetValues: nil, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "6. Command too long", - command: []string{"TTL", "TTLKey5", "TTLKey6"}, - presetValues: nil, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - for k, v := range test.presetValues { - command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} - if !v.ExpireAt.Equal(time.Time{}) { - command = append(command, []resp.Value{ - resp.StringValue("PX"), - resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), - }...) - } - if err = client.WriteArray(command); err != nil { + // Compare the expiry of the key with what's expected + if err = client.WriteArray([]resp.Value{resp.StringValue("PTTL"), resp.StringValue(key)}); err != nil { t.Error(err) } - res, _, err := client.ReadValue() + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected preset response to be OK, got %s", res.String()) + if expected.ExpireAt.Equal(time.Time{}) { + if res.Integer() != -1 { + t.Error("expected key to be persisted, it was not.") + } + continue + } + if res.Integer() != int(expected.ExpireAt.UnixMilli()) { + t.Errorf("expected expiry %d, got %d", expected.ExpireAt.UnixMilli(), res.Integer()) } } - } + }) + } + }) - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } + t.Run("Test_HandleEXPIRETIME", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - if err = client.WriteArray(command); err != nil { - t.Error(err) - } + tests := []struct { + name string + command []string + presetValues map[string]KeyData + expectedResponse int + expectedError error + }{ + { + name: "1. Return expire time in seconds", + command: []string{"EXPIRETIME", "ExpireTimeKey1"}, + presetValues: map[string]KeyData{ + "ExpireTimeKey1": {Value: "value1", ExpireAt: mockClock.Now().Add(100 * time.Second)}, + }, + expectedResponse: int(mockClock.Now().Add(100 * time.Second).Unix()), + expectedError: nil, + }, + { + name: "2. Return expire time in milliseconds", + command: []string{"PEXPIRETIME", "ExpireTimeKey2"}, + presetValues: map[string]KeyData{ + "ExpireTimeKey2": {Value: "value2", ExpireAt: mockClock.Now().Add(4096 * time.Millisecond)}, + }, + expectedResponse: int(mockClock.Now().Add(4096 * time.Millisecond).UnixMilli()), + expectedError: nil, + }, + { + name: "3. If the key is non-volatile, return -1", + command: []string{"PEXPIRETIME", "ExpireTimeKey3"}, + presetValues: map[string]KeyData{ + "ExpireTimeKey3": {Value: "value3", ExpireAt: time.Time{}}, + }, + expectedResponse: -1, + expectedError: nil, + }, + { + name: "4. If the key is non-existent return -2", + command: []string{"PEXPIRETIME", "ExpireTimeKey4"}, + presetValues: nil, + expectedResponse: -2, + expectedError: nil, + }, + { + name: "5. Command too short", + command: []string{"PEXPIRETIME"}, + presetValues: nil, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "6. Command too long", + command: []string{"PEXPIRETIME", "ExpireTimeKey5", "ExpireTimeKey6"}, + presetValues: nil, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) - } - }) - } -} - -func Test_HandleEXPIRE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - command []string - presetValues map[string]KeyData - expectedResponse int - expectedValues map[string]KeyData - expectedError error - }{ - { - name: "1. Set new expire by seconds", - command: []string{"EXPIRE", "ExpireKey1", "100"}, - presetValues: map[string]KeyData{ - "ExpireKey1": {Value: "value1", ExpireAt: time.Time{}}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "ExpireKey1": {Value: "value1", ExpireAt: mockClock.Now().Add(100 * time.Second)}, - }, - expectedError: nil, - }, - { - name: "2. Set new expire by milliseconds", - command: []string{"PEXPIRE", "ExpireKey2", "1000"}, - presetValues: map[string]KeyData{ - "ExpireKey2": {Value: "value2", ExpireAt: time.Time{}}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "ExpireKey2": {Value: "value2", ExpireAt: mockClock.Now().Add(1000 * time.Millisecond)}, - }, - expectedError: nil, - }, - { - name: "3. Set new expire only when key does not have an expiry time with NX flag", - command: []string{"EXPIRE", "ExpireKey3", "1000", "NX"}, - presetValues: map[string]KeyData{ - "ExpireKey3": {Value: "value3", ExpireAt: time.Time{}}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "ExpireKey3": {Value: "value3", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, - }, - expectedError: nil, - }, - { - name: "4. Return 0, when NX flag is provided and key already has an expiry time", - command: []string{"EXPIRE", "ExpireKey4", "1000", "NX"}, - presetValues: map[string]KeyData{ - "ExpireKey4": {Value: "value4", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, - }, - expectedResponse: 0, - expectedValues: map[string]KeyData{ - "ExpireKey4": {Value: "value4", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, - }, - expectedError: nil, - }, - { - name: "5. Set new expire time from now key only when the key already has an expiry time with XX flag", - command: []string{"EXPIRE", "ExpireKey5", "1000", "XX"}, - presetValues: map[string]KeyData{ - "ExpireKey5": {Value: "value5", ExpireAt: mockClock.Now().Add(30 * time.Second)}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "ExpireKey5": {Value: "value5", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, - }, - expectedError: nil, - }, - { - name: "6. Return 0 when key does not have an expiry and the XX flag is provided", - command: []string{"EXPIRE", "ExpireKey6", "1000", "XX"}, - presetValues: map[string]KeyData{ - "ExpireKey6": {Value: "value6", ExpireAt: time.Time{}}, - }, - expectedResponse: 0, - expectedValues: map[string]KeyData{ - "ExpireKey6": {Value: "value6", ExpireAt: time.Time{}}, - }, - expectedError: nil, - }, - { - name: "7. Set expiry time when the provided time is after the current expiry time when GT flag is provided", - command: []string{"EXPIRE", "ExpireKey7", "1000", "GT"}, - presetValues: map[string]KeyData{ - "ExpireKey7": {Value: "value7", ExpireAt: mockClock.Now().Add(30 * time.Second)}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "ExpireKey7": {Value: "value7", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, - }, - expectedError: nil, - }, - { - name: "8. Return 0 when GT flag is passed and current expiry time is greater than provided time", - command: []string{"EXPIRE", "ExpireKey8", "1000", "GT"}, - presetValues: map[string]KeyData{ - "ExpireKey8": {Value: "value8", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, - }, - expectedResponse: 0, - expectedValues: map[string]KeyData{ - "ExpireKey8": {Value: "value8", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, - }, - expectedError: nil, - }, - { - name: "9. Return 0 when GT flag is passed and key does not have an expiry time", - command: []string{"EXPIRE", "ExpireKey9", "1000", "GT"}, - presetValues: map[string]KeyData{ - "ExpireKey9": {Value: "value9", ExpireAt: time.Time{}}, - }, - expectedResponse: 0, - expectedValues: map[string]KeyData{ - "ExpireKey9": {Value: "value9", ExpireAt: time.Time{}}, - }, - expectedError: nil, - }, - { - name: "10. Set expiry time when the provided time is before the current expiry time when LT flag is provided", - command: []string{"EXPIRE", "ExpireKey10", "1000", "LT"}, - presetValues: map[string]KeyData{ - "ExpireKey10": {Value: "value10", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "ExpireKey10": {Value: "value10", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, - }, - expectedError: nil, - }, - { - name: "11. Return 0 when LT flag is passed and current expiry time is less than provided time", - command: []string{"EXPIRE", "ExpireKey11", "5000", "LT"}, - presetValues: map[string]KeyData{ - "ExpireKey11": {Value: "value11", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, - }, - expectedResponse: 0, - expectedValues: map[string]KeyData{ - "ExpireKey11": {Value: "value11", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, - }, - expectedError: nil, - }, - { - name: "12. Return 0 when LT flag is passed and key does not have an expiry time", - command: []string{"EXPIRE", "ExpireKey12", "1000", "LT"}, - presetValues: map[string]KeyData{ - "ExpireKey12": {Value: "value12", ExpireAt: time.Time{}}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "ExpireKey12": {Value: "value12", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, - }, - expectedError: nil, - }, - { - name: "13. Return error when unknown flag is passed", - command: []string{"EXPIRE", "ExpireKey13", "1000", "UNKNOWN"}, - presetValues: map[string]KeyData{ - "ExpireKey13": {Value: "value13", ExpireAt: time.Time{}}, - }, - expectedResponse: 0, - expectedValues: nil, - expectedError: errors.New("unknown option UNKNOWN"), - }, - { - name: "14. Return error when expire time is not a valid integer", - command: []string{"EXPIRE", "ExpireKey14", "expire"}, - presetValues: nil, - expectedResponse: 0, - expectedValues: nil, - expectedError: errors.New("expire time must be integer"), - }, - { - name: "15. Command too short", - command: []string{"EXPIRE"}, - presetValues: nil, - expectedResponse: 0, - expectedValues: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "16. Command too long", - command: []string{"EXPIRE", "ExpireKey16", "10", "NX", "GT"}, - presetValues: nil, - expectedResponse: 0, - expectedValues: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - for k, v := range test.presetValues { - command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} - if !v.ExpireAt.Equal(time.Time{}) { - command = append(command, []resp.Value{ - resp.StringValue("PX"), - resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), - }...) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + for k, v := range test.presetValues { + command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} + if !v.ExpireAt.Equal(time.Time{}) { + command = append(command, []resp.Value{ + resp.StringValue("PX"), + resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), + }...) + } + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } } - if err = client.WriteArray(command); err != nil { + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + }) + } + }) + + t.Run("Test_HandleTTL", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + command []string + presetValues map[string]KeyData + expectedResponse int + expectedError error + }{ + { + name: "1. Return TTL time in seconds", + command: []string{"TTL", "TTLKey1"}, + presetValues: map[string]KeyData{ + "TTLKey1": {Value: "value1", ExpireAt: mockClock.Now().Add(100 * time.Second)}, + }, + expectedResponse: 100, + expectedError: nil, + }, + { + name: "2. Return TTL time in milliseconds", + command: []string{"PTTL", "TTLKey2"}, + presetValues: map[string]KeyData{ + "TTLKey2": {Value: "value2", ExpireAt: mockClock.Now().Add(4096 * time.Millisecond)}, + }, + expectedResponse: 4096, + expectedError: nil, + }, + { + name: "3. If the key is non-volatile, return -1", + command: []string{"TTL", "TTLKey3"}, + presetValues: map[string]KeyData{ + "TTLKey3": {Value: "value3", ExpireAt: time.Time{}}, + }, + expectedResponse: -1, + expectedError: nil, + }, + { + name: "4. If the key is non-existent return -2", + command: []string{"TTL", "TTLKey4"}, + presetValues: nil, + expectedResponse: -2, + expectedError: nil, + }, + { + name: "5. Command too short", + command: []string{"TTL"}, + presetValues: nil, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "6. Command too long", + command: []string{"TTL", "TTLKey5", "TTLKey6"}, + presetValues: nil, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + for k, v := range test.presetValues { + command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} + if !v.ExpireAt.Equal(time.Time{}) { + command = append(command, []resp.Value{ + resp.StringValue("PX"), + resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), + }...) + } + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + }) + } + }) + + t.Run("Test_HandleEXPIRE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + command []string + presetValues map[string]KeyData + expectedResponse int + expectedValues map[string]KeyData + expectedError error + }{ + { + name: "1. Set new expire by seconds", + command: []string{"EXPIRE", "ExpireKey1", "100"}, + presetValues: map[string]KeyData{ + "ExpireKey1": {Value: "value1", ExpireAt: time.Time{}}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "ExpireKey1": {Value: "value1", ExpireAt: mockClock.Now().Add(100 * time.Second)}, + }, + expectedError: nil, + }, + { + name: "2. Set new expire by milliseconds", + command: []string{"PEXPIRE", "ExpireKey2", "1000"}, + presetValues: map[string]KeyData{ + "ExpireKey2": {Value: "value2", ExpireAt: time.Time{}}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "ExpireKey2": {Value: "value2", ExpireAt: mockClock.Now().Add(1000 * time.Millisecond)}, + }, + expectedError: nil, + }, + { + name: "3. Set new expire only when key does not have an expiry time with NX flag", + command: []string{"EXPIRE", "ExpireKey3", "1000", "NX"}, + presetValues: map[string]KeyData{ + "ExpireKey3": {Value: "value3", ExpireAt: time.Time{}}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "ExpireKey3": {Value: "value3", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, + }, + expectedError: nil, + }, + { + name: "4. Return 0, when NX flag is provided and key already has an expiry time", + command: []string{"EXPIRE", "ExpireKey4", "1000", "NX"}, + presetValues: map[string]KeyData{ + "ExpireKey4": {Value: "value4", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, + }, + expectedResponse: 0, + expectedValues: map[string]KeyData{ + "ExpireKey4": {Value: "value4", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, + }, + expectedError: nil, + }, + { + name: "5. Set new expire time from now key only when the key already has an expiry time with XX flag", + command: []string{"EXPIRE", "ExpireKey5", "1000", "XX"}, + presetValues: map[string]KeyData{ + "ExpireKey5": {Value: "value5", ExpireAt: mockClock.Now().Add(30 * time.Second)}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "ExpireKey5": {Value: "value5", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, + }, + expectedError: nil, + }, + { + name: "6. Return 0 when key does not have an expiry and the XX flag is provided", + command: []string{"EXPIRE", "ExpireKey6", "1000", "XX"}, + presetValues: map[string]KeyData{ + "ExpireKey6": {Value: "value6", ExpireAt: time.Time{}}, + }, + expectedResponse: 0, + expectedValues: map[string]KeyData{ + "ExpireKey6": {Value: "value6", ExpireAt: time.Time{}}, + }, + expectedError: nil, + }, + { + name: "7. Set expiry time when the provided time is after the current expiry time when GT flag is provided", + command: []string{"EXPIRE", "ExpireKey7", "1000", "GT"}, + presetValues: map[string]KeyData{ + "ExpireKey7": {Value: "value7", ExpireAt: mockClock.Now().Add(30 * time.Second)}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "ExpireKey7": {Value: "value7", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, + }, + expectedError: nil, + }, + { + name: "8. Return 0 when GT flag is passed and current expiry time is greater than provided time", + command: []string{"EXPIRE", "ExpireKey8", "1000", "GT"}, + presetValues: map[string]KeyData{ + "ExpireKey8": {Value: "value8", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, + }, + expectedResponse: 0, + expectedValues: map[string]KeyData{ + "ExpireKey8": {Value: "value8", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, + }, + expectedError: nil, + }, + { + name: "9. Return 0 when GT flag is passed and key does not have an expiry time", + command: []string{"EXPIRE", "ExpireKey9", "1000", "GT"}, + presetValues: map[string]KeyData{ + "ExpireKey9": {Value: "value9", ExpireAt: time.Time{}}, + }, + expectedResponse: 0, + expectedValues: map[string]KeyData{ + "ExpireKey9": {Value: "value9", ExpireAt: time.Time{}}, + }, + expectedError: nil, + }, + { + name: "10. Set expiry time when the provided time is before the current expiry time when LT flag is provided", + command: []string{"EXPIRE", "ExpireKey10", "1000", "LT"}, + presetValues: map[string]KeyData{ + "ExpireKey10": {Value: "value10", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "ExpireKey10": {Value: "value10", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, + }, + expectedError: nil, + }, + { + name: "11. Return 0 when LT flag is passed and current expiry time is less than provided time", + command: []string{"EXPIRE", "ExpireKey11", "5000", "LT"}, + presetValues: map[string]KeyData{ + "ExpireKey11": {Value: "value11", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, + }, + expectedResponse: 0, + expectedValues: map[string]KeyData{ + "ExpireKey11": {Value: "value11", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, + }, + expectedError: nil, + }, + { + name: "12. Return 0 when LT flag is passed and key does not have an expiry time", + command: []string{"EXPIRE", "ExpireKey12", "1000", "LT"}, + presetValues: map[string]KeyData{ + "ExpireKey12": {Value: "value12", ExpireAt: time.Time{}}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "ExpireKey12": {Value: "value12", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, + }, + expectedError: nil, + }, + { + name: "13. Return error when unknown flag is passed", + command: []string{"EXPIRE", "ExpireKey13", "1000", "UNKNOWN"}, + presetValues: map[string]KeyData{ + "ExpireKey13": {Value: "value13", ExpireAt: time.Time{}}, + }, + expectedResponse: 0, + expectedValues: nil, + expectedError: errors.New("unknown option UNKNOWN"), + }, + { + name: "14. Return error when expire time is not a valid integer", + command: []string{"EXPIRE", "ExpireKey14", "expire"}, + presetValues: nil, + expectedResponse: 0, + expectedValues: nil, + expectedError: errors.New("expire time must be integer"), + }, + { + name: "15. Command too short", + command: []string{"EXPIRE"}, + presetValues: nil, + expectedResponse: 0, + expectedValues: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "16. Command too long", + command: []string{"EXPIRE", "ExpireKey16", "10", "NX", "GT"}, + presetValues: nil, + expectedResponse: 0, + expectedValues: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + for k, v := range test.presetValues { + command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} + if !v.ExpireAt.Equal(time.Time{}) { + command = append(command, []resp.Value{ + resp.StringValue("PX"), + resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), + }...) + } + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + + if test.expectedValues == nil { + return + } + + for key, expected := range test.expectedValues { + // Compare the value of the key with what's expected + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { t.Error(err) } - res, _, err := client.ReadValue() + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected preset response to be OK, got %s", res.String()) + if res.String() != expected.Value.(string) { + t.Errorf("expected value %s, got %s", expected.Value.(string), res.String()) } - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) - } - - if test.expectedValues == nil { - return - } - - for key, expected := range test.expectedValues { - // Compare the value of the key with what's expected - if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - if res.String() != expected.Value.(string) { - t.Errorf("expected value %s, got %s", expected.Value.(string), res.String()) - } - // Compare the expiry of the key with what's expected - if err = client.WriteArray([]resp.Value{resp.StringValue("PTTL"), resp.StringValue(key)}); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - if expected.ExpireAt.Equal(time.Time{}) { - if res.Integer() != -1 { - t.Error("expected key to be persisted, it was not.") - } - continue - } - if res.Integer() != int(expected.ExpireAt.Sub(mockClock.Now()).Milliseconds()) { - t.Errorf("expected expiry %d, got %d", expected.ExpireAt.Sub(mockClock.Now()).Milliseconds(), res.Integer()) - } - } - }) - } -} - -func Test_HandleEXPIREAT(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - command []string - presetValues map[string]KeyData - expectedResponse int - expectedValues map[string]KeyData - expectedError error - }{ - { - name: "1. Set new expire by unix seconds", - command: []string{"EXPIREAT", "ExpireAtKey1", fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix())}, - presetValues: map[string]KeyData{ - "ExpireAtKey1": {Value: "value1", ExpireAt: time.Time{}}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "ExpireAtKey1": {Value: "value1", ExpireAt: time.Unix(mockClock.Now().Add(1000*time.Second).Unix(), 0)}, - }, - expectedError: nil, - }, - { - name: "2. Set new expire by milliseconds", - command: []string{"PEXPIREAT", "ExpireAtKey2", fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).UnixMilli())}, - presetValues: map[string]KeyData{ - "ExpireAtKey2": {Value: "value2", ExpireAt: time.Time{}}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "ExpireAtKey2": {Value: "value2", ExpireAt: time.UnixMilli(mockClock.Now().Add(1000 * time.Second).UnixMilli())}, - }, - expectedError: nil, - }, - { - name: "3. Set new expire only when key does not have an expiry time with NX flag", - command: []string{"EXPIREAT", "ExpireAtKey3", fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "NX"}, - presetValues: map[string]KeyData{ - "ExpireAtKey3": {Value: "value3", ExpireAt: time.Time{}}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "ExpireAtKey3": {Value: "value3", ExpireAt: time.Unix(mockClock.Now().Add(1000*time.Second).Unix(), 0)}, - }, - expectedError: nil, - }, - { - name: "4. Return 0, when NX flag is provided and key already has an expiry time", - command: []string{"EXPIREAT", "ExpireAtKey4", fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "NX"}, - presetValues: map[string]KeyData{ - "ExpireAtKey4": {Value: "value4", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, - }, - expectedResponse: 0, - expectedValues: map[string]KeyData{ - "ExpireAtKey4": {Value: "value4", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, - }, - expectedError: nil, - }, - { - name: "5. Set new expire time from now key only when the key already has an expiry time with XX flag", - command: []string{ - "EXPIREAT", "ExpireAtKey5", - fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "XX", - }, - presetValues: map[string]KeyData{ - "ExpireAtKey5": {Value: "value5", ExpireAt: mockClock.Now().Add(30 * time.Second)}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "ExpireAtKey5": {Value: "value5", ExpireAt: time.Unix(mockClock.Now().Add(1000*time.Second).Unix(), 0)}, - }, - expectedError: nil, - }, - { - name: "6. Return 0 when key does not have an expiry and the XX flag is provided", - command: []string{ - "EXPIREAT", "ExpireAtKey6", - fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "XX", - }, - presetValues: map[string]KeyData{ - "ExpireAtKey6": {Value: "value6", ExpireAt: time.Time{}}, - }, - expectedResponse: 0, - expectedValues: map[string]KeyData{ - "ExpireAtKey6": {Value: "value6", ExpireAt: time.Time{}}, - }, - expectedError: nil, - }, - { - name: "7. Set expiry time when the provided time is after the current expiry time when GT flag is provided", - command: []string{ - "EXPIREAT", "ExpireAtKey7", - fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "GT", - }, - presetValues: map[string]KeyData{ - "ExpireAtKey7": {Value: "value7", ExpireAt: mockClock.Now().Add(30 * time.Second)}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "ExpireAtKey7": {Value: "value7", ExpireAt: time.Unix(mockClock.Now().Add(1000*time.Second).Unix(), 0)}, - }, - expectedError: nil, - }, - { - name: "8. Return 0 when GT flag is passed and current expiry time is greater than provided time", - command: []string{ - "EXPIREAT", "ExpireAtKey8", - fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "GT", - }, - presetValues: map[string]KeyData{ - "ExpireAtKey8": {Value: "value8", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, - }, - expectedResponse: 0, - expectedValues: map[string]KeyData{ - "ExpireAtKey8": {Value: "value8", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, - }, - expectedError: nil, - }, - { - name: "9. Return 0 when GT flag is passed and key does not have an expiry time", - command: []string{ - "EXPIREAT", "ExpireAtKey9", - fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "GT", - }, - presetValues: map[string]KeyData{ - "ExpireAtKey9": {Value: "value9", ExpireAt: time.Time{}}, - }, - expectedResponse: 0, - expectedValues: map[string]KeyData{ - "ExpireAtKey9": {Value: "value9", ExpireAt: time.Time{}}, - }, - expectedError: nil, - }, - { - name: "10. Set expiry time when the provided time is before the current expiry time when LT flag is provided", - command: []string{ - "EXPIREAT", "ExpireAtKey10", - fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "LT", - }, - presetValues: map[string]KeyData{ - "ExpireAtKey10": {Value: "value10", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "ExpireAtKey10": {Value: "value10", ExpireAt: time.Unix(mockClock.Now().Add(1000*time.Second).Unix(), 0)}, - }, - expectedError: nil, - }, - { - name: "11. Return 0 when LT flag is passed and current expiry time is less than provided time", - command: []string{ - "EXPIREAT", "ExpireAtKey11", - fmt.Sprintf("%d", mockClock.Now().Add(3000*time.Second).Unix()), "LT", - }, - presetValues: map[string]KeyData{ - "ExpireAtKey11": {Value: "value11", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, - }, - expectedResponse: 0, - expectedValues: map[string]KeyData{ - "ExpireAtKey11": {Value: "value11", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, - }, - expectedError: nil, - }, - { - name: "12. Return 0 when LT flag is passed and key does not have an expiry time", - command: []string{ - "EXPIREAT", "ExpireAtKey12", - fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "LT", - }, - presetValues: map[string]KeyData{ - "ExpireAtKey12": {Value: "value12", ExpireAt: time.Time{}}, - }, - expectedResponse: 1, - expectedValues: map[string]KeyData{ - "ExpireAtKey12": {Value: "value12", ExpireAt: time.Unix(mockClock.Now().Add(1000*time.Second).Unix(), 0)}, - }, - expectedError: nil, - }, - { - name: "13. Return error when unknown flag is passed", - command: []string{"EXPIREAT", "ExpireAtKey13", "1000", "UNKNOWN"}, - presetValues: map[string]KeyData{ - "ExpireAtKey13": {Value: "value13", ExpireAt: time.Time{}}, - }, - expectedResponse: 0, - expectedValues: nil, - expectedError: errors.New("unknown option UNKNOWN"), - }, - { - name: "14. Return error when expire time is not a valid integer", - command: []string{"EXPIREAT", "ExpireAtKey14", "expire"}, - presetValues: nil, - expectedResponse: 0, - expectedValues: nil, - expectedError: errors.New("expire time must be integer"), - }, - { - name: "15. Command too short", - command: []string{"EXPIREAT"}, - presetValues: nil, - expectedResponse: 0, - expectedValues: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "16. Command too long", - command: []string{"EXPIREAT", "ExpireAtKey16", "10", "NX", "GT"}, - presetValues: nil, - expectedResponse: 0, - expectedValues: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - for k, v := range test.presetValues { - command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} - if !v.ExpireAt.Equal(time.Time{}) { - command = append(command, []resp.Value{ - resp.StringValue("PX"), - resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), - }...) - } - if err = client.WriteArray(command); err != nil { + // Compare the expiry of the key with what's expected + if err = client.WriteArray([]resp.Value{resp.StringValue("PTTL"), resp.StringValue(key)}); err != nil { t.Error(err) } - res, _, err := client.ReadValue() + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected preset response to be OK, got %s", res.String()) + if expected.ExpireAt.Equal(time.Time{}) { + if res.Integer() != -1 { + t.Error("expected key to be persisted, it was not.") + } + continue + } + if res.Integer() != int(expected.ExpireAt.Sub(mockClock.Now()).Milliseconds()) { + t.Errorf("expected expiry %d, got %d", expected.ExpireAt.Sub(mockClock.Now()).Milliseconds(), res.Integer()) } } - } + }) + } + }) - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } + t.Run("Test_HandleEXPIREAT", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - if err = client.WriteArray(command); err != nil { - t.Error(err) - } + tests := []struct { + name string + command []string + presetValues map[string]KeyData + expectedResponse int + expectedValues map[string]KeyData + expectedError error + }{ + { + name: "1. Set new expire by unix seconds", + command: []string{"EXPIREAT", "ExpireAtKey1", fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix())}, + presetValues: map[string]KeyData{ + "ExpireAtKey1": {Value: "value1", ExpireAt: time.Time{}}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "ExpireAtKey1": {Value: "value1", ExpireAt: time.Unix(mockClock.Now().Add(1000*time.Second).Unix(), 0)}, + }, + expectedError: nil, + }, + { + name: "2. Set new expire by milliseconds", + command: []string{"PEXPIREAT", "ExpireAtKey2", fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).UnixMilli())}, + presetValues: map[string]KeyData{ + "ExpireAtKey2": {Value: "value2", ExpireAt: time.Time{}}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "ExpireAtKey2": {Value: "value2", ExpireAt: time.UnixMilli(mockClock.Now().Add(1000 * time.Second).UnixMilli())}, + }, + expectedError: nil, + }, + { + name: "3. Set new expire only when key does not have an expiry time with NX flag", + command: []string{"EXPIREAT", "ExpireAtKey3", fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "NX"}, + presetValues: map[string]KeyData{ + "ExpireAtKey3": {Value: "value3", ExpireAt: time.Time{}}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "ExpireAtKey3": {Value: "value3", ExpireAt: time.Unix(mockClock.Now().Add(1000*time.Second).Unix(), 0)}, + }, + expectedError: nil, + }, + { + name: "4. Return 0, when NX flag is provided and key already has an expiry time", + command: []string{"EXPIREAT", "ExpireAtKey4", fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "NX"}, + presetValues: map[string]KeyData{ + "ExpireAtKey4": {Value: "value4", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, + }, + expectedResponse: 0, + expectedValues: map[string]KeyData{ + "ExpireAtKey4": {Value: "value4", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, + }, + expectedError: nil, + }, + { + name: "5. Set new expire time from now key only when the key already has an expiry time with XX flag", + command: []string{ + "EXPIREAT", "ExpireAtKey5", + fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "XX", + }, + presetValues: map[string]KeyData{ + "ExpireAtKey5": {Value: "value5", ExpireAt: mockClock.Now().Add(30 * time.Second)}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "ExpireAtKey5": {Value: "value5", ExpireAt: time.Unix(mockClock.Now().Add(1000*time.Second).Unix(), 0)}, + }, + expectedError: nil, + }, + { + name: "6. Return 0 when key does not have an expiry and the XX flag is provided", + command: []string{ + "EXPIREAT", "ExpireAtKey6", + fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "XX", + }, + presetValues: map[string]KeyData{ + "ExpireAtKey6": {Value: "value6", ExpireAt: time.Time{}}, + }, + expectedResponse: 0, + expectedValues: map[string]KeyData{ + "ExpireAtKey6": {Value: "value6", ExpireAt: time.Time{}}, + }, + expectedError: nil, + }, + { + name: "7. Set expiry time when the provided time is after the current expiry time when GT flag is provided", + command: []string{ + "EXPIREAT", "ExpireAtKey7", + fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "GT", + }, + presetValues: map[string]KeyData{ + "ExpireAtKey7": {Value: "value7", ExpireAt: mockClock.Now().Add(30 * time.Second)}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "ExpireAtKey7": {Value: "value7", ExpireAt: time.Unix(mockClock.Now().Add(1000*time.Second).Unix(), 0)}, + }, + expectedError: nil, + }, + { + name: "8. Return 0 when GT flag is passed and current expiry time is greater than provided time", + command: []string{ + "EXPIREAT", "ExpireAtKey8", + fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "GT", + }, + presetValues: map[string]KeyData{ + "ExpireAtKey8": {Value: "value8", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, + }, + expectedResponse: 0, + expectedValues: map[string]KeyData{ + "ExpireAtKey8": {Value: "value8", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, + }, + expectedError: nil, + }, + { + name: "9. Return 0 when GT flag is passed and key does not have an expiry time", + command: []string{ + "EXPIREAT", "ExpireAtKey9", + fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "GT", + }, + presetValues: map[string]KeyData{ + "ExpireAtKey9": {Value: "value9", ExpireAt: time.Time{}}, + }, + expectedResponse: 0, + expectedValues: map[string]KeyData{ + "ExpireAtKey9": {Value: "value9", ExpireAt: time.Time{}}, + }, + expectedError: nil, + }, + { + name: "10. Set expiry time when the provided time is before the current expiry time when LT flag is provided", + command: []string{ + "EXPIREAT", "ExpireAtKey10", + fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "LT", + }, + presetValues: map[string]KeyData{ + "ExpireAtKey10": {Value: "value10", ExpireAt: mockClock.Now().Add(3000 * time.Second)}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "ExpireAtKey10": {Value: "value10", ExpireAt: time.Unix(mockClock.Now().Add(1000*time.Second).Unix(), 0)}, + }, + expectedError: nil, + }, + { + name: "11. Return 0 when LT flag is passed and current expiry time is less than provided time", + command: []string{ + "EXPIREAT", "ExpireAtKey11", + fmt.Sprintf("%d", mockClock.Now().Add(3000*time.Second).Unix()), "LT", + }, + presetValues: map[string]KeyData{ + "ExpireAtKey11": {Value: "value11", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, + }, + expectedResponse: 0, + expectedValues: map[string]KeyData{ + "ExpireAtKey11": {Value: "value11", ExpireAt: mockClock.Now().Add(1000 * time.Second)}, + }, + expectedError: nil, + }, + { + name: "12. Return 0 when LT flag is passed and key does not have an expiry time", + command: []string{ + "EXPIREAT", "ExpireAtKey12", + fmt.Sprintf("%d", mockClock.Now().Add(1000*time.Second).Unix()), "LT", + }, + presetValues: map[string]KeyData{ + "ExpireAtKey12": {Value: "value12", ExpireAt: time.Time{}}, + }, + expectedResponse: 1, + expectedValues: map[string]KeyData{ + "ExpireAtKey12": {Value: "value12", ExpireAt: time.Unix(mockClock.Now().Add(1000*time.Second).Unix(), 0)}, + }, + expectedError: nil, + }, + { + name: "13. Return error when unknown flag is passed", + command: []string{"EXPIREAT", "ExpireAtKey13", "1000", "UNKNOWN"}, + presetValues: map[string]KeyData{ + "ExpireAtKey13": {Value: "value13", ExpireAt: time.Time{}}, + }, + expectedResponse: 0, + expectedValues: nil, + expectedError: errors.New("unknown option UNKNOWN"), + }, + { + name: "14. Return error when expire time is not a valid integer", + command: []string{"EXPIREAT", "ExpireAtKey14", "expire"}, + presetValues: nil, + expectedResponse: 0, + expectedValues: nil, + expectedError: errors.New("expire time must be integer"), + }, + { + name: "15. Command too short", + command: []string{"EXPIREAT"}, + presetValues: nil, + expectedResponse: 0, + expectedValues: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "16. Command too long", + command: []string{"EXPIREAT", "ExpireAtKey16", "10", "NX", "GT"}, + presetValues: nil, + expectedResponse: 0, + expectedValues: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + for k, v := range test.presetValues { + command := []resp.Value{resp.StringValue("SET"), resp.StringValue(k), resp.StringValue(v.Value.(string))} + if !v.ExpireAt.Equal(time.Time{}) { + command = append(command, []resp.Value{ + resp.StringValue("PX"), + resp.StringValue(fmt.Sprintf("%d", v.ExpireAt.Sub(mockClock.Now()).Milliseconds())), + }...) + } + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } + } } - return - } - if res.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) - } + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } - if test.expectedValues == nil { - return - } - - for key, expected := range test.expectedValues { - // Compare the value of the key with what's expected - if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { + if err = client.WriteArray(command); err != nil { t.Error(err) } - res, _, err = client.ReadValue() + + res, _, err := client.ReadValue() if err != nil { t.Error(err) } - if res.String() != expected.Value.(string) { - t.Errorf("expected value %s, got %s", expected.Value.(string), res.String()) - } - // Compare the expiry of the key with what's expected - if err = client.WriteArray([]resp.Value{resp.StringValue("PTTL"), resp.StringValue(key)}); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - if expected.ExpireAt.Equal(time.Time{}) { - if res.Integer() != -1 { - t.Error("expected key to be persisted, it was not.") + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } - continue + return } - if res.Integer() != int(expected.ExpireAt.Sub(mockClock.Now()).Milliseconds()) { - t.Errorf("expected expiry %d, got %d", expected.ExpireAt.Sub(mockClock.Now()).Milliseconds(), res.Integer()) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) } - } - }) - } + + if test.expectedValues == nil { + return + } + + for key, expected := range test.expectedValues { + // Compare the value of the key with what's expected + if err = client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + if res.String() != expected.Value.(string) { + t.Errorf("expected value %s, got %s", expected.Value.(string), res.String()) + } + // Compare the expiry of the key with what's expected + if err = client.WriteArray([]resp.Value{resp.StringValue("PTTL"), resp.StringValue(key)}); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + if expected.ExpireAt.Equal(time.Time{}) { + if res.Integer() != -1 { + t.Error("expected key to be persisted, it was not.") + } + continue + } + if res.Integer() != int(expected.ExpireAt.Sub(mockClock.Now()).Milliseconds()) { + t.Errorf("expected expiry %d, got %d", expected.ExpireAt.Sub(mockClock.Now()).Milliseconds(), res.Integer()) + } + } + }) + } + }) + } diff --git a/internal/modules/hash/commands_test.go b/internal/modules/hash/commands_test.go index ad7a7a8..e5cbb7a 100644 --- a/internal/modules/hash/commands_test.go +++ b/internal/modules/hash/commands_test.go @@ -30,20 +30,26 @@ import ( "testing" ) -var mockServer *echovault.EchoVault -var addr = "localhost" -var port int +func Test_Hash(t *testing.T) { + port, err := internal.GetFreePort() + if err != nil { + t.Error(err) + return + } -func init() { - port, _ = internal.GetFreePort() - mockServer, _ = echovault.NewEchoVault( + mockServer, err := echovault.NewEchoVault( echovault.WithConfig(config.Config{ - BindAddr: addr, + BindAddr: "localhost", Port: uint16(port), DataDir: "", EvictionPolicy: constants.NoEviction, }), ) + if err != nil { + t.Error(err) + return + } + wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -51,984 +57,148 @@ func init() { mockServer.Start() }() wg.Wait() -} -func Test_HandleHSET(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) + t.Cleanup(func() { + mockServer.ShutDown() + }) - // Tests for both HSet and HSetNX - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse int // Change count - expectedValue map[string]string - expectedError error - }{ - { - name: "1. HSETNX set field on non-existent hash map", - key: "HsetKey1", - presetValue: nil, - command: []string{"HSETNX", "HsetKey1", "field1", "value1"}, - expectedResponse: 1, - expectedValue: map[string]string{"field1": "value1"}, - expectedError: nil, - }, - { - name: "2. HSETNX set field on existing hash map", - key: "HsetKey2", - presetValue: map[string]string{"field1": "value1"}, - command: []string{"HSETNX", "HsetKey2", "field2", "value2"}, - expectedResponse: 1, - expectedValue: map[string]string{"field1": "value1", "field2": "value2"}, - expectedError: nil, - }, - { - name: "3. HSETNX skips operation when setting on existing field", - key: "HsetKey3", - presetValue: map[string]string{"field1": "value1"}, - command: []string{"HSETNX", "HsetKey3", "field1", "value1-new"}, - expectedResponse: 0, - expectedValue: map[string]string{"field1": "value1"}, - expectedError: nil, - }, - { - name: "4. Regular HSET command on non-existent hash map", - key: "HsetKey4", - presetValue: nil, - command: []string{"HSET", "HsetKey4", "field1", "value1", "field2", "value2"}, - expectedResponse: 2, - expectedValue: map[string]string{"field1": "value1", "field2": "value2"}, - expectedError: nil, - }, - { - name: "5. Regular HSET update on existing hash map", - key: "HsetKey5", - presetValue: map[string]string{"field1": "value1", "field2": "value2"}, - command: []string{"HSET", "HsetKey5", "field1", "value1-new", "field2", "value2-ne2", "field3", "value3"}, - expectedResponse: 3, - expectedValue: map[string]string{"field1": "value1-new", "field2": "value2-ne2", "field3": "value3"}, - expectedError: nil, - }, - { - name: "6. HSET overwrites when the target key is not a map", - key: "HsetKey6", - presetValue: "Default preset value", - command: []string{"HSET", "HsetKey6", "field1", "value1"}, - expectedResponse: 1, - expectedValue: map[string]string{"field1": "value1"}, - expectedError: nil, - }, - { - name: "7. HSET returns error when there's a mismatch in key/values", - key: "HsetKey7", - presetValue: nil, - command: []string{"HSET", "HsetKey7", "field1", "value1", "field2"}, - expectedResponse: 0, - expectedValue: map[string]string{}, - expectedError: errors.New("each field must have a corresponding value"), - }, - { - name: "8. Command too short", - key: "HsetKey8", - presetValue: nil, - command: []string{"HSET", "field1"}, - expectedResponse: 0, - expectedValue: map[string]string{}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + t.Run("Test_HandleHSET", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case map[string]string: - command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} - for key, value := range test.presetValue.(map[string]string) { - command = append(command, []resp.Value{ - resp.StringValue(key), - resp.StringValue(value)}..., - ) - } - expected = strconv.Itoa(len(test.presetValue.(map[string]string))) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - - // Check that all the values are what is expected - if err := client.WriteArray([]resp.Value{ - resp.StringValue("HGETALL"), - resp.StringValue(test.key), - }); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - for idx, field := range res.Array() { - if idx%2 == 0 { - if res.Array()[idx+1].String() != test.expectedValue[field.String()] { - t.Errorf( - "expected value \"%+v\" for field \"%s\", got \"%+v\"", - test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), - ) - } - } - } - }) - } -} - -func Test_HandleHINCRBY(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - // Tests for both HIncrBy and HIncrByFloat - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse string // Change count - expectedValue map[string]string - expectedError error - }{ - { - name: "1. Increment by integer on non-existent hash should create a new one", - key: "HincrbyKey1", - presetValue: nil, - command: []string{"HINCRBY", "HincrbyKey1", "field1", "1"}, - expectedResponse: "1", - expectedValue: map[string]string{"field1": "1"}, - expectedError: nil, - }, - { - name: "2. Increment by float on non-existent hash should create one", - key: "HincrbyKey2", - presetValue: nil, - command: []string{"HINCRBYFLOAT", "HincrbyKey2", "field1", "3.142"}, - expectedResponse: "3.142", - expectedValue: map[string]string{"field1": "3.142"}, - expectedError: nil, - }, - { - name: "3. Increment by integer on existing hash", - key: "HincrbyKey3", - presetValue: map[string]string{"field1": "1"}, - command: []string{"HINCRBY", "HincrbyKey3", "field1", "10"}, - expectedResponse: "11", - expectedValue: map[string]string{"field1": "11"}, - expectedError: nil, - }, - { - name: "4. Increment by float on an existing hash", - key: "HincrbyKey4", - presetValue: map[string]string{"field1": "3.142"}, - command: []string{"HINCRBYFLOAT", "HincrbyKey4", "field1", "3.142"}, - expectedResponse: "6.284", - expectedValue: map[string]string{"field1": "6.284"}, - expectedError: nil, - }, - { - name: "5. Command too short", - key: "HincrbyKey5", - presetValue: nil, - command: []string{"HINCRBY", "HincrbyKey5"}, - expectedResponse: "0", - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "6. Command too long", - key: "HincrbyKey6", - presetValue: nil, - command: []string{"HINCRBY", "HincrbyKey6", "field1", "23", "45"}, - expectedResponse: "0", - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "7. Error when increment by float does not pass valid float", - key: "HincrbyKey7", - presetValue: nil, - command: []string{"HINCRBYFLOAT", "HincrbyKey7", "field1", "three point one four two"}, - expectedResponse: "0", - expectedValue: nil, - expectedError: errors.New("increment must be a float"), - }, - { - name: "8. Error when increment does not pass valid integer", - key: "HincrbyKey8", - presetValue: nil, - command: []string{"HINCRBY", "HincrbyKey8", "field1", "three"}, - expectedResponse: "0", - expectedValue: nil, - expectedError: errors.New("increment must be an integer"), - }, - { - name: "9. Error when trying to increment on a key that is not a hash", - key: "HincrbyKey9", - presetValue: "Default value", - command: []string{"HINCRBY", "HincrbyKey9", "field1", "3"}, - expectedResponse: "0", - expectedValue: nil, - expectedError: errors.New("value at HincrbyKey9 is not a hash"), - }, - { - name: "10. Error when trying to increment a hash field that is not a number", - key: "HincrbyKey10", - presetValue: map[string]string{"field1": "value1"}, - command: []string{"HINCRBY", "HincrbyKey10", "field1", "3"}, - expectedResponse: "0", - expectedValue: nil, - expectedError: errors.New("value at field field1 is not a number"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case map[string]string: - command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} - for key, value := range test.presetValue.(map[string]string) { - command = append(command, []resp.Value{ - resp.StringValue(key), - resp.StringValue(value)}..., - ) - } - expected = strconv.Itoa(len(test.presetValue.(map[string]string))) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.String() != test.expectedResponse { - t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) - } - - // Check that all the values are what is expected - if err := client.WriteArray([]resp.Value{ - resp.StringValue("HGETALL"), - resp.StringValue(test.key), - }); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - for idx, field := range res.Array() { - if idx%2 == 0 { - if res.Array()[idx+1].String() != test.expectedValue[field.String()] { - t.Errorf( - "expected value \"%+v\" for field \"%s\", got \"%+v\"", - test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), - ) - } - } - } - }) - } -} - -func Test_HandleHGET(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse []string // Change count - expectedValue map[string]string - expectedError error - }{ - { - name: "1. Get values from existing hash.", - key: "HgetKey1", - presetValue: map[string]string{"field1": "value1", "field2": "365", "field3": "3.142"}, - command: []string{"HGET", "HgetKey1", "field1", "field2", "field3", "field4"}, - expectedResponse: []string{"value1", "365", "3.142", ""}, - expectedValue: map[string]string{"field1": "value1", "field2": "365", "field3": "3.142"}, - expectedError: nil, - }, - { - name: "2. Return nil when attempting to get from non-existed key", - key: "HgetKey2", - presetValue: nil, - command: []string{"HGET", "HgetKey2", "field1"}, - expectedResponse: nil, - expectedValue: nil, - expectedError: nil, - }, - { - name: "3. Error when trying to get from a value that is not a hash map", - key: "HgetKey3", - presetValue: "Default Value", - command: []string{"HGET", "HgetKey3", "field1"}, - expectedResponse: nil, - expectedValue: nil, - expectedError: errors.New("value at HgetKey3 is not a hash"), - }, - { - name: "4. Command too short", - key: "HgetKey4", - presetValue: nil, - command: []string{"HGET", "HgetKey4"}, - expectedResponse: nil, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case map[string]string: - command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} - for key, value := range test.presetValue.(map[string]string) { - command = append(command, []resp.Value{ - resp.StringValue(key), - resp.StringValue(value)}..., - ) - } - expected = strconv.Itoa(len(test.presetValue.(map[string]string))) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if test.expectedResponse == nil { - if !res.IsNull() { - t.Errorf("expected nil response, got %+v", res) - } - return - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedResponse, item.String()) { - t.Errorf("unexpected element \"%s\" in response", item.String()) - } - } - - // Check that all the values are what is expected - if err := client.WriteArray([]resp.Value{ - resp.StringValue("HGETALL"), - resp.StringValue(test.key), - }); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - for idx, field := range res.Array() { - if idx%2 == 0 { - if res.Array()[idx+1].String() != test.expectedValue[field.String()] { - t.Errorf( - "expected value \"%+v\" for field \"%s\", got \"%+v\"", - test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), - ) - } - } - } - }) - } -} - -func Test_HandleHSTRLEN(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse []int // Change count - expectedValue map[string]string - expectedError error - }{ - { - // Return lengths of field values. - // If the key does not exist, its length should be 0. - name: "1. Return lengths of field values.", - key: "HstrlenKey1", - presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, - command: []string{"HSTRLEN", "HstrlenKey1", "field1", "field2", "field3", "field4"}, - expectedResponse: []int{len("value1"), len("123456789"), len("3.142"), 0}, - expectedValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, - expectedError: nil, - }, - { - name: "2. Nil response when trying to get HSTRLEN non-existent key", - key: "HstrlenKey2", - presetValue: nil, - command: []string{"HSTRLEN", "HstrlenKey2", "field1"}, - expectedResponse: nil, - expectedValue: nil, - expectedError: nil, - }, - { - name: "3. Command too short", - key: "HstrlenKey3", - presetValue: nil, - command: []string{"HSTRLEN", "HstrlenKey3"}, - expectedResponse: nil, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "4. Trying to get lengths on a non hash map returns error", - key: "HstrlenKey4", - presetValue: "Default value", - command: []string{"HSTRLEN", "HstrlenKey4", "field1"}, - expectedResponse: nil, - expectedValue: nil, - expectedError: errors.New("value at HstrlenKey4 is not a hash"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case map[string]string: - command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} - for key, value := range test.presetValue.(map[string]string) { - command = append(command, []resp.Value{ - resp.StringValue(key), - resp.StringValue(value)}..., - ) - } - expected = strconv.Itoa(len(test.presetValue.(map[string]string))) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if test.expectedResponse == nil { - if !res.IsNull() { - t.Errorf("expected nil response, got %+v", res) - } - return - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedResponse, item.Integer()) { - t.Errorf("unexpected element \"%d\" in response", item.Integer()) - } - } - - // Check that all the values are what is expected - if err := client.WriteArray([]resp.Value{ - resp.StringValue("HGETALL"), - resp.StringValue(test.key), - }); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - for idx, field := range res.Array() { - if idx%2 == 0 { - if res.Array()[idx+1].String() != test.expectedValue[field.String()] { - t.Errorf( - "expected value \"%+v\" for field \"%s\", got \"%+v\"", - test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), - ) - } - } - } - }) - } -} - -func Test_HandleHVALS(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse []string - expectedValue map[string]string - expectedError error - }{ - { - name: "1. Return all the values from a hash", - key: "HvalsKey1", - presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, - command: []string{"HVALS", "HvalsKey1"}, - expectedResponse: []string{"value1", "123456789", "3.142"}, - expectedValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, - expectedError: nil, - }, - { - name: "2. Empty array response when trying to get HSTRLEN non-existent key", - key: "HvalsKey2", - presetValue: nil, - command: []string{"HVALS", "HvalsKey2"}, - expectedResponse: []string{}, - expectedValue: nil, - expectedError: nil, - }, - { - name: "3. Command too short", - key: "HvalsKey3", - presetValue: nil, - command: []string{"HVALS"}, - expectedResponse: nil, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "4. Command too long", - key: "HvalsKey4", - presetValue: nil, - command: []string{"HVALS", "HvalsKey4", "HvalsKey4"}, - expectedResponse: nil, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Trying to get lengths on a non hash map returns error", - key: "HvalsKey5", - presetValue: "Default value", - command: []string{"HVALS", "HvalsKey5"}, - expectedResponse: nil, - expectedValue: nil, - expectedError: errors.New("value at HvalsKey5 is not a hash"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case map[string]string: - command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} - for key, value := range test.presetValue.(map[string]string) { - command = append(command, []resp.Value{ - resp.StringValue(key), - resp.StringValue(value)}..., - ) - } - expected = strconv.Itoa(len(test.presetValue.(map[string]string))) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if test.expectedResponse == nil { - if !res.IsNull() { - t.Errorf("expected nil response, got %+v", res) - } - return - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedResponse, item.String()) { - t.Errorf("unexpected element \"%s\" in response", item.String()) - } - } - }) - } -} - -func Test_HandleHRANDFIELD(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse []string - expectedError error - }{ - { - name: "1. Get a random field", - key: "HrandfieldKey1", - presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, - command: []string{"HRANDFIELD", "HrandfieldKey1"}, - expectedResponse: []string{"field1", "field2", "field3"}, - expectedError: nil, - }, - { - name: "2. Get a random field with a value", - key: "HrandfieldKey2", - presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, - command: []string{"HRANDFIELD", "HrandfieldKey2", "1", "WITHVALUES"}, - expectedResponse: []string{"field1", "value1", "field2", "123456789", "field3", "3.142"}, - expectedError: nil, - }, - { - name: "3. Get several random fields", - key: "HrandfieldKey3", - presetValue: map[string]string{ - "field1": "value1", - "field2": "123456789", - "field3": "3.142", - "field4": "value4", - "field5": "value5", + // Tests for both HSet and HSetNX + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse int // Change count + expectedValue map[string]string + expectedError error + }{ + { + name: "1. HSETNX set field on non-existent hash map", + key: "HsetKey1", + presetValue: nil, + command: []string{"HSETNX", "HsetKey1", "field1", "value1"}, + expectedResponse: 1, + expectedValue: map[string]string{"field1": "value1"}, + expectedError: nil, }, - command: []string{"HRANDFIELD", "HrandfieldKey3", "3"}, - expectedResponse: []string{"field1", "field2", "field3", "field4", "field5"}, - expectedError: nil, - }, - { - name: "4. Get several random fields with their corresponding values", - key: "HrandfieldKey4", - presetValue: map[string]string{ - "field1": "value1", - "field2": "123456789", - "field3": "3.142", - "field4": "value4", - "field5": "value5", + { + name: "2. HSETNX set field on existing hash map", + key: "HsetKey2", + presetValue: map[string]string{"field1": "value1"}, + command: []string{"HSETNX", "HsetKey2", "field2", "value2"}, + expectedResponse: 1, + expectedValue: map[string]string{"field1": "value1", "field2": "value2"}, + expectedError: nil, }, - command: []string{"HRANDFIELD", "HrandfieldKey4", "3", "WITHVALUES"}, - expectedResponse: []string{ - "field1", "value1", "field2", "123456789", "field3", - "3.142", "field4", "value4", "field5", "value5", + { + name: "3. HSETNX skips operation when setting on existing field", + key: "HsetKey3", + presetValue: map[string]string{"field1": "value1"}, + command: []string{"HSETNX", "HsetKey3", "field1", "value1-new"}, + expectedResponse: 0, + expectedValue: map[string]string{"field1": "value1"}, + expectedError: nil, }, - expectedError: nil, - }, - { - name: "5. Get the entire hash", - key: "HrandfieldKey5", - presetValue: map[string]string{ - "field1": "value1", - "field2": "123456789", - "field3": "3.142", - "field4": "value4", - "field5": "value5", + { + name: "4. Regular HSET command on non-existent hash map", + key: "HsetKey4", + presetValue: nil, + command: []string{"HSET", "HsetKey4", "field1", "value1", "field2", "value2"}, + expectedResponse: 2, + expectedValue: map[string]string{"field1": "value1", "field2": "value2"}, + expectedError: nil, }, - command: []string{"HRANDFIELD", "HrandfieldKey5", "5"}, - expectedResponse: []string{"field1", "field2", "field3", "field4", "field5"}, - expectedError: nil, - }, - { - name: "6. Get the entire hash with values", - key: "HrandfieldKey5", - presetValue: map[string]string{ - "field1": "value1", - "field2": "123456789", - "field3": "3.142", - "field4": "value4", - "field5": "value5", + { + name: "5. Regular HSET update on existing hash map", + key: "HsetKey5", + presetValue: map[string]string{"field1": "value1", "field2": "value2"}, + command: []string{"HSET", "HsetKey5", "field1", "value1-new", "field2", "value2-ne2", "field3", "value3"}, + expectedResponse: 3, + expectedValue: map[string]string{"field1": "value1-new", "field2": "value2-ne2", "field3": "value3"}, + expectedError: nil, }, - command: []string{"HRANDFIELD", "HrandfieldKey5", "5", "WITHVALUES"}, - expectedResponse: []string{ - "field1", "value1", "field2", "123456789", "field3", - "3.142", "field4", "value4", "field5", "value5", + { + name: "6. HSET overwrites when the target key is not a map", + key: "HsetKey6", + presetValue: "Default preset value", + command: []string{"HSET", "HsetKey6", "field1", "value1"}, + expectedResponse: 1, + expectedValue: map[string]string{"field1": "value1"}, + expectedError: nil, }, - expectedError: nil, - }, - { - name: "7. Command too short", - key: "HrandfieldKey10", - presetValue: nil, - command: []string{"HRANDFIELD"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "8. Command too long", - key: "HrandfieldKey11", - presetValue: nil, - command: []string{"HRANDFIELD", "HrandfieldKey11", "HrandfieldKey11", "HrandfieldKey11", "HrandfieldKey11"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "9. Trying to get random field on a non hash map returns error", - key: "HrandfieldKey12", - presetValue: "Default value", - command: []string{"HRANDFIELD", "HrandfieldKey12"}, - expectedError: errors.New("value at HrandfieldKey12 is not a hash"), - }, - { - name: "10. Throw error when count provided is not an integer", - key: "HrandfieldKey12", - presetValue: "Default value", - command: []string{"HRANDFIELD", "HrandfieldKey12", "COUNT"}, - expectedError: errors.New("count must be an integer"), - }, - { - name: "11. If fourth argument is provided, it must be \"WITHVALUES\"", - key: "HrandfieldKey12", - presetValue: "Default value", - command: []string{"HRANDFIELD", "HrandfieldKey12", "10", "FLAG"}, - expectedError: errors.New("result modifier must be withvalues"), - }, - } + { + name: "7. HSET returns error when there's a mismatch in key/values", + key: "HsetKey7", + presetValue: nil, + command: []string{"HSET", "HsetKey7", "field1", "value1", "field2"}, + expectedResponse: 0, + expectedValue: map[string]string{}, + expectedError: errors.New("each field must have a corresponding value"), + }, + { + name: "8. Command too short", + key: "HsetKey8", + presetValue: nil, + command: []string{"HSET", "field1"}, + expectedResponse: 0, + expectedValue: map[string]string{}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) } - expected = "ok" - case map[string]string: - command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} - for key, value := range test.presetValue.(map[string]string) { - command = append(command, []resp.Value{ - resp.StringValue(key), - resp.StringValue(value)}..., - ) + + if err = client.WriteArray(command); err != nil { + t.Error(err) } - expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -1039,127 +209,198 @@ func Test_HandleHRANDFIELD(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if test.expectedResponse == nil { - if !res.IsNull() { - t.Errorf("expected nil response, got %+v", res) - } - return - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedResponse, item.String()) { - t.Errorf("unexpected element \"%s\" in response", item.String()) - } - } - }) - } -} - -func Test_HandleHLEN(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse int // Change count - expectedError error - }{ - { - name: "1. Return the correct length of the hash", - key: "HlenKey1", - presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, - command: []string{"HLEN", "HlenKey1"}, - expectedResponse: 3, - expectedError: nil, - }, - { - name: "2. 0 response when trying to call HLEN on non-existent key", - key: "HlenKey2", - presetValue: nil, - command: []string{"HLEN", "HlenKey2"}, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "3. Command too short", - key: "HlenKey3", - presetValue: nil, - command: []string{"HLEN"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "4. Command too long", - presetValue: nil, - command: []string{"HLEN", "HlenKey4", "HlenKey4"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Trying to get lengths on a non hash map returns error", - key: "HlenKey5", - presetValue: "Default value", - command: []string{"HLEN", "HlenKey5"}, - expectedResponse: 0, - expectedError: errors.New("value at HlenKey5 is not a hash"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } - expected = "ok" - case map[string]string: - command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} - for key, value := range test.presetValue.(map[string]string) { - command = append(command, []resp.Value{ - resp.StringValue(key), - resp.StringValue(value)}..., - ) + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) + } + + // Check that all the values are what is expected + if err := client.WriteArray([]resp.Value{ + resp.StringValue("HGETALL"), + resp.StringValue(test.key), + }); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + for idx, field := range res.Array() { + if idx%2 == 0 { + if res.Array()[idx+1].String() != test.expectedValue[field.String()] { + t.Errorf( + "expected value \"%+v\" for field \"%s\", got \"%+v\"", + test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), + ) + } } - expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + }) + } + }) + + t.Run("Test_HandleHINCRBY", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + // Tests for both HIncrBy and HIncrByFloat + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse string // Change count + expectedValue map[string]string + expectedError error + }{ + { + name: "1. Increment by integer on non-existent hash should create a new one", + key: "HincrbyKey1", + presetValue: nil, + command: []string{"HINCRBY", "HincrbyKey1", "field1", "1"}, + expectedResponse: "1", + expectedValue: map[string]string{"field1": "1"}, + expectedError: nil, + }, + { + name: "2. Increment by float on non-existent hash should create one", + key: "HincrbyKey2", + presetValue: nil, + command: []string{"HINCRBYFLOAT", "HincrbyKey2", "field1", "3.142"}, + expectedResponse: "3.142", + expectedValue: map[string]string{"field1": "3.142"}, + expectedError: nil, + }, + { + name: "3. Increment by integer on existing hash", + key: "HincrbyKey3", + presetValue: map[string]string{"field1": "1"}, + command: []string{"HINCRBY", "HincrbyKey3", "field1", "10"}, + expectedResponse: "11", + expectedValue: map[string]string{"field1": "11"}, + expectedError: nil, + }, + { + name: "4. Increment by float on an existing hash", + key: "HincrbyKey4", + presetValue: map[string]string{"field1": "3.142"}, + command: []string{"HINCRBYFLOAT", "HincrbyKey4", "field1", "3.142"}, + expectedResponse: "6.284", + expectedValue: map[string]string{"field1": "6.284"}, + expectedError: nil, + }, + { + name: "5. Command too short", + key: "HincrbyKey5", + presetValue: nil, + command: []string{"HINCRBY", "HincrbyKey5"}, + expectedResponse: "0", + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "6. Command too long", + key: "HincrbyKey6", + presetValue: nil, + command: []string{"HINCRBY", "HincrbyKey6", "field1", "23", "45"}, + expectedResponse: "0", + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "7. Error when increment by float does not pass valid float", + key: "HincrbyKey7", + presetValue: nil, + command: []string{"HINCRBYFLOAT", "HincrbyKey7", "field1", "three point one four two"}, + expectedResponse: "0", + expectedValue: nil, + expectedError: errors.New("increment must be a float"), + }, + { + name: "8. Error when increment does not pass valid integer", + key: "HincrbyKey8", + presetValue: nil, + command: []string{"HINCRBY", "HincrbyKey8", "field1", "three"}, + expectedResponse: "0", + expectedValue: nil, + expectedError: errors.New("increment must be an integer"), + }, + { + name: "9. Error when trying to increment on a key that is not a hash", + key: "HincrbyKey9", + presetValue: "Default value", + command: []string{"HINCRBY", "HincrbyKey9", "field1", "3"}, + expectedResponse: "0", + expectedValue: nil, + expectedError: errors.New("value at HincrbyKey9 is not a hash"), + }, + { + name: "10. Error when trying to increment a hash field that is not a number", + key: "HincrbyKey10", + presetValue: map[string]string{"field1": "value1"}, + command: []string{"HINCRBY", "HincrbyKey10", "field1", "3"}, + expectedResponse: "0", + expectedValue: nil, + expectedError: errors.New("value at field field1 is not a number"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -1170,118 +411,143 @@ func Test_HandleHLEN(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) - } - }) - } -} - -func Test_HandleHKeys(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse []string - expectedError error - }{ - { - name: "1. Return an array containing all the keys of the hash", - key: "HkeysKey1", - presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, - command: []string{"HKEYS", "HkeysKey1"}, - expectedResponse: []string{"field1", "field2", "field3"}, - expectedError: nil, - }, - { - name: "2. Empty array response when trying to call HKEYS on non-existent key", - key: "HkeysKey2", - presetValue: nil, - command: []string{"HKEYS", "HkeysKey2"}, - expectedResponse: []string{}, - expectedError: nil, - }, - { - name: "3. Command too short", - key: "HkeysKey3", - presetValue: nil, - command: []string{"HKEYS"}, - expectedResponse: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "4. Command too long", - key: "HkeysKey4", - presetValue: nil, - command: []string{"HKEYS", "HkeysKey4", "HkeysKey4"}, - expectedResponse: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Trying to get lengths on a non hash map returns error", - key: "HkeysKey5", - presetValue: "Default value", - command: []string{"HKEYS", "HkeysKey5"}, - expectedError: errors.New("value at HkeysKey5 is not a hash"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } - expected = "ok" - case map[string]string: - command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} - for key, value := range test.presetValue.(map[string]string) { - command = append(command, []resp.Value{ - resp.StringValue(key), - resp.StringValue(value)}..., - ) + return + } + + if res.String() != test.expectedResponse { + t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) + } + + // Check that all the values are what is expected + if err := client.WriteArray([]resp.Value{ + resp.StringValue("HGETALL"), + resp.StringValue(test.key), + }); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + for idx, field := range res.Array() { + if idx%2 == 0 { + if res.Array()[idx+1].String() != test.expectedValue[field.String()] { + t.Errorf( + "expected value \"%+v\" for field \"%s\", got \"%+v\"", + test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), + ) + } } - expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + }) + } + }) + + t.Run("Test_HandleHGET", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse []string // Change count + expectedValue map[string]string + expectedError error + }{ + { + name: "1. Get values from existing hash.", + key: "HgetKey1", + presetValue: map[string]string{"field1": "value1", "field2": "365", "field3": "3.142"}, + command: []string{"HGET", "HgetKey1", "field1", "field2", "field3", "field4"}, + expectedResponse: []string{"value1", "365", "3.142", ""}, + expectedValue: map[string]string{"field1": "value1", "field2": "365", "field3": "3.142"}, + expectedError: nil, + }, + { + name: "2. Return nil when attempting to get from non-existed key", + key: "HgetKey2", + presetValue: nil, + command: []string{"HGET", "HgetKey2", "field1"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: nil, + }, + { + name: "3. Error when trying to get from a value that is not a hash map", + key: "HgetKey3", + presetValue: "Default Value", + command: []string{"HGET", "HgetKey3", "field1"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: errors.New("value at HgetKey3 is not a hash"), + }, + { + name: "4. Command too short", + key: "HgetKey4", + presetValue: nil, + command: []string{"HGET", "HgetKey4"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -1292,121 +558,154 @@ func Test_HandleHKeys(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedResponse, item.String()) { - t.Errorf("unexpected value \"%s\" in response", item.String()) - } - } - }) - } -} - -func Test_HandleHGETALL(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse map[string]string - expectedError error - }{ - { - name: "1. Return an array containing all the fields and values of the hash", - key: "HGetAllKey1", - presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, - command: []string{"HGETALL", "HGetAllKey1"}, - expectedResponse: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, - expectedError: nil, - }, - { - name: "2. Empty array response when trying to call HGETALL on non-existent key", - key: "HGetAllKey2", - presetValue: nil, - command: []string{"HGETALL", "HGetAllKey2"}, - expectedResponse: nil, - expectedError: nil, - }, - { - name: "3. Command too short", - key: "HGetAllKey3", - presetValue: nil, - command: []string{"HGETALL"}, - expectedResponse: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "4. Command too long", - key: "HGetAllKey4", - presetValue: nil, - command: []string{"HGETALL", "HGetAllKey4", "HGetAllKey4"}, - expectedResponse: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Trying to get lengths on a non hash map returns error", - key: "HGetAllKey5", - presetValue: "Default value", - command: []string{"HGETALL", "HGetAllKey5"}, - expectedResponse: nil, - expectedError: errors.New("value at HGetAllKey5 is not a hash"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } - expected = "ok" - case map[string]string: - command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} - for key, value := range test.presetValue.(map[string]string) { - command = append(command, []resp.Value{ - resp.StringValue(key), - resp.StringValue(value)}..., - ) + return + } + + if test.expectedResponse == nil { + if !res.IsNull() { + t.Errorf("expected nil response, got %+v", res) } - expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + return + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected element \"%s\" in response", item.String()) + } + } + + // Check that all the values are what is expected + if err := client.WriteArray([]resp.Value{ + resp.StringValue("HGETALL"), + resp.StringValue(test.key), + }); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + for idx, field := range res.Array() { + if idx%2 == 0 { + if res.Array()[idx+1].String() != test.expectedValue[field.String()] { + t.Errorf( + "expected value \"%+v\" for field \"%s\", got \"%+v\"", + test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), + ) + } + } + } + }) + } + }) + + t.Run("Test_HandleHSTRLEN", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse []int // Change count + expectedValue map[string]string + expectedError error + }{ + { + // Return lengths of field values. + // If the key does not exist, its length should be 0. + name: "1. Return lengths of field values.", + key: "HstrlenKey1", + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, + command: []string{"HSTRLEN", "HstrlenKey1", "field1", "field2", "field3", "field4"}, + expectedResponse: []int{len("value1"), len("123456789"), len("3.142"), 0}, + expectedValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, + expectedError: nil, + }, + { + name: "2. Nil response when trying to get HSTRLEN non-existent key", + key: "HstrlenKey2", + presetValue: nil, + command: []string{"HSTRLEN", "HstrlenKey2", "field1"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: nil, + }, + { + name: "3. Command too short", + key: "HstrlenKey3", + presetValue: nil, + command: []string{"HSTRLEN", "HstrlenKey3"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "4. Trying to get lengths on a non hash map returns error", + key: "HstrlenKey4", + presetValue: "Default value", + command: []string{"HSTRLEN", "HstrlenKey4", "field1"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: errors.New("value at HstrlenKey4 is not a hash"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -1417,133 +716,161 @@ func Test_HandleHGETALL(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if test.expectedResponse == nil { + if !res.IsNull() { + t.Errorf("expected nil response, got %+v", res) + } + return } - return - } - if test.expectedResponse == nil { - if len(res.Array()) != 0 { - t.Errorf("expected response to be empty array, got %+v", res) - } - return - } - - for i, item := range res.Array() { - if i%2 == 0 { - field := item.String() - value := res.Array()[i+1].String() - if test.expectedResponse[field] != value { - t.Errorf("expected value at field \"%s\" to be \"%s\", got \"%s\"", field, test.expectedResponse[field], value) + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.Integer()) { + t.Errorf("unexpected element \"%d\" in response", item.Integer()) } } - } - }) - } -} + // Check that all the values are what is expected + if err := client.WriteArray([]resp.Value{ + resp.StringValue("HGETALL"), + resp.StringValue(test.key), + }); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } -func Test_HandleHEXISTS(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse bool - expectedError error - }{ - { - name: "1. Return 1 if the field exists in the hash", - key: "HexistsKey1", - presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, - command: []string{"HEXISTS", "HexistsKey1", "field1"}, - expectedResponse: true, - expectedError: nil, - }, - { - name: "2. 0 response when trying to call HEXISTS on non-existent key", - key: "HexistsKey2", - presetValue: nil, - command: []string{"HEXISTS", "HexistsKey2", "field1"}, - expectedResponse: false, - expectedError: nil, - }, - { - name: "3. Command too short", - key: "HexistsKey3", - presetValue: nil, - command: []string{"HEXISTS", "HexistsKey3"}, - expectedResponse: false, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "4. Command too long", - key: "HexistsKey4", - presetValue: nil, - command: []string{"HEXISTS", "HexistsKey4", "field1", "field2"}, - expectedResponse: false, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Trying to get lengths on a non hash map returns error", - key: "HexistsKey5", - presetValue: "Default value", - command: []string{"HEXISTS", "HexistsKey5", "field1"}, - expectedResponse: false, - expectedError: errors.New("value at HexistsKey5 is not a hash"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), + for idx, field := range res.Array() { + if idx%2 == 0 { + if res.Array()[idx+1].String() != test.expectedValue[field.String()] { + t.Errorf( + "expected value \"%+v\" for field \"%s\", got \"%+v\"", + test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), + ) + } } - expected = "ok" - case map[string]string: - command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} - for key, value := range test.presetValue.(map[string]string) { - command = append(command, []resp.Value{ - resp.StringValue(key), - resp.StringValue(value)}..., - ) + } + }) + } + }) + + t.Run("Test_HandleHVALS", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse []string + expectedValue map[string]string + expectedError error + }{ + { + name: "1. Return all the values from a hash", + key: "HvalsKey1", + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, + command: []string{"HVALS", "HvalsKey1"}, + expectedResponse: []string{"value1", "123456789", "3.142"}, + expectedValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, + expectedError: nil, + }, + { + name: "2. Empty array response when trying to get HSTRLEN non-existent key", + key: "HvalsKey2", + presetValue: nil, + command: []string{"HVALS", "HvalsKey2"}, + expectedResponse: []string{}, + expectedValue: nil, + expectedError: nil, + }, + { + name: "3. Command too short", + key: "HvalsKey3", + presetValue: nil, + command: []string{"HVALS"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "4. Command too long", + key: "HvalsKey4", + presetValue: nil, + command: []string{"HVALS", "HvalsKey4", "HvalsKey4"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Trying to get lengths on a non hash map returns error", + key: "HvalsKey5", + presetValue: "Default value", + command: []string{"HVALS", "HvalsKey5"}, + expectedResponse: nil, + expectedValue: nil, + expectedError: errors.New("value at HvalsKey5 is not a hash"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) } - expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -1554,125 +881,205 @@ func Test_HandleHEXISTS(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Bool() != test.expectedResponse { - t.Errorf("expected response to be %v, got %v", test.expectedResponse, res.Bool()) - } - }) - } -} - -func Test_HandleHDEL(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse int - expectedValue map[string]string - expectedError error - }{ - { - name: "1. Return count of deleted fields in the specified hash", - key: "HdelKey1", - presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142", "field7": "value7"}, - command: []string{"HDEL", "HdelKey1", "field1", "field2", "field3", "field4", "field5", "field6"}, - expectedResponse: 3, - expectedValue: map[string]string{"field7": "value7"}, - expectedError: nil, - }, - { - name: "2. 0 response when passing delete fields that are non-existent on valid hash", - key: "HdelKey2", - presetValue: map[string]string{"field1": "value1", "field2": "value2", "field3": "value3"}, - command: []string{"HDEL", "HdelKey2", "field4", "field5", "field6"}, - expectedResponse: 0, - expectedValue: map[string]string{"field1": "value1", "field2": "value2", "field3": "value3"}, - expectedError: nil, - }, - { - name: "3. 0 response when trying to call HDEL on non-existent key", - key: "HdelKey3", - presetValue: nil, - command: []string{"HDEL", "HdelKey3", "field1"}, - expectedResponse: 0, - expectedValue: nil, - expectedError: nil, - }, - { - name: "4. Command too short", - key: "HdelKey4", - presetValue: nil, - command: []string{"HDEL", "HdelKey4"}, - expectedResponse: 0, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Trying to get lengths on a non hash map returns error", - key: "HdelKey5", - presetValue: "Default value", - command: []string{"HDEL", "HdelKey5", "field1"}, - expectedResponse: 0, - expectedValue: nil, - expectedError: errors.New("value at HdelKey5 is not a hash"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } - expected = "ok" - case map[string]string: - command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} - for key, value := range test.presetValue.(map[string]string) { - command = append(command, []resp.Value{ - resp.StringValue(key), - resp.StringValue(value)}..., - ) + return + } + + if test.expectedResponse == nil { + if !res.IsNull() { + t.Errorf("expected nil response, got %+v", res) } - expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + return + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected element \"%s\" in response", item.String()) + } + } + }) + } + }) + + t.Run("Test_HandleHRANDFIELD", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse []string + expectedError error + }{ + { + name: "1. Get a random field", + key: "HrandfieldKey1", + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, + command: []string{"HRANDFIELD", "HrandfieldKey1"}, + expectedResponse: []string{"field1", "field2", "field3"}, + expectedError: nil, + }, + { + name: "2. Get a random field with a value", + key: "HrandfieldKey2", + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, + command: []string{"HRANDFIELD", "HrandfieldKey2", "1", "WITHVALUES"}, + expectedResponse: []string{"field1", "value1", "field2", "123456789", "field3", "3.142"}, + expectedError: nil, + }, + { + name: "3. Get several random fields", + key: "HrandfieldKey3", + presetValue: map[string]string{ + "field1": "value1", + "field2": "123456789", + "field3": "3.142", + "field4": "value4", + "field5": "value5", + }, + command: []string{"HRANDFIELD", "HrandfieldKey3", "3"}, + expectedResponse: []string{"field1", "field2", "field3", "field4", "field5"}, + expectedError: nil, + }, + { + name: "4. Get several random fields with their corresponding values", + key: "HrandfieldKey4", + presetValue: map[string]string{ + "field1": "value1", + "field2": "123456789", + "field3": "3.142", + "field4": "value4", + "field5": "value5", + }, + command: []string{"HRANDFIELD", "HrandfieldKey4", "3", "WITHVALUES"}, + expectedResponse: []string{ + "field1", "value1", "field2", "123456789", "field3", + "3.142", "field4", "value4", "field5", "value5", + }, + expectedError: nil, + }, + { + name: "5. Get the entire hash", + key: "HrandfieldKey5", + presetValue: map[string]string{ + "field1": "value1", + "field2": "123456789", + "field3": "3.142", + "field4": "value4", + "field5": "value5", + }, + command: []string{"HRANDFIELD", "HrandfieldKey5", "5"}, + expectedResponse: []string{"field1", "field2", "field3", "field4", "field5"}, + expectedError: nil, + }, + { + name: "6. Get the entire hash with values", + key: "HrandfieldKey5", + presetValue: map[string]string{ + "field1": "value1", + "field2": "123456789", + "field3": "3.142", + "field4": "value4", + "field5": "value5", + }, + command: []string{"HRANDFIELD", "HrandfieldKey5", "5", "WITHVALUES"}, + expectedResponse: []string{ + "field1", "value1", "field2", "123456789", "field3", + "3.142", "field4", "value4", "field5", "value5", + }, + expectedError: nil, + }, + { + name: "7. Command too short", + key: "HrandfieldKey10", + presetValue: nil, + command: []string{"HRANDFIELD"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "8. Command too long", + key: "HrandfieldKey11", + presetValue: nil, + command: []string{"HRANDFIELD", "HrandfieldKey11", "HrandfieldKey11", "HrandfieldKey11", "HrandfieldKey11"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "9. Trying to get random field on a non hash map returns error", + key: "HrandfieldKey12", + presetValue: "Default value", + command: []string{"HRANDFIELD", "HrandfieldKey12"}, + expectedError: errors.New("value at HrandfieldKey12 is not a hash"), + }, + { + name: "10. Throw error when count provided is not an integer", + key: "HrandfieldKey12", + presetValue: "Default value", + command: []string{"HRANDFIELD", "HrandfieldKey12", "COUNT"}, + expectedError: errors.New("count must be an integer"), + }, + { + name: "11. If fourth argument is provided, it must be \"WITHVALUES\"", + key: "HrandfieldKey12", + presetValue: "Default value", + command: []string{"HRANDFIELD", "HrandfieldKey12", "10", "FLAG"}, + expectedError: errors.New("result modifier must be withvalues"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -1683,45 +1090,692 @@ func Test_HandleHDEL(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if test.expectedResponse == nil { + if !res.IsNull() { + t.Errorf("expected nil response, got %+v", res) + } + return } - return - } - if res.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) - } - - for idx, field := range res.Array() { - if idx%2 == 0 { - if res.Array()[idx+1].String() != test.expectedValue[field.String()] { - t.Errorf( - "expected value \"%+v\" for field \"%s\", got \"%+v\"", - test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), - ) + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected element \"%s\" in response", item.String()) } } - } - }) - } + }) + } + }) + + t.Run("Test_HandleHLEN", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse int // Change count + expectedError error + }{ + { + name: "1. Return the correct length of the hash", + key: "HlenKey1", + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, + command: []string{"HLEN", "HlenKey1"}, + expectedResponse: 3, + expectedError: nil, + }, + { + name: "2. 0 response when trying to call HLEN on non-existent key", + key: "HlenKey2", + presetValue: nil, + command: []string{"HLEN", "HlenKey2"}, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "3. Command too short", + key: "HlenKey3", + presetValue: nil, + command: []string{"HLEN"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "4. Command too long", + presetValue: nil, + command: []string{"HLEN", "HlenKey4", "HlenKey4"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Trying to get lengths on a non hash map returns error", + key: "HlenKey5", + presetValue: "Default value", + command: []string{"HLEN", "HlenKey5"}, + expectedResponse: 0, + expectedError: errors.New("value at HlenKey5 is not a hash"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + }) + } + }) + + t.Run("Test_HandleHKeys", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse []string + expectedError error + }{ + { + name: "1. Return an array containing all the keys of the hash", + key: "HkeysKey1", + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, + command: []string{"HKEYS", "HkeysKey1"}, + expectedResponse: []string{"field1", "field2", "field3"}, + expectedError: nil, + }, + { + name: "2. Empty array response when trying to call HKEYS on non-existent key", + key: "HkeysKey2", + presetValue: nil, + command: []string{"HKEYS", "HkeysKey2"}, + expectedResponse: []string{}, + expectedError: nil, + }, + { + name: "3. Command too short", + key: "HkeysKey3", + presetValue: nil, + command: []string{"HKEYS"}, + expectedResponse: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "4. Command too long", + key: "HkeysKey4", + presetValue: nil, + command: []string{"HKEYS", "HkeysKey4", "HkeysKey4"}, + expectedResponse: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Trying to get lengths on a non hash map returns error", + key: "HkeysKey5", + presetValue: "Default value", + command: []string{"HKEYS", "HkeysKey5"}, + expectedError: errors.New("value at HkeysKey5 is not a hash"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected value \"%s\" in response", item.String()) + } + } + }) + } + }) + + t.Run("Test_HandleHGETALL", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse map[string]string + expectedError error + }{ + { + name: "1. Return an array containing all the fields and values of the hash", + key: "HGetAllKey1", + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, + command: []string{"HGETALL", "HGetAllKey1"}, + expectedResponse: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, + expectedError: nil, + }, + { + name: "2. Empty array response when trying to call HGETALL on non-existent key", + key: "HGetAllKey2", + presetValue: nil, + command: []string{"HGETALL", "HGetAllKey2"}, + expectedResponse: nil, + expectedError: nil, + }, + { + name: "3. Command too short", + key: "HGetAllKey3", + presetValue: nil, + command: []string{"HGETALL"}, + expectedResponse: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "4. Command too long", + key: "HGetAllKey4", + presetValue: nil, + command: []string{"HGETALL", "HGetAllKey4", "HGetAllKey4"}, + expectedResponse: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Trying to get lengths on a non hash map returns error", + key: "HGetAllKey5", + presetValue: "Default value", + command: []string{"HGETALL", "HGetAllKey5"}, + expectedResponse: nil, + expectedError: errors.New("value at HGetAllKey5 is not a hash"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if test.expectedResponse == nil { + if len(res.Array()) != 0 { + t.Errorf("expected response to be empty array, got %+v", res) + } + return + } + + for i, item := range res.Array() { + if i%2 == 0 { + field := item.String() + value := res.Array()[i+1].String() + if test.expectedResponse[field] != value { + t.Errorf("expected value at field \"%s\" to be \"%s\", got \"%s\"", field, test.expectedResponse[field], value) + } + } + } + + }) + } + }) + + t.Run("Test_HandleHEXISTS", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse bool + expectedError error + }{ + { + name: "1. Return 1 if the field exists in the hash", + key: "HexistsKey1", + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142"}, + command: []string{"HEXISTS", "HexistsKey1", "field1"}, + expectedResponse: true, + expectedError: nil, + }, + { + name: "2. 0 response when trying to call HEXISTS on non-existent key", + key: "HexistsKey2", + presetValue: nil, + command: []string{"HEXISTS", "HexistsKey2", "field1"}, + expectedResponse: false, + expectedError: nil, + }, + { + name: "3. Command too short", + key: "HexistsKey3", + presetValue: nil, + command: []string{"HEXISTS", "HexistsKey3"}, + expectedResponse: false, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "4. Command too long", + key: "HexistsKey4", + presetValue: nil, + command: []string{"HEXISTS", "HexistsKey4", "field1", "field2"}, + expectedResponse: false, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Trying to get lengths on a non hash map returns error", + key: "HexistsKey5", + presetValue: "Default value", + command: []string{"HEXISTS", "HexistsKey5", "field1"}, + expectedResponse: false, + expectedError: errors.New("value at HexistsKey5 is not a hash"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.Bool() != test.expectedResponse { + t.Errorf("expected response to be %v, got %v", test.expectedResponse, res.Bool()) + } + }) + } + }) + + t.Run("Test_HandleHDEL", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse int + expectedValue map[string]string + expectedError error + }{ + { + name: "1. Return count of deleted fields in the specified hash", + key: "HdelKey1", + presetValue: map[string]string{"field1": "value1", "field2": "123456789", "field3": "3.142", "field7": "value7"}, + command: []string{"HDEL", "HdelKey1", "field1", "field2", "field3", "field4", "field5", "field6"}, + expectedResponse: 3, + expectedValue: map[string]string{"field7": "value7"}, + expectedError: nil, + }, + { + name: "2. 0 response when passing delete fields that are non-existent on valid hash", + key: "HdelKey2", + presetValue: map[string]string{"field1": "value1", "field2": "value2", "field3": "value3"}, + command: []string{"HDEL", "HdelKey2", "field4", "field5", "field6"}, + expectedResponse: 0, + expectedValue: map[string]string{"field1": "value1", "field2": "value2", "field3": "value3"}, + expectedError: nil, + }, + { + name: "3. 0 response when trying to call HDEL on non-existent key", + key: "HdelKey3", + presetValue: nil, + command: []string{"HDEL", "HdelKey3", "field1"}, + expectedResponse: 0, + expectedValue: nil, + expectedError: nil, + }, + { + name: "4. Command too short", + key: "HdelKey4", + presetValue: nil, + command: []string{"HDEL", "HdelKey4"}, + expectedResponse: 0, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Trying to get lengths on a non hash map returns error", + key: "HdelKey5", + presetValue: "Default value", + command: []string{"HDEL", "HdelKey5", "field1"}, + expectedResponse: 0, + expectedValue: nil, + expectedError: errors.New("value at HdelKey5 is not a hash"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case map[string]string: + command = []resp.Value{resp.StringValue("HSET"), resp.StringValue(test.key)} + for key, value := range test.presetValue.(map[string]string) { + command = append(command, []resp.Value{ + resp.StringValue(key), + resp.StringValue(value)}..., + ) + } + expected = strconv.Itoa(len(test.presetValue.(map[string]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + + for idx, field := range res.Array() { + if idx%2 == 0 { + if res.Array()[idx+1].String() != test.expectedValue[field.String()] { + t.Errorf( + "expected value \"%+v\" for field \"%s\", got \"%+v\"", + test.expectedValue[field.String()], field.String(), res.Array()[idx+1].String(), + ) + } + } + } + }) + } + }) } diff --git a/internal/modules/list/commands_test.go b/internal/modules/list/commands_test.go index b368b10..1e93f2c 100644 --- a/internal/modules/list/commands_test.go +++ b/internal/modules/list/commands_test.go @@ -30,20 +30,26 @@ import ( "testing" ) -var mockServer *echovault.EchoVault -var addr = "localhost" -var port int +func Test_List(t *testing.T) { + port, err := internal.GetFreePort() + if err != nil { + t.Error(err) + return + } -func init() { - port, _ = internal.GetFreePort() - mockServer, _ = echovault.NewEchoVault( + mockServer, err := echovault.NewEchoVault( echovault.WithConfig(config.Config{ - BindAddr: addr, + BindAddr: "localhost", Port: uint16(port), DataDir: "", EvictionPolicy: constants.NoEviction, }), ) + if err != nil { + t.Error(err) + return + } + wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -51,1172 +57,93 @@ func init() { mockServer.Start() }() wg.Wait() -} -func Test_HandleLLEN(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse int - expectedError error - }{ - { - name: "1. If key exists and is a list, return the lists length", - key: "LlenKey1", - presetValue: []string{"value1", "value2", "value3", "value4"}, - command: []string{"LLEN", "LlenKey1"}, - expectedResponse: 4, - expectedError: nil, - }, - { - name: "2. If key does not exist, return 0", - key: "LlenKey2", - presetValue: nil, - command: []string{"LLEN", "LlenKey2"}, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "3. Command too short", - key: "LlenKey3", - presetValue: nil, - command: []string{"LLEN"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "4. Command too long", - key: "LlenKey4", - presetValue: nil, - command: []string{"LLEN", "LlenKey4", "LlenKey4"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Trying to get lengths on a non-list returns error", - key: "LlenKey5", - presetValue: "Default value", - command: []string{"LLEN", "LlenKey5"}, - expectedResponse: 0, - expectedError: errors.New("LLEN command on non-list item"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case []string: - command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} - for _, element := range test.presetValue.([]string) { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(len(test.presetValue.([]string))) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response to be %d, got %d", test.expectedResponse, res.Integer()) - } - }) - } -} - -func Test_HandleLINDEX(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse string - expectedError error - }{ - { - name: "1. Return last element within range", - key: "LindexKey1", - presetValue: []string{"value1", "value2", "value3", "value4"}, - command: []string{"LINDEX", "LindexKey1", "3"}, - expectedResponse: "value4", - expectedError: nil, - }, - { - name: "2. Return first element within range", - key: "LindexKey2", - presetValue: []string{"value1", "value2", "value3", "value4"}, - command: []string{"LINDEX", "LindexKey1", "0"}, - expectedResponse: "value1", - expectedError: nil, - }, - { - name: "3. Return middle element within range", - key: "LindexKey3", - presetValue: []string{"value1", "value2", "value3", "value4"}, - command: []string{"LINDEX", "LindexKey1", "1"}, - expectedResponse: "value2", - expectedError: nil, - }, - { - name: "4. If key does not exist, return error", - key: "LindexKey4", - presetValue: nil, - command: []string{"LINDEX", "LindexKey4", "0"}, - expectedResponse: "", - expectedError: errors.New("LINDEX command on non-list item"), - }, - { - name: "5. Command too short", - key: "LindexKey3", - presetValue: nil, - command: []string{"LINDEX", "LindexKey3"}, - expectedResponse: "", - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: " 6. Command too long", - key: "LindexKey4", - presetValue: nil, - command: []string{"LINDEX", "LindexKey4", "0", "20"}, - expectedResponse: "", - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "7. Trying to get element by index on a non-list returns error", - key: "LindexKey5", - presetValue: "Default value", - command: []string{"LINDEX", "LindexKey5", "0"}, - expectedResponse: "", - expectedError: errors.New("LINDEX command on non-list item"), - }, - { - name: "8. Trying to get index out of range index beyond last index", - key: "LindexKey6", - presetValue: []string{"value1", "value2", "value3"}, - command: []string{"LINDEX", "LindexKey6", "3"}, - expectedResponse: "", - expectedError: errors.New("index must be within list range"), - }, - { - name: "9. Trying to get index out of range with negative index", - key: "LindexKey7", - presetValue: []string{"value1", "value2", "value3"}, - command: []string{"LINDEX", "LindexKey7", "-1"}, - expectedResponse: "", - expectedError: errors.New("index must be within list range"), - }, - { - name: " 10. Return error when index is not an integer", - key: "LindexKey8", - presetValue: []string{"value1", "value2", "value3"}, - command: []string{"LINDEX", "LindexKey8", "index"}, - expectedResponse: "", - expectedError: errors.New("index must be an integer"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case []string: - command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} - for _, element := range test.presetValue.([]string) { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(len(test.presetValue.([]string))) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.String() != test.expectedResponse { - t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) - } - }) - } -} - -func Test_HandleLRANGE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse []string - expectedError error - }{ - { - // Return sub-list within range. - // Both start and end indices are positive. - // End index is greater than start index. - name: "1. Return sub-list within range.", - key: "LrangeKey1", - presetValue: []string{"value1", "value2", "value3", "value4", "value5", "value6", "value7", "value8"}, - command: []string{"LRANGE", "LrangeKey1", "3", "6"}, - expectedResponse: []string{"value4", "value5", "value6", "value7"}, - expectedError: nil, - }, - { - name: "2. Return sub-list from start index to the end of the list when end index is -1", - key: "LrangeKey2", - presetValue: []string{"value1", "value2", "value3", "value4", "value5", "value6", "value7", "value8"}, - command: []string{"LRANGE", "LrangeKey2", "3", "-1"}, - expectedResponse: []string{"value4", "value5", "value6", "value7", "value8"}, - expectedError: nil, - }, - { - name: "3. Return the reversed sub-list when the end index is greater than -1 but less than start index", - key: "LrangeKey3", - presetValue: []string{"value1", "value2", "value3", "value4", "value5", "value6", "value7", "value8"}, - command: []string{"LRANGE", "LrangeKey3", "3", "0"}, - expectedResponse: []string{"value4", "value3", "value2", "value1"}, - expectedError: nil, - }, - { - name: "4. If key does not exist, return error", - key: "LrangeKey4", - presetValue: nil, - command: []string{"LRANGE", "LrangeKey4", "0", "2"}, - expectedResponse: nil, - expectedError: errors.New("LRANGE command on non-list item"), - }, - { - name: "5. Command too short", - key: "LrangeKey5", - presetValue: nil, - command: []string{"LRANGE", "LrangeKey5"}, - expectedResponse: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "6. Command too long", - key: "LrangeKey6", - presetValue: nil, - command: []string{"LRANGE", "LrangeKey6", "0", "element", "element"}, - expectedResponse: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "7. Error when executing command on non-list command", - key: "LrangeKey5", - presetValue: "Default value", - command: []string{"LRANGE", "LrangeKey5", "0", "3"}, - expectedResponse: nil, - expectedError: errors.New("LRANGE command on non-list item"), - }, - { - name: "8. Error when start index is less than 0", - key: "LrangeKey7", - presetValue: []string{"value1", "value2", "value3", "value4"}, - command: []string{"LRANGE", "LrangeKey7", "-1", "3"}, - expectedResponse: nil, - expectedError: errors.New("start index must be within list boundary"), - }, - { - name: "9. Error when start index is higher than the length of the list", - key: "LrangeKey8", - presetValue: []string{"value1", "value2", "value3"}, - command: []string{"LRANGE", "LrangeKey8", "10", "11"}, - expectedResponse: nil, - expectedError: errors.New("start index must be within list boundary"), - }, - { - name: "10. Return error when start index is not an integer", - key: "LrangeKey9", - presetValue: []string{"value1", "value2", "value3"}, - command: []string{"LRANGE", "LrangeKey9", "start", "7"}, - expectedResponse: nil, - expectedError: errors.New("start and end indices must be integers"), - }, - { - name: "11. Return error when end index is not an integer", - key: "LrangeKey10", - presetValue: []string{"value1", "value2", "value3"}, - command: []string{"LRANGE", "LrangeKey10", "0", "end"}, - expectedResponse: nil, - expectedError: errors.New("start and end indices must be integers"), - }, - { - name: "12. Error when start and end indices are equal", - key: "LrangeKey11", - presetValue: []string{"value1", "value2", "value3"}, - command: []string{"LRANGE", "LrangeKey11", "1", "1"}, - expectedResponse: nil, - expectedError: errors.New("start and end indices cannot be equal"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case []string: - command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} - for _, element := range test.presetValue.([]string) { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(len(test.presetValue.([]string))) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if len(res.Array()) != len(test.expectedResponse) { - t.Errorf("expected response of length %d, got length %d", len(test.expectedResponse), len(res.Array())) - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedResponse, item.String()) { - t.Errorf("unexpected element \"%s\" in response", item.String()) - } - } - }) - } -} - -func Test_HandleLSET(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedValue []string - expectedError error - }{ - { - name: "1. Return last element within range", - key: "LsetKey1", - presetValue: []string{"value1", "value2", "value3", "value4"}, - command: []string{"LSET", "LsetKey1", "3", "new-value"}, - expectedValue: []string{"value1", "value2", "value3", "new-value"}, - expectedError: nil, - }, - { - name: "2. Return first element within range", - key: "LsetKey2", - presetValue: []string{"value1", "value2", "value3", "value4"}, - command: []string{"LSET", "LsetKey2", "0", "new-value"}, - expectedValue: []string{"new-value", "value2", "value3", "value4"}, - expectedError: nil, - }, - { - name: "3. Return middle element within range", - key: "LsetKey3", - presetValue: []string{"value1", "value2", "value3", "value4"}, - command: []string{"LSET", "LsetKey3", "1", "new-value"}, - expectedValue: []string{"value1", "new-value", "value3", "value4"}, - expectedError: nil, - }, - { - name: "4. If key does not exist, return error", - key: "LsetKey4", - presetValue: nil, - command: []string{"LSET", "LsetKey4", "0", "element"}, - expectedValue: nil, - expectedError: errors.New("LSET command on non-list item"), - }, - { - name: "5. Command too short", - key: "LsetKey5", - presetValue: nil, - command: []string{"LSET", "LsetKey5"}, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "6. Command too long", - key: "LsetKey6", - presetValue: nil, - command: []string{"LSET", "LsetKey6", "0", "element", "element"}, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "7. Trying to get element by index on a non-list returns error", - key: "LsetKey5", - presetValue: "Default value", - command: []string{"LSET", "LsetKey5", "0", "element"}, - expectedValue: nil, - expectedError: errors.New("LSET command on non-list item"), - }, - { - name: "8. Trying to get index out of range index beyond last index", - key: "LsetKey6", - presetValue: []string{"value1", "value2", "value3"}, - command: []string{"LSET", "LsetKey6", "3", "element"}, - expectedValue: nil, - expectedError: errors.New("index must be within list range"), - }, - { - name: "9. Trying to get index out of range with negative index", - key: "LsetKey7", - presetValue: []string{"value1", "value2", "value3"}, - command: []string{"LSET", "LsetKey7", "-1", "element"}, - expectedValue: nil, - expectedError: errors.New("index must be within list range"), - }, - { - name: "10. Return error when index is not an integer", - key: "LsetKey8", - presetValue: []string{"value1", "value2", "value3"}, - command: []string{"LSET", "LsetKey8", "index", "element"}, - expectedValue: nil, - expectedError: errors.New("index must be an integer"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case []string: - command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} - for _, element := range test.presetValue.([]string) { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(len(test.presetValue.([]string))) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected response OK, got \"%s\"", res.String()) - } - - if err = client.WriteArray([]resp.Value{ - resp.StringValue("LRANGE"), - resp.StringValue(test.key), - resp.StringValue("0"), - resp.StringValue("-1"), - }); err != nil { - t.Error(err) - } - - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != len(test.expectedValue) { - t.Errorf("expected list at key \"%s\" to be length %d, got %d", - test.key, len(test.expectedValue), len(res.Array())) - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedValue, item.String()) { - t.Errorf("unexpected value \"%s\" in updated list", item.String()) - } - } - }) - } -} - -func Test_HandleLTRIM(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedValue []string - expectedError error - }{ - { - // Return trim within range. - // Both start and end indices are positive. - // End index is greater than start index. - name: "1. Return trim within range.", - key: "LtrimKey1", - presetValue: []string{"value1", "value2", "value3", "value4", "value5", "value6", "value7", "value8"}, - command: []string{"LTRIM", "LtrimKey1", "3", "6"}, - expectedValue: []string{"value4", "value5", "value6"}, - expectedError: nil, - }, - { - name: "2. Return element from start index to end index when end index is greater than length of the list", - key: "LtrimKey2", - presetValue: []string{"value1", "value2", "value3", "value4", "value5", "value6", "value7", "value8"}, - command: []string{"LTRIM", "LtrimKey2", "5", "-1"}, - expectedValue: []string{"value6", "value7", "value8"}, - expectedError: nil, - }, - { - name: "3. Return error when end index is smaller than start index but greater than -1", - key: "LtrimKey3", - presetValue: []string{"value1", "value2", "value3", "value4"}, - command: []string{"LTRIM", "LtrimKey3", "3", "1"}, - expectedValue: nil, - expectedError: errors.New("end index must be greater than start index or -1"), - }, - { - name: "4. If key does not exist, return error", - key: "LtrimKey4", - presetValue: nil, - command: []string{"LTRIM", "LtrimKey4", "0", "2"}, - expectedValue: nil, - expectedError: errors.New("LTRIM command on non-list item"), - }, - { - name: "5. Command too short", - key: "LtrimKey5", - presetValue: nil, - command: []string{"LTRIM", "LtrimKey5"}, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "6. Command too long", - key: "LtrimKey6", - presetValue: nil, - command: []string{"LTRIM", "LtrimKey6", "0", "element", "element"}, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "7. Trying to get element by index on a non-list returns error", - key: "LtrimKey5", - presetValue: "Default value", - command: []string{"LTRIM", "LtrimKey5", "0", "3"}, - expectedValue: nil, - expectedError: errors.New("LTRIM command on non-list item"), - }, - { - name: "8. Error when start index is less than 0", - key: "LtrimKey7", - presetValue: []string{"value1", "value2", "value3", "value4"}, - command: []string{"LTRIM", "LtrimKey7", "-1", "3"}, - expectedValue: nil, - expectedError: errors.New("start index must be within list boundary"), - }, - { - name: "9. Error when start index is higher than the length of the list", - key: "LtrimKey8", - presetValue: []string{"value1", "value2", "value3"}, - command: []string{"LTRIM", "LtrimKey8", "10", "11"}, - expectedValue: nil, - expectedError: errors.New("start index must be within list boundary"), - }, - { - name: "10. Return error when start index is not an integer", - key: "LtrimKey9", - presetValue: []string{"value1", "value2", "value3"}, - command: []string{"LTRIM", "LtrimKey9", "start", "7"}, - expectedValue: nil, - expectedError: errors.New("start and end indices must be integers"), - }, - { - name: "11. Return error when end index is not an integer", - key: "LtrimKey10", - presetValue: []string{"value1", "value2", "value3"}, - command: []string{"LTRIM", "LtrimKey10", "0", "end"}, - expectedValue: nil, - expectedError: errors.New("start and end indices must be integers"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case []string: - command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} - for _, element := range test.presetValue.([]string) { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(len(test.presetValue.([]string))) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected response OK, got \"%s\"", res.String()) - } - - if err = client.WriteArray([]resp.Value{ - resp.StringValue("LRANGE"), - resp.StringValue(test.key), - resp.StringValue("0"), - resp.StringValue("-1"), - }); err != nil { - t.Error(err) - } - - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != len(test.expectedValue) { - t.Errorf("expected list at key \"%s\" to be length %d, got %d", - test.key, len(test.expectedValue), len(res.Array())) - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedValue, item.String()) { - t.Errorf("unexpected value \"%s\" in updated list", item.String()) - } - } - }) - } -} - -func Test_HandleLREM(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedValue []string - expectedError error - }{ - { - name: "1. Remove the first 3 elements that appear in the list", - key: "LremKey1", - presetValue: []string{"1", "2", "4", "4", "5", "6", "7", "4", "8", "4", "9", "10", "5", "4"}, - command: []string{"LREM", "LremKey1", "3", "4"}, - expectedValue: []string{"1", "2", "5", "6", "7", "8", "4", "9", "10", "5", "4"}, - expectedError: nil, - }, - { - name: "2. Remove the last 3 elements that appear in the list", - key: "LremKey2", - presetValue: []string{"1", "2", "4", "4", "5", "6", "7", "4", "8", "4", "9", "10", "5", "4"}, - command: []string{"LREM", "LremKey2", "-3", "4"}, - expectedValue: []string{"1", "2", "4", "4", "5", "6", "7", "8", "9", "10", "5"}, - expectedError: nil, - }, - { - name: "3. Command too short", - key: "LremKey3", - presetValue: nil, - command: []string{"LREM", "LremKey3"}, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "4. Command too long", - key: "LremKey4", - presetValue: nil, - command: []string{"LREM", "LremKey4", "0", "element", "element"}, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Throw error when count is not an integer", - key: "LremKey5", - presetValue: nil, - command: []string{"LREM", "LremKey5", "count", "value1"}, - expectedValue: nil, - expectedError: errors.New("count must be an integer"), - }, - { - name: "6. Throw error on non-list item", - key: "LremKey6", - presetValue: "Default value", - command: []string{"LREM", "LremKey6", "0", "value1"}, - expectedValue: nil, - expectedError: errors.New("LREM command on non-list item"), - }, - { - name: "7. Throw error on non-existent item", - key: "LremKey7", - presetValue: "Default value", - command: []string{"LREM", "LremKey7", "0", "value1"}, - expectedValue: nil, - expectedError: errors.New("LREM command on non-list item"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case []string: - command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} - for _, element := range test.presetValue.([]string) { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(len(test.presetValue.([]string))) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected response OK, got \"%s\"", res.String()) - } - - if err = client.WriteArray([]resp.Value{ - resp.StringValue("LRANGE"), - resp.StringValue(test.key), - resp.StringValue("0"), - resp.StringValue("-1"), - }); err != nil { - t.Error(err) - } - - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != len(test.expectedValue) { - t.Errorf("expected list at key \"%s\" to be length %d, got %d", - test.key, len(test.expectedValue), len(res.Array())) - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedValue, item.String()) { - t.Errorf("unexpected value \"%s\" in updated list", item.String()) - } - } - }) - } -} - -func Test_HandleLMOVE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValue map[string]interface{} - command []string - expectedValue map[string][]string - expectedError error - }{ - { - name: "1. Move element from LEFT of left list to LEFT of right list", - presetValue: map[string]interface{}{ - "source1": []string{"one", "two", "three"}, - "destination1": []string{"one", "two", "three"}, + t.Cleanup(func() { + mockServer.ShutDown() + }) + + t.Run("Test_HandleLLEN", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse int + expectedError error + }{ + { + name: "1. If key exists and is a list, return the lists length", + key: "LlenKey1", + presetValue: []string{"value1", "value2", "value3", "value4"}, + command: []string{"LLEN", "LlenKey1"}, + expectedResponse: 4, + expectedError: nil, }, - command: []string{"LMOVE", "source1", "destination1", "LEFT", "LEFT"}, - expectedValue: map[string][]string{ - "source1": {"two", "three"}, - "destination1": {"one", "one", "two", "three"}, + { + name: "2. If key does not exist, return 0", + key: "LlenKey2", + presetValue: nil, + command: []string{"LLEN", "LlenKey2"}, + expectedResponse: 0, + expectedError: nil, }, - expectedError: nil, - }, - { - name: "2. Move element from LEFT of left list to RIGHT of right list", - presetValue: map[string]interface{}{ - "source2": []string{"one", "two", "three"}, - "destination2": []string{"one", "two", "three"}, + { + name: "3. Command too short", + key: "LlenKey3", + presetValue: nil, + command: []string{"LLEN"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), }, - command: []string{"LMOVE", "source2", "destination2", "LEFT", "RIGHT"}, - expectedValue: map[string][]string{ - "source2": {"two", "three"}, - "destination2": {"one", "two", "three", "one"}, + { + name: "4. Command too long", + key: "LlenKey4", + presetValue: nil, + command: []string{"LLEN", "LlenKey4", "LlenKey4"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), }, - expectedError: nil, - }, - { - name: "3. Move element from RIGHT of left list to LEFT of right list", - presetValue: map[string]interface{}{ - "source3": []string{"one", "two", "three"}, - "destination3": []string{"one", "two", "three"}, + { + name: "5. Trying to get lengths on a non-list returns error", + key: "LlenKey5", + presetValue: "Default value", + command: []string{"LLEN", "LlenKey5"}, + expectedResponse: 0, + expectedError: errors.New("LLEN command on non-list item"), }, - command: []string{"LMOVE", "source3", "destination3", "RIGHT", "LEFT"}, - expectedValue: map[string][]string{ - "source3": {"one", "two"}, - "destination3": {"three", "one", "two", "three"}, - }, - expectedError: nil, - }, - { - name: "4. Move element from RIGHT of left list to RIGHT of right list", - presetValue: map[string]interface{}{ - "source4": []string{"one", "two", "three"}, - "destination4": []string{"one", "two", "three"}, - }, - command: []string{"LMOVE", "source4", "destination4", "RIGHT", "RIGHT"}, - expectedValue: map[string][]string{ - "source4": {"one", "two"}, - "destination4": {"one", "two", "three", "three"}, - }, - expectedError: nil, - }, - { - name: "5. Throw error when the right list is non-existent", - presetValue: map[string]interface{}{ - "source5": []string{"one", "two", "three"}, - }, - command: []string{"LMOVE", "source5", "destination5", "LEFT", "LEFT"}, - expectedValue: nil, - expectedError: errors.New("both source and destination must be lists"), - }, - { - name: "6. Throw error when right list in not a list", - presetValue: map[string]interface{}{ - "source6": []string{"one", "two", "tree"}, - "destination6": "Default value", - }, - command: []string{"LMOVE", "source6", "destination6", "LEFT", "LEFT"}, - expectedValue: nil, - expectedError: errors.New("both source and destination must be lists"), - }, - { - name: "7. Throw error when left list is non-existent", - presetValue: map[string]interface{}{ - "destination7": []string{"one", "two", "three"}, - }, - command: []string{"LMOVE", "source7", "destination7", "LEFT", "LEFT"}, - expectedValue: nil, - expectedError: errors.New("both source and destination must be lists"), - }, - { - name: "8. Throw error when left list is not a list", - presetValue: map[string]interface{}{ - "source8": "Default value", - "destination8": []string{"one", "two", "three"}, - }, - command: []string{"LMOVE", "source8", "destination8", "LEFT", "LEFT"}, - expectedValue: nil, - expectedError: errors.New("both source and destination must be lists"), - }, - { - name: "9. Throw error when command is too short", - presetValue: map[string]interface{}{}, - command: []string{"LMOVE", "source9", "destination9"}, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "10. Throw error when command is too long", - presetValue: map[string]interface{}{}, - command: []string{"LMOVE", "source10", "destination10", "LEFT", "LEFT", "RIGHT"}, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "11. Throw error when WHEREFROM argument is not LEFT/RIGHT", - presetValue: map[string]interface{}{}, - command: []string{"LMOVE", "source11", "destination11", "UP", "RIGHT"}, - expectedValue: nil, - expectedError: errors.New("wherefrom and whereto arguments must be either LEFT or RIGHT"), - }, - { - name: "12. Throw error when WHERETO argument is not LEFT/RIGHT", - presetValue: map[string]interface{}{}, - command: []string{"LMOVE", "source11", "destination11", "LEFT", "DOWN"}, - expectedValue: nil, - expectedError: errors.New("wherefrom and whereto arguments must be either LEFT or RIGHT"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - for key, value := range test.presetValue { + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { var command []resp.Value var expected string - switch value.(type) { + switch test.presetValue.(type) { case string: command = []resp.Value{ resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), } expected = "ok" case []string: - command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(key)} - for _, element := range value.([]string) { + command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} + for _, element := range test.presetValue.([]string) { command = append(command, []resp.Value{resp.StringValue(element)}...) } - expected = strconv.Itoa(len(value.([]string))) + expected = strconv.Itoa(len(test.presetValue.([]string))) } if err = client.WriteArray(command); err != nil { @@ -1231,36 +158,551 @@ func Test_HandleLMOVE(t *testing.T) { t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) } } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - return - } - if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected response OK, got \"%s\"", res.String()) - } + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response to be %d, got %d", test.expectedResponse, res.Integer()) + } + }) + } + }) + + t.Run("Test_HandleLINDEX", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse string + expectedError error + }{ + { + name: "1. Return last element within range", + key: "LindexKey1", + presetValue: []string{"value1", "value2", "value3", "value4"}, + command: []string{"LINDEX", "LindexKey1", "3"}, + expectedResponse: "value4", + expectedError: nil, + }, + { + name: "2. Return first element within range", + key: "LindexKey2", + presetValue: []string{"value1", "value2", "value3", "value4"}, + command: []string{"LINDEX", "LindexKey1", "0"}, + expectedResponse: "value1", + expectedError: nil, + }, + { + name: "3. Return middle element within range", + key: "LindexKey3", + presetValue: []string{"value1", "value2", "value3", "value4"}, + command: []string{"LINDEX", "LindexKey1", "1"}, + expectedResponse: "value2", + expectedError: nil, + }, + { + name: "4. If key does not exist, return error", + key: "LindexKey4", + presetValue: nil, + command: []string{"LINDEX", "LindexKey4", "0"}, + expectedResponse: "", + expectedError: errors.New("LINDEX command on non-list item"), + }, + { + name: "5. Command too short", + key: "LindexKey3", + presetValue: nil, + command: []string{"LINDEX", "LindexKey3"}, + expectedResponse: "", + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: " 6. Command too long", + key: "LindexKey4", + presetValue: nil, + command: []string{"LINDEX", "LindexKey4", "0", "20"}, + expectedResponse: "", + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "7. Trying to get element by index on a non-list returns error", + key: "LindexKey5", + presetValue: "Default value", + command: []string{"LINDEX", "LindexKey5", "0"}, + expectedResponse: "", + expectedError: errors.New("LINDEX command on non-list item"), + }, + { + name: "8. Trying to get index out of range index beyond last index", + key: "LindexKey6", + presetValue: []string{"value1", "value2", "value3"}, + command: []string{"LINDEX", "LindexKey6", "3"}, + expectedResponse: "", + expectedError: errors.New("index must be within list range"), + }, + { + name: "9. Trying to get index out of range with negative index", + key: "LindexKey7", + presetValue: []string{"value1", "value2", "value3"}, + command: []string{"LINDEX", "LindexKey7", "-1"}, + expectedResponse: "", + expectedError: errors.New("index must be within list range"), + }, + { + name: " 10. Return error when index is not an integer", + key: "LindexKey8", + presetValue: []string{"value1", "value2", "value3"}, + command: []string{"LINDEX", "LindexKey8", "index"}, + expectedResponse: "", + expectedError: errors.New("index must be an integer"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case []string: + command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} + for _, element := range test.presetValue.([]string) { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(len(test.presetValue.([]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.String() != test.expectedResponse { + t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) + } + }) + } + }) + + t.Run("Test_HandleLRANGE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse []string + expectedError error + }{ + { + // Return sub-list within range. + // Both start and end indices are positive. + // End index is greater than start index. + name: "1. Return sub-list within range.", + key: "LrangeKey1", + presetValue: []string{"value1", "value2", "value3", "value4", "value5", "value6", "value7", "value8"}, + command: []string{"LRANGE", "LrangeKey1", "3", "6"}, + expectedResponse: []string{"value4", "value5", "value6", "value7"}, + expectedError: nil, + }, + { + name: "2. Return sub-list from start index to the end of the list when end index is -1", + key: "LrangeKey2", + presetValue: []string{"value1", "value2", "value3", "value4", "value5", "value6", "value7", "value8"}, + command: []string{"LRANGE", "LrangeKey2", "3", "-1"}, + expectedResponse: []string{"value4", "value5", "value6", "value7", "value8"}, + expectedError: nil, + }, + { + name: "3. Return the reversed sub-list when the end index is greater than -1 but less than start index", + key: "LrangeKey3", + presetValue: []string{"value1", "value2", "value3", "value4", "value5", "value6", "value7", "value8"}, + command: []string{"LRANGE", "LrangeKey3", "3", "0"}, + expectedResponse: []string{"value4", "value3", "value2", "value1"}, + expectedError: nil, + }, + { + name: "4. If key does not exist, return error", + key: "LrangeKey4", + presetValue: nil, + command: []string{"LRANGE", "LrangeKey4", "0", "2"}, + expectedResponse: nil, + expectedError: errors.New("LRANGE command on non-list item"), + }, + { + name: "5. Command too short", + key: "LrangeKey5", + presetValue: nil, + command: []string{"LRANGE", "LrangeKey5"}, + expectedResponse: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "6. Command too long", + key: "LrangeKey6", + presetValue: nil, + command: []string{"LRANGE", "LrangeKey6", "0", "element", "element"}, + expectedResponse: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "7. Error when executing command on non-list command", + key: "LrangeKey5", + presetValue: "Default value", + command: []string{"LRANGE", "LrangeKey5", "0", "3"}, + expectedResponse: nil, + expectedError: errors.New("LRANGE command on non-list item"), + }, + { + name: "8. Error when start index is less than 0", + key: "LrangeKey7", + presetValue: []string{"value1", "value2", "value3", "value4"}, + command: []string{"LRANGE", "LrangeKey7", "-1", "3"}, + expectedResponse: nil, + expectedError: errors.New("start index must be within list boundary"), + }, + { + name: "9. Error when start index is higher than the length of the list", + key: "LrangeKey8", + presetValue: []string{"value1", "value2", "value3"}, + command: []string{"LRANGE", "LrangeKey8", "10", "11"}, + expectedResponse: nil, + expectedError: errors.New("start index must be within list boundary"), + }, + { + name: "10. Return error when start index is not an integer", + key: "LrangeKey9", + presetValue: []string{"value1", "value2", "value3"}, + command: []string{"LRANGE", "LrangeKey9", "start", "7"}, + expectedResponse: nil, + expectedError: errors.New("start and end indices must be integers"), + }, + { + name: "11. Return error when end index is not an integer", + key: "LrangeKey10", + presetValue: []string{"value1", "value2", "value3"}, + command: []string{"LRANGE", "LrangeKey10", "0", "end"}, + expectedResponse: nil, + expectedError: errors.New("start and end indices must be integers"), + }, + { + name: "12. Error when start and end indices are equal", + key: "LrangeKey11", + presetValue: []string{"value1", "value2", "value3"}, + command: []string{"LRANGE", "LrangeKey11", "1", "1"}, + expectedResponse: nil, + expectedError: errors.New("start and end indices cannot be equal"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case []string: + command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} + for _, element := range test.presetValue.([]string) { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(len(test.presetValue.([]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response of length %d, got length %d", len(test.expectedResponse), len(res.Array())) + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected element \"%s\" in response", item.String()) + } + } + }) + } + }) + + t.Run("Test_HandleLSET", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedValue []string + expectedError error + }{ + { + name: "1. Return last element within range", + key: "LsetKey1", + presetValue: []string{"value1", "value2", "value3", "value4"}, + command: []string{"LSET", "LsetKey1", "3", "new-value"}, + expectedValue: []string{"value1", "value2", "value3", "new-value"}, + expectedError: nil, + }, + { + name: "2. Return first element within range", + key: "LsetKey2", + presetValue: []string{"value1", "value2", "value3", "value4"}, + command: []string{"LSET", "LsetKey2", "0", "new-value"}, + expectedValue: []string{"new-value", "value2", "value3", "value4"}, + expectedError: nil, + }, + { + name: "3. Return middle element within range", + key: "LsetKey3", + presetValue: []string{"value1", "value2", "value3", "value4"}, + command: []string{"LSET", "LsetKey3", "1", "new-value"}, + expectedValue: []string{"value1", "new-value", "value3", "value4"}, + expectedError: nil, + }, + { + name: "4. If key does not exist, return error", + key: "LsetKey4", + presetValue: nil, + command: []string{"LSET", "LsetKey4", "0", "element"}, + expectedValue: nil, + expectedError: errors.New("LSET command on non-list item"), + }, + { + name: "5. Command too short", + key: "LsetKey5", + presetValue: nil, + command: []string{"LSET", "LsetKey5"}, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "6. Command too long", + key: "LsetKey6", + presetValue: nil, + command: []string{"LSET", "LsetKey6", "0", "element", "element"}, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "7. Trying to get element by index on a non-list returns error", + key: "LsetKey5", + presetValue: "Default value", + command: []string{"LSET", "LsetKey5", "0", "element"}, + expectedValue: nil, + expectedError: errors.New("LSET command on non-list item"), + }, + { + name: "8. Trying to get index out of range index beyond last index", + key: "LsetKey6", + presetValue: []string{"value1", "value2", "value3"}, + command: []string{"LSET", "LsetKey6", "3", "element"}, + expectedValue: nil, + expectedError: errors.New("index must be within list range"), + }, + { + name: "9. Trying to get index out of range with negative index", + key: "LsetKey7", + presetValue: []string{"value1", "value2", "value3"}, + command: []string{"LSET", "LsetKey7", "-1", "element"}, + expectedValue: nil, + expectedError: errors.New("index must be within list range"), + }, + { + name: "10. Return error when index is not an integer", + key: "LsetKey8", + presetValue: []string{"value1", "value2", "value3"}, + command: []string{"LSET", "LsetKey8", "index", "element"}, + expectedValue: nil, + expectedError: errors.New("index must be an integer"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case []string: + command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} + for _, element := range test.presetValue.([]string) { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(len(test.presetValue.([]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected response OK, got \"%s\"", res.String()) + } - for key, list := range test.expectedValue { if err = client.WriteArray([]resp.Value{ resp.StringValue("LRANGE"), - resp.StringValue(key), + resp.StringValue(test.key), resp.StringValue("0"), resp.StringValue("-1"), }); err != nil { @@ -1272,105 +714,171 @@ func Test_HandleLMOVE(t *testing.T) { t.Error(err) } - if len(res.Array()) != len(list) { + if len(res.Array()) != len(test.expectedValue) { t.Errorf("expected list at key \"%s\" to be length %d, got %d", - key, len(test.expectedValue), len(res.Array())) + test.key, len(test.expectedValue), len(res.Array())) } for _, item := range res.Array() { - if !slices.Contains(list, item.String()) { - t.Errorf("unexpected value \"%s\" in updated list %s", item.String(), key) + if !slices.Contains(test.expectedValue, item.String()) { + t.Errorf("unexpected value \"%s\" in updated list", item.String()) } } - } - }) - } -} + }) + } + }) -func Test_HandleLPUSH(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) + t.Run("Test_HandleLTRIM", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse int - expectedValue []string - expectedError error - }{ - { - name: "1. LPUSHX to existing list prepends the element to the list", - key: "LpushKey1", - presetValue: []string{"1", "2", "4", "5"}, - command: []string{"LPUSHX", "LpushKey1", "value1", "value2"}, - expectedResponse: 6, - expectedValue: []string{"value1", "value2", "1", "2", "4", "5"}, - expectedError: nil, - }, - { - name: "2. LPUSH on existing list prepends the elements to the list", - key: "LpushKey2", - presetValue: []string{"1", "2", "4", "5"}, - command: []string{"LPUSH", "LpushKey2", "value1", "value2"}, - expectedResponse: 6, - expectedValue: []string{"value1", "value2", "1", "2", "4", "5"}, - expectedError: nil, - }, - { - name: "3. LPUSH on non-existent list creates the list", - key: "LpushKey3", - presetValue: nil, - command: []string{"LPUSH", "LpushKey3", "value1", "value2"}, - expectedResponse: 2, - expectedValue: []string{"value1", "value2"}, - expectedError: nil, - }, - { - name: "4. Command too short", - key: "LpushKey5", - presetValue: nil, - command: []string{"LPUSH", "LpushKey5"}, - expectedResponse: 0, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. LPUSHX command returns error on non-existent list", - key: "LpushKey6", - presetValue: nil, - command: []string{"LPUSHX", "LpushKey7", "count", "value1"}, - expectedResponse: 0, - expectedValue: nil, - expectedError: errors.New("LPUSHX command on non-existent key"), - }, - } + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedValue []string + expectedError error + }{ + { + // Return trim within range. + // Both start and end indices are positive. + // End index is greater than start index. + name: "1. Return trim within range.", + key: "LtrimKey1", + presetValue: []string{"value1", "value2", "value3", "value4", "value5", "value6", "value7", "value8"}, + command: []string{"LTRIM", "LtrimKey1", "3", "6"}, + expectedValue: []string{"value4", "value5", "value6"}, + expectedError: nil, + }, + { + name: "2. Return element from start index to end index when end index is greater than length of the list", + key: "LtrimKey2", + presetValue: []string{"value1", "value2", "value3", "value4", "value5", "value6", "value7", "value8"}, + command: []string{"LTRIM", "LtrimKey2", "5", "-1"}, + expectedValue: []string{"value6", "value7", "value8"}, + expectedError: nil, + }, + { + name: "3. Return error when end index is smaller than start index but greater than -1", + key: "LtrimKey3", + presetValue: []string{"value1", "value2", "value3", "value4"}, + command: []string{"LTRIM", "LtrimKey3", "3", "1"}, + expectedValue: nil, + expectedError: errors.New("end index must be greater than start index or -1"), + }, + { + name: "4. If key does not exist, return error", + key: "LtrimKey4", + presetValue: nil, + command: []string{"LTRIM", "LtrimKey4", "0", "2"}, + expectedValue: nil, + expectedError: errors.New("LTRIM command on non-list item"), + }, + { + name: "5. Command too short", + key: "LtrimKey5", + presetValue: nil, + command: []string{"LTRIM", "LtrimKey5"}, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "6. Command too long", + key: "LtrimKey6", + presetValue: nil, + command: []string{"LTRIM", "LtrimKey6", "0", "element", "element"}, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "7. Trying to get element by index on a non-list returns error", + key: "LtrimKey5", + presetValue: "Default value", + command: []string{"LTRIM", "LtrimKey5", "0", "3"}, + expectedValue: nil, + expectedError: errors.New("LTRIM command on non-list item"), + }, + { + name: "8. Error when start index is less than 0", + key: "LtrimKey7", + presetValue: []string{"value1", "value2", "value3", "value4"}, + command: []string{"LTRIM", "LtrimKey7", "-1", "3"}, + expectedValue: nil, + expectedError: errors.New("start index must be within list boundary"), + }, + { + name: "9. Error when start index is higher than the length of the list", + key: "LtrimKey8", + presetValue: []string{"value1", "value2", "value3"}, + command: []string{"LTRIM", "LtrimKey8", "10", "11"}, + expectedValue: nil, + expectedError: errors.New("start index must be within list boundary"), + }, + { + name: "10. Return error when start index is not an integer", + key: "LtrimKey9", + presetValue: []string{"value1", "value2", "value3"}, + command: []string{"LTRIM", "LtrimKey9", "start", "7"}, + expectedValue: nil, + expectedError: errors.New("start and end indices must be integers"), + }, + { + name: "11. Return error when end index is not an integer", + key: "LtrimKey10", + presetValue: []string{"value1", "value2", "value3"}, + command: []string{"LTRIM", "LtrimKey10", "0", "end"}, + expectedValue: nil, + expectedError: errors.New("start and end indices must be integers"), + }, + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case []string: + command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} + for _, element := range test.presetValue.([]string) { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(len(test.presetValue.([]string))) } - expected = "ok" - case []string: - command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} - for _, element := range test.presetValue.([]string) { - command = append(command, []resp.Value{resp.StringValue(element)}...) + + if err = client.WriteArray(command); err != nil { + t.Error(err) } - expected = strconv.Itoa(len(test.presetValue.([]string))) + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -1381,147 +889,161 @@ func Test_HandleLPUSH(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) - } - - if err = client.WriteArray([]resp.Value{ - resp.StringValue("LRANGE"), - resp.StringValue(test.key), - resp.StringValue("0"), - resp.StringValue("-1"), - }); err != nil { - t.Error(err) - } - - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != len(test.expectedValue) { - t.Errorf("expected list at key \"%s\" to be length %d, got %d", - test.key, len(test.expectedValue), len(res.Array())) - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedValue, item.String()) { - t.Errorf("unexpected value \"%s\" in updated list", item.String()) - } - } - }) - } -} - -func Test_HandleRPUSH(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse int - expectedValue []string - expectedError error - }{ - { - name: "1. RPUSHX to existing list prepends the element to the list", - key: "RpushKey1", - presetValue: []string{"1", "2", "4", "5"}, - command: []string{"RPUSHX", "RpushKey1", "value1", "value2"}, - expectedResponse: 6, - expectedValue: []string{"1", "2", "4", "5", "value1", "value2"}, - expectedError: nil, - }, - { - name: "2. RPUSH on existing list prepends the elements to the list", - key: "RpushKey2", - presetValue: []string{"1", "2", "4", "5"}, - command: []string{"RPUSH", "RpushKey2", "value1", "value2"}, - expectedResponse: 6, - expectedValue: []string{"1", "2", "4", "5", "value1", "value2"}, - expectedError: nil, - }, - { - name: "3. RPUSH on non-existent list creates the list", - key: "RpushKey3", - presetValue: nil, - command: []string{"RPUSH", "RpushKey3", "value1", "value2"}, - expectedResponse: 2, - expectedValue: []string{"value1", "value2"}, - expectedError: nil, - }, - { - name: "4. Command too short", - key: "RpushKey5", - presetValue: nil, - command: []string{"RPUSH", "RpushKey5"}, - expectedResponse: 0, - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. RPUSHX command returns error on non-existent list", - key: "RpushKey6", - presetValue: nil, - command: []string{"RPUSHX", "RpushKey7", "count", "value1"}, - expectedResponse: 0, - expectedValue: nil, - expectedError: errors.New("RPUSHX command on non-existent key"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } - expected = "ok" - case []string: - command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} - for _, element := range test.presetValue.([]string) { - command = append(command, []resp.Value{resp.StringValue(element)}...) + return + } + + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected response OK, got \"%s\"", res.String()) + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("LRANGE"), + resp.StringValue(test.key), + resp.StringValue("0"), + resp.StringValue("-1"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != len(test.expectedValue) { + t.Errorf("expected list at key \"%s\" to be length %d, got %d", + test.key, len(test.expectedValue), len(res.Array())) + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedValue, item.String()) { + t.Errorf("unexpected value \"%s\" in updated list", item.String()) } - expected = strconv.Itoa(len(test.presetValue.([]string))) + } + }) + } + }) + + t.Run("Test_HandleLREM", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedValue []string + expectedError error + }{ + { + name: "1. Remove the first 3 elements that appear in the list", + key: "LremKey1", + presetValue: []string{"1", "2", "4", "4", "5", "6", "7", "4", "8", "4", "9", "10", "5", "4"}, + command: []string{"LREM", "LremKey1", "3", "4"}, + expectedValue: []string{"1", "2", "5", "6", "7", "8", "4", "9", "10", "5", "4"}, + expectedError: nil, + }, + { + name: "2. Remove the last 3 elements that appear in the list", + key: "LremKey2", + presetValue: []string{"1", "2", "4", "4", "5", "6", "7", "4", "8", "4", "9", "10", "5", "4"}, + command: []string{"LREM", "LremKey2", "-3", "4"}, + expectedValue: []string{"1", "2", "4", "4", "5", "6", "7", "8", "9", "10", "5"}, + expectedError: nil, + }, + { + name: "3. Command too short", + key: "LremKey3", + presetValue: nil, + command: []string{"LREM", "LremKey3"}, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "4. Command too long", + key: "LremKey4", + presetValue: nil, + command: []string{"LREM", "LremKey4", "0", "element", "element"}, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Throw error when count is not an integer", + key: "LremKey5", + presetValue: nil, + command: []string{"LREM", "LremKey5", "count", "value1"}, + expectedValue: nil, + expectedError: errors.New("count must be an integer"), + }, + { + name: "6. Throw error on non-list item", + key: "LremKey6", + presetValue: "Default value", + command: []string{"LREM", "LremKey6", "0", "value1"}, + expectedValue: nil, + expectedError: errors.New("LREM command on non-list item"), + }, + { + name: "7. Throw error on non-existent item", + key: "LremKey7", + presetValue: "Default value", + command: []string{"LREM", "LremKey7", "0", "value1"}, + expectedValue: nil, + expectedError: errors.New("LREM command on non-list item"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case []string: + command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} + for _, element := range test.presetValue.([]string) { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(len(test.presetValue.([]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -1532,156 +1054,225 @@ func Test_HandleRPUSH(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) - } - - if err = client.WriteArray([]resp.Value{ - resp.StringValue("LRANGE"), - resp.StringValue(test.key), - resp.StringValue("0"), - resp.StringValue("-1"), - }); err != nil { - t.Error(err) - } - - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != len(test.expectedValue) { - t.Errorf("expected list at key \"%s\" to be length %d, got %d", - test.key, len(test.expectedValue), len(res.Array())) - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedValue, item.String()) { - t.Errorf("unexpected value \"%s\" in updated list", item.String()) - } - } - }) - } -} - -func Test_HandlePOP(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse string - expectedValue []string - expectedError error - }{ - { - name: "1. LPOP returns last element and removed first element from the list", - key: "PopKey1", - presetValue: []string{"value1", "value2", "value3", "value4"}, - command: []string{"LPOP", "PopKey1"}, - expectedResponse: "value1", - expectedValue: []string{"value2", "value3", "value4"}, - expectedError: nil, - }, - { - name: "2. RPOP returns last element and removed last element from the list", - key: "PopKey2", - presetValue: []string{"value1", "value2", "value3", "value4"}, - command: []string{"RPOP", "PopKey2"}, - expectedResponse: "value4", - expectedValue: []string{"value1", "value2", "value3"}, - expectedError: nil, - }, - { - name: "3. Command too short", - key: "PopKey3", - presetValue: nil, - command: []string{"LPOP"}, - expectedResponse: "", - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "4. Command too long", - key: "PopKey4", - presetValue: nil, - command: []string{"LPOP", "PopKey4", "PopKey4"}, - expectedResponse: "", - expectedValue: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Trying to execute LPOP from a non-list item return an error", - key: "PopKey5", - presetValue: "Default value", - command: []string{"LPOP", "PopKey5"}, - expectedResponse: "", - expectedValue: nil, - expectedError: errors.New("LPOP command on non-list item"), - }, - { - name: "6. Trying to execute RPOP from a non-list item return an error", - key: "PopKey6", - presetValue: "Default value", - command: []string{"RPOP", "PopKey6"}, - expectedResponse: "", - expectedValue: nil, - expectedError: errors.New("RPOP command on non-list item"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } - expected = "ok" - case []string: - command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} - for _, element := range test.presetValue.([]string) { - command = append(command, []resp.Value{resp.StringValue(element)}...) + return + } + + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected response OK, got \"%s\"", res.String()) + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("LRANGE"), + resp.StringValue(test.key), + resp.StringValue("0"), + resp.StringValue("-1"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != len(test.expectedValue) { + t.Errorf("expected list at key \"%s\" to be length %d, got %d", + test.key, len(test.expectedValue), len(res.Array())) + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedValue, item.String()) { + t.Errorf("unexpected value \"%s\" in updated list", item.String()) } - expected = strconv.Itoa(len(test.presetValue.([]string))) + } + }) + } + }) + + t.Run("Test_HandleLMOVE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValue map[string]interface{} + command []string + expectedValue map[string][]string + expectedError error + }{ + { + name: "1. Move element from LEFT of left list to LEFT of right list", + presetValue: map[string]interface{}{ + "source1": []string{"one", "two", "three"}, + "destination1": []string{"one", "two", "three"}, + }, + command: []string{"LMOVE", "source1", "destination1", "LEFT", "LEFT"}, + expectedValue: map[string][]string{ + "source1": {"two", "three"}, + "destination1": {"one", "one", "two", "three"}, + }, + expectedError: nil, + }, + { + name: "2. Move element from LEFT of left list to RIGHT of right list", + presetValue: map[string]interface{}{ + "source2": []string{"one", "two", "three"}, + "destination2": []string{"one", "two", "three"}, + }, + command: []string{"LMOVE", "source2", "destination2", "LEFT", "RIGHT"}, + expectedValue: map[string][]string{ + "source2": {"two", "three"}, + "destination2": {"one", "two", "three", "one"}, + }, + expectedError: nil, + }, + { + name: "3. Move element from RIGHT of left list to LEFT of right list", + presetValue: map[string]interface{}{ + "source3": []string{"one", "two", "three"}, + "destination3": []string{"one", "two", "three"}, + }, + command: []string{"LMOVE", "source3", "destination3", "RIGHT", "LEFT"}, + expectedValue: map[string][]string{ + "source3": {"one", "two"}, + "destination3": {"three", "one", "two", "three"}, + }, + expectedError: nil, + }, + { + name: "4. Move element from RIGHT of left list to RIGHT of right list", + presetValue: map[string]interface{}{ + "source4": []string{"one", "two", "three"}, + "destination4": []string{"one", "two", "three"}, + }, + command: []string{"LMOVE", "source4", "destination4", "RIGHT", "RIGHT"}, + expectedValue: map[string][]string{ + "source4": {"one", "two"}, + "destination4": {"one", "two", "three", "three"}, + }, + expectedError: nil, + }, + { + name: "5. Throw error when the right list is non-existent", + presetValue: map[string]interface{}{ + "source5": []string{"one", "two", "three"}, + }, + command: []string{"LMOVE", "source5", "destination5", "LEFT", "LEFT"}, + expectedValue: nil, + expectedError: errors.New("both source and destination must be lists"), + }, + { + name: "6. Throw error when right list in not a list", + presetValue: map[string]interface{}{ + "source6": []string{"one", "two", "tree"}, + "destination6": "Default value", + }, + command: []string{"LMOVE", "source6", "destination6", "LEFT", "LEFT"}, + expectedValue: nil, + expectedError: errors.New("both source and destination must be lists"), + }, + { + name: "7. Throw error when left list is non-existent", + presetValue: map[string]interface{}{ + "destination7": []string{"one", "two", "three"}, + }, + command: []string{"LMOVE", "source7", "destination7", "LEFT", "LEFT"}, + expectedValue: nil, + expectedError: errors.New("both source and destination must be lists"), + }, + { + name: "8. Throw error when left list is not a list", + presetValue: map[string]interface{}{ + "source8": "Default value", + "destination8": []string{"one", "two", "three"}, + }, + command: []string{"LMOVE", "source8", "destination8", "LEFT", "LEFT"}, + expectedValue: nil, + expectedError: errors.New("both source and destination must be lists"), + }, + { + name: "9. Throw error when command is too short", + presetValue: map[string]interface{}{}, + command: []string{"LMOVE", "source9", "destination9"}, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "10. Throw error when command is too long", + presetValue: map[string]interface{}{}, + command: []string{"LMOVE", "source10", "destination10", "LEFT", "LEFT", "RIGHT"}, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "11. Throw error when WHEREFROM argument is not LEFT/RIGHT", + presetValue: map[string]interface{}{}, + command: []string{"LMOVE", "source11", "destination11", "UP", "RIGHT"}, + expectedValue: nil, + expectedError: errors.New("wherefrom and whereto arguments must be either LEFT or RIGHT"), + }, + { + name: "12. Throw error when WHERETO argument is not LEFT/RIGHT", + presetValue: map[string]interface{}{}, + command: []string{"LMOVE", "source11", "destination11", "LEFT", "DOWN"}, + expectedValue: nil, + expectedError: errors.New("wherefrom and whereto arguments must be either LEFT or RIGHT"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + for key, value := range test.presetValue { + + var command []resp.Value + var expected string + + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case []string: + command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(key)} + for _, element := range value.([]string) { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(len(value.([]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -1692,59 +1283,518 @@ func Test_HandlePOP(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected response OK, got \"%s\"", res.String()) } - return - } - if res.String() != test.expectedResponse { - t.Errorf("expected response %s, got %s", test.expectedResponse, res.String()) - } + for key, list := range test.expectedValue { + if err = client.WriteArray([]resp.Value{ + resp.StringValue("LRANGE"), + resp.StringValue(key), + resp.StringValue("0"), + resp.StringValue("-1"), + }); err != nil { + t.Error(err) + } - if err = client.WriteArray([]resp.Value{ - resp.StringValue("LRANGE"), - resp.StringValue(test.key), - resp.StringValue("0"), - resp.StringValue("-1"), - }); err != nil { - t.Error(err) - } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } + if len(res.Array()) != len(list) { + t.Errorf("expected list at key \"%s\" to be length %d, got %d", + key, len(test.expectedValue), len(res.Array())) + } - if len(res.Array()) != len(test.expectedValue) { - t.Errorf("expected list at key \"%s\" to be length %d, got %d", - test.key, len(test.expectedValue), len(res.Array())) - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedValue, item.String()) { - t.Errorf("unexpected value \"%s\" in updated list", item.String()) + for _, item := range res.Array() { + if !slices.Contains(list, item.String()) { + t.Errorf("unexpected value \"%s\" in updated list %s", item.String(), key) + } + } } - } - }) - } + }) + } + }) + + t.Run("Test_HandleLPUSH", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse int + expectedValue []string + expectedError error + }{ + { + name: "1. LPUSHX to existing list prepends the element to the list", + key: "LpushKey1", + presetValue: []string{"1", "2", "4", "5"}, + command: []string{"LPUSHX", "LpushKey1", "value1", "value2"}, + expectedResponse: 6, + expectedValue: []string{"value1", "value2", "1", "2", "4", "5"}, + expectedError: nil, + }, + { + name: "2. LPUSH on existing list prepends the elements to the list", + key: "LpushKey2", + presetValue: []string{"1", "2", "4", "5"}, + command: []string{"LPUSH", "LpushKey2", "value1", "value2"}, + expectedResponse: 6, + expectedValue: []string{"value1", "value2", "1", "2", "4", "5"}, + expectedError: nil, + }, + { + name: "3. LPUSH on non-existent list creates the list", + key: "LpushKey3", + presetValue: nil, + command: []string{"LPUSH", "LpushKey3", "value1", "value2"}, + expectedResponse: 2, + expectedValue: []string{"value1", "value2"}, + expectedError: nil, + }, + { + name: "4. Command too short", + key: "LpushKey5", + presetValue: nil, + command: []string{"LPUSH", "LpushKey5"}, + expectedResponse: 0, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. LPUSHX command returns error on non-existent list", + key: "LpushKey6", + presetValue: nil, + command: []string{"LPUSHX", "LpushKey7", "count", "value1"}, + expectedResponse: 0, + expectedValue: nil, + expectedError: errors.New("LPUSHX command on non-existent key"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case []string: + command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} + for _, element := range test.presetValue.([]string) { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(len(test.presetValue.([]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("LRANGE"), + resp.StringValue(test.key), + resp.StringValue("0"), + resp.StringValue("-1"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != len(test.expectedValue) { + t.Errorf("expected list at key \"%s\" to be length %d, got %d", + test.key, len(test.expectedValue), len(res.Array())) + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedValue, item.String()) { + t.Errorf("unexpected value \"%s\" in updated list", item.String()) + } + } + }) + } + }) + + t.Run("Test_HandleRPUSH", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse int + expectedValue []string + expectedError error + }{ + { + name: "1. RPUSHX to existing list prepends the element to the list", + key: "RpushKey1", + presetValue: []string{"1", "2", "4", "5"}, + command: []string{"RPUSHX", "RpushKey1", "value1", "value2"}, + expectedResponse: 6, + expectedValue: []string{"1", "2", "4", "5", "value1", "value2"}, + expectedError: nil, + }, + { + name: "2. RPUSH on existing list prepends the elements to the list", + key: "RpushKey2", + presetValue: []string{"1", "2", "4", "5"}, + command: []string{"RPUSH", "RpushKey2", "value1", "value2"}, + expectedResponse: 6, + expectedValue: []string{"1", "2", "4", "5", "value1", "value2"}, + expectedError: nil, + }, + { + name: "3. RPUSH on non-existent list creates the list", + key: "RpushKey3", + presetValue: nil, + command: []string{"RPUSH", "RpushKey3", "value1", "value2"}, + expectedResponse: 2, + expectedValue: []string{"value1", "value2"}, + expectedError: nil, + }, + { + name: "4. Command too short", + key: "RpushKey5", + presetValue: nil, + command: []string{"RPUSH", "RpushKey5"}, + expectedResponse: 0, + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. RPUSHX command returns error on non-existent list", + key: "RpushKey6", + presetValue: nil, + command: []string{"RPUSHX", "RpushKey7", "count", "value1"}, + expectedResponse: 0, + expectedValue: nil, + expectedError: errors.New("RPUSHX command on non-existent key"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case []string: + command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} + for _, element := range test.presetValue.([]string) { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(len(test.presetValue.([]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("LRANGE"), + resp.StringValue(test.key), + resp.StringValue("0"), + resp.StringValue("-1"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != len(test.expectedValue) { + t.Errorf("expected list at key \"%s\" to be length %d, got %d", + test.key, len(test.expectedValue), len(res.Array())) + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedValue, item.String()) { + t.Errorf("unexpected value \"%s\" in updated list", item.String()) + } + } + }) + } + }) + + t.Run("Test_HandlePOP", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse string + expectedValue []string + expectedError error + }{ + { + name: "1. LPOP returns last element and removed first element from the list", + key: "PopKey1", + presetValue: []string{"value1", "value2", "value3", "value4"}, + command: []string{"LPOP", "PopKey1"}, + expectedResponse: "value1", + expectedValue: []string{"value2", "value3", "value4"}, + expectedError: nil, + }, + { + name: "2. RPOP returns last element and removed last element from the list", + key: "PopKey2", + presetValue: []string{"value1", "value2", "value3", "value4"}, + command: []string{"RPOP", "PopKey2"}, + expectedResponse: "value4", + expectedValue: []string{"value1", "value2", "value3"}, + expectedError: nil, + }, + { + name: "3. Command too short", + key: "PopKey3", + presetValue: nil, + command: []string{"LPOP"}, + expectedResponse: "", + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "4. Command too long", + key: "PopKey4", + presetValue: nil, + command: []string{"LPOP", "PopKey4", "PopKey4"}, + expectedResponse: "", + expectedValue: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Trying to execute LPOP from a non-list item return an error", + key: "PopKey5", + presetValue: "Default value", + command: []string{"LPOP", "PopKey5"}, + expectedResponse: "", + expectedValue: nil, + expectedError: errors.New("LPOP command on non-list item"), + }, + { + name: "6. Trying to execute RPOP from a non-list item return an error", + key: "PopKey6", + presetValue: "Default value", + command: []string{"RPOP", "PopKey6"}, + expectedResponse: "", + expectedValue: nil, + expectedError: errors.New("RPOP command on non-list item"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case []string: + command = []resp.Value{resp.StringValue("LPUSH"), resp.StringValue(test.key)} + for _, element := range test.presetValue.([]string) { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(len(test.presetValue.([]string))) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if res.String() != test.expectedResponse { + t.Errorf("expected response %s, got %s", test.expectedResponse, res.String()) + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("LRANGE"), + resp.StringValue(test.key), + resp.StringValue("0"), + resp.StringValue("-1"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != len(test.expectedValue) { + t.Errorf("expected list at key \"%s\" to be length %d, got %d", + test.key, len(test.expectedValue), len(res.Array())) + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedValue, item.String()) { + t.Errorf("unexpected value \"%s\" in updated list", item.String()) + } + } + }) + } + }) } diff --git a/internal/modules/set/commands_test.go b/internal/modules/set/commands_test.go index b7108c5..2971c31 100644 --- a/internal/modules/set/commands_test.go +++ b/internal/modules/set/commands_test.go @@ -31,20 +31,26 @@ import ( "testing" ) -var mockServer *echovault.EchoVault -var addr = "localhost" -var port int +func Test_Set(t *testing.T) { + port, err := internal.GetFreePort() + if err != nil { + t.Error(err) + return + } -func init() { - port, _ = internal.GetFreePort() - mockServer, _ = echovault.NewEchoVault( + mockServer, err := echovault.NewEchoVault( echovault.WithConfig(config.Config{ - BindAddr: addr, + BindAddr: "localhost", Port: uint16(port), DataDir: "", EvictionPolicy: constants.NoEviction, }), ) + if err != nil { + t.Error(err) + return + } + wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -52,371 +58,93 @@ func init() { mockServer.Start() }() wg.Wait() -} -func Test_HandleSADD(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) + t.Cleanup(func() { + mockServer.ShutDown() + }) - tests := []struct { - name string - preset bool - presetValue interface{} - key string - command []string - expectedValue *set.Set - expectedResponse int - expectedError error - }{ - { - name: "1. Create new set on a non-existent key, return count of added elements", - preset: false, - presetValue: nil, - key: "SaddKey1", - command: []string{"SADD", "SaddKey1", "one", "two", "three", "four"}, - expectedValue: set.NewSet([]string{"one", "two", "three", "four"}), - expectedResponse: 4, - expectedError: nil, - }, - { - name: "2. Add members to an exiting set, skip members that already exist in the set, return added count.", - preset: true, - presetValue: set.NewSet([]string{"one", "two", "three", "four"}), - key: "SaddKey2", - command: []string{"SADD", "SaddKey2", "three", "four", "five", "six", "seven"}, - expectedValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven"}), - expectedResponse: 3, - expectedError: nil, - }, - { - name: "3. Throw error when trying to add to a key that does not hold a set", - preset: true, - presetValue: "Default value", - key: "SaddKey3", - command: []string{"SADD", "SaddKey3", "member"}, - expectedResponse: 0, - expectedError: errors.New("value at key SaddKey3 is not a set"), - }, - { - name: "4. Command too short", - preset: false, - key: "SaddKey4", - command: []string{"SADD", "SaddKey4"}, - expectedValue: nil, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + t.Run("Test_HandleSADD", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} - for _, element := range test.presetValue.(*set.Set).GetAll() { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - - // Check if the resulting set(s) contain the expected members. - if test.expectedValue == nil { - return - } - - if err := client.WriteArray([]resp.Value{resp.StringValue("SMEMBERS"), resp.StringValue(test.key)}); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != test.expectedValue.Cardinality() { - t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", - test.key, test.expectedValue.Cardinality(), len(res.Array())) - } - - for _, item := range res.Array() { - if !test.expectedValue.Contains(item.String()) { - t.Errorf("unexpected memeber \"%s\", in response", item.String()) - } - } - }) - } -} - -func Test_HandleSCARD(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValue interface{} - key string - command []string - expectedValue *set.Set - expectedResponse int - expectedError error - }{ - { - name: "1. Get cardinality of valid set.", - presetValue: set.NewSet([]string{"one", "two", "three", "four"}), - key: "ScardKey1", - command: []string{"SCARD", "ScardKey1"}, - expectedValue: nil, - expectedResponse: 4, - expectedError: nil, - }, - { - name: "2. Return 0 when trying to get cardinality on non-existent key", - presetValue: nil, - key: "ScardKey2", - command: []string{"SCARD", "ScardKey2"}, - expectedValue: nil, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "3. Throw error when trying to get cardinality of a value that is not a set", - presetValue: "Default value", - key: "ScardKey3", - command: []string{"SCARD", "ScardKey3"}, - expectedResponse: 0, - expectedError: errors.New("value at key ScardKey3 is not a set"), - }, - { - name: "4. Command too short", - key: "ScardKey4", - command: []string{"SCARD"}, - expectedValue: nil, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Command too long", - key: "ScardKey5", - command: []string{"SCARD", "ScardKey5", "ScardKey5"}, - expectedValue: nil, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} - for _, element := range test.presetValue.(*set.Set).GetAll() { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - }) - } -} - -func Test_HandleSDIFF(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedResponse []string - expectedError error - }{ - { - name: "1. Get the difference between 2 sets.", - presetValues: map[string]interface{}{ - "SdiffKey1": set.NewSet([]string{"one", "two", "three", "four", "five"}), - "SdiffKey2": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight"}), + tests := []struct { + name string + preset bool + presetValue interface{} + key string + command []string + expectedValue *set.Set + expectedResponse int + expectedError error + }{ + { + name: "1. Create new set on a non-existent key, return count of added elements", + preset: false, + presetValue: nil, + key: "SaddKey1", + command: []string{"SADD", "SaddKey1", "one", "two", "three", "four"}, + expectedValue: set.NewSet([]string{"one", "two", "three", "four"}), + expectedResponse: 4, + expectedError: nil, }, - command: []string{"SDIFF", "SdiffKey1", "SdiffKey2"}, - expectedResponse: []string{"one", "two"}, - expectedError: nil, - }, - { - name: "2. Get the difference between 3 sets.", - presetValues: map[string]interface{}{ - "SdiffKey3": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SdiffKey4": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), - "SdiffKey5": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + { + name: "2. Add members to an exiting set, skip members that already exist in the set, return added count.", + preset: true, + presetValue: set.NewSet([]string{"one", "two", "three", "four"}), + key: "SaddKey2", + command: []string{"SADD", "SaddKey2", "three", "four", "five", "six", "seven"}, + expectedValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven"}), + expectedResponse: 3, + expectedError: nil, }, - command: []string{"SDIFF", "SdiffKey3", "SdiffKey4", "SdiffKey5"}, - expectedResponse: []string{"three", "four", "five", "six"}, - expectedError: nil, - }, - { - name: "3. Return base set element if base set is the only valid set", - presetValues: map[string]interface{}{ - "SdiffKey6": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SdiffKey7": "Default value", - "SdiffKey8": "123456789", + { + name: "3. Throw error when trying to add to a key that does not hold a set", + preset: true, + presetValue: "Default value", + key: "SaddKey3", + command: []string{"SADD", "SaddKey3", "member"}, + expectedResponse: 0, + expectedError: errors.New("value at key SaddKey3 is not a set"), }, - command: []string{"SDIFF", "SdiffKey6", "SdiffKey7", "SdiffKey8"}, - expectedResponse: []string{"one", "two", "three", "four", "five", "six", "seven", "eight"}, - expectedError: nil, - }, - { - name: "4. Throw error when base set is not a set.", - presetValues: map[string]interface{}{ - "SdiffKey9": "Default value", - "SdiffKey10": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), - "SdiffKey11": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + { + name: "4. Command too short", + preset: false, + key: "SaddKey4", + command: []string{"SADD", "SaddKey4"}, + expectedValue: nil, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), }, - command: []string{"SDIFF", "SdiffKey9", "SdiffKey10", "SdiffKey11"}, - expectedResponse: nil, - expectedError: errors.New("value at key SdiffKey9 is not a set"), - }, - { - name: "5. Throw error when base set is non-existent.", - presetValues: map[string]interface{}{ - "SdiffKey12": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), - "SdiffKey13": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), - }, - command: []string{"SDIFF", "non-existent", "SdiffKey7", "SdiffKey8"}, - expectedResponse: nil, - expectedError: errors.New("key for base set \"non-existent\" does not exist"), - }, - { - name: "6. Command too short", - command: []string{"SDIFF"}, - expectedResponse: []string{}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { case string: command = []resp.Value{ resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), } expected = "ok" case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} - for _, element := range value.(*set.Set).GetAll() { + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} + for _, element := range test.presetValue.(*set.Set).GetAll() { command = append(command, []resp.Value{resp.StringValue(element)}...) } - expected = strconv.Itoa(value.(*set.Set).Cardinality()) + expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) } if err = client.WriteArray(command); err != nil { @@ -431,771 +159,10 @@ func Test_HandleSDIFF(t *testing.T) { t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) } } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if len(res.Array()) != len(test.expectedResponse) { - t.Errorf("expected response array of length \"%d\", got \"%d\"", - len(test.expectedResponse), len(res.Array())) - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedResponse, item.String()) { - t.Errorf("unexpected element \"%s\" in response", item.String()) - } - } - }) - } -} - -func Test_HandleSDIFFSTORE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - destination string - command []string - expectedValue *set.Set - expectedResponse int - expectedError error - }{ - { - name: "1. Get the difference between 2 sets.", - presetValues: map[string]interface{}{ - "SdiffStoreKey1": set.NewSet([]string{"one", "two", "three", "four", "five"}), - "SdiffStoreKey2": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight"}), - }, - destination: "SdiffStoreDestination1", - command: []string{"SDIFFSTORE", "SdiffStoreDestination1", "SdiffStoreKey1", "SdiffStoreKey2"}, - expectedValue: set.NewSet([]string{"one", "two"}), - expectedResponse: 2, - expectedError: nil, - }, - { - name: "2. Get the difference between 3 sets.", - presetValues: map[string]interface{}{ - "SdiffStoreKey3": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SdiffStoreKey4": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), - "SdiffStoreKey5": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), - }, - destination: "SdiffStoreDestination2", - command: []string{"SDIFFSTORE", "SdiffStoreDestination2", "SdiffStoreKey3", "SdiffStoreKey4", "SdiffStoreKey5"}, - expectedValue: set.NewSet([]string{"three", "four", "five", "six"}), - expectedResponse: 4, - expectedError: nil, - }, - { - name: "3. Return base set element if base set is the only valid set", - presetValues: map[string]interface{}{ - "SdiffStoreKey6": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SdiffStoreKey7": "Default value", - "SdiffStoreKey8": "123456789", - }, - destination: "SdiffStoreDestination3", - command: []string{"SDIFFSTORE", "SdiffStoreDestination3", "SdiffStoreKey6", "SdiffStoreKey7", "SdiffStoreKey8"}, - expectedValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - expectedResponse: 8, - expectedError: nil, - }, - { - name: "4. Throw error when base set is not a set.", - presetValues: map[string]interface{}{ - "SdiffStoreKey9": "Default value", - "SdiffStoreKey10": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), - "SdiffStoreKey11": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), - }, - destination: "SdiffStoreDestination4", - command: []string{"SDIFFSTORE", "SdiffStoreDestination4", "SdiffStoreKey9", "SdiffStoreKey10", "SdiffStoreKey11"}, - expectedValue: nil, - expectedResponse: 0, - expectedError: errors.New("value at key SdiffStoreKey9 is not a set"), - }, - { - name: "5. Throw error when base set is non-existent.", - destination: "SdiffStoreDestination5", - presetValues: map[string]interface{}{ - "SdiffStoreKey12": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), - "SdiffStoreKey13": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), - }, - command: []string{"SDIFFSTORE", "SdiffStoreDestination5", "non-existent", "SdiffStoreKey7", "SdiffStoreKey8"}, - expectedValue: nil, - expectedResponse: 0, - expectedError: errors.New("key for base set \"non-existent\" does not exist"), - }, - { - name: "6. Command too short", - command: []string{"SDIFFSTORE", "SdiffStoreDestination6"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), - } - expected = "ok" - case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} - for _, element := range value.(*set.Set).GetAll() { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(value.(*set.Set).Cardinality()) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - - // Check if the resulting set(s) contain the expected members. - if test.expectedValue == nil { - return - } - - if err := client.WriteArray([]resp.Value{ - resp.StringValue("SMEMBERS"), - resp.StringValue(test.destination), - }); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != test.expectedValue.Cardinality() { - t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", - test.destination, test.expectedValue.Cardinality(), len(res.Array())) - } - - for _, item := range res.Array() { - if !test.expectedValue.Contains(item.String()) { - t.Errorf("unexpected memeber \"%s\", in response", item.String()) - } - } - }) - } -} - -func Test_HandleSINTER(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedResponse []string - expectedError error - }{ - { - name: "1. Get the intersection between 2 sets.", - presetValues: map[string]interface{}{ - "SinterKey1": set.NewSet([]string{"one", "two", "three", "four", "five"}), - "SinterKey2": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight"}), - }, - command: []string{"SINTER", "SinterKey1", "SinterKey2"}, - expectedResponse: []string{"three", "four", "five"}, - expectedError: nil, - }, - { - name: "2. Get the intersection between 3 sets.", - presetValues: map[string]interface{}{ - "SinterKey3": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SinterKey4": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven", "eight"}), - "SinterKey5": set.NewSet([]string{"one", "eight", "nine", "ten", "twelve"}), - }, - command: []string{"SINTER", "SinterKey3", "SinterKey4", "SinterKey5"}, - expectedResponse: []string{"one", "eight"}, - expectedError: nil, - }, - { - name: "3. Throw an error if any of the provided keys are not sets", - presetValues: map[string]interface{}{ - "SinterKey6": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SinterKey7": "Default value", - "SinterKey8": set.NewSet([]string{"one"}), - }, - command: []string{"SINTER", "SinterKey6", "SinterKey7", "SinterKey8"}, - expectedResponse: nil, - expectedError: errors.New("value at key SinterKey7 is not a set"), - }, - { - name: "4. Throw error when base set is not a set.", - presetValues: map[string]interface{}{ - "SinterKey9": "Default value", - "SinterKey10": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), - "SinterKey11": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), - }, - command: []string{"SINTER", "SinterKey9", "SinterKey10", "SinterKey11"}, - expectedResponse: nil, - expectedError: errors.New("value at key SinterKey9 is not a set"), - }, - { - name: "5. If any of the keys does not exist, return an empty array.", - presetValues: map[string]interface{}{ - "SinterKey12": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), - "SinterKey13": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), - }, - command: []string{"SINTER", "non-existent", "SinterKey7", "SinterKey8"}, - expectedResponse: []string{}, - expectedError: nil, - }, - { - name: "6. Command too short", - command: []string{"SINTER"}, - expectedResponse: []string{}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), - } - expected = "ok" - case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} - for _, element := range value.(*set.Set).GetAll() { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(value.(*set.Set).Cardinality()) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if len(res.Array()) != len(test.expectedResponse) { - t.Errorf("expected response array of length \"%d\", got \"%d\"", - len(test.expectedResponse), len(res.Array())) - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedResponse, item.String()) { - t.Errorf("unexpected element \"%s\" in response", item.String()) - } - } - }) - } -} - -func Test_HandleSINTERCARD(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedResponse int - expectedError error - }{ - { - name: "1. Get the full intersect cardinality between 2 sets.", - presetValues: map[string]interface{}{ - "SinterCardKey1": set.NewSet([]string{"one", "two", "three", "four", "five"}), - "SinterCardKey2": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight"}), - }, - command: []string{"SINTERCARD", "SinterCardKey1", "SinterCardKey2"}, - expectedResponse: 3, - expectedError: nil, - }, - { - name: "2. Get an intersect cardinality between 2 sets with a limit", - presetValues: map[string]interface{}{ - "SinterCardKey3": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"}), - "SinterCardKey4": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight", "nine", "ten", "eleven", "twelve"}), - }, - command: []string{"SINTERCARD", "SinterCardKey3", "SinterCardKey4", "LIMIT", "3"}, - expectedResponse: 3, - expectedError: nil, - }, - { - name: "3. Get the full intersect cardinality between 3 sets.", - presetValues: map[string]interface{}{ - "SinterCardKey5": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SinterCardKey6": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven", "eight"}), - "SinterCardKey7": set.NewSet([]string{"one", "seven", "eight", "nine", "ten", "twelve"}), - }, - command: []string{"SINTERCARD", "SinterCardKey5", "SinterCardKey6", "SinterCardKey7"}, - expectedResponse: 2, - expectedError: nil, - }, - { - name: "4. Get the intersection of 3 sets with a limit", - presetValues: map[string]interface{}{ - "SinterCardKey8": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SinterCardKey9": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven", "eight"}), - "SinterCardKey10": set.NewSet([]string{"one", "two", "seven", "eight", "nine", "ten", "twelve"}), - }, - command: []string{"SINTERCARD", "SinterCardKey8", "SinterCardKey9", "SinterCardKey10", "LIMIT", "2"}, - expectedResponse: 2, - expectedError: nil, - }, - { - name: "5. Return 0 if any of the keys does not exist", - presetValues: map[string]interface{}{ - "SinterCardKey11": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SinterCardKey12": "Default value", - "SinterCardKey13": set.NewSet([]string{"one"}), - }, - command: []string{"SINTERCARD", "SinterCardKey11", "SinterCardKey12", "SinterCardKey13", "non-existent"}, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "6. Throw error when one of the keys is not a valid set.", - presetValues: map[string]interface{}{ - "SinterCardKey14": "Default value", - "SinterCardKey15": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), - "SinterCardKey16": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), - }, - command: []string{"SINTERCARD", "SinterCardKey14", "SinterCardKey15", "SinterCardKey16"}, - expectedResponse: 0, - expectedError: errors.New("value at key SinterCardKey14 is not a set"), - }, - { - name: "7. Command too short", - command: []string{"SINTERCARD"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), - } - expected = "ok" - case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} - for _, element := range value.(*set.Set).GetAll() { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(value.(*set.Set).Cardinality()) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response array of length \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - }) - } - -} - -func Test_HandleSINTERSTORE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - destination string - command []string - expectedValue *set.Set - expectedResponse int - expectedError error - }{ - { - name: "1. Get the intersection between 2 sets and store it at the destination.", - presetValues: map[string]interface{}{ - "SinterStoreKey1": set.NewSet([]string{"one", "two", "three", "four", "five"}), - "SinterStoreKey2": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight"}), - }, - destination: "SinterStoreDestination1", - command: []string{"SINTERSTORE", "SinterStoreDestination1", "SinterStoreKey1", "SinterStoreKey2"}, - expectedValue: set.NewSet([]string{"three", "four", "five"}), - expectedResponse: 3, - expectedError: nil, - }, - { - name: "2. Get the intersection between 3 sets and store it at the destination key.", - presetValues: map[string]interface{}{ - "SinterStoreKey3": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SinterStoreKey4": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven", "eight"}), - "SinterStoreKey5": set.NewSet([]string{"one", "seven", "eight", "nine", "ten", "twelve"}), - }, - destination: "SinterStoreDestination2", - command: []string{"SINTERSTORE", "SinterStoreDestination2", "SinterStoreKey3", "SinterStoreKey4", "SinterStoreKey5"}, - expectedValue: set.NewSet([]string{"one", "eight"}), - expectedResponse: 2, - expectedError: nil, - }, - { - name: "3. Throw error when any of the keys is not a set", - presetValues: map[string]interface{}{ - "SinterStoreKey6": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SinterStoreKey7": "Default value", - "SinterStoreKey8": set.NewSet([]string{"one"}), - }, - destination: "SinterStoreDestination3", - command: []string{"SINTERSTORE", "SinterStoreDestination3", "SinterStoreKey6", "SinterStoreKey7", "SinterStoreKey8"}, - expectedValue: nil, - expectedResponse: 0, - expectedError: errors.New("value at key SinterStoreKey7 is not a set"), - }, - { - name: "4. Throw error when base set is not a set.", - presetValues: map[string]interface{}{ - "SinterStoreKey9": "Default value", - "SinterStoreKey10": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), - "SinterStoreKey11": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), - }, - destination: "SinterStoreDestination4", - command: []string{"SINTERSTORE", "SinterStoreDestination4", "SinterStoreKey9", "SinterStoreKey10", "SinterStoreKey11"}, - expectedValue: nil, - expectedResponse: 0, - expectedError: errors.New("value at key SinterStoreKey9 is not a set"), - }, - { - name: "5. Return an empty intersection if one of the keys does not exist.", - destination: "SinterStoreDestination5", - presetValues: map[string]interface{}{ - "SinterStoreKey12": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), - "SinterStoreKey13": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), - }, - command: []string{"SINTERSTORE", "SinterStoreDestination5", "non-existent", "SinterStoreKey7", "SinterStoreKey8"}, - expectedValue: nil, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "6. Command too short", - command: []string{"SINTERSTORE", "SinterStoreDestination6"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), - } - expected = "ok" - case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} - for _, element := range value.(*set.Set).GetAll() { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(value.(*set.Set).Cardinality()) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - - // Check if the resulting set(s) contain the expected members. - if test.expectedValue == nil { - return - } - - if err := client.WriteArray([]resp.Value{ - resp.StringValue("SMEMBERS"), - resp.StringValue(test.destination), - }); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != test.expectedValue.Cardinality() { - t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", - test.destination, test.expectedValue.Cardinality(), len(res.Array())) - } - - for _, item := range res.Array() { - if !test.expectedValue.Contains(item.String()) { - t.Errorf("unexpected memeber \"%s\", in response", item.String()) - } - } - }) - } -} - -func Test_HandleSISMEMBER(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValue interface{} - key string - command []string - expectedResponse int - expectedError error - }{ - { - name: "1. Return 1 when element is a member of the set", - presetValue: set.NewSet([]string{"one", "two", "three", "four"}), - key: "SIsMemberKey1", - command: []string{"SISMEMBER", "SIsMemberKey1", "three"}, - expectedResponse: 1, - expectedError: nil, - }, - { - name: "2. Return 0 when element is not a member of the set", - presetValue: set.NewSet([]string{"one", "two", "three", "four"}), - key: "SIsMemberKey2", - command: []string{"SISMEMBER", "SIsMemberKey2", "five"}, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "3. Throw error when trying to assert membership when the key does not hold a valid set", - presetValue: "Default value", - key: "SIsMemberKey3", - command: []string{"SISMEMBER", "SIsMemberKey3", "one"}, - expectedResponse: 0, - expectedError: errors.New("value at key SIsMemberKey3 is not a set"), - }, - { - name: "4. Command too short", - key: "SIsMemberKey4", - command: []string{"SISMEMBER", "SIsMemberKey4"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Command too long", - key: "SIsMemberKey5", - command: []string{"SISMEMBER", "SIsMemberKey5", "one", "two", "three"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} - for _, element := range test.presetValue.(*set.Set).GetAll() { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -1206,390 +173,129 @@ func Test_HandleSISMEMBER(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - }) - } -} - -func Test_HandleSMEMBERS(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedResponse []string - expectedError error - }{ - { - name: "1. Return all the members of the set.", - key: "SmembersKey1", - presetValue: set.NewSet([]string{"one", "two", "three", "four", "five"}), - command: []string{"SMEMBERS", "SmembersKey1"}, - expectedResponse: []string{"one", "two", "three", "four", "five"}, - expectedError: nil, - }, - { - name: "2. If the key does not exist, return an empty array.", - key: "SmembersKey2", - presetValue: nil, - command: []string{"SMEMBERS", "SmembersKey2"}, - expectedResponse: nil, - expectedError: nil, - }, - { - name: "3. Throw error when the provided key is not a set.", - key: "SmembersKey3", - presetValue: "Default value", - command: []string{"SMEMBERS", "SmembersKey3"}, - expectedResponse: nil, - expectedError: errors.New("value at key SmembersKey3 is not a set"), - }, - { - name: "4. Command too short", - command: []string{"SMEMBERS"}, - expectedResponse: []string{}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Command too long", - command: []string{"SMEMBERS", "SmembersKey5", "SmembersKey6"}, - expectedResponse: []string{}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } - expected = "ok" - case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} - for _, element := range test.presetValue.(*set.Set).GetAll() { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) + return } - if err = client.WriteArray(command); err != nil { + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) + } + + // Check if the resulting set(s) contain the expected members. + if test.expectedValue == nil { + return + } + + if err := client.WriteArray([]resp.Value{resp.StringValue("SMEMBERS"), resp.StringValue(test.key)}); err != nil { t.Error(err) } - res, _, err := client.ReadValue() + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + if len(res.Array()) != test.expectedValue.Cardinality() { + t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", + test.key, test.expectedValue.Cardinality(), len(res.Array())) } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if len(res.Array()) != len(test.expectedResponse) { - t.Errorf("expected response array of length \"%d\", got \"%d\"", - len(test.expectedResponse), len(res.Array())) - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedResponse, item.String()) { - t.Errorf("unexpected element \"%s\" in response", item.String()) - } - } - }) - } -} - -func Test_HandleSMISMEMBER(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValue interface{} - key string - command []string - expectedResponse []int - expectedError error - }{ - { - // 1. Return set membership status for multiple elements - // Return 1 for present and 0 for absent - // The placement of the membership status flag should me consistent with the order the elements - // are in within the original command - name: "1. Return set membership status for multiple elements", - presetValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven"}), - key: "SmismemberKey1", - command: []string{"SMISMEMBER", "SmismemberKey1", "three", "four", "five", "six", "eight", "nine", "seven"}, - expectedResponse: []int{1, 1, 1, 1, 0, 0, 1}, - expectedError: nil, - }, - { - name: "2. If the set key does not exist, return an array of zeroes as long as the list of members", - presetValue: nil, - key: "SmismemberKey2", - command: []string{"SMISMEMBER", "SmismemberKey2", "one", "two", "three", "four"}, - expectedResponse: []int{0, 0, 0, 0}, - expectedError: nil, - }, - { - name: "3. Throw error when trying to assert membership when the key does not hold a valid set", - presetValue: "Default value", - key: "SmismemberKey3", - command: []string{"SMISMEMBER", "SmismemberKey3", "one"}, - expectedResponse: nil, - expectedError: errors.New("value at key SmismemberKey3 is not a set"), - }, - { - name: "4. Command too short", - presetValue: nil, - key: "SmismemberKey4", - command: []string{"SMISMEMBER", "SmismemberKey4"}, - expectedResponse: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), + for _, item := range res.Array() { + if !test.expectedValue.Contains(item.String()) { + t.Errorf("unexpected memeber \"%s\", in response", item.String()) } - expected = "ok" - case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} - for _, element := range test.presetValue.(*set.Set).GetAll() { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) } + }) + } + }) - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } + t.Run("Test_HandleSCARD", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if len(res.Array()) != len(test.expectedResponse) { - t.Errorf("expected response array of length \"%d\", got \"%d\"", - len(test.expectedResponse), len(res.Array())) - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedResponse, item.Integer()) { - t.Errorf("unexpected element \"%d\" in response", item.Integer()) - } - } - }) - } -} - -func Test_HandleSMOVE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedValues map[string]interface{} - expectedResponse int - expectedError error - }{ - { - name: "1. Return 1 after a successful move of a member from source set to destination set", - presetValues: map[string]interface{}{ - "SmoveSource1": set.NewSet([]string{"one", "two", "three", "four"}), - "SmoveDestination1": set.NewSet([]string{"five", "six", "seven", "eight"}), + tests := []struct { + name string + presetValue interface{} + key string + command []string + expectedValue *set.Set + expectedResponse int + expectedError error + }{ + { + name: "1. Get cardinality of valid set.", + presetValue: set.NewSet([]string{"one", "two", "three", "four"}), + key: "ScardKey1", + command: []string{"SCARD", "ScardKey1"}, + expectedValue: nil, + expectedResponse: 4, + expectedError: nil, }, - command: []string{"SMOVE", "SmoveSource1", "SmoveDestination1", "four"}, - expectedValues: map[string]interface{}{ - "SmoveSource1": set.NewSet([]string{"one", "two", "three"}), - "SmoveDestination1": set.NewSet([]string{"four", "five", "six", "seven", "eight"}), + { + name: "2. Return 0 when trying to get cardinality on non-existent key", + presetValue: nil, + key: "ScardKey2", + command: []string{"SCARD", "ScardKey2"}, + expectedValue: nil, + expectedResponse: 0, + expectedError: nil, }, - expectedResponse: 1, - expectedError: nil, - }, - { - name: "2. Return 0 when trying to move a member from source set to destination set when it doesn't exist in source", - presetValues: map[string]interface{}{ - "SmoveSource2": set.NewSet([]string{"one", "two", "three", "four", "five"}), - "SmoveDestination2": set.NewSet([]string{"five", "six", "seven", "eight"}), + { + name: "3. Throw error when trying to get cardinality of a value that is not a set", + presetValue: "Default value", + key: "ScardKey3", + command: []string{"SCARD", "ScardKey3"}, + expectedResponse: 0, + expectedError: errors.New("value at key ScardKey3 is not a set"), }, - command: []string{"SMOVE", "SmoveSource2", "SmoveDestination2", "six"}, - expectedValues: map[string]interface{}{ - "SmoveSource2": set.NewSet([]string{"one", "two", "three", "four", "five"}), - "SmoveDestination2": set.NewSet([]string{"five", "six", "seven", "eight"}), + { + name: "4. Command too short", + key: "ScardKey4", + command: []string{"SCARD"}, + expectedValue: nil, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), }, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "3. Return error when the source key is not a set", - presetValues: map[string]interface{}{ - "SmoveSource3": "Default value", - "SmoveDestination3": set.NewSet([]string{"five", "six", "seven", "eight"}), + { + name: "5. Command too long", + key: "ScardKey5", + command: []string{"SCARD", "ScardKey5", "ScardKey5"}, + expectedValue: nil, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), }, - command: []string{"SMOVE", "SmoveSource3", "SmoveDestination3", "five"}, - expectedValues: map[string]interface{}{ - "SmoveSource3": "Default value", - "SmoveDestination3": set.NewSet([]string{"five", "six", "seven", "eight"}), - }, - expectedResponse: 0, - expectedError: errors.New("source is not a set"), - }, - { - name: "4. Return error when the destination key is not a set", - presetValues: map[string]interface{}{ - "SmoveSource4": set.NewSet([]string{"one", "two", "three", "four", "five"}), - "SmoveDestination4": "Default value", - }, - command: []string{"SMOVE", "SmoveSource4", "SmoveDestination4", "five"}, - expectedValues: map[string]interface{}{ - "SmoveSource4": set.NewSet([]string{"one", "two", "three", "four", "five"}), - "SmoveDestination4": "Default value", - }, - expectedResponse: 0, - expectedError: errors.New("destination is not a set"), - }, - { - name: "5. Command too short", - presetValues: nil, - command: []string{"SMOVE", "SmoveSource5", "SmoveSource6"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "6. Command too long", - presetValues: nil, - command: []string{"SMOVE", "SmoveSource5", "SmoveSource6", "member1", "member2"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { case string: command = []resp.Value{ resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), } expected = "ok" case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} - for _, element := range value.(*set.Set).GetAll() { + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} + for _, element := range test.presetValue.(*set.Set).GetAll() { command = append(command, []resp.Value{resp.StringValue(element)}...) } - expected = strconv.Itoa(value.(*set.Set).Cardinality()) + expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) } if err = client.WriteArray(command); err != nil { @@ -1604,157 +310,153 @@ func Test_HandleSMOVE(t *testing.T) { t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) } } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - return - } - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } - // Check if the resulting set(s) contain the expected members. - if test.expectedValues == nil { - return - } + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } - for key, value := range test.expectedValues { - switch value.(type) { - case string: - if err := client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - if res.String() != value.(string) { - t.Errorf("expected value at key \"%s\" to be \"%s\", got \"%s\"", key, value.(string), res.String()) - } - case *set.Set: - if err := client.WriteArray([]resp.Value{ - resp.StringValue("SMEMBERS"), - resp.StringValue(key), - }); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) + } + }) + } + }) - if len(res.Array()) != value.(*set.Set).Cardinality() { - t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", - key, value.(*set.Set).Cardinality(), len(res.Array())) - } + t.Run("Test_HandleSDIFF", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - for _, item := range res.Array() { - if !value.(*set.Set).Contains(item.String()) { - t.Errorf("unexpected memeber \"%s\", in response", item.String()) + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedResponse []string + expectedError error + }{ + { + name: "1. Get the difference between 2 sets.", + presetValues: map[string]interface{}{ + "SdiffKey1": set.NewSet([]string{"one", "two", "three", "four", "five"}), + "SdiffKey2": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight"}), + }, + command: []string{"SDIFF", "SdiffKey1", "SdiffKey2"}, + expectedResponse: []string{"one", "two"}, + expectedError: nil, + }, + { + name: "2. Get the difference between 3 sets.", + presetValues: map[string]interface{}{ + "SdiffKey3": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SdiffKey4": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), + "SdiffKey5": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + }, + command: []string{"SDIFF", "SdiffKey3", "SdiffKey4", "SdiffKey5"}, + expectedResponse: []string{"three", "four", "five", "six"}, + expectedError: nil, + }, + { + name: "3. Return base set element if base set is the only valid set", + presetValues: map[string]interface{}{ + "SdiffKey6": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SdiffKey7": "Default value", + "SdiffKey8": "123456789", + }, + command: []string{"SDIFF", "SdiffKey6", "SdiffKey7", "SdiffKey8"}, + expectedResponse: []string{"one", "two", "three", "four", "five", "six", "seven", "eight"}, + expectedError: nil, + }, + { + name: "4. Throw error when base set is not a set.", + presetValues: map[string]interface{}{ + "SdiffKey9": "Default value", + "SdiffKey10": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), + "SdiffKey11": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + }, + command: []string{"SDIFF", "SdiffKey9", "SdiffKey10", "SdiffKey11"}, + expectedResponse: nil, + expectedError: errors.New("value at key SdiffKey9 is not a set"), + }, + { + name: "5. Throw error when base set is non-existent.", + presetValues: map[string]interface{}{ + "SdiffKey12": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), + "SdiffKey13": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + }, + command: []string{"SDIFF", "non-existent", "SdiffKey7", "SdiffKey8"}, + expectedResponse: nil, + expectedError: errors.New("key for base set \"non-existent\" does not exist"), + }, + { + name: "6. Command too short", + command: []string{"SDIFF"}, + expectedResponse: []string{}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *set.Set: + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} + for _, element := range value.(*set.Set).GetAll() { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(value.(*set.Set).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) } } } - } - }) - } -} -func Test_HandleSPOP(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedValue int // The final cardinality of the resulting set - expectedResponse []string - expectedError error - }{ - { - name: "1. Return multiple popped elements and modify the set", - key: "SpopKey1", - presetValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - command: []string{"SPOP", "SpopKey1", "3"}, - expectedValue: 5, - expectedResponse: []string{"one", "two", "three", "four", "five", "six", "seven", "eight"}, - expectedError: nil, - }, - { - name: "2. Return error when the source key is not a set", - key: "SpopKey2", - presetValue: "Default value", - command: []string{"SPOP", "SpopKey2"}, - expectedValue: 0, - expectedResponse: nil, - expectedError: errors.New("value at SpopKey2 is not a set"), - }, - { - name: "3. Command too short", - presetValue: nil, - command: []string{"SPOP"}, - expectedValue: 0, - expectedResponse: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "4. Command too long", - presetValue: nil, - command: []string{"SPOP", "SpopSource5", "SpopSource6", "member1", "member2"}, - expectedValue: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Throw error when count is not an integer", - presetValue: nil, - command: []string{"SPOP", "SpopKey1", "count"}, - expectedValue: 0, - expectedError: errors.New("count must be an integer"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} - for _, element := range test.presetValue.(*set.Set).GetAll() { - command = append(command, []resp.Value{resp.StringValue(element)}...) - } - expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -1765,143 +467,158 @@ func Test_HandleSPOP(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - // Check that each returned element is in the list of expected elements. - for _, item := range res.Array() { - if !slices.Contains(test.expectedResponse, item.String()) { - t.Errorf("unexpected element \"%s\" in response", item.String()) - } - } - - // Check if the resulting set's cardinality is as expected. - if err := client.WriteArray([]resp.Value{resp.StringValue("SCARD"), resp.StringValue(test.key)}); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if res.Integer() != test.expectedValue { - t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", - test.key, test.expectedValue, res.Integer()) - } - }) - } -} - -func Test_HandleSRANDMEMBER(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedValue int // The final cardinality of the resulting set - allowRepeat bool - expectedResponse []string - expectedError error - }{ - { - // 1. Return multiple random elements without removing them - // Count is positive, do not allow repeated elements - name: "1. Return multiple random elements without removing them", - key: "SRandMemberKey1", - presetValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - command: []string{"SRANDMEMBER", "SRandMemberKey1", "3"}, - expectedValue: 8, - allowRepeat: false, - expectedResponse: []string{"one", "two", "three", "four", "five", "six", "seven", "eight"}, - expectedError: nil, - }, - { - // 2. Return multiple random elements without removing them - // Count is negative, so allow repeated numbers - name: "2. Return multiple random elements without removing them", - key: "SRandMemberKey2", - presetValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - command: []string{"SRANDMEMBER", "SRandMemberKey2", "-5"}, - expectedValue: 8, - allowRepeat: true, - expectedResponse: []string{"one", "two", "three", "four", "five", "six", "seven", "eight"}, - expectedError: nil, - }, - { - name: "3. Return error when the source key is not a set", - key: "SRandMemberKey3", - presetValue: "Default value", - command: []string{"SRANDMEMBER", "SRandMemberKey3"}, - expectedValue: 0, - expectedResponse: []string{}, - expectedError: errors.New("value at SRandMemberKey3 is not a set"), - }, - { - name: "4. Command too short", - command: []string{"SRANDMEMBER"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Command too long", - command: []string{"SRANDMEMBER", "SRandMemberSource5", "SRandMemberSource6", "member1", "member2"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "6. Throw error when count is not an integer", - command: []string{"SRANDMEMBER", "SRandMemberKey1", "count"}, - expectedError: errors.New("count must be an integer"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } - expected = "ok" - case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} - for _, element := range test.presetValue.(*set.Set).GetAll() { - command = append(command, []resp.Value{resp.StringValue(element)}...) + return + } + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length \"%d\", got \"%d\"", + len(test.expectedResponse), len(res.Array())) + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected element \"%s\" in response", item.String()) } - expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) + } + }) + } + }) + + t.Run("Test_HandleSDIFFSTORE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + destination string + command []string + expectedValue *set.Set + expectedResponse int + expectedError error + }{ + { + name: "1. Get the difference between 2 sets.", + presetValues: map[string]interface{}{ + "SdiffStoreKey1": set.NewSet([]string{"one", "two", "three", "four", "five"}), + "SdiffStoreKey2": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight"}), + }, + destination: "SdiffStoreDestination1", + command: []string{"SDIFFSTORE", "SdiffStoreDestination1", "SdiffStoreKey1", "SdiffStoreKey2"}, + expectedValue: set.NewSet([]string{"one", "two"}), + expectedResponse: 2, + expectedError: nil, + }, + { + name: "2. Get the difference between 3 sets.", + presetValues: map[string]interface{}{ + "SdiffStoreKey3": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SdiffStoreKey4": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), + "SdiffStoreKey5": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + }, + destination: "SdiffStoreDestination2", + command: []string{"SDIFFSTORE", "SdiffStoreDestination2", "SdiffStoreKey3", "SdiffStoreKey4", "SdiffStoreKey5"}, + expectedValue: set.NewSet([]string{"three", "four", "five", "six"}), + expectedResponse: 4, + expectedError: nil, + }, + { + name: "3. Return base set element if base set is the only valid set", + presetValues: map[string]interface{}{ + "SdiffStoreKey6": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SdiffStoreKey7": "Default value", + "SdiffStoreKey8": "123456789", + }, + destination: "SdiffStoreDestination3", + command: []string{"SDIFFSTORE", "SdiffStoreDestination3", "SdiffStoreKey6", "SdiffStoreKey7", "SdiffStoreKey8"}, + expectedValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + expectedResponse: 8, + expectedError: nil, + }, + { + name: "4. Throw error when base set is not a set.", + presetValues: map[string]interface{}{ + "SdiffStoreKey9": "Default value", + "SdiffStoreKey10": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), + "SdiffStoreKey11": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + }, + destination: "SdiffStoreDestination4", + command: []string{"SDIFFSTORE", "SdiffStoreDestination4", "SdiffStoreKey9", "SdiffStoreKey10", "SdiffStoreKey11"}, + expectedValue: nil, + expectedResponse: 0, + expectedError: errors.New("value at key SdiffStoreKey9 is not a set"), + }, + { + name: "5. Throw error when base set is non-existent.", + destination: "SdiffStoreDestination5", + presetValues: map[string]interface{}{ + "SdiffStoreKey12": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), + "SdiffStoreKey13": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + }, + command: []string{"SDIFFSTORE", "SdiffStoreDestination5", "non-existent", "SdiffStoreKey7", "SdiffStoreKey8"}, + expectedValue: nil, + expectedResponse: 0, + expectedError: errors.New("key for base set \"non-existent\" does not exist"), + }, + { + name: "6. Command too short", + command: []string{"SDIFFSTORE", "SdiffStoreDestination6"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *set.Set: + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} + for _, element := range value.(*set.Set).GetAll() { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(value.(*set.Set).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -1912,140 +629,166 @@ func Test_HandleSRANDMEMBER(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - // Check that each returned element is in the list of expected elements. - for _, item := range res.Array() { - if !slices.Contains(test.expectedResponse, item.String()) { - t.Errorf("unexpected element \"%s\" in response", item.String()) - } - } - - // If no repeats are allowed, check if the response contains any repeated elements - if !test.allowRepeat { - s := set.NewSet(func() []string { - elements := make([]string, len(res.Array())) - for i, item := range res.Array() { - elements[i] = item.String() + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } - return elements - }()) - if s.Cardinality() != len(res.Array()) { - t.Error("response has repeated elements, expected only unique elements.") + return } - } - // Check if the resulting set's cardinality is as expected. - if err := client.WriteArray([]resp.Value{resp.StringValue("SCARD"), resp.StringValue(test.key)}); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) + } - if res.Integer() != test.expectedValue { - t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", - test.key, test.expectedValue, res.Integer()) - } - }) - } -} + // Check if the resulting set(s) contain the expected members. + if test.expectedValue == nil { + return + } -func Test_HandleSREM(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) + if err := client.WriteArray([]resp.Value{ + resp.StringValue("SMEMBERS"), + resp.StringValue(test.destination), + }); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedValue *set.Set // The final cardinality of the resulting set - expectedResponse int - expectedError error - }{ - { - name: "1. Remove multiple elements and return the number of elements removed", - key: "SremKey1", - presetValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - command: []string{"SREM", "SremKey1", "one", "two", "three", "nine"}, - expectedValue: set.NewSet([]string{"four", "five", "six", "seven", "eight"}), - expectedResponse: 3, - expectedError: nil, - }, - { - name: "2. If key does not exist, return 0", - key: "SremKey2", - presetValue: nil, - command: []string{"SREM", "SremKey1", "one", "two", "three", "nine"}, - expectedValue: nil, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "3. Return error when the source key is not a set", - key: "SremKey3", - presetValue: "Default value", - command: []string{"SREM", "SremKey3", "one"}, - expectedValue: nil, - expectedResponse: 0, - expectedError: errors.New("value at key SremKey3 is not a set"), - }, - { - name: "4. Command too short", - command: []string{"SREM", "SremKey"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + if len(res.Array()) != test.expectedValue.Cardinality() { + t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", + test.destination, test.expectedValue.Cardinality(), len(res.Array())) + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), + for _, item := range res.Array() { + if !test.expectedValue.Contains(item.String()) { + t.Errorf("unexpected memeber \"%s\", in response", item.String()) } - expected = "ok" - case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} - for _, element := range test.presetValue.(*set.Set).GetAll() { - command = append(command, []resp.Value{resp.StringValue(element)}...) + } + }) + } + }) + + t.Run("Test_HandleSINTER", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedResponse []string + expectedError error + }{ + { + name: "1. Get the intersection between 2 sets.", + presetValues: map[string]interface{}{ + "SinterKey1": set.NewSet([]string{"one", "two", "three", "four", "five"}), + "SinterKey2": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight"}), + }, + command: []string{"SINTER", "SinterKey1", "SinterKey2"}, + expectedResponse: []string{"three", "four", "five"}, + expectedError: nil, + }, + { + name: "2. Get the intersection between 3 sets.", + presetValues: map[string]interface{}{ + "SinterKey3": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SinterKey4": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven", "eight"}), + "SinterKey5": set.NewSet([]string{"one", "eight", "nine", "ten", "twelve"}), + }, + command: []string{"SINTER", "SinterKey3", "SinterKey4", "SinterKey5"}, + expectedResponse: []string{"one", "eight"}, + expectedError: nil, + }, + { + name: "3. Throw an error if any of the provided keys are not sets", + presetValues: map[string]interface{}{ + "SinterKey6": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SinterKey7": "Default value", + "SinterKey8": set.NewSet([]string{"one"}), + }, + command: []string{"SINTER", "SinterKey6", "SinterKey7", "SinterKey8"}, + expectedResponse: nil, + expectedError: errors.New("value at key SinterKey7 is not a set"), + }, + { + name: "4. Throw error when base set is not a set.", + presetValues: map[string]interface{}{ + "SinterKey9": "Default value", + "SinterKey10": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), + "SinterKey11": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + }, + command: []string{"SINTER", "SinterKey9", "SinterKey10", "SinterKey11"}, + expectedResponse: nil, + expectedError: errors.New("value at key SinterKey9 is not a set"), + }, + { + name: "5. If any of the keys does not exist, return an empty array.", + presetValues: map[string]interface{}{ + "SinterKey12": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), + "SinterKey13": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + }, + command: []string{"SINTER", "non-existent", "SinterKey7", "SinterKey8"}, + expectedResponse: []string{}, + expectedError: nil, + }, + { + name: "6. Command too short", + command: []string{"SINTER"}, + expectedResponse: []string{}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *set.Set: + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} + for _, element := range value.(*set.Set).GetAll() { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(value.(*set.Set).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } } - expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -2056,151 +799,444 @@ func Test_HandleSREM(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length \"%d\", got \"%d\"", + len(test.expectedResponse), len(res.Array())) } - return - } - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - - // Check if the resulting set(s) contain the expected members. - if test.expectedValue == nil { - return - } - - if err := client.WriteArray([]resp.Value{resp.StringValue("SMEMBERS"), resp.StringValue(test.key)}); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != test.expectedValue.Cardinality() { - t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", - test.key, test.expectedValue.Cardinality(), len(res.Array())) - } - - for _, item := range res.Array() { - if !test.expectedValue.Contains(item.String()) { - t.Errorf("unexpected memeber \"%s\", in response", item.String()) + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected element \"%s\" in response", item.String()) + } } - } - }) - } -} + }) + } + }) -func Test_HandleSUNION(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) + t.Run("Test_HandleSINTERCARD", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedResponse []string - expectedError error - }{ - { - name: "1. Get the union between 2 sets.", - presetValues: map[string]interface{}{ - "SunionKey1": set.NewSet([]string{"one", "two", "three", "four", "five"}), - "SunionKey2": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight"}), + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedResponse int + expectedError error + }{ + { + name: "1. Get the full intersect cardinality between 2 sets.", + presetValues: map[string]interface{}{ + "SinterCardKey1": set.NewSet([]string{"one", "two", "three", "four", "five"}), + "SinterCardKey2": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight"}), + }, + command: []string{"SINTERCARD", "SinterCardKey1", "SinterCardKey2"}, + expectedResponse: 3, + expectedError: nil, }, - command: []string{"SUNION", "SunionKey1", "SunionKey2"}, - expectedResponse: []string{"one", "two", "three", "four", "five", "six", "seven", "eight"}, - expectedError: nil, - }, - { - name: "2. Get the union between 3 sets.", - presetValues: map[string]interface{}{ - "SunionKey3": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SunionKey4": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven", "eight"}), - "SunionKey5": set.NewSet([]string{"one", "eight", "nine", "ten", "twelve"}), + { + name: "2. Get an intersect cardinality between 2 sets with a limit", + presetValues: map[string]interface{}{ + "SinterCardKey3": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"}), + "SinterCardKey4": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight", "nine", "ten", "eleven", "twelve"}), + }, + command: []string{"SINTERCARD", "SinterCardKey3", "SinterCardKey4", "LIMIT", "3"}, + expectedResponse: 3, + expectedError: nil, }, - command: []string{"SUNION", "SunionKey3", "SunionKey4", "SunionKey5"}, - expectedResponse: []string{ - "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", - "ten", "eleven", "twelve", "thirty-six", + { + name: "3. Get the full intersect cardinality between 3 sets.", + presetValues: map[string]interface{}{ + "SinterCardKey5": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SinterCardKey6": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven", "eight"}), + "SinterCardKey7": set.NewSet([]string{"one", "seven", "eight", "nine", "ten", "twelve"}), + }, + command: []string{"SINTERCARD", "SinterCardKey5", "SinterCardKey6", "SinterCardKey7"}, + expectedResponse: 2, + expectedError: nil, }, - expectedError: nil, - }, - { - name: "3. Throw an error if any of the provided keys are not sets", - presetValues: map[string]interface{}{ - "SunionKey6": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SunionKey7": "Default value", - "SunionKey8": set.NewSet([]string{"one"}), + { + name: "4. Get the intersection of 3 sets with a limit", + presetValues: map[string]interface{}{ + "SinterCardKey8": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SinterCardKey9": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven", "eight"}), + "SinterCardKey10": set.NewSet([]string{"one", "two", "seven", "eight", "nine", "ten", "twelve"}), + }, + command: []string{"SINTERCARD", "SinterCardKey8", "SinterCardKey9", "SinterCardKey10", "LIMIT", "2"}, + expectedResponse: 2, + expectedError: nil, }, - command: []string{"SUNION", "SunionKey6", "SunionKey7", "SunionKey8"}, - expectedResponse: nil, - expectedError: errors.New("value at key SunionKey7 is not a set"), - }, - { - name: "4. Throw error any of the keys does not hold a set.", - presetValues: map[string]interface{}{ - "SunionKey9": "Default value", - "SunionKey10": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), - "SunionKey11": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + { + name: "5. Return 0 if any of the keys does not exist", + presetValues: map[string]interface{}{ + "SinterCardKey11": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SinterCardKey12": "Default value", + "SinterCardKey13": set.NewSet([]string{"one"}), + }, + command: []string{"SINTERCARD", "SinterCardKey11", "SinterCardKey12", "SinterCardKey13", "non-existent"}, + expectedResponse: 0, + expectedError: nil, }, - command: []string{"SUNION", "SunionKey9", "SunionKey10", "SunionKey11"}, - expectedResponse: nil, - expectedError: errors.New("value at key SunionKey9 is not a set"), - }, - { - name: "6. Command too short", - command: []string{"SUNION"}, - expectedResponse: []string{}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + { + name: "6. Throw error when one of the keys is not a valid set.", + presetValues: map[string]interface{}{ + "SinterCardKey14": "Default value", + "SinterCardKey15": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), + "SinterCardKey16": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + }, + command: []string{"SINTERCARD", "SinterCardKey14", "SinterCardKey15", "SinterCardKey16"}, + expectedResponse: 0, + expectedError: errors.New("value at key SinterCardKey14 is not a set"), + }, + { + name: "7. Command too short", + command: []string{"SINTERCARD"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *set.Set: + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} + for _, element := range value.(*set.Set).GetAll() { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(value.(*set.Set).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response array of length \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) + } + }) + } + + }) + + t.Run("Test_HandleSINTERSTORE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + destination string + command []string + expectedValue *set.Set + expectedResponse int + expectedError error + }{ + { + name: "1. Get the intersection between 2 sets and store it at the destination.", + presetValues: map[string]interface{}{ + "SinterStoreKey1": set.NewSet([]string{"one", "two", "three", "four", "five"}), + "SinterStoreKey2": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight"}), + }, + destination: "SinterStoreDestination1", + command: []string{"SINTERSTORE", "SinterStoreDestination1", "SinterStoreKey1", "SinterStoreKey2"}, + expectedValue: set.NewSet([]string{"three", "four", "five"}), + expectedResponse: 3, + expectedError: nil, + }, + { + name: "2. Get the intersection between 3 sets and store it at the destination key.", + presetValues: map[string]interface{}{ + "SinterStoreKey3": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SinterStoreKey4": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven", "eight"}), + "SinterStoreKey5": set.NewSet([]string{"one", "seven", "eight", "nine", "ten", "twelve"}), + }, + destination: "SinterStoreDestination2", + command: []string{"SINTERSTORE", "SinterStoreDestination2", "SinterStoreKey3", "SinterStoreKey4", "SinterStoreKey5"}, + expectedValue: set.NewSet([]string{"one", "eight"}), + expectedResponse: 2, + expectedError: nil, + }, + { + name: "3. Throw error when any of the keys is not a set", + presetValues: map[string]interface{}{ + "SinterStoreKey6": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SinterStoreKey7": "Default value", + "SinterStoreKey8": set.NewSet([]string{"one"}), + }, + destination: "SinterStoreDestination3", + command: []string{"SINTERSTORE", "SinterStoreDestination3", "SinterStoreKey6", "SinterStoreKey7", "SinterStoreKey8"}, + expectedValue: nil, + expectedResponse: 0, + expectedError: errors.New("value at key SinterStoreKey7 is not a set"), + }, + { + name: "4. Throw error when base set is not a set.", + presetValues: map[string]interface{}{ + "SinterStoreKey9": "Default value", + "SinterStoreKey10": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), + "SinterStoreKey11": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + }, + destination: "SinterStoreDestination4", + command: []string{"SINTERSTORE", "SinterStoreDestination4", "SinterStoreKey9", "SinterStoreKey10", "SinterStoreKey11"}, + expectedValue: nil, + expectedResponse: 0, + expectedError: errors.New("value at key SinterStoreKey9 is not a set"), + }, + { + name: "5. Return an empty intersection if one of the keys does not exist.", + destination: "SinterStoreDestination5", + presetValues: map[string]interface{}{ + "SinterStoreKey12": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), + "SinterStoreKey13": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + }, + command: []string{"SINTERSTORE", "SinterStoreDestination5", "non-existent", "SinterStoreKey7", "SinterStoreKey8"}, + expectedValue: nil, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "6. Command too short", + command: []string{"SINTERSTORE", "SinterStoreDestination6"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *set.Set: + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} + for _, element := range value.(*set.Set).GetAll() { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(value.(*set.Set).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) + } + + // Check if the resulting set(s) contain the expected members. + if test.expectedValue == nil { + return + } + + if err := client.WriteArray([]resp.Value{ + resp.StringValue("SMEMBERS"), + resp.StringValue(test.destination), + }); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != test.expectedValue.Cardinality() { + t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", + test.destination, test.expectedValue.Cardinality(), len(res.Array())) + } + + for _, item := range res.Array() { + if !test.expectedValue.Contains(item.String()) { + t.Errorf("unexpected memeber \"%s\", in response", item.String()) + } + } + }) + } + }) + + t.Run("Test_HandleSISMEMBER", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValue interface{} + key string + command []string + expectedResponse int + expectedError error + }{ + { + name: "1. Return 1 when element is a member of the set", + presetValue: set.NewSet([]string{"one", "two", "three", "four"}), + key: "SIsMemberKey1", + command: []string{"SISMEMBER", "SIsMemberKey1", "three"}, + expectedResponse: 1, + expectedError: nil, + }, + { + name: "2. Return 0 when element is not a member of the set", + presetValue: set.NewSet([]string{"one", "two", "three", "four"}), + key: "SIsMemberKey2", + command: []string{"SISMEMBER", "SIsMemberKey2", "five"}, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "3. Throw error when trying to assert membership when the key does not hold a valid set", + presetValue: "Default value", + key: "SIsMemberKey3", + command: []string{"SISMEMBER", "SIsMemberKey3", "one"}, + expectedResponse: 0, + expectedError: errors.New("value at key SIsMemberKey3 is not a set"), + }, + { + name: "4. Command too short", + key: "SIsMemberKey4", + command: []string{"SISMEMBER", "SIsMemberKey4"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Command too long", + key: "SIsMemberKey5", + command: []string{"SISMEMBER", "SIsMemberKey5", "one", "two", "three"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { case string: command = []resp.Value{ resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), } expected = "ok" case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} - for _, element := range value.(*set.Set).GetAll() { + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} + for _, element := range test.presetValue.(*set.Set).GetAll() { command = append(command, []resp.Value{resp.StringValue(element)}...) } - expected = strconv.Itoa(value.(*set.Set).Cardinality()) + expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) } if err = client.WriteArray(command); err != nil { @@ -2215,128 +1251,112 @@ func Test_HandleSUNION(t *testing.T) { t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) } } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - return - } - if len(res.Array()) != len(test.expectedResponse) { - t.Errorf("expected response array of length \"%d\", got \"%d\"", - len(test.expectedResponse), len(res.Array())) - } - - for _, item := range res.Array() { - if !slices.Contains(test.expectedResponse, item.String()) { - t.Errorf("unexpected element \"%s\" in response", item.String()) + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - } - }) - } -} -func Test_HandleSUNIONSTORE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } - tests := []struct { - name string - presetValues map[string]interface{} - destination string - command []string - expectedValue *set.Set - expectedResponse int - expectedError error - }{ - { - name: "1. Get the intersection between 2 sets and store it at the destination.", - presetValues: map[string]interface{}{ - "SunionStoreKey1": set.NewSet([]string{"one", "two", "three", "four", "five"}), - "SunionStoreKey2": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight"}), - }, - destination: "SunionStoreDestination1", - command: []string{"SUNIONSTORE", "SunionStoreDestination1", "SunionStoreKey1", "SunionStoreKey2"}, - expectedValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - expectedResponse: 8, - expectedError: nil, - }, - { - name: "2. Get the intersection between 3 sets and store it at the destination key.", - presetValues: map[string]interface{}{ - "SunionStoreKey3": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SunionStoreKey4": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven", "eight"}), - "SunionStoreKey5": set.NewSet([]string{"one", "seven", "eight", "nine", "ten", "twelve"}), - }, - destination: "SunionStoreDestination2", - command: []string{"SUNIONSTORE", "SunionStoreDestination2", "SunionStoreKey3", "SunionStoreKey4", "SunionStoreKey5"}, - expectedValue: set.NewSet([]string{ - "one", "two", "three", "four", "five", "six", "seven", "eight", - "nine", "ten", "eleven", "twelve", "thirty-six", - }), - expectedResponse: 13, - expectedError: nil, - }, - { - name: "3. Throw error when any of the keys is not a set", - presetValues: map[string]interface{}{ - "SunionStoreKey6": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), - "SunionStoreKey7": "Default value", - "SunionStoreKey8": set.NewSet([]string{"one"}), - }, - destination: "SunionStoreDestination3", - command: []string{"SUNIONSTORE", "SunionStoreDestination3", "SunionStoreKey6", "SunionStoreKey7", "SunionStoreKey8"}, - expectedValue: nil, - expectedResponse: 0, - expectedError: errors.New("value at key SunionStoreKey7 is not a set"), - }, - { - name: "5. Command too short", - command: []string{"SUNIONSTORE", "SunionStoreDestination6"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) + } + }) + } + }) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { + t.Run("Test_HandleSMEMBERS", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedResponse []string + expectedError error + }{ + { + name: "1. Return all the members of the set.", + key: "SmembersKey1", + presetValue: set.NewSet([]string{"one", "two", "three", "four", "five"}), + command: []string{"SMEMBERS", "SmembersKey1"}, + expectedResponse: []string{"one", "two", "three", "four", "five"}, + expectedError: nil, + }, + { + name: "2. If the key does not exist, return an empty array.", + key: "SmembersKey2", + presetValue: nil, + command: []string{"SMEMBERS", "SmembersKey2"}, + expectedResponse: nil, + expectedError: nil, + }, + { + name: "3. Throw error when the provided key is not a set.", + key: "SmembersKey3", + presetValue: "Default value", + command: []string{"SMEMBERS", "SmembersKey3"}, + expectedResponse: nil, + expectedError: errors.New("value at key SmembersKey3 is not a set"), + }, + { + name: "4. Command too short", + command: []string{"SMEMBERS"}, + expectedResponse: []string{}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Command too long", + command: []string{"SMEMBERS", "SmembersKey5", "SmembersKey6"}, + expectedResponse: []string{}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { case string: command = []resp.Value{ resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), } expected = "ok" case *set.Set: - command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} - for _, element := range value.(*set.Set).GetAll() { + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} + for _, element := range test.presetValue.(*set.Set).GetAll() { command = append(command, []resp.Value{resp.StringValue(element)}...) } - expected = strconv.Itoa(value.(*set.Set).Cardinality()) + expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) } if err = client.WriteArray(command); err != nil { @@ -2351,58 +1371,1112 @@ func Test_HandleSUNIONSTORE(t *testing.T) { t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) } } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - return - } - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - - // Check if the resulting set(s) contain the expected members. - if test.expectedValue == nil { - return - } - - if err := client.WriteArray([]resp.Value{ - resp.StringValue("SMEMBERS"), - resp.StringValue(test.destination), - }); err != nil { - t.Error(err) - } - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != test.expectedValue.Cardinality() { - t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", - test.destination, test.expectedValue.Cardinality(), len(res.Array())) - } - - for _, item := range res.Array() { - if !test.expectedValue.Contains(item.String()) { - t.Errorf("unexpected memeber \"%s\", in response", item.String()) + if err = client.WriteArray(command); err != nil { + t.Error(err) } - } - }) - } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length \"%d\", got \"%d\"", + len(test.expectedResponse), len(res.Array())) + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected element \"%s\" in response", item.String()) + } + } + }) + } + }) + + t.Run("Test_HandleSMISMEMBER", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValue interface{} + key string + command []string + expectedResponse []int + expectedError error + }{ + { + // 1. Return set membership status for multiple elements + // Return 1 for present and 0 for absent + // The placement of the membership status flag should me consistent with the order the elements + // are in within the original command + name: "1. Return set membership status for multiple elements", + presetValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven"}), + key: "SmismemberKey1", + command: []string{"SMISMEMBER", "SmismemberKey1", "three", "four", "five", "six", "eight", "nine", "seven"}, + expectedResponse: []int{1, 1, 1, 1, 0, 0, 1}, + expectedError: nil, + }, + { + name: "2. If the set key does not exist, return an array of zeroes as long as the list of members", + presetValue: nil, + key: "SmismemberKey2", + command: []string{"SMISMEMBER", "SmismemberKey2", "one", "two", "three", "four"}, + expectedResponse: []int{0, 0, 0, 0}, + expectedError: nil, + }, + { + name: "3. Throw error when trying to assert membership when the key does not hold a valid set", + presetValue: "Default value", + key: "SmismemberKey3", + command: []string{"SMISMEMBER", "SmismemberKey3", "one"}, + expectedResponse: nil, + expectedError: errors.New("value at key SmismemberKey3 is not a set"), + }, + { + name: "4. Command too short", + presetValue: nil, + key: "SmismemberKey4", + command: []string{"SMISMEMBER", "SmismemberKey4"}, + expectedResponse: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case *set.Set: + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} + for _, element := range test.presetValue.(*set.Set).GetAll() { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length \"%d\", got \"%d\"", + len(test.expectedResponse), len(res.Array())) + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.Integer()) { + t.Errorf("unexpected element \"%d\" in response", item.Integer()) + } + } + }) + } + }) + + t.Run("Test_HandleSMOVE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedValues map[string]interface{} + expectedResponse int + expectedError error + }{ + { + name: "1. Return 1 after a successful move of a member from source set to destination set", + presetValues: map[string]interface{}{ + "SmoveSource1": set.NewSet([]string{"one", "two", "three", "four"}), + "SmoveDestination1": set.NewSet([]string{"five", "six", "seven", "eight"}), + }, + command: []string{"SMOVE", "SmoveSource1", "SmoveDestination1", "four"}, + expectedValues: map[string]interface{}{ + "SmoveSource1": set.NewSet([]string{"one", "two", "three"}), + "SmoveDestination1": set.NewSet([]string{"four", "five", "six", "seven", "eight"}), + }, + expectedResponse: 1, + expectedError: nil, + }, + { + name: "2. Return 0 when trying to move a member from source set to destination set when it doesn't exist in source", + presetValues: map[string]interface{}{ + "SmoveSource2": set.NewSet([]string{"one", "two", "three", "four", "five"}), + "SmoveDestination2": set.NewSet([]string{"five", "six", "seven", "eight"}), + }, + command: []string{"SMOVE", "SmoveSource2", "SmoveDestination2", "six"}, + expectedValues: map[string]interface{}{ + "SmoveSource2": set.NewSet([]string{"one", "two", "three", "four", "five"}), + "SmoveDestination2": set.NewSet([]string{"five", "six", "seven", "eight"}), + }, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "3. Return error when the source key is not a set", + presetValues: map[string]interface{}{ + "SmoveSource3": "Default value", + "SmoveDestination3": set.NewSet([]string{"five", "six", "seven", "eight"}), + }, + command: []string{"SMOVE", "SmoveSource3", "SmoveDestination3", "five"}, + expectedValues: map[string]interface{}{ + "SmoveSource3": "Default value", + "SmoveDestination3": set.NewSet([]string{"five", "six", "seven", "eight"}), + }, + expectedResponse: 0, + expectedError: errors.New("source is not a set"), + }, + { + name: "4. Return error when the destination key is not a set", + presetValues: map[string]interface{}{ + "SmoveSource4": set.NewSet([]string{"one", "two", "three", "four", "five"}), + "SmoveDestination4": "Default value", + }, + command: []string{"SMOVE", "SmoveSource4", "SmoveDestination4", "five"}, + expectedValues: map[string]interface{}{ + "SmoveSource4": set.NewSet([]string{"one", "two", "three", "four", "five"}), + "SmoveDestination4": "Default value", + }, + expectedResponse: 0, + expectedError: errors.New("destination is not a set"), + }, + { + name: "5. Command too short", + presetValues: nil, + command: []string{"SMOVE", "SmoveSource5", "SmoveSource6"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "6. Command too long", + presetValues: nil, + command: []string{"SMOVE", "SmoveSource5", "SmoveSource6", "member1", "member2"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *set.Set: + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} + for _, element := range value.(*set.Set).GetAll() { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(value.(*set.Set).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) + } + + // Check if the resulting set(s) contain the expected members. + if test.expectedValues == nil { + return + } + + for key, value := range test.expectedValues { + switch value.(type) { + case string: + if err := client.WriteArray([]resp.Value{resp.StringValue("GET"), resp.StringValue(key)}); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + if res.String() != value.(string) { + t.Errorf("expected value at key \"%s\" to be \"%s\", got \"%s\"", key, value.(string), res.String()) + } + case *set.Set: + if err := client.WriteArray([]resp.Value{ + resp.StringValue("SMEMBERS"), + resp.StringValue(key), + }); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != value.(*set.Set).Cardinality() { + t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", + key, value.(*set.Set).Cardinality(), len(res.Array())) + } + + for _, item := range res.Array() { + if !value.(*set.Set).Contains(item.String()) { + t.Errorf("unexpected memeber \"%s\", in response", item.String()) + } + } + } + } + }) + } + }) + + t.Run("Test_HandleSPOP", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedValue int // The final cardinality of the resulting set + expectedResponse []string + expectedError error + }{ + { + name: "1. Return multiple popped elements and modify the set", + key: "SpopKey1", + presetValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + command: []string{"SPOP", "SpopKey1", "3"}, + expectedValue: 5, + expectedResponse: []string{"one", "two", "three", "four", "five", "six", "seven", "eight"}, + expectedError: nil, + }, + { + name: "2. Return error when the source key is not a set", + key: "SpopKey2", + presetValue: "Default value", + command: []string{"SPOP", "SpopKey2"}, + expectedValue: 0, + expectedResponse: nil, + expectedError: errors.New("value at SpopKey2 is not a set"), + }, + { + name: "3. Command too short", + presetValue: nil, + command: []string{"SPOP"}, + expectedValue: 0, + expectedResponse: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "4. Command too long", + presetValue: nil, + command: []string{"SPOP", "SpopSource5", "SpopSource6", "member1", "member2"}, + expectedValue: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Throw error when count is not an integer", + presetValue: nil, + command: []string{"SPOP", "SpopKey1", "count"}, + expectedValue: 0, + expectedError: errors.New("count must be an integer"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case *set.Set: + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} + for _, element := range test.presetValue.(*set.Set).GetAll() { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + // Check that each returned element is in the list of expected elements. + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected element \"%s\" in response", item.String()) + } + } + + // Check if the resulting set's cardinality is as expected. + if err := client.WriteArray([]resp.Value{resp.StringValue("SCARD"), resp.StringValue(test.key)}); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if res.Integer() != test.expectedValue { + t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", + test.key, test.expectedValue, res.Integer()) + } + }) + } + }) + + t.Run("Test_HandleSRANDMEMBER", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedValue int // The final cardinality of the resulting set + allowRepeat bool + expectedResponse []string + expectedError error + }{ + { + // 1. Return multiple random elements without removing them + // Count is positive, do not allow repeated elements + name: "1. Return multiple random elements without removing them", + key: "SRandMemberKey1", + presetValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + command: []string{"SRANDMEMBER", "SRandMemberKey1", "3"}, + expectedValue: 8, + allowRepeat: false, + expectedResponse: []string{"one", "two", "three", "four", "five", "six", "seven", "eight"}, + expectedError: nil, + }, + { + // 2. Return multiple random elements without removing them + // Count is negative, so allow repeated numbers + name: "2. Return multiple random elements without removing them", + key: "SRandMemberKey2", + presetValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + command: []string{"SRANDMEMBER", "SRandMemberKey2", "-5"}, + expectedValue: 8, + allowRepeat: true, + expectedResponse: []string{"one", "two", "three", "four", "five", "six", "seven", "eight"}, + expectedError: nil, + }, + { + name: "3. Return error when the source key is not a set", + key: "SRandMemberKey3", + presetValue: "Default value", + command: []string{"SRANDMEMBER", "SRandMemberKey3"}, + expectedValue: 0, + expectedResponse: []string{}, + expectedError: errors.New("value at SRandMemberKey3 is not a set"), + }, + { + name: "4. Command too short", + command: []string{"SRANDMEMBER"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Command too long", + command: []string{"SRANDMEMBER", "SRandMemberSource5", "SRandMemberSource6", "member1", "member2"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "6. Throw error when count is not an integer", + command: []string{"SRANDMEMBER", "SRandMemberKey1", "count"}, + expectedError: errors.New("count must be an integer"), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case *set.Set: + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} + for _, element := range test.presetValue.(*set.Set).GetAll() { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + // Check that each returned element is in the list of expected elements. + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected element \"%s\" in response", item.String()) + } + } + + // If no repeats are allowed, check if the response contains any repeated elements + if !test.allowRepeat { + s := set.NewSet(func() []string { + elements := make([]string, len(res.Array())) + for i, item := range res.Array() { + elements[i] = item.String() + } + return elements + }()) + if s.Cardinality() != len(res.Array()) { + t.Error("response has repeated elements, expected only unique elements.") + } + } + + // Check if the resulting set's cardinality is as expected. + if err := client.WriteArray([]resp.Value{resp.StringValue("SCARD"), resp.StringValue(test.key)}); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if res.Integer() != test.expectedValue { + t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", + test.key, test.expectedValue, res.Integer()) + } + }) + } + }) + + t.Run("Test_HandleSREM", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedValue *set.Set // The final cardinality of the resulting set + expectedResponse int + expectedError error + }{ + { + name: "1. Remove multiple elements and return the number of elements removed", + key: "SremKey1", + presetValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + command: []string{"SREM", "SremKey1", "one", "two", "three", "nine"}, + expectedValue: set.NewSet([]string{"four", "five", "six", "seven", "eight"}), + expectedResponse: 3, + expectedError: nil, + }, + { + name: "2. If key does not exist, return 0", + key: "SremKey2", + presetValue: nil, + command: []string{"SREM", "SremKey1", "one", "two", "three", "nine"}, + expectedValue: nil, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "3. Return error when the source key is not a set", + key: "SremKey3", + presetValue: "Default value", + command: []string{"SREM", "SremKey3", "one"}, + expectedValue: nil, + expectedResponse: 0, + expectedError: errors.New("value at key SremKey3 is not a set"), + }, + { + name: "4. Command too short", + command: []string{"SREM", "SremKey"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case *set.Set: + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(test.key)} + for _, element := range test.presetValue.(*set.Set).GetAll() { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(test.presetValue.(*set.Set).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) + } + + // Check if the resulting set(s) contain the expected members. + if test.expectedValue == nil { + return + } + + if err := client.WriteArray([]resp.Value{resp.StringValue("SMEMBERS"), resp.StringValue(test.key)}); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != test.expectedValue.Cardinality() { + t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", + test.key, test.expectedValue.Cardinality(), len(res.Array())) + } + + for _, item := range res.Array() { + if !test.expectedValue.Contains(item.String()) { + t.Errorf("unexpected memeber \"%s\", in response", item.String()) + } + } + }) + } + }) + + t.Run("Test_HandleSUNION", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedResponse []string + expectedError error + }{ + { + name: "1. Get the union between 2 sets.", + presetValues: map[string]interface{}{ + "SunionKey1": set.NewSet([]string{"one", "two", "three", "four", "five"}), + "SunionKey2": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight"}), + }, + command: []string{"SUNION", "SunionKey1", "SunionKey2"}, + expectedResponse: []string{"one", "two", "three", "four", "five", "six", "seven", "eight"}, + expectedError: nil, + }, + { + name: "2. Get the union between 3 sets.", + presetValues: map[string]interface{}{ + "SunionKey3": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SunionKey4": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven", "eight"}), + "SunionKey5": set.NewSet([]string{"one", "eight", "nine", "ten", "twelve"}), + }, + command: []string{"SUNION", "SunionKey3", "SunionKey4", "SunionKey5"}, + expectedResponse: []string{ + "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", + "ten", "eleven", "twelve", "thirty-six", + }, + expectedError: nil, + }, + { + name: "3. Throw an error if any of the provided keys are not sets", + presetValues: map[string]interface{}{ + "SunionKey6": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SunionKey7": "Default value", + "SunionKey8": set.NewSet([]string{"one"}), + }, + command: []string{"SUNION", "SunionKey6", "SunionKey7", "SunionKey8"}, + expectedResponse: nil, + expectedError: errors.New("value at key SunionKey7 is not a set"), + }, + { + name: "4. Throw error any of the keys does not hold a set.", + presetValues: map[string]interface{}{ + "SunionKey9": "Default value", + "SunionKey10": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), + "SunionKey11": set.NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), + }, + command: []string{"SUNION", "SunionKey9", "SunionKey10", "SunionKey11"}, + expectedResponse: nil, + expectedError: errors.New("value at key SunionKey9 is not a set"), + }, + { + name: "6. Command too short", + command: []string{"SUNION"}, + expectedResponse: []string{}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *set.Set: + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} + for _, element := range value.(*set.Set).GetAll() { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(value.(*set.Set).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length \"%d\", got \"%d\"", + len(test.expectedResponse), len(res.Array())) + } + + for _, item := range res.Array() { + if !slices.Contains(test.expectedResponse, item.String()) { + t.Errorf("unexpected element \"%s\" in response", item.String()) + } + } + }) + } + }) + + t.Run("Test_HandleSUNIONSTORE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + destination string + command []string + expectedValue *set.Set + expectedResponse int + expectedError error + }{ + { + name: "1. Get the intersection between 2 sets and store it at the destination.", + presetValues: map[string]interface{}{ + "SunionStoreKey1": set.NewSet([]string{"one", "two", "three", "four", "five"}), + "SunionStoreKey2": set.NewSet([]string{"three", "four", "five", "six", "seven", "eight"}), + }, + destination: "SunionStoreDestination1", + command: []string{"SUNIONSTORE", "SunionStoreDestination1", "SunionStoreKey1", "SunionStoreKey2"}, + expectedValue: set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + expectedResponse: 8, + expectedError: nil, + }, + { + name: "2. Get the intersection between 3 sets and store it at the destination key.", + presetValues: map[string]interface{}{ + "SunionStoreKey3": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SunionStoreKey4": set.NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven", "eight"}), + "SunionStoreKey5": set.NewSet([]string{"one", "seven", "eight", "nine", "ten", "twelve"}), + }, + destination: "SunionStoreDestination2", + command: []string{"SUNIONSTORE", "SunionStoreDestination2", "SunionStoreKey3", "SunionStoreKey4", "SunionStoreKey5"}, + expectedValue: set.NewSet([]string{ + "one", "two", "three", "four", "five", "six", "seven", "eight", + "nine", "ten", "eleven", "twelve", "thirty-six", + }), + expectedResponse: 13, + expectedError: nil, + }, + { + name: "3. Throw error when any of the keys is not a set", + presetValues: map[string]interface{}{ + "SunionStoreKey6": set.NewSet([]string{"one", "two", "three", "four", "five", "six", "seven", "eight"}), + "SunionStoreKey7": "Default value", + "SunionStoreKey8": set.NewSet([]string{"one"}), + }, + destination: "SunionStoreDestination3", + command: []string{"SUNIONSTORE", "SunionStoreDestination3", "SunionStoreKey6", "SunionStoreKey7", "SunionStoreKey8"}, + expectedValue: nil, + expectedResponse: 0, + expectedError: errors.New("value at key SunionStoreKey7 is not a set"), + }, + { + name: "5. Command too short", + command: []string{"SUNIONSTORE", "SunionStoreDestination6"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *set.Set: + command = []resp.Value{resp.StringValue("SADD"), resp.StringValue(key)} + for _, element := range value.(*set.Set).GetAll() { + command = append(command, []resp.Value{resp.StringValue(element)}...) + } + expected = strconv.Itoa(value.(*set.Set).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) + } + + // Check if the resulting set(s) contain the expected members. + if test.expectedValue == nil { + return + } + + if err := client.WriteArray([]resp.Value{ + resp.StringValue("SMEMBERS"), + resp.StringValue(test.destination), + }); err != nil { + t.Error(err) + } + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != test.expectedValue.Cardinality() { + t.Errorf("expected set at key \"%s\" to have cardinality %d, got %d", + test.destination, test.expectedValue.Cardinality(), len(res.Array())) + } + + for _, item := range res.Array() { + if !test.expectedValue.Contains(item.String()) { + t.Errorf("unexpected memeber \"%s\", in response", item.String()) + } + } + }) + } + }) } diff --git a/internal/modules/sorted_set/commands_test.go b/internal/modules/sorted_set/commands_test.go index 7474a92..f7e6a73 100644 --- a/internal/modules/sorted_set/commands_test.go +++ b/internal/modules/sorted_set/commands_test.go @@ -32,20 +32,26 @@ import ( "testing" ) -var mockServer *echovault.EchoVault -var addr = "localhost" -var port int +func Test_SortedSet(t *testing.T) { + port, err := internal.GetFreePort() + if err != nil { + t.Error(err) + return + } -func init() { - port, _ = internal.GetFreePort() - mockServer, _ = echovault.NewEchoVault( + mockServer, err := echovault.NewEchoVault( echovault.WithConfig(config.Config{ - BindAddr: addr, + BindAddr: "localhost", Port: uint16(port), DataDir: "", EvictionPolicy: constants.NoEviction, }), ) + if err != nil { + t.Error(err) + return + } + wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -53,305 +59,203 @@ func init() { mockServer.Start() }() wg.Wait() -} -func Test_HandleZADD(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) + t.Cleanup(func() { + mockServer.ShutDown() + }) - tests := []struct { - name string - presetValue *sorted_set.SortedSet - key string - command []string - expectedResponse int - expectedError error - }{ - { - name: "1. Create new sorted set and return the cardinality of the new sorted set", - presetValue: nil, - key: "ZaddKey1", - command: []string{"ZADD", "ZaddKey1", "5.5", "member1", "67.77", "member2", "10", "member3", "-inf", "member4", "+inf", "member5"}, - expectedResponse: 5, - expectedError: nil, - }, - { - name: "2. Only add the elements that do not currently exist in the sorted set when NX flag is provided", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - }), - key: "ZaddKey2", - command: []string{"ZADD", "ZaddKey2", "NX", "5.5", "member1", "67.77", "member4", "10", "member5"}, - expectedResponse: 2, - expectedError: nil, - }, - { - name: "3. Do not add any elements when providing existing members with NX flag", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - }), - key: "ZaddKey3", - command: []string{"ZADD", "ZaddKey3", "NX", "5.5", "member1", "67.77", "member2", "10", "member3"}, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "4. Successfully add elements to an existing set when XX flag is provided with existing elements", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - }), - key: "ZaddKey4", - command: []string{"ZADD", "ZaddKey4", "XX", "CH", "55", "member1", "1005", "member2", "15", "member3", "99.75", "member4"}, - expectedResponse: 3, - expectedError: nil, - }, - { - name: "5. Fail to add element when providing XX flag with elements that do not exist in the sorted set.", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - }), - key: "ZaddKey5", - command: []string{"ZADD", "ZaddKey5", "XX", "5.5", "member4", "100.5", "member5", "15", "member6"}, - expectedResponse: 0, - expectedError: nil, - }, - { - // 6. Only update the elements where provided score is greater than current score and GT flag is provided - // Return only the new elements added by default - name: "6. Only update the elements where provided score is greater than current score and GT flag is provided", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - }), - key: "ZaddKey6", - command: []string{"ZADD", "ZaddKey6", "XX", "CH", "GT", "7.5", "member1", "100.5", "member4", "15", "member5"}, - expectedResponse: 1, - expectedError: nil, - }, - { - // 7. Only update the elements where provided score is less than current score if LT flag is provided - // Return only the new elements added by default. - name: "7. Only update the elements where provided score is less than current score if LT flag is provided", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - }), - key: "ZaddKey7", - command: []string{"ZADD", "ZaddKey7", "XX", "LT", "3.5", "member1", "100.5", "member4", "15", "member5"}, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "8. Return all the elements that were updated AND added when CH flag is provided", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - }), - key: "ZaddKey8", - command: []string{"ZADD", "ZaddKey8", "XX", "LT", "CH", "3.5", "member1", "100.5", "member4", "15", "member5"}, - expectedResponse: 1, - expectedError: nil, - }, - { - name: "9. Increment the member by score", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - }), - key: "ZaddKey9", - command: []string{"ZADD", "ZaddKey9", "INCR", "5.5", "member3"}, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "10. Fail when GT/LT flag is provided alongside NX flag", - presetValue: nil, - key: "ZaddKey10", - command: []string{"ZADD", "ZaddKey10", "NX", "LT", "CH", "3.5", "member1", "100.5", "member4", "15", "member5"}, - expectedResponse: 0, - expectedError: errors.New("GT/LT flags not allowed if NX flag is provided"), - }, - { - name: "11. Command is too short", - presetValue: nil, - key: "ZaddKey11", - command: []string{"ZADD", "ZaddKey11"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "12. Throw error when score/member entries are do not match", - presetValue: nil, - key: "ZaddKey11", - command: []string{"ZADD", "ZaddKey12", "10.5", "member1", "12.5"}, - expectedResponse: 0, - expectedError: errors.New("score/member pairs must be float/string"), - }, - { - name: "13. Throw error when INCR flag is passed with more than one score/member pair", - presetValue: nil, - key: "ZaddKey13", - command: []string{"ZADD", "ZaddKey13", "INCR", "10.5", "member1", "12.5", "member2"}, - expectedResponse: 0, - expectedError: errors.New("cannot pass more than one score/member pair when INCR flag is provided"), - }, - } + t.Run("Test_HandleZADD", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string + tests := []struct { + name string + presetValue *sorted_set.SortedSet + key string + command []string + expectedResponse int + expectedError error + }{ + { + name: "1. Create new sorted set and return the cardinality of the new sorted set", + presetValue: nil, + key: "ZaddKey1", + command: []string{"ZADD", "ZaddKey1", "5.5", "member1", "67.77", "member2", "10", "member3", "-inf", "member4", "+inf", "member5"}, + expectedResponse: 5, + expectedError: nil, + }, + { + name: "2. Only add the elements that do not currently exist in the sorted set when NX flag is provided", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "member1", Score: sorted_set.Score(5.5)}, + {Value: "member2", Score: sorted_set.Score(67.77)}, + {Value: "member3", Score: sorted_set.Score(10)}, + }), + key: "ZaddKey2", + command: []string{"ZADD", "ZaddKey2", "NX", "5.5", "member1", "67.77", "member4", "10", "member5"}, + expectedResponse: 2, + expectedError: nil, + }, + { + name: "3. Do not add any elements when providing existing members with NX flag", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "member1", Score: sorted_set.Score(5.5)}, + {Value: "member2", Score: sorted_set.Score(67.77)}, + {Value: "member3", Score: sorted_set.Score(10)}, + }), + key: "ZaddKey3", + command: []string{"ZADD", "ZaddKey3", "NX", "5.5", "member1", "67.77", "member2", "10", "member3"}, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "4. Successfully add elements to an existing set when XX flag is provided with existing elements", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "member1", Score: sorted_set.Score(5.5)}, + {Value: "member2", Score: sorted_set.Score(67.77)}, + {Value: "member3", Score: sorted_set.Score(10)}, + }), + key: "ZaddKey4", + command: []string{"ZADD", "ZaddKey4", "XX", "CH", "55", "member1", "1005", "member2", "15", "member3", "99.75", "member4"}, + expectedResponse: 3, + expectedError: nil, + }, + { + name: "5. Fail to add element when providing XX flag with elements that do not exist in the sorted set.", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "member1", Score: sorted_set.Score(5.5)}, + {Value: "member2", Score: sorted_set.Score(67.77)}, + {Value: "member3", Score: sorted_set.Score(10)}, + }), + key: "ZaddKey5", + command: []string{"ZADD", "ZaddKey5", "XX", "5.5", "member4", "100.5", "member5", "15", "member6"}, + expectedResponse: 0, + expectedError: nil, + }, + { + // 6. Only update the elements where provided score is greater than current score and GT flag is provided + // Return only the new elements added by default + name: "6. Only update the elements where provided score is greater than current score and GT flag is provided", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "member1", Score: sorted_set.Score(5.5)}, + {Value: "member2", Score: sorted_set.Score(67.77)}, + {Value: "member3", Score: sorted_set.Score(10)}, + }), + key: "ZaddKey6", + command: []string{"ZADD", "ZaddKey6", "XX", "CH", "GT", "7.5", "member1", "100.5", "member4", "15", "member5"}, + expectedResponse: 1, + expectedError: nil, + }, + { + // 7. Only update the elements where provided score is less than current score if LT flag is provided + // Return only the new elements added by default. + name: "7. Only update the elements where provided score is less than current score if LT flag is provided", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "member1", Score: sorted_set.Score(5.5)}, + {Value: "member2", Score: sorted_set.Score(67.77)}, + {Value: "member3", Score: sorted_set.Score(10)}, + }), + key: "ZaddKey7", + command: []string{"ZADD", "ZaddKey7", "XX", "LT", "3.5", "member1", "100.5", "member4", "15", "member5"}, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "8. Return all the elements that were updated AND added when CH flag is provided", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "member1", Score: sorted_set.Score(5.5)}, + {Value: "member2", Score: sorted_set.Score(67.77)}, + {Value: "member3", Score: sorted_set.Score(10)}, + }), + key: "ZaddKey8", + command: []string{"ZADD", "ZaddKey8", "XX", "LT", "CH", "3.5", "member1", "100.5", "member4", "15", "member5"}, + expectedResponse: 1, + expectedError: nil, + }, + { + name: "9. Increment the member by score", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "member1", Score: sorted_set.Score(5.5)}, + {Value: "member2", Score: sorted_set.Score(67.77)}, + {Value: "member3", Score: sorted_set.Score(10)}, + }), + key: "ZaddKey9", + command: []string{"ZADD", "ZaddKey9", "INCR", "5.5", "member3"}, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "10. Fail when GT/LT flag is provided alongside NX flag", + presetValue: nil, + key: "ZaddKey10", + command: []string{"ZADD", "ZaddKey10", "NX", "LT", "CH", "3.5", "member1", "100.5", "member4", "15", "member5"}, + expectedResponse: 0, + expectedError: errors.New("GT/LT flags not allowed if NX flag is provided"), + }, + { + name: "11. Command is too short", + presetValue: nil, + key: "ZaddKey11", + command: []string{"ZADD", "ZaddKey11"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "12. Throw error when score/member entries are do not match", + presetValue: nil, + key: "ZaddKey11", + command: []string{"ZADD", "ZaddKey12", "10.5", "member1", "12.5"}, + expectedResponse: 0, + expectedError: errors.New("score/member pairs must be float/string"), + }, + { + name: "13. Throw error when INCR flag is passed with more than one score/member pair", + presetValue: nil, + key: "ZaddKey13", + command: []string{"ZADD", "ZaddKey13", "INCR", "10.5", "member1", "12.5", "member2"}, + expectedResponse: 0, + expectedError: errors.New("cannot pass more than one score/member pair when INCR flag is provided"), + }, + } - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} - for _, member := range test.presetValue.GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if res.Integer() != test.presetValue.Cardinality() { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - }) - } -} - -func Test_HandleZCARD(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValue interface{} - key string - command []string - expectedResponse int - expectedError error - }{ - { - name: "1. Get cardinality of valid sorted set.", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - }), - key: "ZcardKey1", - command: []string{"ZCARD", "ZcardKey1"}, - expectedResponse: 3, - expectedError: nil, - }, - { - name: "2. Return 0 when trying to get cardinality from non-existent key", - presetValue: nil, - key: "ZcardKey2", - command: []string{"ZCARD", "ZcardKey2"}, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "3. Command is too short", - presetValue: nil, - key: "ZcardKey3", - command: []string{"ZCARD"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "4. Command too long", - presetValue: nil, - key: "ZcardKey4", - command: []string{"ZCARD", "ZcardKey4", "ZcardKey5"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Return error when not a sorted set", - presetValue: "Default value", - key: "ZcardKey5", - command: []string{"ZCARD", "ZcardKey5"}, - expectedResponse: 0, - expectedError: errors.New("value at ZcardKey5 is not a sorted set"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case *sorted_set.SortedSet: command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} - for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { + for _, member := range test.presetValue.GetAll() { command = append(command, []resp.Value{ resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), resp.StringValue(string(member.Value)), }...) } - expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if res.Integer() != test.presetValue.Cardinality() { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -362,506 +266,109 @@ func Test_HandleZCARD(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - }) - } -} - -func Test_HandleZCOUNT(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValue interface{} - key string - command []string - expectedResponse int - expectedError error - }{ - { - name: "1. Get entire count using infinity boundaries", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - {Value: "member4", Score: sorted_set.Score(1083.13)}, - {Value: "member5", Score: sorted_set.Score(11)}, - {Value: "member6", Score: sorted_set.Score(math.Inf(-1))}, - {Value: "member7", Score: sorted_set.Score(math.Inf(1))}, - }), - key: "ZcountKey1", - command: []string{"ZCOUNT", "ZcountKey1", "-inf", "+inf"}, - expectedResponse: 7, - expectedError: nil, - }, - { - name: "2. Get count of sub-set from -inf to limit", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - {Value: "member4", Score: sorted_set.Score(1083.13)}, - {Value: "member5", Score: sorted_set.Score(11)}, - {Value: "member6", Score: sorted_set.Score(math.Inf(-1))}, - {Value: "member7", Score: sorted_set.Score(math.Inf(1))}, - }), - key: "ZcountKey2", - command: []string{"ZCOUNT", "ZcountKey2", "-inf", "90"}, - expectedResponse: 5, - expectedError: nil, - }, - { - name: "3. Get count of sub-set from bottom boundary to +inf limit", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "member1", Score: sorted_set.Score(5.5)}, - {Value: "member2", Score: sorted_set.Score(67.77)}, - {Value: "member3", Score: sorted_set.Score(10)}, - {Value: "member4", Score: sorted_set.Score(1083.13)}, - {Value: "member5", Score: sorted_set.Score(11)}, - {Value: "member6", Score: sorted_set.Score(math.Inf(-1))}, - {Value: "member7", Score: sorted_set.Score(math.Inf(1))}, - }), - key: "ZcountKey3", - command: []string{"ZCOUNT", "ZcountKey3", "1000", "+inf"}, - expectedResponse: 2, - expectedError: nil, - }, - { - name: "4. Return error when bottom boundary is not a valid double/float", - presetValue: nil, - key: "ZcountKey4", - command: []string{"ZCOUNT", "ZcountKey4", "min", "10"}, - expectedResponse: 0, - expectedError: errors.New("min constraint must be a double"), - }, - { - name: "5. Return error when top boundary is not a valid double/float", - presetValue: nil, - key: "ZcountKey5", - command: []string{"ZCOUNT", "ZcountKey5", "-10", "max"}, - expectedResponse: 0, - expectedError: errors.New("max constraint must be a double"), - }, - { - name: "6. Command is too short", - presetValue: nil, - key: "ZcountKey6", - command: []string{"ZCOUNT"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "7. Command too long", - presetValue: nil, - key: "ZcountKey7", - command: []string{"ZCOUNT", "ZcountKey4", "min", "max", "count"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "8. Throw error when value at the key is not a sorted set", - presetValue: "Default value", - key: "ZcountKey8", - command: []string{"ZCOUNT", "ZcountKey8", "1", "10"}, - expectedResponse: 0, - expectedError: errors.New("value at ZcountKey8 is not a sorted set"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} - for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) - } - expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) + return } - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) } + }) + } + }) - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } + t.Run("Test_HandleZCARD", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - }) - } -} - -func Test_HandleZLEXCOUNT(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValue interface{} - key string - command []string - expectedResponse int - expectedError error - }{ - { - name: "1. Get entire count using infinity boundaries", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "e", Score: sorted_set.Score(1)}, - {Value: "f", Score: sorted_set.Score(1)}, - {Value: "g", Score: sorted_set.Score(1)}, - {Value: "h", Score: sorted_set.Score(1)}, - {Value: "i", Score: sorted_set.Score(1)}, - {Value: "j", Score: sorted_set.Score(1)}, - {Value: "k", Score: sorted_set.Score(1)}, - }), - key: "ZlexCountKey1", - command: []string{"ZLEXCOUNT", "ZlexCountKey1", "f", "j"}, - expectedResponse: 5, - expectedError: nil, - }, - { - name: "2. Return 0 when the members do not have the same score", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: sorted_set.Score(5.5)}, - {Value: "b", Score: sorted_set.Score(67.77)}, - {Value: "c", Score: sorted_set.Score(10)}, - {Value: "d", Score: sorted_set.Score(1083.13)}, - {Value: "e", Score: sorted_set.Score(11)}, - {Value: "f", Score: sorted_set.Score(math.Inf(-1))}, - {Value: "g", Score: sorted_set.Score(math.Inf(1))}, - }), - key: "ZlexCountKey2", - command: []string{"ZLEXCOUNT", "ZlexCountKey2", "a", "b"}, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "3. Return 0 when the key does not exist", - presetValue: nil, - key: "ZlexCountKey3", - command: []string{"ZLEXCOUNT", "ZlexCountKey3", "a", "z"}, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "4. Return error when the value at the key is not a sorted set", - presetValue: "Default value", - key: "ZlexCountKey4", - command: []string{"ZLEXCOUNT", "ZlexCountKey4", "a", "z"}, - expectedResponse: 0, - expectedError: errors.New("value at ZlexCountKey4 is not a sorted set"), - }, - { - name: "5. Command is too short", - presetValue: nil, - key: "ZlexCountKey5", - command: []string{"ZLEXCOUNT"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "6. Command too long", - presetValue: nil, - key: "ZlexCountKey6", - command: []string{"ZLEXCOUNT", "ZlexCountKey6", "min", "max", "count"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} - for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) - } - expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - }) - } -} - -func Test_HandleZDIFF(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedResponse [][]string - expectedError error - }{ - { - name: "1. Get the difference between 2 sorted sets without scores.", - presetValues: map[string]interface{}{ - "ZdiffKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, - {Value: "two", Score: 2}, - {Value: "three", Score: 3}, - {Value: "four", Score: 4}, - }), - "ZdiffKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "three", Score: 3}, - {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, - {Value: "eight", Score: 8}, + tests := []struct { + name string + presetValue interface{} + key string + command []string + expectedResponse int + expectedError error + }{ + { + name: "1. Get cardinality of valid sorted set.", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "member1", Score: sorted_set.Score(5.5)}, + {Value: "member2", Score: sorted_set.Score(67.77)}, + {Value: "member3", Score: sorted_set.Score(10)}, }), + key: "ZcardKey1", + command: []string{"ZCARD", "ZcardKey1"}, + expectedResponse: 3, + expectedError: nil, }, - command: []string{"ZDIFF", "ZdiffKey1", "ZdiffKey2"}, - expectedResponse: [][]string{{"one"}, {"two"}}, - expectedError: nil, - }, - { - name: "2. Get the difference between 2 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZdiffKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, - {Value: "two", Score: 2}, - {Value: "three", Score: 3}, - {Value: "four", Score: 4}, - }), - "ZdiffKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "three", Score: 3}, - {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, - {Value: "eight", Score: 8}, - }), + { + name: "2. Return 0 when trying to get cardinality from non-existent key", + presetValue: nil, + key: "ZcardKey2", + command: []string{"ZCARD", "ZcardKey2"}, + expectedResponse: 0, + expectedError: nil, }, - command: []string{"ZDIFF", "ZdiffKey3", "ZdiffKey4", "WITHSCORES"}, - expectedResponse: [][]string{{"one", "1"}, {"two", "2"}}, - expectedError: nil, - }, - { - name: "3. Get the difference between 3 sets with scores.", - presetValues: map[string]interface{}{ - "ZdiffKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZdiffKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, - }), - "ZdiffKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), + { + name: "3. Command is too short", + presetValue: nil, + key: "ZcardKey3", + command: []string{"ZCARD"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), }, - command: []string{"ZDIFF", "ZdiffKey5", "ZdiffKey6", "ZdiffKey7", "WITHSCORES"}, - expectedResponse: [][]string{{"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}}, - expectedError: nil, - }, - { - name: "4. Return sorted set if only one key exists and is a sorted set", - presetValues: map[string]interface{}{ - "ZdiffKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), + { + name: "4. Command too long", + presetValue: nil, + key: "ZcardKey4", + command: []string{"ZCARD", "ZcardKey4", "ZcardKey5"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), }, - command: []string{"ZDIFF", "ZdiffKey8", "ZdiffKey9", "ZdiffKey10", "WITHSCORES"}, - expectedResponse: [][]string{ - {"one", "1"}, {"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, - {"six", "6"}, {"seven", "7"}, {"eight", "8"}, + { + name: "5. Return error when not a sorted set", + presetValue: "Default value", + key: "ZcardKey5", + command: []string{"ZCARD", "ZcardKey5"}, + expectedResponse: 0, + expectedError: errors.New("value at ZcardKey5 is not a sorted set"), }, - expectedError: nil, - }, - { - name: "5. Throw error when one of the keys is not a sorted set.", - presetValues: map[string]interface{}{ - "ZdiffKey11": "Default value", - "ZdiffKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, - }), - "ZdiffKey13": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - command: []string{"ZDIFF", "ZdiffKey11", "ZdiffKey12", "ZdiffKey13"}, - expectedResponse: nil, - expectedError: errors.New("value at ZdiffKey11 is not a sorted set"), - }, - { - name: "6. Command too short", - command: []string{"ZDIFF"}, - expectedResponse: [][]string{}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { case string: command = []resp.Value{ resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), } expected = "ok" case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} + for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { command = append(command, []resp.Value{ resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), resp.StringValue(string(member.Value)), }...) } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) } if err = client.WriteArray(command); err != nil { @@ -877,496 +384,9 @@ func Test_HandleZDIFF(t *testing.T) { } } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if len(res.Array()) != len(test.expectedResponse) { - t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) - } - - for _, item := range res.Array() { - value := item.Array()[0].String() - score := func() string { - if len(item.Array()) == 2 { - return item.Array()[1].String() - } - return "" - }() - if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { - return expected[0] == value - }) { - t.Errorf("unexpected member \"%s\" in response", value) - } - if score != "" { - for _, expected := range test.expectedResponse { - if expected[0] == value && expected[1] != score { - t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) - } - } - } - } - }) - } -} - -func Test_HandleZDIFFSTORE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - destination string - command []string - expectedValue *sorted_set.SortedSet - expectedResponse int - expectedError error - }{ - { - name: "1. Get the difference between 2 sorted sets.", - presetValues: map[string]interface{}{ - "ZdiffStoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - "ZdiffStoreKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - }, - destination: "ZdiffStoreDestinationKey1", - command: []string{"ZDIFFSTORE", "ZdiffStoreDestinationKey1", "ZdiffStoreKey1", "ZdiffStoreKey2"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}, {Value: "two", Score: 2}}), - expectedResponse: 2, - expectedError: nil, - }, - { - name: "2. Get the difference between 3 sorted sets.", - presetValues: map[string]interface{}{ - "ZdiffStoreKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZdiffStoreKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, - }), - "ZdiffStoreKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - destination: "ZdiffStoreDestinationKey2", - command: []string{"ZDIFFSTORE", "ZdiffStoreDestinationKey2", "ZdiffStoreKey3", "ZdiffStoreKey4", "ZdiffStoreKey5"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - }), - expectedResponse: 4, - expectedError: nil, - }, - { - name: "3. Return base sorted set element if base set is the only existing key provided and is a valid sorted set", - presetValues: map[string]interface{}{ - "ZdiffStoreKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - }, - destination: "ZdiffStoreDestinationKey3", - command: []string{"ZDIFFSTORE", "ZdiffStoreDestinationKey3", "ZdiffStoreKey6", "ZdiffStoreKey7", "ZdiffStoreKey8"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - expectedResponse: 8, - expectedError: nil, - }, - { - name: "4. Throw error when base sorted set is not a set.", - presetValues: map[string]interface{}{ - "ZdiffStoreKey9": "Default value", - "ZdiffStoreKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, - }), - "ZdiffStoreKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - destination: "ZdiffStoreDestinationKey4", - command: []string{"ZDIFFSTORE", "ZdiffStoreDestinationKey4", "ZdiffStoreKey9", "ZdiffStoreKey10", "ZdiffStoreKey11"}, - expectedValue: nil, - expectedResponse: 0, - expectedError: errors.New("value at ZdiffStoreKey9 is not a sorted set"), - }, - { - name: "5. Return 0 when base set is non-existent.", - destination: "ZdiffStoreDestinationKey5", - presetValues: map[string]interface{}{ - "ZdiffStoreKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, - }), - "ZdiffStoreKey13": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - command: []string{"ZDIFFSTORE", "ZdiffStoreDestinationKey5", "non-existent", "ZdiffStoreKey12", "ZdiffStoreKey13"}, - expectedValue: nil, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "6. Command too short", - command: []string{"ZDIFFSTORE", "ZdiffStoreDestinationKey6"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), - } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) - } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) - } - - // Check if the resulting sorted set has the expected members/scores - if test.expectedValue == nil { - return - } - - if err = client.WriteArray([]resp.Value{ - resp.StringValue("ZRANGE"), - resp.StringValue(test.destination), - resp.StringValue("-inf"), - resp.StringValue("+inf"), - resp.StringValue("BYSCORE"), - resp.StringValue("WITHSCORES"), - }); err != nil { - t.Error(err) - } - - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != test.expectedValue.Cardinality() { - t.Errorf("expected resulting set %s to have cardinality %d, got %d", - test.destination, test.expectedValue.Cardinality(), len(res.Array())) - } - - for _, member := range res.Array() { - value := sorted_set.Value(member.Array()[0].String()) - score := sorted_set.Score(member.Array()[1].Float()) - if !test.expectedValue.Contains(value) { - t.Errorf("unexpected value %s in resulting sorted set", value) - } - if test.expectedValue.Get(value).Score != score { - t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) - } - } - }) - } -} - -func Test_HandleZINCRBY(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValue interface{} - key string - command []string - expectedValue *sorted_set.SortedSet - expectedResponse string - expectedError error - }{ - { - name: "1. Successfully increment by int. Return the new score", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - key: "ZincrbyKey1", - command: []string{"ZINCRBY", "ZincrbyKey1", "5", "one"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 6}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - expectedResponse: "6", - expectedError: nil, - }, - { - name: "2. Successfully increment by float. Return new score", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - key: "ZincrbyKey2", - command: []string{"ZINCRBY", "ZincrbyKey2", "346.785", "one"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 347.785}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - expectedResponse: "347.785", - expectedError: nil, - }, - { - name: "3. Increment on non-existent sorted set will create the set with the member and increment as its score", - presetValue: nil, - key: "ZincrbyKey3", - command: []string{"ZINCRBY", "ZincrbyKey3", "346.785", "one"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 346.785}, - }), - expectedResponse: "346.785", - expectedError: nil, - }, - { - name: "4. Increment score to +inf", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - key: "ZincrbyKey4", - command: []string{"ZINCRBY", "ZincrbyKey4", "+inf", "one"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: sorted_set.Score(math.Inf(1))}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - expectedResponse: "+Inf", - expectedError: nil, - }, - { - name: "5. Increment score to -inf", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - key: "ZincrbyKey5", - command: []string{"ZINCRBY", "ZincrbyKey5", "-inf", "one"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: sorted_set.Score(math.Inf(-1))}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - expectedResponse: "-Inf", - expectedError: nil, - }, - { - name: "6. Incrementing score by negative increment should lower the score", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - key: "ZincrbyKey6", - command: []string{"ZINCRBY", "ZincrbyKey6", "-2.5", "five"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 2.5}, - }), - expectedResponse: "2.5", - expectedError: nil, - }, - { - name: "7. Return error when attempting to increment on a value that is not a valid sorted set", - presetValue: "Default value", - key: "ZincrbyKey7", - command: []string{"ZINCRBY", "ZincrbyKey7", "-2.5", "five"}, - expectedValue: nil, - expectedResponse: "", - expectedError: errors.New("value at ZincrbyKey7 is not a sorted set"), - }, - { - name: "8. Return error when trying to increment a member that already has score -inf", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: sorted_set.Score(math.Inf(-1))}, - }), - key: "ZincrbyKey8", - command: []string{"ZINCRBY", "ZincrbyKey8", "2.5", "one"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: sorted_set.Score(math.Inf(-1))}, - }), - expectedResponse: "", - expectedError: errors.New("cannot increment -inf or +inf"), - }, - { - name: "9. Return error when trying to increment a member that already has score +inf", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: sorted_set.Score(math.Inf(1))}, - }), - key: "ZincrbyKey9", - command: []string{"ZINCRBY", "ZincrbyKey9", "2.5", "one"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: sorted_set.Score(math.Inf(-1))}, - }), - expectedResponse: "", - expectedError: errors.New("cannot increment -inf or +inf"), - }, - { - name: "10. Return error when increment is not a valid number", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, - }), - key: "ZincrbyKey10", - command: []string{"ZINCRBY", "ZincrbyKey10", "increment", "one"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, - }), - expectedResponse: "", - expectedError: errors.New("increment must be a double"), - }, - { - name: "11. Command too short", - key: "ZincrbyKey11", - command: []string{"ZINCRBY", "ZincrbyKey11", "one"}, - expectedResponse: "", - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "12. Command too long", - key: "ZincrbyKey12", - command: []string{"ZINCRBY", "ZincrbyKey12", "one", "1", "2"}, - expectedResponse: "", - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} - for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) - } - expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -1377,280 +397,153 @@ func Test_HandleZINCRBY(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) } - return - } + }) + } + }) - if res.String() != test.expectedResponse { - t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) - } + t.Run("Test_HandleZCOUNT", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - // Check if the resulting sorted set has the expected members/scores - if test.expectedValue == nil { - return - } + tests := []struct { + name string + presetValue interface{} + key string + command []string + expectedResponse int + expectedError error + }{ + { + name: "1. Get entire count using infinity boundaries", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "member1", Score: sorted_set.Score(5.5)}, + {Value: "member2", Score: sorted_set.Score(67.77)}, + {Value: "member3", Score: sorted_set.Score(10)}, + {Value: "member4", Score: sorted_set.Score(1083.13)}, + {Value: "member5", Score: sorted_set.Score(11)}, + {Value: "member6", Score: sorted_set.Score(math.Inf(-1))}, + {Value: "member7", Score: sorted_set.Score(math.Inf(1))}, + }), + key: "ZcountKey1", + command: []string{"ZCOUNT", "ZcountKey1", "-inf", "+inf"}, + expectedResponse: 7, + expectedError: nil, + }, + { + name: "2. Get count of sub-set from -inf to limit", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "member1", Score: sorted_set.Score(5.5)}, + {Value: "member2", Score: sorted_set.Score(67.77)}, + {Value: "member3", Score: sorted_set.Score(10)}, + {Value: "member4", Score: sorted_set.Score(1083.13)}, + {Value: "member5", Score: sorted_set.Score(11)}, + {Value: "member6", Score: sorted_set.Score(math.Inf(-1))}, + {Value: "member7", Score: sorted_set.Score(math.Inf(1))}, + }), + key: "ZcountKey2", + command: []string{"ZCOUNT", "ZcountKey2", "-inf", "90"}, + expectedResponse: 5, + expectedError: nil, + }, + { + name: "3. Get count of sub-set from bottom boundary to +inf limit", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "member1", Score: sorted_set.Score(5.5)}, + {Value: "member2", Score: sorted_set.Score(67.77)}, + {Value: "member3", Score: sorted_set.Score(10)}, + {Value: "member4", Score: sorted_set.Score(1083.13)}, + {Value: "member5", Score: sorted_set.Score(11)}, + {Value: "member6", Score: sorted_set.Score(math.Inf(-1))}, + {Value: "member7", Score: sorted_set.Score(math.Inf(1))}, + }), + key: "ZcountKey3", + command: []string{"ZCOUNT", "ZcountKey3", "1000", "+inf"}, + expectedResponse: 2, + expectedError: nil, + }, + { + name: "4. Return error when bottom boundary is not a valid double/float", + presetValue: nil, + key: "ZcountKey4", + command: []string{"ZCOUNT", "ZcountKey4", "min", "10"}, + expectedResponse: 0, + expectedError: errors.New("min constraint must be a double"), + }, + { + name: "5. Return error when top boundary is not a valid double/float", + presetValue: nil, + key: "ZcountKey5", + command: []string{"ZCOUNT", "ZcountKey5", "-10", "max"}, + expectedResponse: 0, + expectedError: errors.New("max constraint must be a double"), + }, + { + name: "6. Command is too short", + presetValue: nil, + key: "ZcountKey6", + command: []string{"ZCOUNT"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "7. Command too long", + presetValue: nil, + key: "ZcountKey7", + command: []string{"ZCOUNT", "ZcountKey4", "min", "max", "count"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "8. Throw error when value at the key is not a sorted set", + presetValue: "Default value", + key: "ZcountKey8", + command: []string{"ZCOUNT", "ZcountKey8", "1", "10"}, + expectedResponse: 0, + expectedError: errors.New("value at ZcountKey8 is not a sorted set"), + }, + } - if err = client.WriteArray([]resp.Value{ - resp.StringValue("ZRANGE"), - resp.StringValue(test.key), - resp.StringValue("-inf"), - resp.StringValue("+inf"), - resp.StringValue("BYSCORE"), - resp.StringValue("WITHSCORES"), - }); err != nil { - t.Error(err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != test.expectedValue.Cardinality() { - t.Errorf("expected resulting set %s to have cardinality %d, got %d", - test.key, test.expectedValue.Cardinality(), len(res.Array())) - } - - for _, member := range res.Array() { - value := sorted_set.Value(member.Array()[0].String()) - score := sorted_set.Score(member.Array()[1].Float()) - if !test.expectedValue.Contains(value) { - t.Errorf("unexpected value %s in resulting sorted set", value) - } - if test.expectedValue.Get(value).Score != score { - t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) - } - } - }) - } -} - -func Test_HandleZMPOP(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - preset bool - presetValues map[string]interface{} - command []string - expectedValues map[string]*sorted_set.SortedSet - expectedResponse [][]string - expectedError error - }{ - { - name: "1. Successfully pop one min element by default", - preset: true, - presetValues: map[string]interface{}{ - "ZmpopKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - }, - command: []string{"ZMPOP", "ZmpopKey1"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZmpopKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - }, - expectedResponse: [][]string{ - {"one", "1"}, - }, - expectedError: nil, - }, - { - name: "2. Successfully pop one min element by specifying MIN", - preset: true, - presetValues: map[string]interface{}{ - "ZmpopKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - }, - command: []string{"ZMPOP", "ZmpopKey2", "MIN"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZmpopKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - }, - expectedResponse: [][]string{ - {"one", "1"}, - }, - expectedError: nil, - }, - { - name: "3. Successfully pop one max element by specifying MAX modifier", - preset: true, - presetValues: map[string]interface{}{ - "ZmpopKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - }, - command: []string{"ZMPOP", "ZmpopKey3", "MAX"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZmpopKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - }), - }, - expectedResponse: [][]string{ - {"five", "5"}, - }, - expectedError: nil, - }, - { - name: "4. Successfully pop multiple min elements", - preset: true, - presetValues: map[string]interface{}{ - "ZmpopKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - }), - }, - command: []string{"ZMPOP", "ZmpopKey4", "MIN", "COUNT", "5"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZmpopKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "six", Score: 6}, - }), - }, - expectedResponse: [][]string{ - {"one", "1"}, {"two", "2"}, {"three", "3"}, - {"four", "4"}, {"five", "5"}, - }, - expectedError: nil, - }, - { - name: "5. Successfully pop multiple max elements", - preset: true, - presetValues: map[string]interface{}{ - "ZmpopKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - }), - }, - command: []string{"ZMPOP", "ZmpopKey5", "MAX", "COUNT", "5"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZmpopKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, - }), - }, - expectedResponse: [][]string{{"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}}, - expectedError: nil, - }, - { - name: "6. Successfully pop elements from the first set which is non-empty", - preset: true, - presetValues: map[string]interface{}{ - "ZmpopKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - }), - }, - command: []string{"ZMPOP", "ZmpopKey6", "ZmpopKey7", "MAX", "COUNT", "5"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZmpopKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{}), - "ZmpopKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, - }), - }, - expectedResponse: [][]string{{"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}}, - expectedError: nil, - }, - { - name: "7. Skip the non-set items and pop elements from the first non-empty sorted set found", - preset: true, - presetValues: map[string]interface{}{ - "ZmpopKey8": "Default value", - "ZmpopKey9": "56", - "ZmpopKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - }), - }, - command: []string{"ZMPOP", "ZmpopKey8", "ZmpopKey9", "ZmpopKey10", "ZmpopKey11", "MIN", "COUNT", "5"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZmpopKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{}), - "ZmpopKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "six", Score: 6}, - }), - }, - expectedResponse: [][]string{{"one", "1"}, {"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}}, - expectedError: nil, - }, - { - name: "9. Return error when count is a negative integer", - preset: false, - command: []string{"ZMPOP", "ZmpopKey8", "MAX", "COUNT", "-20"}, - expectedError: errors.New("count must be a positive integer"), - }, - { - name: "9. Command too short", - preset: false, - command: []string{"ZMPOP"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { + switch test.presetValue.(type) { case string: command = []resp.Value{ resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), } expected = "ok" case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} + for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { command = append(command, []resp.Value{ resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), resp.StringValue(string(member.Value)), }...) } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) } if err = client.WriteArray(command); err != nil { @@ -1666,63 +559,615 @@ func Test_HandleZMPOP(t *testing.T) { } } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - return - } - if len(res.Array()) != len(test.expectedResponse) { - t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) - } + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } - for _, item := range res.Array() { - value := item.Array()[0].String() - score := func() string { - if len(item.Array()) == 2 { - return item.Array()[1].String() + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } - return "" - }() - if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { - return expected[0] == value - }) { - t.Errorf("unexpected member \"%s\" in response", value) + return } - if score != "" { - for _, expected := range test.expectedResponse { - if expected[0] == value && expected[1] != score { - t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) + } + }) + } + }) + + t.Run("Test_HandleZLEXCOUNT", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValue interface{} + key string + command []string + expectedResponse int + expectedError error + }{ + { + name: "1. Get entire count using infinity boundaries", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "e", Score: sorted_set.Score(1)}, + {Value: "f", Score: sorted_set.Score(1)}, + {Value: "g", Score: sorted_set.Score(1)}, + {Value: "h", Score: sorted_set.Score(1)}, + {Value: "i", Score: sorted_set.Score(1)}, + {Value: "j", Score: sorted_set.Score(1)}, + {Value: "k", Score: sorted_set.Score(1)}, + }), + key: "ZlexCountKey1", + command: []string{"ZLEXCOUNT", "ZlexCountKey1", "f", "j"}, + expectedResponse: 5, + expectedError: nil, + }, + { + name: "2. Return 0 when the members do not have the same score", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: sorted_set.Score(5.5)}, + {Value: "b", Score: sorted_set.Score(67.77)}, + {Value: "c", Score: sorted_set.Score(10)}, + {Value: "d", Score: sorted_set.Score(1083.13)}, + {Value: "e", Score: sorted_set.Score(11)}, + {Value: "f", Score: sorted_set.Score(math.Inf(-1))}, + {Value: "g", Score: sorted_set.Score(math.Inf(1))}, + }), + key: "ZlexCountKey2", + command: []string{"ZLEXCOUNT", "ZlexCountKey2", "a", "b"}, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "3. Return 0 when the key does not exist", + presetValue: nil, + key: "ZlexCountKey3", + command: []string{"ZLEXCOUNT", "ZlexCountKey3", "a", "z"}, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "4. Return error when the value at the key is not a sorted set", + presetValue: "Default value", + key: "ZlexCountKey4", + command: []string{"ZLEXCOUNT", "ZlexCountKey4", "a", "z"}, + expectedResponse: 0, + expectedError: errors.New("value at ZlexCountKey4 is not a sorted set"), + }, + { + name: "5. Command is too short", + presetValue: nil, + key: "ZlexCountKey5", + command: []string{"ZLEXCOUNT"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "6. Command too long", + presetValue: nil, + key: "ZlexCountKey6", + command: []string{"ZLEXCOUNT", "ZlexCountKey6", "min", "max", "count"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} + for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) + } + }) + } + }) + + t.Run("Test_HandleZDIFF", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedResponse [][]string + expectedError error + }{ + { + name: "1. Get the difference between 2 sorted sets without scores.", + presetValues: map[string]interface{}{ + "ZdiffKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, + {Value: "two", Score: 2}, + {Value: "three", Score: 3}, + {Value: "four", Score: 4}, + }), + "ZdiffKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "three", Score: 3}, + {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, + {Value: "eight", Score: 8}, + }), + }, + command: []string{"ZDIFF", "ZdiffKey1", "ZdiffKey2"}, + expectedResponse: [][]string{{"one"}, {"two"}}, + expectedError: nil, + }, + { + name: "2. Get the difference between 2 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZdiffKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, + {Value: "two", Score: 2}, + {Value: "three", Score: 3}, + {Value: "four", Score: 4}, + }), + "ZdiffKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "three", Score: 3}, + {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, + {Value: "eight", Score: 8}, + }), + }, + command: []string{"ZDIFF", "ZdiffKey3", "ZdiffKey4", "WITHSCORES"}, + expectedResponse: [][]string{{"one", "1"}, {"two", "2"}}, + expectedError: nil, + }, + { + name: "3. Get the difference between 3 sets with scores.", + presetValues: map[string]interface{}{ + "ZdiffKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZdiffKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, + }), + "ZdiffKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZDIFF", "ZdiffKey5", "ZdiffKey6", "ZdiffKey7", "WITHSCORES"}, + expectedResponse: [][]string{{"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}}, + expectedError: nil, + }, + { + name: "4. Return sorted set if only one key exists and is a sorted set", + presetValues: map[string]interface{}{ + "ZdiffKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + command: []string{"ZDIFF", "ZdiffKey8", "ZdiffKey9", "ZdiffKey10", "WITHSCORES"}, + expectedResponse: [][]string{ + {"one", "1"}, {"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, + {"six", "6"}, {"seven", "7"}, {"eight", "8"}, + }, + expectedError: nil, + }, + { + name: "5. Throw error when one of the keys is not a sorted set.", + presetValues: map[string]interface{}{ + "ZdiffKey11": "Default value", + "ZdiffKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, + }), + "ZdiffKey13": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZDIFF", "ZdiffKey11", "ZdiffKey12", "ZdiffKey13"}, + expectedResponse: nil, + expectedError: errors.New("value at ZdiffKey11 is not a sorted set"), + }, + { + name: "6. Command too short", + command: []string{"ZDIFF"}, + expectedResponse: [][]string{}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) + } + + for _, item := range res.Array() { + value := item.Array()[0].String() + score := func() string { + if len(item.Array()) == 2 { + return item.Array()[1].String() + } + return "" + }() + if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { + return expected[0] == value + }) { + t.Errorf("unexpected member \"%s\" in response", value) + } + if score != "" { + for _, expected := range test.expectedResponse { + if expected[0] == value && expected[1] != score { + t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) + } } } } - } + }) + } + }) - // Check if the resulting sorted set has the expected members/scores - for key, expectedSortedSet := range test.expectedValues { - if expectedSortedSet == nil { - continue + t.Run("Test_HandleZDIFFSTORE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + destination string + command []string + expectedValue *sorted_set.SortedSet + expectedResponse int + expectedError error + }{ + { + name: "1. Get the difference between 2 sorted sets.", + presetValues: map[string]interface{}{ + "ZdiffStoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + "ZdiffStoreKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + destination: "ZdiffStoreDestinationKey1", + command: []string{"ZDIFFSTORE", "ZdiffStoreDestinationKey1", "ZdiffStoreKey1", "ZdiffStoreKey2"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}, {Value: "two", Score: 2}}), + expectedResponse: 2, + expectedError: nil, + }, + { + name: "2. Get the difference between 3 sorted sets.", + presetValues: map[string]interface{}{ + "ZdiffStoreKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZdiffStoreKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, + }), + "ZdiffStoreKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + destination: "ZdiffStoreDestinationKey2", + command: []string{"ZDIFFSTORE", "ZdiffStoreDestinationKey2", "ZdiffStoreKey3", "ZdiffStoreKey4", "ZdiffStoreKey5"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + }), + expectedResponse: 4, + expectedError: nil, + }, + { + name: "3. Return base sorted set element if base set is the only existing key provided and is a valid sorted set", + presetValues: map[string]interface{}{ + "ZdiffStoreKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + destination: "ZdiffStoreDestinationKey3", + command: []string{"ZDIFFSTORE", "ZdiffStoreDestinationKey3", "ZdiffStoreKey6", "ZdiffStoreKey7", "ZdiffStoreKey8"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + expectedResponse: 8, + expectedError: nil, + }, + { + name: "4. Throw error when base sorted set is not a set.", + presetValues: map[string]interface{}{ + "ZdiffStoreKey9": "Default value", + "ZdiffStoreKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, + }), + "ZdiffStoreKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + destination: "ZdiffStoreDestinationKey4", + command: []string{"ZDIFFSTORE", "ZdiffStoreDestinationKey4", "ZdiffStoreKey9", "ZdiffStoreKey10", "ZdiffStoreKey11"}, + expectedValue: nil, + expectedResponse: 0, + expectedError: errors.New("value at ZdiffStoreKey9 is not a sorted set"), + }, + { + name: "5. Return 0 when base set is non-existent.", + destination: "ZdiffStoreDestinationKey5", + presetValues: map[string]interface{}{ + "ZdiffStoreKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, + }), + "ZdiffStoreKey13": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZDIFFSTORE", "ZdiffStoreDestinationKey5", "non-existent", "ZdiffStoreKey12", "ZdiffStoreKey13"}, + expectedValue: nil, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "6. Command too short", + command: []string{"ZDIFFSTORE", "ZdiffStoreDestinationKey6"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + + // Check if the resulting sorted set has the expected members/scores + if test.expectedValue == nil { + return } if err = client.WriteArray([]resp.Value{ resp.StringValue("ZRANGE"), - resp.StringValue(key), + resp.StringValue(test.destination), resp.StringValue("-inf"), resp.StringValue("+inf"), resp.StringValue("BYSCORE"), @@ -1736,178 +1181,229 @@ func Test_HandleZMPOP(t *testing.T) { t.Error(err) } - if len(res.Array()) != expectedSortedSet.Cardinality() { + if len(res.Array()) != test.expectedValue.Cardinality() { t.Errorf("expected resulting set %s to have cardinality %d, got %d", - key, expectedSortedSet.Cardinality(), len(res.Array())) + test.destination, test.expectedValue.Cardinality(), len(res.Array())) } for _, member := range res.Array() { value := sorted_set.Value(member.Array()[0].String()) score := sorted_set.Score(member.Array()[1].Float()) - if !expectedSortedSet.Contains(value) { + if !test.expectedValue.Contains(value) { t.Errorf("unexpected value %s in resulting sorted set", value) } - if expectedSortedSet.Get(value).Score != score { - t.Errorf("expected value %s to have score %v, got %v", - value, expectedSortedSet.Get(value).Score, score) + if test.expectedValue.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) } } - } - }) - } -} + }) + } + }) -func Test_HandleZPOP(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) + t.Run("Test_HandleZINCRBY", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - tests := []struct { - name string - preset bool - presetValues map[string]interface{} - command []string - expectedValues map[string]*sorted_set.SortedSet - expectedResponse [][]string - expectedError error - }{ - { - name: "1. Successfully pop one min element by default", - preset: true, - presetValues: map[string]interface{}{ - "ZmpopMinKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + tests := []struct { + name string + presetValue interface{} + key string + command []string + expectedValue *sorted_set.SortedSet + expectedResponse string + expectedError error + }{ + { + name: "1. Successfully increment by int. Return the new score", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, {Value: "five", Score: 5}, }), - }, - command: []string{"ZPOPMIN", "ZmpopMinKey1"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZmpopMinKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "two", Score: 2}, + key: "ZincrbyKey1", + command: []string{"ZINCRBY", "ZincrbyKey1", "5", "one"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 6}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, {Value: "five", Score: 5}, }), + expectedResponse: "6", + expectedError: nil, }, - expectedResponse: [][]string{ - {"one", "1"}, - }, - expectedError: nil, - }, - { - name: "2. Successfully pop one max element by default", - preset: true, - presetValues: map[string]interface{}{ - "ZmpopMaxKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + { + name: "2. Successfully increment by float. Return new score", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, {Value: "five", Score: 5}, }), + key: "ZincrbyKey2", + command: []string{"ZINCRBY", "ZincrbyKey2", "346.785", "one"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 347.785}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + expectedResponse: "347.785", + expectedError: nil, }, - command: []string{"ZPOPMAX", "ZmpopMaxKey2"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZmpopMaxKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + { + name: "3. Increment on non-existent sorted set will create the set with the member and increment as its score", + presetValue: nil, + key: "ZincrbyKey3", + command: []string{"ZINCRBY", "ZincrbyKey3", "346.785", "one"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 346.785}, + }), + expectedResponse: "346.785", + expectedError: nil, + }, + { + name: "4. Increment score to +inf", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, }), + key: "ZincrbyKey4", + command: []string{"ZINCRBY", "ZincrbyKey4", "+inf", "one"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: sorted_set.Score(math.Inf(1))}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + expectedResponse: "+Inf", + expectedError: nil, }, - expectedResponse: [][]string{ - {"five", "5"}, - }, - expectedError: nil, - }, - { - name: "3. Successfully pop multiple min elements", - preset: true, - presetValues: map[string]interface{}{ - "ZmpopMinKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + { + name: "5. Increment score to -inf", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "five", Score: 5}, }), - }, - command: []string{"ZPOPMIN", "ZmpopMinKey3", "5"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZmpopMinKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "six", Score: 6}, + key: "ZincrbyKey5", + command: []string{"ZINCRBY", "ZincrbyKey5", "-inf", "one"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: sorted_set.Score(math.Inf(-1))}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, }), + expectedResponse: "-Inf", + expectedError: nil, }, - expectedResponse: [][]string{ - {"one", "1"}, {"two", "2"}, {"three", "3"}, - {"four", "4"}, {"five", "5"}, - }, - expectedError: nil, - }, - { - name: "4. Successfully pop multiple max elements", - preset: true, - presetValues: map[string]interface{}{ - "ZmpopMaxKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + { + name: "6. Incrementing score by negative increment should lower the score", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "five", Score: 5}, }), + key: "ZincrbyKey6", + command: []string{"ZINCRBY", "ZincrbyKey6", "-2.5", "five"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 2.5}, + }), + expectedResponse: "2.5", + expectedError: nil, }, - command: []string{"ZPOPMAX", "ZmpopMaxKey4", "5"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZmpopMaxKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + { + name: "7. Return error when attempting to increment on a value that is not a valid sorted set", + presetValue: "Default value", + key: "ZincrbyKey7", + command: []string{"ZINCRBY", "ZincrbyKey7", "-2.5", "five"}, + expectedValue: nil, + expectedResponse: "", + expectedError: errors.New("value at ZincrbyKey7 is not a sorted set"), + }, + { + name: "8. Return error when trying to increment a member that already has score -inf", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: sorted_set.Score(math.Inf(-1))}, + }), + key: "ZincrbyKey8", + command: []string{"ZINCRBY", "ZincrbyKey8", "2.5", "one"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: sorted_set.Score(math.Inf(-1))}, + }), + expectedResponse: "", + expectedError: errors.New("cannot increment -inf or +inf"), + }, + { + name: "9. Return error when trying to increment a member that already has score +inf", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: sorted_set.Score(math.Inf(1))}, + }), + key: "ZincrbyKey9", + command: []string{"ZINCRBY", "ZincrbyKey9", "2.5", "one"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: sorted_set.Score(math.Inf(-1))}, + }), + expectedResponse: "", + expectedError: errors.New("cannot increment -inf or +inf"), + }, + { + name: "10. Return error when increment is not a valid number", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, }), + key: "ZincrbyKey10", + command: []string{"ZINCRBY", "ZincrbyKey10", "increment", "one"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, + }), + expectedResponse: "", + expectedError: errors.New("increment must be a double"), }, - expectedResponse: [][]string{{"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}}, - expectedError: nil, - }, - { - name: "5. Throw an error when trying to pop from an element that's not a sorted set", - preset: true, - presetValues: map[string]interface{}{ - "ZmpopMinKey5": "Default value", + { + name: "11. Command too short", + key: "ZincrbyKey11", + command: []string{"ZINCRBY", "ZincrbyKey11", "one"}, + expectedResponse: "", + expectedError: errors.New(constants.WrongArgsResponse), }, - command: []string{"ZPOPMIN", "ZmpopMinKey5"}, - expectedValues: nil, - expectedResponse: nil, - expectedError: errors.New("value at key ZmpopMinKey5 is not a sorted set"), - }, - { - name: "6. Command too short", - preset: false, - command: []string{"ZPOPMAX"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "7. Command too long", - preset: false, - command: []string{"ZPOPMAX", "ZmpopMaxKey7", "6", "3"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + { + name: "12. Command too long", + key: "ZincrbyKey12", + command: []string{"ZINCRBY", "ZincrbyKey12", "one", "1", "2"}, + expectedResponse: "", + expectedError: errors.New(constants.WrongArgsResponse), + }, + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { case string: command = []resp.Value{ resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), } expected = "ok" case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} + for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { command = append(command, []resp.Value{ resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), resp.StringValue(string(member.Value)), }...) } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) } if err = client.WriteArray(command); err != nil { @@ -1923,63 +1419,38 @@ func Test_HandleZPOP(t *testing.T) { } } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - return - } - if len(res.Array()) != len(test.expectedResponse) { - t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) - } + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } - for _, item := range res.Array() { - value := item.Array()[0].String() - score := func() string { - if len(item.Array()) == 2 { - return item.Array()[1].String() + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) } - return "" - }() - if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { - return expected[0] == value - }) { - t.Errorf("unexpected member \"%s\" in response", value) + return } - if score != "" { - for _, expected := range test.expectedResponse { - if expected[0] == value && expected[1] != score { - t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) - } - } - } - } - // Check if the resulting sorted set has the expected members/scores - for key, expectedSortedSet := range test.expectedValues { - if expectedSortedSet == nil { - continue + if res.String() != test.expectedResponse { + t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) + } + + // Check if the resulting sorted set has the expected members/scores + if test.expectedValue == nil { + return } if err = client.WriteArray([]resp.Value{ resp.StringValue("ZRANGE"), - resp.StringValue(key), + resp.StringValue(test.key), resp.StringValue("-inf"), resp.StringValue("+inf"), resp.StringValue("BYSCORE"), @@ -1993,391 +1464,254 @@ func Test_HandleZPOP(t *testing.T) { t.Error(err) } - if len(res.Array()) != expectedSortedSet.Cardinality() { + if len(res.Array()) != test.expectedValue.Cardinality() { t.Errorf("expected resulting set %s to have cardinality %d, got %d", - key, expectedSortedSet.Cardinality(), len(res.Array())) + test.key, test.expectedValue.Cardinality(), len(res.Array())) } for _, member := range res.Array() { value := sorted_set.Value(member.Array()[0].String()) score := sorted_set.Score(member.Array()[1].Float()) - if !expectedSortedSet.Contains(value) { + if !test.expectedValue.Contains(value) { t.Errorf("unexpected value %s in resulting sorted set", value) } - if expectedSortedSet.Get(value).Score != score { - t.Errorf("expected value %s to have score %v, got %v", - value, expectedSortedSet.Get(value).Score, score) + if test.expectedValue.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) } } - } - }) - } -} + }) + } + }) -func Test_HandleZMSCORE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) + t.Run("Test_HandleZMPOP", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedResponse []string - expectedError error - }{ - { - // 1. Return multiple scores from the sorted set. - // Return nil for elements that do not exist in the sorted set. - name: "1. Return multiple scores from the sorted set.", - presetValues: map[string]interface{}{ - "ZmScoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1.1}, {Value: "two", Score: 245}, - {Value: "three", Score: 3}, {Value: "four", Score: 4.055}, - {Value: "five", Score: 5}, - }), + tests := []struct { + name string + preset bool + presetValues map[string]interface{} + command []string + expectedValues map[string]*sorted_set.SortedSet + expectedResponse [][]string + expectedError error + }{ + { + name: "1. Successfully pop one min element by default", + preset: true, + presetValues: map[string]interface{}{ + "ZmpopKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + }, + command: []string{"ZMPOP", "ZmpopKey1"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZmpopKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + }, + expectedResponse: [][]string{ + {"one", "1"}, + }, + expectedError: nil, }, - command: []string{"ZMSCORE", "ZmScoreKey1", "one", "none", "two", "one", "three", "four", "none", "five"}, - expectedResponse: []string{"1.1", "", "245", "1.1", "3", "4.055", "", "5"}, - expectedError: nil, - }, - { - name: "2. If key does not exist, return empty array", - presetValues: nil, - command: []string{"ZMSCORE", "ZmScoreKey2", "one", "two", "three", "four"}, - expectedResponse: []string{}, - expectedError: nil, - }, - { - name: "3. Throw error when trying to find scores from elements that are not sorted sets", - presetValues: map[string]interface{}{"ZmScoreKey3": "Default value"}, - command: []string{"ZMSCORE", "ZmScoreKey3", "one", "two", "three"}, - expectedError: errors.New("value at ZmScoreKey3 is not a sorted set"), - }, - { - name: "9. Command too short", - command: []string{"ZMSCORE"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + { + name: "2. Successfully pop one min element by specifying MIN", + preset: true, + presetValues: map[string]interface{}{ + "ZmpopKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + }, + command: []string{"ZMPOP", "ZmpopKey2", "MIN"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZmpopKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + }, + expectedResponse: [][]string{ + {"one", "1"}, + }, + expectedError: nil, + }, + { + name: "3. Successfully pop one max element by specifying MAX modifier", + preset: true, + presetValues: map[string]interface{}{ + "ZmpopKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + }, + command: []string{"ZMPOP", "ZmpopKey3", "MAX"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZmpopKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + }), + }, + expectedResponse: [][]string{ + {"five", "5"}, + }, + expectedError: nil, + }, + { + name: "4. Successfully pop multiple min elements", + preset: true, + presetValues: map[string]interface{}{ + "ZmpopKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + }), + }, + command: []string{"ZMPOP", "ZmpopKey4", "MIN", "COUNT", "5"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZmpopKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "six", Score: 6}, + }), + }, + expectedResponse: [][]string{ + {"one", "1"}, {"two", "2"}, {"three", "3"}, + {"four", "4"}, {"five", "5"}, + }, + expectedError: nil, + }, + { + name: "5. Successfully pop multiple max elements", + preset: true, + presetValues: map[string]interface{}{ + "ZmpopKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + }), + }, + command: []string{"ZMPOP", "ZmpopKey5", "MAX", "COUNT", "5"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZmpopKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, + }), + }, + expectedResponse: [][]string{{"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}}, + expectedError: nil, + }, + { + name: "6. Successfully pop elements from the first set which is non-empty", + preset: true, + presetValues: map[string]interface{}{ + "ZmpopKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + }), + }, + command: []string{"ZMPOP", "ZmpopKey6", "ZmpopKey7", "MAX", "COUNT", "5"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZmpopKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{}), + "ZmpopKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, + }), + }, + expectedResponse: [][]string{{"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}}, + expectedError: nil, + }, + { + name: "7. Skip the non-set items and pop elements from the first non-empty sorted set found", + preset: true, + presetValues: map[string]interface{}{ + "ZmpopKey8": "Default value", + "ZmpopKey9": "56", + "ZmpopKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + }), + }, + command: []string{"ZMPOP", "ZmpopKey8", "ZmpopKey9", "ZmpopKey10", "ZmpopKey11", "MIN", "COUNT", "5"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZmpopKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{}), + "ZmpopKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "six", Score: 6}, + }), + }, + expectedResponse: [][]string{{"one", "1"}, {"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}}, + expectedError: nil, + }, + { + name: "9. Return error when count is a negative integer", + preset: false, + command: []string{"ZMPOP", "ZmpopKey8", "MAX", "COUNT", "-20"}, + expectedError: errors.New("count must be a positive integer"), + }, + { + name: "9. Command too short", + preset: false, + command: []string{"ZMPOP"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) + + if err = client.WriteArray(command); err != nil { + t.Error(err) } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if len(res.Array()) != len(test.expectedResponse) { - t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) - } - - for i := 0; i < len(res.Array()); i++ { - if test.expectedResponse[i] != res.Array()[i].String() { - t.Errorf("expected element at index %d to be \"%s\", got %s", - i, test.expectedResponse[i], res.Array()[i].String()) - } - } - }) - } -} - -func Test_HandleZSCORE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedResponse string - expectedError error - }{ - { - name: "1. Return score from a sorted set.", - presetValues: map[string]interface{}{ - "ZscoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1.1}, {Value: "two", Score: 245}, - {Value: "three", Score: 3}, {Value: "four", Score: 4.055}, - {Value: "five", Score: 5}, - }), - }, - command: []string{"ZSCORE", "ZscoreKey1", "four"}, - expectedResponse: "4.055", - expectedError: nil, - }, - { - name: "2. If key does not exist, return nil value", - presetValues: nil, - command: []string{"ZSCORE", "ZscoreKey2", "one"}, - expectedResponse: "", - expectedError: nil, - }, - { - name: "3. If key exists and is a sorted set, but the member does not exist, return nil", - presetValues: map[string]interface{}{ - "ZscoreKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1.1}, {Value: "two", Score: 245}, - {Value: "three", Score: 3}, {Value: "four", Score: 4.055}, - {Value: "five", Score: 5}, - }), - }, - command: []string{"ZSCORE", "ZscoreKey3", "non-existent"}, - expectedResponse: "", - expectedError: nil, - }, - { - name: "4. Throw error when trying to find scores from elements that are not sorted sets", - presetValues: map[string]interface{}{"ZscoreKey4": "Default value"}, - command: []string{"ZSCORE", "ZscoreKey4", "one"}, - expectedError: errors.New("value at ZscoreKey4 is not a sorted set"), - }, - { - name: "5. Command too short", - command: []string{"ZSCORE"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "6. Command too long", - command: []string{"ZSCORE", "ZscoreKey5", "one", "two"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) } - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if res.String() != test.expectedResponse { - t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) - } - }) - } -} - -func Test_HandleZRANDMEMBER(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - key string - presetValue interface{} - command []string - expectedValue int // The final cardinality of the resulting set - allowRepeat bool - expectedResponse [][]string - expectedError error - }{ - { - // 1. Return multiple random elements without removing them. - // Count is positive, do not allow repeated elements - name: "1. Return multiple random elements without removing them.", - key: "ZrandMemberKey1", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - command: []string{"ZRANDMEMBER", "ZrandMemberKey1", "3"}, - expectedValue: 8, - allowRepeat: false, - expectedResponse: [][]string{ - {"one"}, {"two"}, {"three"}, {"four"}, - {"five"}, {"six"}, {"seven"}, {"eight"}, - }, - expectedError: nil, - }, - { - // 2. Return multiple random elements and their scores without removing them. - // Count is negative, so allow repeated numbers. - name: "2. Return multiple random elements and their scores without removing them.", - key: "ZrandMemberKey2", - presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - command: []string{"ZRANDMEMBER", "ZrandMemberKey2", "-5", "WITHSCORES"}, - expectedValue: 8, - allowRepeat: true, - expectedResponse: [][]string{ - {"one", "1"}, {"two", "2"}, {"three", "3"}, {"four", "4"}, - {"five", "5"}, {"six", "6"}, {"seven", "7"}, {"eight", "8"}, - }, - expectedError: nil, - }, - { - name: "2. Return error when the source key is not a sorted set.", - key: "ZrandMemberKey3", - presetValue: "Default value", - command: []string{"ZRANDMEMBER", "ZrandMemberKey3"}, - expectedValue: 0, - expectedError: errors.New("value at ZrandMemberKey3 is not a sorted set"), - }, - { - name: "5. Command too short", - command: []string{"ZRANDMEMBER"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "6. Command too long", - command: []string{"ZRANDMEMBER", "source5", "source6", "member1", "member2"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "7. Throw error when count is not an integer", - command: []string{"ZRANDMEMBER", "ZrandMemberKey1", "count"}, - expectedError: errors.New("count must be an integer"), - }, - { - name: "8. Throw error when the fourth argument is not WITHSCORES", - command: []string{"ZRANDMEMBER", "ZrandMemberKey1", "8", "ANOTHER"}, - expectedError: errors.New("last option must be WITHSCORES"), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != nil { - var command []resp.Value - var expected string - - switch test.presetValue.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue.(string)), - } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} - for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) - } - expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } if err = client.WriteArray(command); err != nil { @@ -2388,718 +1722,719 @@ func Test_HandleZRANDMEMBER(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) - } - return - } - - // Check that each of the returned elements is in the expected response. - for _, item := range res.Array() { - value := sorted_set.Value(item.Array()[0].String()) - if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { - return expected[0] == string(value) - }) { - t.Errorf("unexected element \"%s\" in response", value) - } - for _, expected := range test.expectedResponse { - if len(item.Array()) != len(expected) { - t.Errorf("expected response for element \"%s\" to have length %d, got %d", - value, len(expected), len(item.Array())) + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) } - if expected[0] != string(value) { + return + } + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) + } + + for _, item := range res.Array() { + value := item.Array()[0].String() + score := func() string { + if len(item.Array()) == 2 { + return item.Array()[1].String() + } + return "" + }() + if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { + return expected[0] == value + }) { + t.Errorf("unexpected member \"%s\" in response", value) + } + if score != "" { + for _, expected := range test.expectedResponse { + if expected[0] == value && expected[1] != score { + t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) + } + } + } + } + + // Check if the resulting sorted set has the expected members/scores + for key, expectedSortedSet := range test.expectedValues { + if expectedSortedSet == nil { continue } - if len(expected) == 2 { - score := item.Array()[1].String() - if expected[1] != score { - t.Errorf("expected score for memebr \"%s\" to be %s, got %s", value, expected[1], score) + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(key), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != expectedSortedSet.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + key, expectedSortedSet.Cardinality(), len(res.Array())) + } + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !expectedSortedSet.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) + } + if expectedSortedSet.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", + value, expectedSortedSet.Get(value).Score, score) } } } - } + }) + } + }) + + t.Run("Test_HandleZPOP", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + preset bool + presetValues map[string]interface{} + command []string + expectedValues map[string]*sorted_set.SortedSet + expectedResponse [][]string + expectedError error + }{ + { + name: "1. Successfully pop one min element by default", + preset: true, + presetValues: map[string]interface{}{ + "ZmpopMinKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + }, + command: []string{"ZPOPMIN", "ZmpopMinKey1"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZmpopMinKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + }, + expectedResponse: [][]string{ + {"one", "1"}, + }, + expectedError: nil, + }, + { + name: "2. Successfully pop one max element by default", + preset: true, + presetValues: map[string]interface{}{ + "ZmpopMaxKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + }, + command: []string{"ZPOPMAX", "ZmpopMaxKey2"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZmpopMaxKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + }), + }, + expectedResponse: [][]string{ + {"five", "5"}, + }, + expectedError: nil, + }, + { + name: "3. Successfully pop multiple min elements", + preset: true, + presetValues: map[string]interface{}{ + "ZmpopMinKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + }), + }, + command: []string{"ZPOPMIN", "ZmpopMinKey3", "5"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZmpopMinKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "six", Score: 6}, + }), + }, + expectedResponse: [][]string{ + {"one", "1"}, {"two", "2"}, {"three", "3"}, + {"four", "4"}, {"five", "5"}, + }, + expectedError: nil, + }, + { + name: "4. Successfully pop multiple max elements", + preset: true, + presetValues: map[string]interface{}{ + "ZmpopMaxKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + }), + }, + command: []string{"ZPOPMAX", "ZmpopMaxKey4", "5"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZmpopMaxKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, + }), + }, + expectedResponse: [][]string{{"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}}, + expectedError: nil, + }, + { + name: "5. Throw an error when trying to pop from an element that's not a sorted set", + preset: true, + presetValues: map[string]interface{}{ + "ZmpopMinKey5": "Default value", + }, + command: []string{"ZPOPMIN", "ZmpopMinKey5"}, + expectedValues: nil, + expectedResponse: nil, + expectedError: errors.New("value at key ZmpopMinKey5 is not a sorted set"), + }, + { + name: "6. Command too short", + preset: false, + command: []string{"ZPOPMAX"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "7. Command too long", + preset: false, + command: []string{"ZPOPMAX", "ZmpopMaxKey7", "6", "3"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return + } + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) + } - // Check that allowRepeat determines whether elements are repeated or not. - if !test.allowRepeat { - ss := sorted_set.NewSortedSet([]sorted_set.MemberParam{}) for _, item := range res.Array() { - member := sorted_set.Value(item.Array()[0].String()) - score := func() sorted_set.Score { + value := item.Array()[0].String() + score := func() string { if len(item.Array()) == 2 { - return sorted_set.Score(item.Array()[1].Float()) + return item.Array()[1].String() } - return sorted_set.Score(0) + return "" }() - _, err = ss.AddOrUpdate( - []sorted_set.MemberParam{{member, score}}, - nil, nil, nil, nil) - if err != nil { - t.Error(err) + if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { + return expected[0] == value + }) { + t.Errorf("unexpected member \"%s\" in response", value) + } + if score != "" { + for _, expected := range test.expectedResponse { + if expected[0] == value && expected[1] != score { + t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) + } + } } } - if len(res.Array()) != ss.Cardinality() { - t.Error("unexpected repeated elements in response") - } - } - }) - } -} -func Test_HandleZRANK(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedResponse []string - expectedError error - }{ - { - name: "1. Return element's rank from a sorted set.", - presetValues: map[string]interface{}{ - "ZrankKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - }, - command: []string{"ZRANK", "ZrankKey1", "four"}, - expectedResponse: []string{"3"}, - expectedError: nil, - }, - { - name: "2. Return element's rank from a sorted set with its score.", - presetValues: map[string]interface{}{ - "ZrankKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100.1}, {Value: "two", Score: 245}, - {Value: "three", Score: 305.43}, {Value: "four", Score: 411.055}, - {Value: "five", Score: 500}, - }), - }, - command: []string{"ZRANK", "ZrankKey1", "four", "WITHSCORES"}, - expectedResponse: []string{"3", "411.055"}, - expectedError: nil, - }, - { - name: "3. If key does not exist, return nil value", - presetValues: nil, - command: []string{"ZRANK", "ZrankKey3", "one"}, - expectedResponse: nil, - expectedError: nil, - }, - { - name: "4. If key exists and is a sorted set, but the member does not exist, return nil", - presetValues: map[string]interface{}{ - "ZrankKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1.1}, {Value: "two", Score: 245}, - {Value: "three", Score: 3}, {Value: "four", Score: 4.055}, - {Value: "five", Score: 5}, - }), - }, - command: []string{"ZRANK", "ZrankKey4", "non-existent"}, - expectedResponse: nil, - expectedError: nil, - }, - { - name: "5. Throw error when trying to find scores from elements that are not sorted sets", - presetValues: map[string]interface{}{"ZrankKey5": "Default value"}, - command: []string{"ZRANK", "ZrankKey5", "one"}, - expectedError: errors.New("value at ZrankKey5 is not a sorted set"), - }, - { - name: "5. Command too short", - command: []string{"ZRANK"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "6. Command too long", - command: []string{"ZRANK", "ZrankKey5", "one", "WITHSCORES", "two"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), - } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) - } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + // Check if the resulting sorted set has the expected members/scores + for key, expectedSortedSet := range test.expectedValues { + if expectedSortedSet == nil { + continue } - if err = client.WriteArray(command); err != nil { + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(key), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { t.Error(err) } - res, _, err := client.ReadValue() + + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + if len(res.Array()) != expectedSortedSet.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + key, expectedSortedSet.Cardinality(), len(res.Array())) } - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if len(res.Array()) != len(test.expectedResponse) { - t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) - } - - for i := 0; i < len(res.Array()); i++ { - if test.expectedResponse[i] != res.Array()[i].String() { - t.Errorf("expected element at index %d to be \"%s\", got %s", - i, test.expectedResponse[i], res.Array()[i].String()) - } - } - }) - } -} - -func Test_HandleZREM(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedValues map[string]*sorted_set.SortedSet - expectedResponse int - expectedError error - }{ - { - // Successfully remove multiple elements from sorted set, skipping non-existent members. - // Return deleted count. - name: "1. Successfully remove multiple elements from sorted set, skipping non-existent members.", - presetValues: map[string]interface{}{ - "ZremKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - }), - }, - command: []string{"ZREM", "ZremKey1", "three", "four", "five", "none", "six", "none", "seven"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZremKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - }), - }, - expectedResponse: 5, - expectedError: nil, - }, - { - name: "2. If key does not exist, return 0", - presetValues: nil, - command: []string{"ZREM", "ZremKey2", "member"}, - expectedValues: nil, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "3. Return error key is not a sorted set", - presetValues: map[string]interface{}{ - "ZremKey3": "Default value", - }, - command: []string{"ZREM", "ZremKey3", "member"}, - expectedError: errors.New("value at ZremKey3 is not a sorted set"), - }, - { - name: "9. Command too short", - command: []string{"ZREM"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !expectedSortedSet.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) + if expectedSortedSet.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", + value, expectedSortedSet.Get(value).Score, score) } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) } } + }) + } + }) - } + t.Run("Test_HandleZMSCORE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedResponse []string + expectedError error + }{ + { + // 1. Return multiple scores from the sorted set. + // Return nil for elements that do not exist in the sorted set. + name: "1. Return multiple scores from the sorted set.", + presetValues: map[string]interface{}{ + "ZmScoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1.1}, {Value: "two", Score: 245}, + {Value: "three", Score: 3}, {Value: "four", Score: 4.055}, + {Value: "five", Score: 5}, + }), + }, + command: []string{"ZMSCORE", "ZmScoreKey1", "one", "none", "two", "one", "three", "four", "none", "five"}, + expectedResponse: []string{"1.1", "", "245", "1.1", "3", "4.055", "", "5"}, + expectedError: nil, + }, + { + name: "2. If key does not exist, return empty array", + presetValues: nil, + command: []string{"ZMSCORE", "ZmScoreKey2", "one", "two", "three", "four"}, + expectedResponse: []string{}, + expectedError: nil, + }, + { + name: "3. Throw error when trying to find scores from elements that are not sorted sets", + presetValues: map[string]interface{}{"ZmScoreKey3": "Default value"}, + command: []string{"ZMSCORE", "ZmScoreKey3", "one", "two", "three"}, + expectedError: errors.New("value at ZmScoreKey3 is not a sorted set"), + }, + { + name: "9. Command too short", + command: []string{"ZMSCORE"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } - if res.Integer() != test.expectedResponse { - t.Errorf("expected response array of length %d, got %d", test.expectedResponse, res.Integer()) - } + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } - // Check if the resulting sorted set has the expected members/scores - for key, expectedSortedSet := range test.expectedValues { - if expectedSortedSet == nil { - continue } - if err = client.WriteArray([]resp.Value{ - resp.StringValue("ZRANGE"), - resp.StringValue(key), - resp.StringValue("-inf"), - resp.StringValue("+inf"), - resp.StringValue("BYSCORE"), - resp.StringValue("WITHSCORES"), - }); err != nil { + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } - - res, _, err = client.ReadValue() + res, _, err := client.ReadValue() if err != nil { t.Error(err) } - if len(res.Array()) != expectedSortedSet.Cardinality() { - t.Errorf("expected resulting set %s to have cardinality %d, got %d", - key, expectedSortedSet.Cardinality(), len(res.Array())) + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return } - for _, member := range res.Array() { - value := sorted_set.Value(member.Array()[0].String()) - score := sorted_set.Score(member.Array()[1].Float()) - if !expectedSortedSet.Contains(value) { - t.Errorf("unexpected value %s in resulting sorted set", value) - } - if expectedSortedSet.Get(value).Score != score { - t.Errorf("expected value %s to have score %v, got %v", - value, expectedSortedSet.Get(value).Score, score) + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) + } + + for i := 0; i < len(res.Array()); i++ { + if test.expectedResponse[i] != res.Array()[i].String() { + t.Errorf("expected element at index %d to be \"%s\", got %s", + i, test.expectedResponse[i], res.Array()[i].String()) } } - } - }) - } -} + }) + } + }) -func Test_HandleZREMRANGEBYSCORE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) + t.Run("Test_HandleZSCORE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedValues map[string]*sorted_set.SortedSet - expectedResponse int - expectedError error - }{ - { - name: "1. Successfully remove multiple elements with scores inside the provided range", - presetValues: map[string]interface{}{ - "ZremRangeByScoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - }), + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedResponse string + expectedError error + }{ + { + name: "1. Return score from a sorted set.", + presetValues: map[string]interface{}{ + "ZscoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1.1}, {Value: "two", Score: 245}, + {Value: "three", Score: 3}, {Value: "four", Score: 4.055}, + {Value: "five", Score: 5}, + }), + }, + command: []string{"ZSCORE", "ZscoreKey1", "four"}, + expectedResponse: "4.055", + expectedError: nil, }, - command: []string{"ZREMRANGEBYSCORE", "ZremRangeByScoreKey1", "3", "7"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZremRangeByScoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - }), + { + name: "2. If key does not exist, return nil value", + presetValues: nil, + command: []string{"ZSCORE", "ZscoreKey2", "one"}, + expectedResponse: "", + expectedError: nil, }, - expectedResponse: 5, - expectedError: nil, - }, - { - name: "2. If key does not exist, return 0", - presetValues: nil, - command: []string{"ZREMRANGEBYSCORE", "ZremRangeByScoreKey2", "2", "4"}, - expectedValues: nil, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "3. Return error key is not a sorted set", - presetValues: map[string]interface{}{ - "ZremRangeByScoreKey3": "Default value", + { + name: "3. If key exists and is a sorted set, but the member does not exist, return nil", + presetValues: map[string]interface{}{ + "ZscoreKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1.1}, {Value: "two", Score: 245}, + {Value: "three", Score: 3}, {Value: "four", Score: 4.055}, + {Value: "five", Score: 5}, + }), + }, + command: []string{"ZSCORE", "ZscoreKey3", "non-existent"}, + expectedResponse: "", + expectedError: nil, }, - command: []string{"ZREMRANGEBYSCORE", "ZremRangeByScoreKey3", "4", "4"}, - expectedError: errors.New("value at ZremRangeByScoreKey3 is not a sorted set"), - }, - { - name: "4. Command too short", - command: []string{"ZREMRANGEBYSCORE", "ZremRangeByScoreKey4", "3"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "5. Command too long", - command: []string{"ZREMRANGEBYSCORE", "ZremRangeByScoreKey5", "4", "5", "8"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + { + name: "4. Throw error when trying to find scores from elements that are not sorted sets", + presetValues: map[string]interface{}{"ZscoreKey4": "Default value"}, + command: []string{"ZSCORE", "ZscoreKey4", "one"}, + expectedError: errors.New("value at ZscoreKey4 is not a sorted set"), + }, + { + name: "5. Command too short", + command: []string{"ZSCORE"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "6. Command too long", + command: []string{"ZSCORE", "ZscoreKey5", "one", "two"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) } - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response array of length %d, got %d", test.expectedResponse, res.Integer()) - } - - // Check if the resulting sorted set has the expected members/scores - for key, expectedSortedSet := range test.expectedValues { - if expectedSortedSet == nil { - continue + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - if err = client.WriteArray([]resp.Value{ - resp.StringValue("ZRANGE"), - resp.StringValue(key), - resp.StringValue("-inf"), - resp.StringValue("+inf"), - resp.StringValue("BYSCORE"), - resp.StringValue("WITHSCORES"), - }); err != nil { + if err = client.WriteArray(command); err != nil { t.Error(err) } - - res, _, err = client.ReadValue() + res, _, err := client.ReadValue() if err != nil { t.Error(err) } - if len(res.Array()) != expectedSortedSet.Cardinality() { - t.Errorf("expected resulting set %s to have cardinality %d, got %d", - key, expectedSortedSet.Cardinality(), len(res.Array())) + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return } - for _, member := range res.Array() { - value := sorted_set.Value(member.Array()[0].String()) - score := sorted_set.Score(member.Array()[1].Float()) - if !expectedSortedSet.Contains(value) { - t.Errorf("unexpected value %s in resulting sorted set", value) - } - if expectedSortedSet.Get(value).Score != score { - t.Errorf("expected value %s to have score %v, got %v", - value, expectedSortedSet.Get(value).Score, score) - } + if res.String() != test.expectedResponse { + t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) } - } - }) - } -} + }) + } + }) -func Test_HandleZREMRANGEBYRANK(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) + t.Run("Test_HandleZRANDMEMBER", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedValues map[string]*sorted_set.SortedSet - expectedResponse int - expectedError error - }{ - { - name: "1. Successfully remove multiple elements within range", - presetValues: map[string]interface{}{ - "ZremRangeByRankKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + tests := []struct { + name string + key string + presetValue interface{} + command []string + expectedValue int // The final cardinality of the resulting set + allowRepeat bool + expectedResponse [][]string + expectedError error + }{ + { + // 1. Return multiple random elements without removing them. + // Count is positive, do not allow repeated elements + name: "1. Return multiple random elements without removing them.", + key: "ZrandMemberKey1", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, }), + command: []string{"ZRANDMEMBER", "ZrandMemberKey1", "3"}, + expectedValue: 8, + allowRepeat: false, + expectedResponse: [][]string{ + {"one"}, {"two"}, {"three"}, {"four"}, + {"five"}, {"six"}, {"seven"}, {"eight"}, + }, + expectedError: nil, }, - command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey1", "0", "5"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZremRangeByRankKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + { + // 2. Return multiple random elements and their scores without removing them. + // Count is negative, so allow repeated numbers. + name: "2. Return multiple random elements and their scores without removing them.", + key: "ZrandMemberKey2", + presetValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, }), + command: []string{"ZRANDMEMBER", "ZrandMemberKey2", "-5", "WITHSCORES"}, + expectedValue: 8, + allowRepeat: true, + expectedResponse: [][]string{ + {"one", "1"}, {"two", "2"}, {"three", "3"}, {"four", "4"}, + {"five", "5"}, {"six", "6"}, {"seven", "7"}, {"eight", "8"}, + }, + expectedError: nil, }, - expectedResponse: 6, - expectedError: nil, - }, - { - name: "2. Establish boundaries from the end of the set when negative boundaries are provided", - presetValues: map[string]interface{}{ - "ZremRangeByRankKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - }), + { + name: "2. Return error when the source key is not a sorted set.", + key: "ZrandMemberKey3", + presetValue: "Default value", + command: []string{"ZRANDMEMBER", "ZrandMemberKey3"}, + expectedValue: 0, + expectedError: errors.New("value at ZrandMemberKey3 is not a sorted set"), }, - command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey2", "-6", "-3"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZremRangeByRankKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - }), + { + name: "5. Command too short", + command: []string{"ZRANDMEMBER"}, + expectedError: errors.New(constants.WrongArgsResponse), }, - expectedResponse: 4, - expectedError: nil, - }, - { - name: "3. If key does not exist, return 0", - presetValues: nil, - command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey3", "2", "4"}, - expectedValues: nil, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "4. Return error key is not a sorted set", - presetValues: map[string]interface{}{ - "ZremRangeByRankKey3": "Default value", + { + name: "6. Command too long", + command: []string{"ZRANDMEMBER", "source5", "source6", "member1", "member2"}, + expectedError: errors.New(constants.WrongArgsResponse), }, - command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey3", "4", "4"}, - expectedError: errors.New("value at ZremRangeByRankKey3 is not a sorted set"), - }, - { - name: "5. Return error when start index is out of bounds", - presetValues: map[string]interface{}{ - "ZremRangeByRankKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - }), + { + name: "7. Throw error when count is not an integer", + command: []string{"ZRANDMEMBER", "ZrandMemberKey1", "count"}, + expectedError: errors.New("count must be an integer"), }, - command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey5", "-12", "5"}, - expectedValues: nil, - expectedResponse: 0, - expectedError: errors.New("indices out of bounds"), - }, - { - name: "6. Return error when end index is out of bounds", - presetValues: map[string]interface{}{ - "ZremRangeByRankKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - }), + { + name: "8. Throw error when the fourth argument is not WITHSCORES", + command: []string{"ZRANDMEMBER", "ZrandMemberKey1", "8", "ANOTHER"}, + expectedError: errors.New("last option must be WITHSCORES"), }, - command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey6", "0", "11"}, - expectedValues: nil, - expectedResponse: 0, - expectedError: errors.New("indices out of bounds"), - }, - { - name: "7. Command too short", - command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey4", "3"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "8. Command too long", - command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey7", "4", "5", "8"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != nil { + var command []resp.Value + var expected string + + switch test.presetValue.(type) { case string: command = []resp.Value{ resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + resp.StringValue(test.key), + resp.StringValue(test.presetValue.(string)), } expected = "ok" case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(test.key)} + for _, member := range test.presetValue.(*sorted_set.SortedSet).GetAll() { command = append(command, []resp.Value{ resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), resp.StringValue(string(member.Value)), }...) } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + expected = strconv.Itoa(test.presetValue.(*sorted_set.SortedSet).Cardinality()) } if err = client.WriteArray(command); err != nil { @@ -3115,2538 +2450,3305 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) { } } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response array of length %d, got %d", test.expectedResponse, res.Integer()) - } - - // Check if the resulting sorted set has the expected members/scores - for key, expectedSortedSet := range test.expectedValues { - if expectedSortedSet == nil { - continue + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) } - if err = client.WriteArray([]resp.Value{ - resp.StringValue("ZRANGE"), - resp.StringValue(key), - resp.StringValue("-inf"), - resp.StringValue("+inf"), - resp.StringValue("BYSCORE"), - resp.StringValue("WITHSCORES"), - }); err != nil { + if err = client.WriteArray(command); err != nil { t.Error(err) } - - res, _, err = client.ReadValue() + res, _, err := client.ReadValue() if err != nil { t.Error(err) } - if len(res.Array()) != expectedSortedSet.Cardinality() { - t.Errorf("expected resulting set %s to have cardinality %d, got %d", - key, expectedSortedSet.Cardinality(), len(res.Array())) + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return } - for _, member := range res.Array() { - value := sorted_set.Value(member.Array()[0].String()) - score := sorted_set.Score(member.Array()[1].Float()) - if !expectedSortedSet.Contains(value) { - t.Errorf("unexpected value %s in resulting sorted set", value) + // Check that each of the returned elements is in the expected response. + for _, item := range res.Array() { + value := sorted_set.Value(item.Array()[0].String()) + if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { + return expected[0] == string(value) + }) { + t.Errorf("unexected element \"%s\" in response", value) } - if expectedSortedSet.Get(value).Score != score { - t.Errorf("expected value %s to have score %v, got %v", - value, expectedSortedSet.Get(value).Score, score) - } - } - } - }) - } -} - -func Test_HandleZREMRANGEBYLEX(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedValues map[string]*sorted_set.SortedSet - expectedResponse int - expectedError error - }{ - { - name: "1. Successfully remove multiple elements with scores inside the provided range", - presetValues: map[string]interface{}{ - "ZremRangeByLexKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: 1}, {Value: "b", Score: 1}, - {Value: "c", Score: 1}, {Value: "d", Score: 1}, - {Value: "e", Score: 1}, {Value: "f", Score: 1}, - {Value: "g", Score: 1}, {Value: "h", Score: 1}, - {Value: "i", Score: 1}, {Value: "j", Score: 1}, - }), - }, - command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey1", "a", "d"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZremRangeByLexKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "e", Score: 1}, {Value: "f", Score: 1}, - {Value: "g", Score: 1}, {Value: "h", Score: 1}, - {Value: "i", Score: 1}, {Value: "j", Score: 1}, - }), - }, - expectedResponse: 4, - expectedError: nil, - }, - { - name: "2. Return 0 if the members do not have the same score", - presetValues: map[string]interface{}{ - "ZremRangeByLexKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: 1}, {Value: "b", Score: 2}, - {Value: "c", Score: 3}, {Value: "d", Score: 4}, - {Value: "e", Score: 5}, {Value: "f", Score: 6}, - {Value: "g", Score: 7}, {Value: "h", Score: 8}, - {Value: "i", Score: 9}, {Value: "j", Score: 10}, - }), - }, - command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey2", "d", "g"}, - expectedValues: map[string]*sorted_set.SortedSet{ - "ZremRangeByLexKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: 1}, {Value: "b", Score: 2}, - {Value: "c", Score: 3}, {Value: "d", Score: 4}, - {Value: "e", Score: 5}, {Value: "f", Score: 6}, - {Value: "g", Score: 7}, {Value: "h", Score: 8}, - {Value: "i", Score: 9}, {Value: "j", Score: 10}, - }), - }, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "3. If key does not exist, return 0", - presetValues: nil, - command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey3", "2", "4"}, - expectedValues: nil, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "4. Return error key is not a sorted set", - presetValues: map[string]interface{}{ - "ZremRangeByLexKey3": "Default value", - }, - command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey3", "a", "d"}, - expectedError: errors.New("value at ZremRangeByLexKey3 is not a sorted set"), - }, - { - name: "5. Command too short", - command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey4", "a"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "6. Command too long", - command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey5", "a", "b", "c"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), - } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) - } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response array of length %d, got %d", test.expectedResponse, res.Integer()) - } - - // Check if the resulting sorted set has the expected members/scores - for key, expectedSortedSet := range test.expectedValues { - if expectedSortedSet == nil { - continue - } - - if err = client.WriteArray([]resp.Value{ - resp.StringValue("ZRANGE"), - resp.StringValue(key), - resp.StringValue("-inf"), - resp.StringValue("+inf"), - resp.StringValue("BYSCORE"), - resp.StringValue("WITHSCORES"), - }); err != nil { - t.Error(err) - } - - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != expectedSortedSet.Cardinality() { - t.Errorf("expected resulting set %s to have cardinality %d, got %d", - key, expectedSortedSet.Cardinality(), len(res.Array())) - } - - for _, member := range res.Array() { - value := sorted_set.Value(member.Array()[0].String()) - score := sorted_set.Score(member.Array()[1].Float()) - if !expectedSortedSet.Contains(value) { - t.Errorf("unexpected value %s in resulting sorted set", value) - } - if expectedSortedSet.Get(value).Score != score { - t.Errorf("expected value %s to have score %v, got %v", - value, expectedSortedSet.Get(value).Score, score) - } - } - } - }) - } -} - -func Test_HandleZRANGE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedResponse [][]string - expectedError error - }{ - { - name: "1. Get elements withing score range without score.", - presetValues: map[string]interface{}{ - "ZrangeKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - }, - command: []string{"ZRANGE", "ZrangeKey1", "3", "7", "BYSCORE"}, - expectedResponse: [][]string{{"three"}, {"four"}, {"five"}, {"six"}, {"seven"}}, - expectedError: nil, - }, - { - name: "2. Get elements within score range with score.", - presetValues: map[string]interface{}{ - "ZrangeKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - }, - command: []string{"ZRANGE", "ZrangeKey2", "3", "7", "BYSCORE", "WITHSCORES"}, - expectedResponse: [][]string{ - {"three", "3"}, {"four", "4"}, {"five", "5"}, - {"six", "6"}, {"seven", "7"}}, - expectedError: nil, - }, - { - // 3. Get elements within score range with offset and limit. - // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). - name: "3. Get elements within score range with offset and limit.", - presetValues: map[string]interface{}{ - "ZrangeKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - }, - command: []string{"ZRANGE", "ZrangeKey3", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4"}, - expectedResponse: [][]string{{"three", "3"}, {"four", "4"}, {"five", "5"}}, - expectedError: nil, - }, - { - // 4. Get elements within score range with offset and limit + reverse the results. - // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). - // REV reverses the original set before getting the range. - name: "4. Get elements within score range with offset and limit + reverse the results.", - presetValues: map[string]interface{}{ - "ZrangeKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - }, - command: []string{"ZRANGE", "ZrangeKey4", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4", "REV"}, - expectedResponse: [][]string{{"six", "6"}, {"five", "5"}, {"four", "4"}}, - expectedError: nil, - }, - { - name: "5. Get elements within lex range without score.", - presetValues: map[string]interface{}{ - "ZrangeKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: 1}, {Value: "e", Score: 1}, - {Value: "b", Score: 1}, {Value: "f", Score: 1}, - {Value: "c", Score: 1}, {Value: "g", Score: 1}, - {Value: "d", Score: 1}, {Value: "h", Score: 1}, - }), - }, - command: []string{"ZRANGE", "ZrangeKey5", "c", "g", "BYLEX"}, - expectedResponse: [][]string{{"c"}, {"d"}, {"e"}, {"f"}, {"g"}}, - expectedError: nil, - }, - { - name: "6. Get elements within lex range with score.", - presetValues: map[string]interface{}{ - "ZrangeKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: 1}, {Value: "e", Score: 1}, - {Value: "b", Score: 1}, {Value: "f", Score: 1}, - {Value: "c", Score: 1}, {Value: "g", Score: 1}, - {Value: "d", Score: 1}, {Value: "h", Score: 1}, - }), - }, - command: []string{"ZRANGE", "ZrangeKey6", "a", "f", "BYLEX", "WITHSCORES"}, - expectedResponse: [][]string{ - {"a", "1"}, {"b", "1"}, {"c", "1"}, - {"d", "1"}, {"e", "1"}, {"f", "1"}}, - expectedError: nil, - }, - { - // 7. Get elements within lex range with offset and limit. - // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). - name: "7. Get elements within lex range with offset and limit.", - presetValues: map[string]interface{}{ - "ZrangeKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: 1}, {Value: "b", Score: 1}, - {Value: "c", Score: 1}, {Value: "d", Score: 1}, - {Value: "e", Score: 1}, {Value: "f", Score: 1}, - {Value: "g", Score: 1}, {Value: "h", Score: 1}, - }), - }, - command: []string{"ZRANGE", "ZrangeKey7", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, - expectedResponse: [][]string{{"c", "1"}, {"d", "1"}, {"e", "1"}}, - expectedError: nil, - }, - { - // 8. Get elements within lex range with offset and limit + reverse the results. - // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). - // REV reverses the original set before getting the range. - name: "8. Get elements within lex range with offset and limit + reverse the results.", - presetValues: map[string]interface{}{ - "ZrangeKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: 1}, {Value: "b", Score: 1}, - {Value: "c", Score: 1}, {Value: "d", Score: 1}, - {Value: "e", Score: 1}, {Value: "f", Score: 1}, - {Value: "g", Score: 1}, {Value: "h", Score: 1}, - }), - }, - command: []string{"ZRANGE", "ZrangeKey8", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4", "REV"}, - expectedResponse: [][]string{{"f", "1"}, {"e", "1"}, {"d", "1"}}, - expectedError: nil, - }, - { - name: "9. Return an empty slice when we use BYLEX while elements have different scores", - presetValues: map[string]interface{}{ - "ZrangeKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: 1}, {Value: "b", Score: 5}, - {Value: "c", Score: 2}, {Value: "d", Score: 6}, - {Value: "e", Score: 3}, {Value: "f", Score: 7}, - {Value: "g", Score: 4}, {Value: "h", Score: 8}, - }), - }, - command: []string{"ZRANGE", "ZrangeKey9", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, - expectedResponse: [][]string{}, - expectedError: nil, - }, - { - name: "10. Throw error when limit does not provide both offset and limit", - presetValues: nil, - command: []string{"ZRANGE", "ZrangeKey10", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2"}, - expectedResponse: [][]string{}, - expectedError: errors.New("limit should contain offset and count as integers"), - }, - { - name: "11. Throw error when offset is not a valid integer", - presetValues: nil, - command: []string{"ZRANGE", "ZrangeKey11", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "offset", "4"}, - expectedResponse: [][]string{}, - expectedError: errors.New("limit offset must be integer"), - }, - { - name: "12. Throw error when limit is not a valid integer", - presetValues: nil, - command: []string{"ZRANGE", "ZrangeKey12", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "4", "limit"}, - expectedResponse: [][]string{}, - expectedError: errors.New("limit count must be integer"), - }, - { - name: "13. Throw error when offset is negative", - presetValues: nil, - command: []string{"ZRANGE", "ZrangeKey13", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9"}, - expectedResponse: [][]string{}, - expectedError: errors.New("limit offset must be >= 0"), - }, - { - name: "14. Throw error when the key does not hold a sorted set", - presetValues: map[string]interface{}{ - "ZrangeKey14": "Default value", - }, - command: []string{"ZRANGE", "ZrangeKey14", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, - expectedResponse: [][]string{}, - expectedError: errors.New("value at ZrangeKey14 is not a sorted set"), - }, - { - name: "15. Command too short", - presetValues: nil, - command: []string{"ZRANGE", "ZrangeKey15", "1"}, - expectedResponse: [][]string{}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "16. Command too long", - presetValues: nil, - command: []string{"ZRANGE", "ZrangeKey16", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9", "REV", "WITHSCORES"}, - expectedResponse: [][]string{}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), - } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) - } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if len(res.Array()) != len(test.expectedResponse) { - t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) - } - - for _, item := range res.Array() { - value := item.Array()[0].String() - score := func() string { - if len(item.Array()) == 2 { - return item.Array()[1].String() - } - return "" - }() - if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { - return expected[0] == value - }) { - t.Errorf("unexpected member \"%s\" in response", value) - } - if score != "" { for _, expected := range test.expectedResponse { - if expected[0] == value && expected[1] != score { - t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) + if len(item.Array()) != len(expected) { + t.Errorf("expected response for element \"%s\" to have length %d, got %d", + value, len(expected), len(item.Array())) + } + if expected[0] != string(value) { + continue + } + if len(expected) == 2 { + score := item.Array()[1].String() + if expected[1] != score { + t.Errorf("expected score for memebr \"%s\" to be %s, got %s", value, expected[1], score) + } } } } - } - }) - } -} -func Test_HandleZRANGESTORE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - destination string - command []string - expectedValue *sorted_set.SortedSet - expectedResponse int - expectedError error - }{ - { - name: "1. Get elements withing score range without score.", - presetValues: map[string]interface{}{ - "ZrangeStoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - }, - destination: "ZrangeStoreDestinationKey1", - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey1", "ZrangeStoreKey1", "3", "7", "BYSCORE"}, - expectedResponse: 5, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "three", Score: 3}, {Value: "four", Score: 4}, {Value: "five", Score: 5}, - {Value: "six", Score: 6}, {Value: "seven", Score: 7}, - }), - expectedError: nil, - }, - { - name: "2. Get elements within score range with score.", - presetValues: map[string]interface{}{ - "ZrangeStoreKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - }, - destination: "ZrangeStoreDestinationKey2", - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey2", "ZrangeStoreKey2", "3", "7", "BYSCORE", "WITHSCORES"}, - expectedResponse: 5, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "three", Score: 3}, {Value: "four", Score: 4}, {Value: "five", Score: 5}, - {Value: "six", Score: 6}, {Value: "seven", Score: 7}, - }), - expectedError: nil, - }, - { - // 3. Get elements within score range with offset and limit. - // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). - name: "3. Get elements within score range with offset and limit.", - presetValues: map[string]interface{}{ - "ZrangeStoreKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - }, - destination: "ZrangeStoreDestinationKey3", - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey3", "ZrangeStoreKey3", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4"}, - expectedResponse: 3, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "three", Score: 3}, {Value: "four", Score: 4}, {Value: "five", Score: 5}, - }), - expectedError: nil, - }, - { - // 4. Get elements within score range with offset and limit + reverse the results. - // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). - // REV reverses the original set before getting the range. - name: "4. Get elements within score range with offset and limit + reverse the results.", - presetValues: map[string]interface{}{ - "ZrangeStoreKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - }, - destination: "ZrangeStoreDestinationKey4", - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey4", "ZrangeStoreKey4", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4", "REV"}, - expectedResponse: 3, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "six", Score: 6}, {Value: "five", Score: 5}, {Value: "four", Score: 4}, - }), - expectedError: nil, - }, - { - name: "5. Get elements within lex range without score.", - presetValues: map[string]interface{}{ - "ZrangeStoreKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: 1}, {Value: "e", Score: 1}, - {Value: "b", Score: 1}, {Value: "f", Score: 1}, - {Value: "c", Score: 1}, {Value: "g", Score: 1}, - {Value: "d", Score: 1}, {Value: "h", Score: 1}, - }), - }, - destination: "ZrangeStoreDestinationKey5", - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey5", "ZrangeStoreKey5", "c", "g", "BYLEX"}, - expectedResponse: 5, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "c", Score: 1}, {Value: "d", Score: 1}, {Value: "e", Score: 1}, - {Value: "f", Score: 1}, {Value: "g", Score: 1}, - }), - expectedError: nil, - }, - { - name: "6. Get elements within lex range with score.", - presetValues: map[string]interface{}{ - "ZrangeStoreKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: 1}, {Value: "e", Score: 1}, - {Value: "b", Score: 1}, {Value: "f", Score: 1}, - {Value: "c", Score: 1}, {Value: "g", Score: 1}, - {Value: "d", Score: 1}, {Value: "h", Score: 1}, - }), - }, - destination: "ZrangeStoreDestinationKey6", - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey6", "ZrangeStoreKey6", "a", "f", "BYLEX", "WITHSCORES"}, - expectedResponse: 6, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: 1}, {Value: "b", Score: 1}, {Value: "c", Score: 1}, - {Value: "d", Score: 1}, {Value: "e", Score: 1}, {Value: "f", Score: 1}, - }), - expectedError: nil, - }, - { - // 7. Get elements within lex range with offset and limit. - // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). - name: "7. Get elements within lex range with offset and limit.", - presetValues: map[string]interface{}{ - "ZrangeStoreKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: 1}, {Value: "b", Score: 1}, - {Value: "c", Score: 1}, {Value: "d", Score: 1}, - {Value: "e", Score: 1}, {Value: "f", Score: 1}, - {Value: "g", Score: 1}, {Value: "h", Score: 1}, - }), - }, - destination: "ZrangeStoreDestinationKey7", - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey7", "ZrangeStoreKey7", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, - expectedResponse: 3, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "c", Score: 1}, {Value: "d", Score: 1}, {Value: "e", Score: 1}, - }), - expectedError: nil, - }, - { - // 8. Get elements within lex range with offset and limit + reverse the results. - // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). - // REV reverses the original set before getting the range. - name: "8. Get elements within lex range with offset and limit + reverse the results.", - presetValues: map[string]interface{}{ - "ZrangeStoreKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: 1}, {Value: "b", Score: 1}, - {Value: "c", Score: 1}, {Value: "d", Score: 1}, - {Value: "e", Score: 1}, {Value: "f", Score: 1}, - {Value: "g", Score: 1}, {Value: "h", Score: 1}, - }), - }, - destination: "ZrangeStoreDestinationKey8", - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey8", "ZrangeStoreKey8", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4", "REV"}, - expectedResponse: 3, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "f", Score: 1}, {Value: "e", Score: 1}, {Value: "d", Score: 1}, - }), - expectedError: nil, - }, - { - name: "9. Return an empty slice when we use BYLEX while elements have different scores", - presetValues: map[string]interface{}{ - "ZrangeStoreKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "a", Score: 1}, {Value: "b", Score: 5}, - {Value: "c", Score: 2}, {Value: "d", Score: 6}, - {Value: "e", Score: 3}, {Value: "f", Score: 7}, - {Value: "g", Score: 4}, {Value: "h", Score: 8}, - }), - }, - destination: "ZrangeStoreDestinationKey9", - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey9", "ZrangeStoreKey9", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, - expectedResponse: 0, - expectedValue: nil, - expectedError: nil, - }, - { - name: "10. Throw error when limit does not provide both offset and limit", - presetValues: nil, - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey10", "ZrangeStoreKey10", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2"}, - expectedResponse: 0, - expectedError: errors.New("limit should contain offset and count as integers"), - }, - { - name: "11. Throw error when offset is not a valid integer", - presetValues: nil, - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey11", "ZrangeStoreKey11", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "offset", "4"}, - expectedResponse: 0, - expectedError: errors.New("limit offset must be integer"), - }, - { - name: "12. Throw error when limit is not a valid integer", - presetValues: nil, - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey12", "ZrangeStoreKey12", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "4", "limit"}, - expectedResponse: 0, - expectedError: errors.New("limit count must be integer"), - }, - { - name: "13. Throw error when offset is negative", - presetValues: nil, - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey13", "ZrangeStoreKey13", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9"}, - expectedResponse: 0, - expectedError: errors.New("limit offset must be >= 0"), - }, - { - name: "14. Throw error when the key does not hold a sorted set", - presetValues: map[string]interface{}{ - "ZrangeStoreKey14": "Default value", - }, - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey14", "ZrangeStoreKey14", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, - expectedResponse: 0, - expectedError: errors.New("value at ZrangeStoreKey14 is not a sorted set"), - }, - { - name: "15. Command too short", - presetValues: nil, - command: []string{"ZRANGESTORE", "ZrangeStoreKey15", "1"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "16 Command too long", - presetValues: nil, - command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey16", "ZrangeStoreKey16", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9", "REV", "WITHSCORES"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + // Check that allowRepeat determines whether elements are repeated or not. + if !test.allowRepeat { + ss := sorted_set.NewSortedSet([]sorted_set.MemberParam{}) + for _, item := range res.Array() { + member := sorted_set.Value(item.Array()[0].String()) + score := func() sorted_set.Score { + if len(item.Array()) == 2 { + return sorted_set.Score(item.Array()[1].Float()) + } + return sorted_set.Score(0) + }() + _, err = ss.AddOrUpdate( + []sorted_set.MemberParam{{member, score}}, + nil, nil, nil, nil) + if err != nil { + t.Error(err) } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) + } + if len(res.Array()) != ss.Cardinality() { + t.Error("unexpected repeated elements in response") + } + } + }) + } + }) + + t.Run("Test_HandleZRANK", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedResponse []string + expectedError error + }{ + { + name: "1. Return element's rank from a sorted set.", + presetValues: map[string]interface{}{ + "ZrankKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + }, + command: []string{"ZRANK", "ZrankKey1", "four"}, + expectedResponse: []string{"3"}, + expectedError: nil, + }, + { + name: "2. Return element's rank from a sorted set with its score.", + presetValues: map[string]interface{}{ + "ZrankKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100.1}, {Value: "two", Score: 245}, + {Value: "three", Score: 305.43}, {Value: "four", Score: 411.055}, + {Value: "five", Score: 500}, + }), + }, + command: []string{"ZRANK", "ZrankKey1", "four", "WITHSCORES"}, + expectedResponse: []string{"3", "411.055"}, + expectedError: nil, + }, + { + name: "3. If key does not exist, return nil value", + presetValues: nil, + command: []string{"ZRANK", "ZrankKey3", "one"}, + expectedResponse: nil, + expectedError: nil, + }, + { + name: "4. If key exists and is a sorted set, but the member does not exist, return nil", + presetValues: map[string]interface{}{ + "ZrankKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1.1}, {Value: "two", Score: 245}, + {Value: "three", Score: 3}, {Value: "four", Score: 4.055}, + {Value: "five", Score: 5}, + }), + }, + command: []string{"ZRANK", "ZrankKey4", "non-existent"}, + expectedResponse: nil, + expectedError: nil, + }, + { + name: "5. Throw error when trying to find scores from elements that are not sorted sets", + presetValues: map[string]interface{}{"ZrankKey5": "Default value"}, + command: []string{"ZRANK", "ZrankKey5", "one"}, + expectedError: errors.New("value at ZrankKey5 is not a sorted set"), + }, + { + name: "5. Command too short", + command: []string{"ZRANK"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "6. Command too long", + command: []string{"ZRANK", "ZrankKey5", "one", "WITHSCORES", "two"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) } - if err = client.WriteArray(command); err != nil { + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) + } + + for i := 0; i < len(res.Array()); i++ { + if test.expectedResponse[i] != res.Array()[i].String() { + t.Errorf("expected element at index %d to be \"%s\", got %s", + i, test.expectedResponse[i], res.Array()[i].String()) + } + } + }) + } + }) + + t.Run("Test_HandleZREM", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedValues map[string]*sorted_set.SortedSet + expectedResponse int + expectedError error + }{ + { + // Successfully remove multiple elements from sorted set, skipping non-existent members. + // Return deleted count. + name: "1. Successfully remove multiple elements from sorted set, skipping non-existent members.", + presetValues: map[string]interface{}{ + "ZremKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + }), + }, + command: []string{"ZREM", "ZremKey1", "three", "four", "five", "none", "six", "none", "seven"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZremKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + }), + }, + expectedResponse: 5, + expectedError: nil, + }, + { + name: "2. If key does not exist, return 0", + presetValues: nil, + command: []string{"ZREM", "ZremKey2", "member"}, + expectedValues: nil, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "3. Return error key is not a sorted set", + presetValues: map[string]interface{}{ + "ZremKey3": "Default value", + }, + command: []string{"ZREM", "ZremKey3", "member"}, + expectedError: errors.New("value at ZremKey3 is not a sorted set"), + }, + { + name: "9. Command too short", + command: []string{"ZREM"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response array of length %d, got %d", test.expectedResponse, res.Integer()) + } + + // Check if the resulting sorted set has the expected members/scores + for key, expectedSortedSet := range test.expectedValues { + if expectedSortedSet == nil { + continue + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(key), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { t.Error(err) } - res, _, err := client.ReadValue() + + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) - } - - // Check if the resulting sorted set has the expected members/scores - if test.expectedValue == nil { - return - } - - if err = client.WriteArray([]resp.Value{ - resp.StringValue("ZRANGE"), - resp.StringValue(test.destination), - resp.StringValue("-inf"), - resp.StringValue("+inf"), - resp.StringValue("BYSCORE"), - resp.StringValue("WITHSCORES"), - }); err != nil { - t.Error(err) - } - - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != test.expectedValue.Cardinality() { - t.Errorf("expected resulting set %s to have cardinality %d, got %d", - test.destination, test.expectedValue.Cardinality(), len(res.Array())) - } - - for _, member := range res.Array() { - value := sorted_set.Value(member.Array()[0].String()) - score := sorted_set.Score(member.Array()[1].Float()) - if !test.expectedValue.Contains(value) { - t.Errorf("unexpected value %s in resulting sorted set", value) - } - if test.expectedValue.Get(value).Score != score { - t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) - } - } - }) - } -} - -func Test_HandleZINTER(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedResponse [][]string - expectedError error - }{ - { - name: "1. Get the intersection between 2 sorted sets.", - presetValues: map[string]interface{}{ - "ZinterKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - "ZinterKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - }, - command: []string{"ZINTER", "ZinterKey1", "ZinterKey2"}, - expectedResponse: [][]string{{"three"}, {"four"}, {"five"}}, - expectedError: nil, - }, - { - // 2. Get the intersection between 3 sorted sets with scores. - // By default, the SUM aggregate will be used. - name: "2. Get the intersection between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZinterKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 8}, - }), - "ZinterKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - command: []string{"ZINTER", "ZinterKey3", "ZinterKey4", "ZinterKey5", "WITHSCORES"}, - expectedResponse: [][]string{{"one", "3"}, {"eight", "24"}}, - expectedError: nil, - }, - { - // 3. Get the intersection between 3 sorted sets with scores. - // Use MIN aggregate. - name: "3. Get the intersection between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZinterKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZinterKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - command: []string{"ZINTER", "ZinterKey6", "ZinterKey7", "ZinterKey8", "WITHSCORES", "AGGREGATE", "MIN"}, - expectedResponse: [][]string{{"one", "1"}, {"eight", "8"}}, - expectedError: nil, - }, - { - // 4. Get the intersection between 3 sorted sets with scores. - // Use MAX aggregate. - name: "4. Get the intersection between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZinterKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZinterKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - command: []string{"ZINTER", "ZinterKey9", "ZinterKey10", "ZinterKey11", "WITHSCORES", "AGGREGATE", "MAX"}, - expectedResponse: [][]string{{"one", "1000"}, {"eight", "800"}}, - expectedError: nil, - }, - { - // 5. Get the intersection between 3 sorted sets with scores. - // Use SUM aggregate with weights modifier. - name: "5. Get the intersection between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZinterKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterKey13": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZinterKey14": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - command: []string{"ZINTER", "ZinterKey12", "ZinterKey13", "ZinterKey14", "WITHSCORES", "AGGREGATE", "SUM", "WEIGHTS", "1", "5", "3"}, - expectedResponse: [][]string{{"one", "3105"}, {"eight", "2808"}}, - expectedError: nil, - }, - { - // 6. Get the intersection between 3 sorted sets with scores. - // Use MAX aggregate with added weights. - name: "6. Get the intersection between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZinterKey15": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterKey16": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZinterKey17": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - command: []string{"ZINTER", "ZinterKey15", "ZinterKey16", "ZinterKey17", "WITHSCORES", "AGGREGATE", "MAX", "WEIGHTS", "1", "5", "3"}, - expectedResponse: [][]string{{"one", "3000"}, {"eight", "2400"}}, - expectedError: nil, - }, - { - // 7. Get the intersection between 3 sorted sets with scores. - // Use MIN aggregate with added weights. - name: "7. Get the intersection between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZinterKey18": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterKey19": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZinterKey20": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - command: []string{"ZINTER", "ZinterKey18", "ZinterKey19", "ZinterKey20", "WITHSCORES", "AGGREGATE", "MIN", "WEIGHTS", "1", "5", "3"}, - expectedResponse: [][]string{{"one", "5"}, {"eight", "8"}}, - expectedError: nil, - }, - { - name: "8. Throw an error if there are more weights than keys", - presetValues: map[string]interface{}{ - "ZinterKey21": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterKey22": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - command: []string{"ZINTER", "ZinterKey21", "ZinterKey22", "WEIGHTS", "1", "2", "3"}, - expectedResponse: nil, - expectedError: errors.New("number of weights should match number of keys"), - }, - { - name: "9. Throw an error if there are fewer weights than keys", - presetValues: map[string]interface{}{ - "ZinterKey23": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterKey24": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - }), - "ZinterKey25": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - command: []string{"ZINTER", "ZinterKey23", "ZinterKey24", "ZinterKey25", "WEIGHTS", "5", "4"}, - expectedResponse: nil, - expectedError: errors.New("number of weights should match number of keys"), - }, - { - name: "10. Throw an error if there are no keys provided", - presetValues: map[string]interface{}{ - "ZinterKey26": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - "ZinterKey27": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - "ZinterKey28": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - command: []string{"ZINTER", "WEIGHTS", "5", "4"}, - expectedResponse: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "11. Throw an error if any of the provided keys are not sorted sets", - presetValues: map[string]interface{}{ - "ZinterKey29": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterKey30": "Default value", - "ZinterKey31": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - command: []string{"ZINTER", "ZinterKey29", "ZinterKey30", "ZinterKey31"}, - expectedResponse: nil, - expectedError: errors.New("value at ZinterKey30 is not a sorted set"), - }, - { - name: "12. If any of the keys does not exist, return an empty array.", - presetValues: map[string]interface{}{ - "ZinterKey32": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, - }), - "ZinterKey33": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - command: []string{"ZINTER", "non-existent", "ZinterKey32", "ZinterKey33"}, - expectedResponse: [][]string{}, - expectedError: nil, - }, - { - name: "13. Command too short", - command: []string{"ZINTER"}, - expectedResponse: [][]string{}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), - } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) - } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + if len(res.Array()) != expectedSortedSet.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + key, expectedSortedSet.Cardinality(), len(res.Array())) } - if err = client.WriteArray(command); err != nil { + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !expectedSortedSet.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) + } + if expectedSortedSet.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", + value, expectedSortedSet.Get(value).Score, score) + } + } + } + }) + } + }) + + t.Run("Test_HandleZREMRANGEBYSCORE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedValues map[string]*sorted_set.SortedSet + expectedResponse int + expectedError error + }{ + { + name: "1. Successfully remove multiple elements with scores inside the provided range", + presetValues: map[string]interface{}{ + "ZremRangeByScoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + }), + }, + command: []string{"ZREMRANGEBYSCORE", "ZremRangeByScoreKey1", "3", "7"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZremRangeByScoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + }), + }, + expectedResponse: 5, + expectedError: nil, + }, + { + name: "2. If key does not exist, return 0", + presetValues: nil, + command: []string{"ZREMRANGEBYSCORE", "ZremRangeByScoreKey2", "2", "4"}, + expectedValues: nil, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "3. Return error key is not a sorted set", + presetValues: map[string]interface{}{ + "ZremRangeByScoreKey3": "Default value", + }, + command: []string{"ZREMRANGEBYSCORE", "ZremRangeByScoreKey3", "4", "4"}, + expectedError: errors.New("value at ZremRangeByScoreKey3 is not a sorted set"), + }, + { + name: "4. Command too short", + command: []string{"ZREMRANGEBYSCORE", "ZremRangeByScoreKey4", "3"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "5. Command too long", + command: []string{"ZREMRANGEBYSCORE", "ZremRangeByScoreKey5", "4", "5", "8"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response array of length %d, got %d", test.expectedResponse, res.Integer()) + } + + // Check if the resulting sorted set has the expected members/scores + for key, expectedSortedSet := range test.expectedValues { + if expectedSortedSet == nil { + continue + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(key), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { t.Error(err) } - res, _, err := client.ReadValue() + + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + if len(res.Array()) != expectedSortedSet.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + key, expectedSortedSet.Cardinality(), len(res.Array())) } - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if len(res.Array()) != len(test.expectedResponse) { - t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) - } - - for _, item := range res.Array() { - value := item.Array()[0].String() - score := func() string { - if len(item.Array()) == 2 { - return item.Array()[1].String() - } - return "" - }() - if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { - return expected[0] == value - }) { - t.Errorf("unexpected member \"%s\" in response", value) - } - if score != "" { - for _, expected := range test.expectedResponse { - if expected[0] == value && expected[1] != score { - t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !expectedSortedSet.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) + } + if expectedSortedSet.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", + value, expectedSortedSet.Get(value).Score, score) } } } - } - }) - } -} + }) + } + }) -func Test_HandleZINTERSTORE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) + t.Run("Test_HandleZREMRANGEBYRANK", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - tests := []struct { - name string - presetValues map[string]interface{} - destination string - command []string - expectedValue *sorted_set.SortedSet - expectedResponse int - expectedError error - }{ - { - name: "1. Get the intersection between 2 sorted sets.", - presetValues: map[string]interface{}{ - "ZinterStoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - "ZinterStoreKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedValues map[string]*sorted_set.SortedSet + expectedResponse int + expectedError error + }{ + { + name: "1. Successfully remove multiple elements within range", + presetValues: map[string]interface{}{ + "ZremRangeByRankKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + }), + }, + command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey1", "0", "5"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZremRangeByRankKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + }), + }, + expectedResponse: 6, + expectedError: nil, }, - destination: "ZinterStoreDestinationKey1", - command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey1", "ZinterStoreKey1", "ZinterStoreKey2"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "three", Score: 6}, {Value: "four", Score: 8}, - {Value: "five", Score: 10}, - }), - expectedResponse: 3, - expectedError: nil, - }, - { - // 2. Get the intersection between 3 sorted sets with scores. - // By default, the SUM aggregate will be used. - name: "2. Get the intersection between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZinterStoreKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterStoreKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 8}, - }), - "ZinterStoreKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), + { + name: "2. Establish boundaries from the end of the set when negative boundaries are provided", + presetValues: map[string]interface{}{ + "ZremRangeByRankKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + }), + }, + command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey2", "-6", "-3"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZremRangeByRankKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + }), + }, + expectedResponse: 4, + expectedError: nil, }, - destination: "ZinterStoreDestinationKey2", - command: []string{ - "ZINTERSTORE", "ZinterStoreDestinationKey2", "ZinterStoreKey3", "ZinterStoreKey4", "ZinterStoreKey5", "WITHSCORES", + { + name: "3. If key does not exist, return 0", + presetValues: nil, + command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey3", "2", "4"}, + expectedValues: nil, + expectedResponse: 0, + expectedError: nil, }, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 3}, {Value: "eight", Score: 24}, - }), - expectedResponse: 2, - expectedError: nil, - }, - { - // 3. Get the intersection between 3 sorted sets with scores. - // Use MIN aggregate. - name: "3. Get the intersection between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZinterStoreKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterStoreKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZinterStoreKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), + { + name: "4. Return error key is not a sorted set", + presetValues: map[string]interface{}{ + "ZremRangeByRankKey3": "Default value", + }, + command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey3", "4", "4"}, + expectedError: errors.New("value at ZremRangeByRankKey3 is not a sorted set"), }, - destination: "ZinterStoreDestinationKey3", - command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey3", "ZinterStoreKey6", "ZinterStoreKey7", "ZinterStoreKey8", "WITHSCORES", "AGGREGATE", "MIN"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "eight", Score: 8}, - }), - expectedResponse: 2, - expectedError: nil, - }, - { - // 4. Get the intersection between 3 sorted sets with scores. - // Use MAX aggregate. - name: "4. Get the intersection between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZinterStoreKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterStoreKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZinterStoreKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), + { + name: "5. Return error when start index is out of bounds", + presetValues: map[string]interface{}{ + "ZremRangeByRankKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + }), + }, + command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey5", "-12", "5"}, + expectedValues: nil, + expectedResponse: 0, + expectedError: errors.New("indices out of bounds"), }, - destination: "ZinterStoreDestinationKey4", - command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey4", "ZinterStoreKey9", "ZinterStoreKey10", "ZinterStoreKey11", "WITHSCORES", "AGGREGATE", "MAX"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - }), - expectedResponse: 2, - expectedError: nil, - }, - { - // 5. Get the intersection between 3 sorted sets with scores. - // Use SUM aggregate with weights modifier. - name: "5. Get the intersection between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZinterStoreKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterStoreKey13": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZinterStoreKey14": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), + { + name: "6. Return error when end index is out of bounds", + presetValues: map[string]interface{}{ + "ZremRangeByRankKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + }), + }, + command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey6", "0", "11"}, + expectedValues: nil, + expectedResponse: 0, + expectedError: errors.New("indices out of bounds"), }, - destination: "ZinterStoreDestinationKey5", - command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey5", "ZinterStoreKey12", "ZinterStoreKey13", "ZinterStoreKey14", "WITHSCORES", "AGGREGATE", "SUM", "WEIGHTS", "1", "5", "3"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 3105}, {Value: "eight", Score: 2808}, - }), - expectedResponse: 2, - expectedError: nil, - }, - { - // 6. Get the intersection between 3 sorted sets with scores. - // Use MAX aggregate with added weights. - name: "6. Get the intersection between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZinterStoreKey15": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterStoreKey16": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZinterStoreKey17": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), + { + name: "7. Command too short", + command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey4", "3"}, + expectedError: errors.New(constants.WrongArgsResponse), }, - destination: "ZinterStoreDestinationKey6", - command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey6", "ZinterStoreKey15", "ZinterStoreKey16", "ZinterStoreKey17", "WITHSCORES", "AGGREGATE", "MAX", "WEIGHTS", "1", "5", "3"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 3000}, {Value: "eight", Score: 2400}, - }), - expectedResponse: 2, - expectedError: nil, - }, - { - // 7. Get the intersection between 3 sorted sets with scores. - // Use MIN aggregate with added weights. - name: "7. Get the intersection between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZinterStoreKey18": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterStoreKey19": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZinterStoreKey20": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), + { + name: "8. Command too long", + command: []string{"ZREMRANGEBYRANK", "ZremRangeByRankKey7", "4", "5", "8"}, + expectedError: errors.New(constants.WrongArgsResponse), }, - destination: "ZinterStoreDestinationKey7", - command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey7", "ZinterStoreKey18", "ZinterStoreKey19", "ZinterStoreKey20", "WITHSCORES", "AGGREGATE", "MIN", "WEIGHTS", "1", "5", "3"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 5}, {Value: "eight", Score: 8}, - }), - expectedResponse: 2, - expectedError: nil, - }, - { - name: "8. Throw an error if there are more weights than keys", - presetValues: map[string]interface{}{ - "ZinterStoreKey21": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterStoreKey22": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey8", "ZinterStoreKey21", "ZinterStoreKey22", "WEIGHTS", "1", "2", "3"}, - expectedResponse: 0, - expectedError: errors.New("number of weights should match number of keys"), - }, - { - name: "9. Throw an error if there are fewer weights than keys", - presetValues: map[string]interface{}{ - "ZinterStoreKey23": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterStoreKey24": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - }), - "ZinterStoreKey25": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey9", "ZinterStoreKey23", "ZinterStoreKey24", "ZinterStoreKey25", "WEIGHTS", "5", "4"}, - expectedResponse: 0, - expectedError: errors.New("number of weights should match number of keys"), - }, - { - name: "10. Throw an error if there are no keys provided", - presetValues: map[string]interface{}{ - "ZinterStoreKey26": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - "ZinterStoreKey27": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - "ZinterStoreKey28": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - command: []string{"ZINTERSTORE", "WEIGHTS", "5", "4"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "11. Throw an error if any of the provided keys are not sorted sets", - presetValues: map[string]interface{}{ - "ZinterStoreKey29": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZinterStoreKey30": "Default value", - "ZinterStoreKey31": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - command: []string{"ZINTERSTORE", "ZinterStoreKey29", "ZinterStoreKey30", "ZinterStoreKey31"}, - expectedResponse: 0, - expectedError: errors.New("value at ZinterStoreKey30 is not a sorted set"), - }, - { - name: "12. If any of the keys does not exist, return an empty array.", - presetValues: map[string]interface{}{ - "ZinterStoreKey32": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, - }), - "ZinterStoreKey33": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey12", "non-existent", "ZinterStoreKey32", "ZinterStoreKey33"}, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "13. Command too short", - command: []string{"ZINTERSTORE"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) } - if err = client.WriteArray(command); err != nil { + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response array of length %d, got %d", test.expectedResponse, res.Integer()) + } + + // Check if the resulting sorted set has the expected members/scores + for key, expectedSortedSet := range test.expectedValues { + if expectedSortedSet == nil { + continue + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(key), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { t.Error(err) } - res, _, err := client.ReadValue() + + res, _, err = client.ReadValue() if err != nil { t.Error(err) } - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + if len(res.Array()) != expectedSortedSet.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + key, expectedSortedSet.Cardinality(), len(res.Array())) + } + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !expectedSortedSet.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) + } + if expectedSortedSet.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", + value, expectedSortedSet.Get(value).Score, score) + } } } - } + }) + } + }) - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } + t.Run("Test_HandleZREMRANGEBYLEX", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) - } - - // Check if the resulting sorted set has the expected members/scores - if test.expectedValue == nil { - return - } - - if err = client.WriteArray([]resp.Value{ - resp.StringValue("ZRANGE"), - resp.StringValue(test.destination), - resp.StringValue("-inf"), - resp.StringValue("+inf"), - resp.StringValue("BYSCORE"), - resp.StringValue("WITHSCORES"), - }); err != nil { - t.Error(err) - } - - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != test.expectedValue.Cardinality() { - t.Errorf("expected resulting set %s to have cardinality %d, got %d", - test.destination, test.expectedValue.Cardinality(), len(res.Array())) - } - - for _, member := range res.Array() { - value := sorted_set.Value(member.Array()[0].String()) - score := sorted_set.Score(member.Array()[1].Float()) - if !test.expectedValue.Contains(value) { - t.Errorf("unexpected value %s in resulting sorted set", value) - } - if test.expectedValue.Get(value).Score != score { - t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) - } - } - }) - } -} - -func Test_HandleZUNION(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - presetValues map[string]interface{} - command []string - expectedResponse [][]string - expectedError error - }{ - { - name: "1. Get the union between 2 sorted sets.", - presetValues: map[string]interface{}{ - "ZunionKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - "ZunionKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedValues map[string]*sorted_set.SortedSet + expectedResponse int + expectedError error + }{ + { + name: "1. Successfully remove multiple elements with scores inside the provided range", + presetValues: map[string]interface{}{ + "ZremRangeByLexKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: 1}, {Value: "b", Score: 1}, + {Value: "c", Score: 1}, {Value: "d", Score: 1}, + {Value: "e", Score: 1}, {Value: "f", Score: 1}, + {Value: "g", Score: 1}, {Value: "h", Score: 1}, + {Value: "i", Score: 1}, {Value: "j", Score: 1}, + }), + }, + command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey1", "a", "d"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZremRangeByLexKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "e", Score: 1}, {Value: "f", Score: 1}, + {Value: "g", Score: 1}, {Value: "h", Score: 1}, + {Value: "i", Score: 1}, {Value: "j", Score: 1}, + }), + }, + expectedResponse: 4, + expectedError: nil, }, - command: []string{"ZUNION", "ZunionKey1", "ZunionKey2"}, - expectedResponse: [][]string{{"one"}, {"two"}, {"three"}, {"four"}, {"five"}, {"six"}, {"seven"}, {"eight"}}, - expectedError: nil, - }, - { - // 2. Get the union between 3 sorted sets with scores. - // By default, the SUM aggregate will be used. - name: "2. Get the union between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZunionKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + { + name: "2. Return 0 if the members do not have the same score", + presetValues: map[string]interface{}{ + "ZremRangeByLexKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: 1}, {Value: "b", Score: 2}, + {Value: "c", Score: 3}, {Value: "d", Score: 4}, + {Value: "e", Score: 5}, {Value: "f", Score: 6}, + {Value: "g", Score: 7}, {Value: "h", Score: 8}, + {Value: "i", Score: 9}, {Value: "j", Score: 10}, + }), + }, + command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey2", "d", "g"}, + expectedValues: map[string]*sorted_set.SortedSet{ + "ZremRangeByLexKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: 1}, {Value: "b", Score: 2}, + {Value: "c", Score: 3}, {Value: "d", Score: 4}, + {Value: "e", Score: 5}, {Value: "f", Score: 6}, + {Value: "g", Score: 7}, {Value: "h", Score: 8}, + {Value: "i", Score: 9}, {Value: "j", Score: 10}, + }), + }, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "3. If key does not exist, return 0", + presetValues: nil, + command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey3", "2", "4"}, + expectedValues: nil, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "4. Return error key is not a sorted set", + presetValues: map[string]interface{}{ + "ZremRangeByLexKey3": "Default value", + }, + command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey3", "a", "d"}, + expectedError: errors.New("value at ZremRangeByLexKey3 is not a sorted set"), + }, + { + name: "5. Command too short", + command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey4", "a"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "6. Command too long", + command: []string{"ZREMRANGEBYLEX", "ZremRangeByLexKey5", "a", "b", "c"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response array of length %d, got %d", test.expectedResponse, res.Integer()) + } + + // Check if the resulting sorted set has the expected members/scores + for key, expectedSortedSet := range test.expectedValues { + if expectedSortedSet == nil { + continue + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(key), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != expectedSortedSet.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + key, expectedSortedSet.Cardinality(), len(res.Array())) + } + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !expectedSortedSet.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) + } + if expectedSortedSet.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", + value, expectedSortedSet.Get(value).Score, score) + } + } + } + }) + } + }) + + t.Run("Test_HandleZRANGE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedResponse [][]string + expectedError error + }{ + { + name: "1. Get elements withing score range without score.", + presetValues: map[string]interface{}{ + "ZrangeKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + command: []string{"ZRANGE", "ZrangeKey1", "3", "7", "BYSCORE"}, + expectedResponse: [][]string{{"three"}, {"four"}, {"five"}, {"six"}, {"seven"}}, + expectedError: nil, + }, + { + name: "2. Get elements within score range with score.", + presetValues: map[string]interface{}{ + "ZrangeKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + command: []string{"ZRANGE", "ZrangeKey2", "3", "7", "BYSCORE", "WITHSCORES"}, + expectedResponse: [][]string{ + {"three", "3"}, {"four", "4"}, {"five", "5"}, + {"six", "6"}, {"seven", "7"}}, + expectedError: nil, + }, + { + // 3. Get elements within score range with offset and limit. + // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). + name: "3. Get elements within score range with offset and limit.", + presetValues: map[string]interface{}{ + "ZrangeKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + command: []string{"ZRANGE", "ZrangeKey3", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4"}, + expectedResponse: [][]string{{"three", "3"}, {"four", "4"}, {"five", "5"}}, + expectedError: nil, + }, + { + // 4. Get elements within score range with offset and limit + reverse the results. + // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). + // REV reverses the original set before getting the range. + name: "4. Get elements within score range with offset and limit + reverse the results.", + presetValues: map[string]interface{}{ + "ZrangeKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + command: []string{"ZRANGE", "ZrangeKey4", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4", "REV"}, + expectedResponse: [][]string{{"six", "6"}, {"five", "5"}, {"four", "4"}}, + expectedError: nil, + }, + { + name: "5. Get elements within lex range without score.", + presetValues: map[string]interface{}{ + "ZrangeKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: 1}, {Value: "e", Score: 1}, + {Value: "b", Score: 1}, {Value: "f", Score: 1}, + {Value: "c", Score: 1}, {Value: "g", Score: 1}, + {Value: "d", Score: 1}, {Value: "h", Score: 1}, + }), + }, + command: []string{"ZRANGE", "ZrangeKey5", "c", "g", "BYLEX"}, + expectedResponse: [][]string{{"c"}, {"d"}, {"e"}, {"f"}, {"g"}}, + expectedError: nil, + }, + { + name: "6. Get elements within lex range with score.", + presetValues: map[string]interface{}{ + "ZrangeKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: 1}, {Value: "e", Score: 1}, + {Value: "b", Score: 1}, {Value: "f", Score: 1}, + {Value: "c", Score: 1}, {Value: "g", Score: 1}, + {Value: "d", Score: 1}, {Value: "h", Score: 1}, + }), + }, + command: []string{"ZRANGE", "ZrangeKey6", "a", "f", "BYLEX", "WITHSCORES"}, + expectedResponse: [][]string{ + {"a", "1"}, {"b", "1"}, {"c", "1"}, + {"d", "1"}, {"e", "1"}, {"f", "1"}}, + expectedError: nil, + }, + { + // 7. Get elements within lex range with offset and limit. + // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). + name: "7. Get elements within lex range with offset and limit.", + presetValues: map[string]interface{}{ + "ZrangeKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: 1}, {Value: "b", Score: 1}, + {Value: "c", Score: 1}, {Value: "d", Score: 1}, + {Value: "e", Score: 1}, {Value: "f", Score: 1}, + {Value: "g", Score: 1}, {Value: "h", Score: 1}, + }), + }, + command: []string{"ZRANGE", "ZrangeKey7", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, + expectedResponse: [][]string{{"c", "1"}, {"d", "1"}, {"e", "1"}}, + expectedError: nil, + }, + { + // 8. Get elements within lex range with offset and limit + reverse the results. + // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). + // REV reverses the original set before getting the range. + name: "8. Get elements within lex range with offset and limit + reverse the results.", + presetValues: map[string]interface{}{ + "ZrangeKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: 1}, {Value: "b", Score: 1}, + {Value: "c", Score: 1}, {Value: "d", Score: 1}, + {Value: "e", Score: 1}, {Value: "f", Score: 1}, + {Value: "g", Score: 1}, {Value: "h", Score: 1}, + }), + }, + command: []string{"ZRANGE", "ZrangeKey8", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4", "REV"}, + expectedResponse: [][]string{{"f", "1"}, {"e", "1"}, {"d", "1"}}, + expectedError: nil, + }, + { + name: "9. Return an empty slice when we use BYLEX while elements have different scores", + presetValues: map[string]interface{}{ + "ZrangeKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: 1}, {Value: "b", Score: 5}, + {Value: "c", Score: 2}, {Value: "d", Score: 6}, + {Value: "e", Score: 3}, {Value: "f", Score: 7}, + {Value: "g", Score: 4}, {Value: "h", Score: 8}, + }), + }, + command: []string{"ZRANGE", "ZrangeKey9", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, + expectedResponse: [][]string{}, + expectedError: nil, + }, + { + name: "10. Throw error when limit does not provide both offset and limit", + presetValues: nil, + command: []string{"ZRANGE", "ZrangeKey10", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2"}, + expectedResponse: [][]string{}, + expectedError: errors.New("limit should contain offset and count as integers"), + }, + { + name: "11. Throw error when offset is not a valid integer", + presetValues: nil, + command: []string{"ZRANGE", "ZrangeKey11", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "offset", "4"}, + expectedResponse: [][]string{}, + expectedError: errors.New("limit offset must be integer"), + }, + { + name: "12. Throw error when limit is not a valid integer", + presetValues: nil, + command: []string{"ZRANGE", "ZrangeKey12", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "4", "limit"}, + expectedResponse: [][]string{}, + expectedError: errors.New("limit count must be integer"), + }, + { + name: "13. Throw error when offset is negative", + presetValues: nil, + command: []string{"ZRANGE", "ZrangeKey13", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9"}, + expectedResponse: [][]string{}, + expectedError: errors.New("limit offset must be >= 0"), + }, + { + name: "14. Throw error when the key does not hold a sorted set", + presetValues: map[string]interface{}{ + "ZrangeKey14": "Default value", + }, + command: []string{"ZRANGE", "ZrangeKey14", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, + expectedResponse: [][]string{}, + expectedError: errors.New("value at ZrangeKey14 is not a sorted set"), + }, + { + name: "15. Command too short", + presetValues: nil, + command: []string{"ZRANGE", "ZrangeKey15", "1"}, + expectedResponse: [][]string{}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "16. Command too long", + presetValues: nil, + command: []string{"ZRANGE", "ZrangeKey16", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9", "REV", "WITHSCORES"}, + expectedResponse: [][]string{}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) + } + + for _, item := range res.Array() { + value := item.Array()[0].String() + score := func() string { + if len(item.Array()) == 2 { + return item.Array()[1].String() + } + return "" + }() + if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { + return expected[0] == value + }) { + t.Errorf("unexpected member \"%s\" in response", value) + } + if score != "" { + for _, expected := range test.expectedResponse { + if expected[0] == value && expected[1] != score { + t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) + } + } + } + } + }) + } + }) + + t.Run("Test_HandleZRANGESTORE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + destination string + command []string + expectedValue *sorted_set.SortedSet + expectedResponse int + expectedError error + }{ + { + name: "1. Get elements withing score range without score.", + presetValues: map[string]interface{}{ + "ZrangeStoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + destination: "ZrangeStoreDestinationKey1", + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey1", "ZrangeStoreKey1", "3", "7", "BYSCORE"}, + expectedResponse: 5, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "three", Score: 3}, {Value: "four", Score: 4}, {Value: "five", Score: 5}, + {Value: "six", Score: 6}, {Value: "seven", Score: 7}, + }), + expectedError: nil, + }, + { + name: "2. Get elements within score range with score.", + presetValues: map[string]interface{}{ + "ZrangeStoreKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + destination: "ZrangeStoreDestinationKey2", + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey2", "ZrangeStoreKey2", "3", "7", "BYSCORE", "WITHSCORES"}, + expectedResponse: 5, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "three", Score: 3}, {Value: "four", Score: 4}, {Value: "five", Score: 5}, + {Value: "six", Score: 6}, {Value: "seven", Score: 7}, + }), + expectedError: nil, + }, + { + // 3. Get elements within score range with offset and limit. + // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). + name: "3. Get elements within score range with offset and limit.", + presetValues: map[string]interface{}{ + "ZrangeStoreKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + destination: "ZrangeStoreDestinationKey3", + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey3", "ZrangeStoreKey3", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4"}, + expectedResponse: 3, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "three", Score: 3}, {Value: "four", Score: 4}, {Value: "five", Score: 5}, + }), + expectedError: nil, + }, + { + // 4. Get elements within score range with offset and limit + reverse the results. + // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). + // REV reverses the original set before getting the range. + name: "4. Get elements within score range with offset and limit + reverse the results.", + presetValues: map[string]interface{}{ + "ZrangeStoreKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + destination: "ZrangeStoreDestinationKey4", + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey4", "ZrangeStoreKey4", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4", "REV"}, + expectedResponse: 3, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "six", Score: 6}, {Value: "five", Score: 5}, {Value: "four", Score: 4}, + }), + expectedError: nil, + }, + { + name: "5. Get elements within lex range without score.", + presetValues: map[string]interface{}{ + "ZrangeStoreKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: 1}, {Value: "e", Score: 1}, + {Value: "b", Score: 1}, {Value: "f", Score: 1}, + {Value: "c", Score: 1}, {Value: "g", Score: 1}, + {Value: "d", Score: 1}, {Value: "h", Score: 1}, + }), + }, + destination: "ZrangeStoreDestinationKey5", + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey5", "ZrangeStoreKey5", "c", "g", "BYLEX"}, + expectedResponse: 5, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "c", Score: 1}, {Value: "d", Score: 1}, {Value: "e", Score: 1}, + {Value: "f", Score: 1}, {Value: "g", Score: 1}, + }), + expectedError: nil, + }, + { + name: "6. Get elements within lex range with score.", + presetValues: map[string]interface{}{ + "ZrangeStoreKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: 1}, {Value: "e", Score: 1}, + {Value: "b", Score: 1}, {Value: "f", Score: 1}, + {Value: "c", Score: 1}, {Value: "g", Score: 1}, + {Value: "d", Score: 1}, {Value: "h", Score: 1}, + }), + }, + destination: "ZrangeStoreDestinationKey6", + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey6", "ZrangeStoreKey6", "a", "f", "BYLEX", "WITHSCORES"}, + expectedResponse: 6, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: 1}, {Value: "b", Score: 1}, {Value: "c", Score: 1}, + {Value: "d", Score: 1}, {Value: "e", Score: 1}, {Value: "f", Score: 1}, + }), + expectedError: nil, + }, + { + // 7. Get elements within lex range with offset and limit. + // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). + name: "7. Get elements within lex range with offset and limit.", + presetValues: map[string]interface{}{ + "ZrangeStoreKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: 1}, {Value: "b", Score: 1}, + {Value: "c", Score: 1}, {Value: "d", Score: 1}, + {Value: "e", Score: 1}, {Value: "f", Score: 1}, + {Value: "g", Score: 1}, {Value: "h", Score: 1}, + }), + }, + destination: "ZrangeStoreDestinationKey7", + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey7", "ZrangeStoreKey7", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, + expectedResponse: 3, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "c", Score: 1}, {Value: "d", Score: 1}, {Value: "e", Score: 1}, + }), + expectedError: nil, + }, + { + // 8. Get elements within lex range with offset and limit + reverse the results. + // Offset and limit are in where we start and stop counting in the original sorted set (NOT THE RESULT). + // REV reverses the original set before getting the range. + name: "8. Get elements within lex range with offset and limit + reverse the results.", + presetValues: map[string]interface{}{ + "ZrangeStoreKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: 1}, {Value: "b", Score: 1}, + {Value: "c", Score: 1}, {Value: "d", Score: 1}, + {Value: "e", Score: 1}, {Value: "f", Score: 1}, + {Value: "g", Score: 1}, {Value: "h", Score: 1}, + }), + }, + destination: "ZrangeStoreDestinationKey8", + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey8", "ZrangeStoreKey8", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4", "REV"}, + expectedResponse: 3, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "f", Score: 1}, {Value: "e", Score: 1}, {Value: "d", Score: 1}, + }), + expectedError: nil, + }, + { + name: "9. Return an empty slice when we use BYLEX while elements have different scores", + presetValues: map[string]interface{}{ + "ZrangeStoreKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "a", Score: 1}, {Value: "b", Score: 5}, + {Value: "c", Score: 2}, {Value: "d", Score: 6}, + {Value: "e", Score: 3}, {Value: "f", Score: 7}, + {Value: "g", Score: 4}, {Value: "h", Score: 8}, + }), + }, + destination: "ZrangeStoreDestinationKey9", + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey9", "ZrangeStoreKey9", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, + expectedResponse: 0, + expectedValue: nil, + expectedError: nil, + }, + { + name: "10. Throw error when limit does not provide both offset and limit", + presetValues: nil, + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey10", "ZrangeStoreKey10", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2"}, + expectedResponse: 0, + expectedError: errors.New("limit should contain offset and count as integers"), + }, + { + name: "11. Throw error when offset is not a valid integer", + presetValues: nil, + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey11", "ZrangeStoreKey11", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "offset", "4"}, + expectedResponse: 0, + expectedError: errors.New("limit offset must be integer"), + }, + { + name: "12. Throw error when limit is not a valid integer", + presetValues: nil, + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey12", "ZrangeStoreKey12", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "4", "limit"}, + expectedResponse: 0, + expectedError: errors.New("limit count must be integer"), + }, + { + name: "13. Throw error when offset is negative", + presetValues: nil, + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey13", "ZrangeStoreKey13", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9"}, + expectedResponse: 0, + expectedError: errors.New("limit offset must be >= 0"), + }, + { + name: "14. Throw error when the key does not hold a sorted set", + presetValues: map[string]interface{}{ + "ZrangeStoreKey14": "Default value", + }, + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey14", "ZrangeStoreKey14", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, + expectedResponse: 0, + expectedError: errors.New("value at ZrangeStoreKey14 is not a sorted set"), + }, + { + name: "15. Command too short", + presetValues: nil, + command: []string{"ZRANGESTORE", "ZrangeStoreKey15", "1"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "16 Command too long", + presetValues: nil, + command: []string{"ZRANGESTORE", "ZrangeStoreDestinationKey16", "ZrangeStoreKey16", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9", "REV", "WITHSCORES"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + + // Check if the resulting sorted set has the expected members/scores + if test.expectedValue == nil { + return + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(test.destination), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != test.expectedValue.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + test.destination, test.expectedValue.Cardinality(), len(res.Array())) + } + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !test.expectedValue.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) + } + if test.expectedValue.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) + } + } + }) + } + }) + + t.Run("Test_HandleZINTER", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedResponse [][]string + expectedError error + }{ + { + name: "1. Get the intersection between 2 sorted sets.", + presetValues: map[string]interface{}{ + "ZinterKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + "ZinterKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + command: []string{"ZINTER", "ZinterKey1", "ZinterKey2"}, + expectedResponse: [][]string{{"three"}, {"four"}, {"five"}}, + expectedError: nil, + }, + { + // 2. Get the intersection between 3 sorted sets with scores. + // By default, the SUM aggregate will be used. + name: "2. Get the intersection between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZinterKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 8}, + }), + "ZinterKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZINTER", "ZinterKey3", "ZinterKey4", "ZinterKey5", "WITHSCORES"}, + expectedResponse: [][]string{{"one", "3"}, {"eight", "24"}}, + expectedError: nil, + }, + { + // 3. Get the intersection between 3 sorted sets with scores. + // Use MIN aggregate. + name: "3. Get the intersection between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZinterKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZinterKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZINTER", "ZinterKey6", "ZinterKey7", "ZinterKey8", "WITHSCORES", "AGGREGATE", "MIN"}, + expectedResponse: [][]string{{"one", "1"}, {"eight", "8"}}, + expectedError: nil, + }, + { + // 4. Get the intersection between 3 sorted sets with scores. + // Use MAX aggregate. + name: "4. Get the intersection between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZinterKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZinterKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZINTER", "ZinterKey9", "ZinterKey10", "ZinterKey11", "WITHSCORES", "AGGREGATE", "MAX"}, + expectedResponse: [][]string{{"one", "1000"}, {"eight", "800"}}, + expectedError: nil, + }, + { + // 5. Get the intersection between 3 sorted sets with scores. + // Use SUM aggregate with weights modifier. + name: "5. Get the intersection between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZinterKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterKey13": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZinterKey14": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZINTER", "ZinterKey12", "ZinterKey13", "ZinterKey14", "WITHSCORES", "AGGREGATE", "SUM", "WEIGHTS", "1", "5", "3"}, + expectedResponse: [][]string{{"one", "3105"}, {"eight", "2808"}}, + expectedError: nil, + }, + { + // 6. Get the intersection between 3 sorted sets with scores. + // Use MAX aggregate with added weights. + name: "6. Get the intersection between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZinterKey15": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterKey16": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZinterKey17": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZINTER", "ZinterKey15", "ZinterKey16", "ZinterKey17", "WITHSCORES", "AGGREGATE", "MAX", "WEIGHTS", "1", "5", "3"}, + expectedResponse: [][]string{{"one", "3000"}, {"eight", "2400"}}, + expectedError: nil, + }, + { + // 7. Get the intersection between 3 sorted sets with scores. + // Use MIN aggregate with added weights. + name: "7. Get the intersection between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZinterKey18": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterKey19": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZinterKey20": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZINTER", "ZinterKey18", "ZinterKey19", "ZinterKey20", "WITHSCORES", "AGGREGATE", "MIN", "WEIGHTS", "1", "5", "3"}, + expectedResponse: [][]string{{"one", "5"}, {"eight", "8"}}, + expectedError: nil, + }, + { + name: "8. Throw an error if there are more weights than keys", + presetValues: map[string]interface{}{ + "ZinterKey21": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterKey22": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + command: []string{"ZINTER", "ZinterKey21", "ZinterKey22", "WEIGHTS", "1", "2", "3"}, + expectedResponse: nil, + expectedError: errors.New("number of weights should match number of keys"), + }, + { + name: "9. Throw an error if there are fewer weights than keys", + presetValues: map[string]interface{}{ + "ZinterKey23": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterKey24": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + }), + "ZinterKey25": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + command: []string{"ZINTER", "ZinterKey23", "ZinterKey24", "ZinterKey25", "WEIGHTS", "5", "4"}, + expectedResponse: nil, + expectedError: errors.New("number of weights should match number of keys"), + }, + { + name: "10. Throw an error if there are no keys provided", + presetValues: map[string]interface{}{ + "ZinterKey26": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + "ZinterKey27": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + "ZinterKey28": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + command: []string{"ZINTER", "WEIGHTS", "5", "4"}, + expectedResponse: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "11. Throw an error if any of the provided keys are not sorted sets", + presetValues: map[string]interface{}{ + "ZinterKey29": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterKey30": "Default value", + "ZinterKey31": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + command: []string{"ZINTER", "ZinterKey29", "ZinterKey30", "ZinterKey31"}, + expectedResponse: nil, + expectedError: errors.New("value at ZinterKey30 is not a sorted set"), + }, + { + name: "12. If any of the keys does not exist, return an empty array.", + presetValues: map[string]interface{}{ + "ZinterKey32": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, + }), + "ZinterKey33": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZINTER", "non-existent", "ZinterKey32", "ZinterKey33"}, + expectedResponse: [][]string{}, + expectedError: nil, + }, + { + name: "13. Command too short", + command: []string{"ZINTER"}, + expectedResponse: [][]string{}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) + } + + for _, item := range res.Array() { + value := item.Array()[0].String() + score := func() string { + if len(item.Array()) == 2 { + return item.Array()[1].String() + } + return "" + }() + if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { + return expected[0] == value + }) { + t.Errorf("unexpected member \"%s\" in response", value) + } + if score != "" { + for _, expected := range test.expectedResponse { + if expected[0] == value && expected[1] != score { + t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) + } + } + } + } + }) + } + }) + + t.Run("Test_HandleZINTERSTORE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + destination string + command []string + expectedValue *sorted_set.SortedSet + expectedResponse int + expectedError error + }{ + { + name: "1. Get the intersection between 2 sorted sets.", + presetValues: map[string]interface{}{ + "ZinterStoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + "ZinterStoreKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + destination: "ZinterStoreDestinationKey1", + command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey1", "ZinterStoreKey1", "ZinterStoreKey2"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "three", Score: 6}, {Value: "four", Score: 8}, + {Value: "five", Score: 10}, + }), + expectedResponse: 3, + expectedError: nil, + }, + { + // 2. Get the intersection between 3 sorted sets with scores. + // By default, the SUM aggregate will be used. + name: "2. Get the intersection between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZinterStoreKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterStoreKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 8}, + }), + "ZinterStoreKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + destination: "ZinterStoreDestinationKey2", + command: []string{ + "ZINTERSTORE", "ZinterStoreDestinationKey2", "ZinterStoreKey3", "ZinterStoreKey4", "ZinterStoreKey5", "WITHSCORES", + }, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 3}, {Value: "eight", Score: 24}, + }), + expectedResponse: 2, + expectedError: nil, + }, + { + // 3. Get the intersection between 3 sorted sets with scores. + // Use MIN aggregate. + name: "3. Get the intersection between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZinterStoreKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterStoreKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZinterStoreKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + destination: "ZinterStoreDestinationKey3", + command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey3", "ZinterStoreKey6", "ZinterStoreKey7", "ZinterStoreKey8", "WITHSCORES", "AGGREGATE", "MIN"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "eight", Score: 8}, + }), + expectedResponse: 2, + expectedError: nil, + }, + { + // 4. Get the intersection between 3 sorted sets with scores. + // Use MAX aggregate. + name: "4. Get the intersection between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZinterStoreKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterStoreKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZinterStoreKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + destination: "ZinterStoreDestinationKey4", + command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey4", "ZinterStoreKey9", "ZinterStoreKey10", "ZinterStoreKey11", "WITHSCORES", "AGGREGATE", "MAX"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + }), + expectedResponse: 2, + expectedError: nil, + }, + { + // 5. Get the intersection between 3 sorted sets with scores. + // Use SUM aggregate with weights modifier. + name: "5. Get the intersection between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZinterStoreKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterStoreKey13": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZinterStoreKey14": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + destination: "ZinterStoreDestinationKey5", + command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey5", "ZinterStoreKey12", "ZinterStoreKey13", "ZinterStoreKey14", "WITHSCORES", "AGGREGATE", "SUM", "WEIGHTS", "1", "5", "3"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 3105}, {Value: "eight", Score: 2808}, + }), + expectedResponse: 2, + expectedError: nil, + }, + { + // 6. Get the intersection between 3 sorted sets with scores. + // Use MAX aggregate with added weights. + name: "6. Get the intersection between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZinterStoreKey15": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterStoreKey16": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZinterStoreKey17": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + destination: "ZinterStoreDestinationKey6", + command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey6", "ZinterStoreKey15", "ZinterStoreKey16", "ZinterStoreKey17", "WITHSCORES", "AGGREGATE", "MAX", "WEIGHTS", "1", "5", "3"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 3000}, {Value: "eight", Score: 2400}, + }), + expectedResponse: 2, + expectedError: nil, + }, + { + // 7. Get the intersection between 3 sorted sets with scores. + // Use MIN aggregate with added weights. + name: "7. Get the intersection between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZinterStoreKey18": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterStoreKey19": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZinterStoreKey20": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + destination: "ZinterStoreDestinationKey7", + command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey7", "ZinterStoreKey18", "ZinterStoreKey19", "ZinterStoreKey20", "WITHSCORES", "AGGREGATE", "MIN", "WEIGHTS", "1", "5", "3"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 5}, {Value: "eight", Score: 8}, + }), + expectedResponse: 2, + expectedError: nil, + }, + { + name: "8. Throw an error if there are more weights than keys", + presetValues: map[string]interface{}{ + "ZinterStoreKey21": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterStoreKey22": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey8", "ZinterStoreKey21", "ZinterStoreKey22", "WEIGHTS", "1", "2", "3"}, + expectedResponse: 0, + expectedError: errors.New("number of weights should match number of keys"), + }, + { + name: "9. Throw an error if there are fewer weights than keys", + presetValues: map[string]interface{}{ + "ZinterStoreKey23": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterStoreKey24": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + }), + "ZinterStoreKey25": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey9", "ZinterStoreKey23", "ZinterStoreKey24", "ZinterStoreKey25", "WEIGHTS", "5", "4"}, + expectedResponse: 0, + expectedError: errors.New("number of weights should match number of keys"), + }, + { + name: "10. Throw an error if there are no keys provided", + presetValues: map[string]interface{}{ + "ZinterStoreKey26": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + "ZinterStoreKey27": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + "ZinterStoreKey28": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + command: []string{"ZINTERSTORE", "WEIGHTS", "5", "4"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "11. Throw an error if any of the provided keys are not sorted sets", + presetValues: map[string]interface{}{ + "ZinterStoreKey29": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZinterStoreKey30": "Default value", + "ZinterStoreKey31": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + command: []string{"ZINTERSTORE", "ZinterStoreKey29", "ZinterStoreKey30", "ZinterStoreKey31"}, + expectedResponse: 0, + expectedError: errors.New("value at ZinterStoreKey30 is not a sorted set"), + }, + { + name: "12. If any of the keys does not exist, return an empty array.", + presetValues: map[string]interface{}{ + "ZinterStoreKey32": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, + }), + "ZinterStoreKey33": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZINTERSTORE", "ZinterStoreDestinationKey12", "non-existent", "ZinterStoreKey32", "ZinterStoreKey33"}, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "13. Command too short", + command: []string{"ZINTERSTORE"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + + // Check if the resulting sorted set has the expected members/scores + if test.expectedValue == nil { + return + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(test.destination), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != test.expectedValue.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + test.destination, test.expectedValue.Cardinality(), len(res.Array())) + } + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !test.expectedValue.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) + } + if test.expectedValue.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) + } + } + }) + } + }) + + t.Run("Test_HandleZUNION", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + presetValues map[string]interface{} + command []string + expectedResponse [][]string + expectedError error + }{ + { + name: "1. Get the union between 2 sorted sets.", + presetValues: map[string]interface{}{ + "ZunionKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + "ZunionKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + command: []string{"ZUNION", "ZunionKey1", "ZunionKey2"}, + expectedResponse: [][]string{{"one"}, {"two"}, {"three"}, {"four"}, {"five"}, {"six"}, {"seven"}, {"eight"}}, + expectedError: nil, + }, + { + // 2. Get the union between 3 sorted sets with scores. + // By default, the SUM aggregate will be used. + name: "2. Get the union between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZunionKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 8}, + }), + "ZunionKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, {Value: "thirty-six", Score: 36}, + }), + }, + command: []string{"ZUNION", "ZunionKey3", "ZunionKey4", "ZunionKey5", "WITHSCORES"}, + expectedResponse: [][]string{ + {"one", "3"}, {"two", "4"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}, + {"seven", "7"}, {"eight", "24"}, {"nine", "9"}, {"ten", "10"}, {"eleven", "11"}, + {"twelve", "24"}, {"thirty-six", "72"}, + }, + expectedError: nil, + }, + { + // 3. Get the union between 3 sorted sets with scores. + // Use MIN aggregate. + name: "3. Get the union between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZunionKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZunionKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, {Value: "thirty-six", Score: 72}, + }), + }, + command: []string{"ZUNION", "ZunionKey6", "ZunionKey7", "ZunionKey8", "WITHSCORES", "AGGREGATE", "MIN"}, + expectedResponse: [][]string{ + {"one", "1"}, {"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}, + {"seven", "7"}, {"eight", "8"}, {"nine", "9"}, {"ten", "10"}, {"eleven", "11"}, + {"twelve", "12"}, {"thirty-six", "36"}, + }, + expectedError: nil, + }, + { + // 4. Get the union between 3 sorted sets with scores. + // Use MAX aggregate. + name: "4. Get the union between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZunionKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZunionKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, {Value: "thirty-six", Score: 72}, + }), + }, + command: []string{"ZUNION", "ZunionKey9", "ZunionKey10", "ZunionKey11", "WITHSCORES", "AGGREGATE", "MAX"}, + expectedResponse: [][]string{ + {"one", "1000"}, {"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}, + {"seven", "7"}, {"eight", "800"}, {"nine", "9"}, {"ten", "10"}, {"eleven", "11"}, + {"twelve", "12"}, {"thirty-six", "72"}, + }, + expectedError: nil, + }, + { + // 5. Get the union between 3 sorted sets with scores. + // Use SUM aggregate with weights modifier. + name: "5. Get the union between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZunionKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionKey13": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZunionKey14": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZUNION", "ZunionKey12", "ZunionKey13", "ZunionKey14", "WITHSCORES", "AGGREGATE", "SUM", "WEIGHTS", "1", "2", "3"}, + expectedResponse: [][]string{ + {"one", "3102"}, {"two", "6"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}, + {"seven", "7"}, {"eight", "2568"}, {"nine", "27"}, {"ten", "30"}, {"eleven", "22"}, + {"twelve", "60"}, {"thirty-six", "72"}, + }, + expectedError: nil, + }, + { + // 6. Get the union between 3 sorted sets with scores. + // Use MAX aggregate with added weights. + name: "6. Get the union between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZunionKey15": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionKey16": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZunionKey17": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZUNION", "ZunionKey15", "ZunionKey16", "ZunionKey17", "WITHSCORES", "AGGREGATE", "MAX", "WEIGHTS", "1", "2", "3"}, + expectedResponse: [][]string{ + {"one", "3000"}, {"two", "4"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}, + {"seven", "7"}, {"eight", "2400"}, {"nine", "27"}, {"ten", "30"}, {"eleven", "22"}, + {"twelve", "36"}, {"thirty-six", "72"}, + }, + expectedError: nil, + }, + { + // 7. Get the union between 3 sorted sets with scores. + // Use MIN aggregate with added weights. + name: "7. Get the union between 3 sorted sets with scores.", + presetValues: map[string]interface{}{ + "ZunionKey18": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionKey19": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZunionKey20": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZUNION", "ZunionKey18", "ZunionKey19", "ZunionKey20", "WITHSCORES", "AGGREGATE", "MIN", "WEIGHTS", "1", "2", "3"}, + expectedResponse: [][]string{ + {"one", "2"}, {"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}, {"seven", "7"}, + {"eight", "8"}, {"nine", "27"}, {"ten", "30"}, {"eleven", "22"}, {"twelve", "24"}, {"thirty-six", "72"}, + }, + expectedError: nil, + }, + { + name: "8. Throw an error if there are more weights than keys", + presetValues: map[string]interface{}{ + "ZunionKey21": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionKey22": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + command: []string{"ZUNION", "ZunionKey21", "ZunionKey22", "WEIGHTS", "1", "2", "3"}, + expectedResponse: nil, + expectedError: errors.New("number of weights should match number of keys"), + }, + { + name: "9. Throw an error if there are fewer weights than keys", + presetValues: map[string]interface{}{ + "ZunionKey23": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionKey24": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + }), + "ZunionKey25": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + command: []string{"ZUNION", "ZunionKey23", "ZunionKey24", "ZunionKey25", "WEIGHTS", "5", "4"}, + expectedResponse: nil, + expectedError: errors.New("number of weights should match number of keys"), + }, + { + name: "10. Throw an error if there are no keys provided", + presetValues: map[string]interface{}{ + "ZunionKey26": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + "ZunionKey27": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + "ZunionKey28": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + command: []string{"ZUNION", "WEIGHTS", "5", "4"}, + expectedResponse: nil, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "11. Throw an error if any of the provided keys are not sorted sets", + presetValues: map[string]interface{}{ + "ZunionKey29": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionKey30": "Default value", + "ZunionKey31": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + command: []string{"ZUNION", "ZunionKey29", "ZunionKey30", "ZunionKey31"}, + expectedResponse: nil, + expectedError: errors.New("value at ZunionKey30 is not a sorted set"), + }, + { + name: "12. If any of the keys does not exist, skip it.", + presetValues: map[string]interface{}{ + "ZunionKey32": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, + }), + "ZunionKey33": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + command: []string{"ZUNION", "non-existent", "ZunionKey32", "ZunionKey33"}, + expectedResponse: [][]string{ + {"one"}, {"two"}, {"thirty-six"}, {"twelve"}, {"eleven"}, + {"seven"}, {"eight"}, {"nine"}, {"ten"}, + }, + expectedError: nil, + }, + { + name: "13. Command too short", + command: []string{"ZUNION"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) + } + } + + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if len(res.Array()) != len(test.expectedResponse) { + t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) + } + + for _, item := range res.Array() { + value := item.Array()[0].String() + score := func() string { + if len(item.Array()) == 2 { + return item.Array()[1].String() + } + return "" + }() + if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { + return expected[0] == value + }) { + t.Errorf("unexpected member \"%s\" in response", value) + } + if score != "" { + for _, expected := range test.expectedResponse { + if expected[0] == value && expected[1] != score { + t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) + } + } + } + } + }) + } + }) + + t.Run("Test_HandleZUNIONSTORE", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error() + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) + + tests := []struct { + name string + preset bool + presetValues map[string]interface{} + destination string + command []string + expectedValue *sorted_set.SortedSet + expectedResponse int + expectedError error + }{ + { + name: "1. Get the union between 2 sorted sets.", + preset: true, + presetValues: map[string]interface{}{ + "ZunionStoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, + }), + "ZunionStoreKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + }, + destination: "ZunionStoreDestinationKey1", + command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey1", "ZunionStoreKey1", "ZunionStoreKey2"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "three", Score: 6}, {Value: "four", Score: 8}, + {Value: "five", Score: 10}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, }), - "ZunionKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 8}, + expectedResponse: 8, + expectedError: nil, + }, + { + // 2. Get the union between 3 sorted sets with scores. + // By default, the SUM aggregate will be used. + name: "2. Get the union between 3 sorted sets with scores.", + preset: true, + presetValues: map[string]interface{}{ + "ZunionStoreKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionStoreKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 8}, + }), + "ZunionStoreKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, {Value: "thirty-six", Score: 36}, + }), + }, + destination: "ZunionStoreDestinationKey2", + command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey2", "ZunionStoreKey3", "ZunionStoreKey4", "ZunionStoreKey5", "WITHSCORES"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 3}, {Value: "two", Score: 4}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 24}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, {Value: "eleven", Score: 11}, + {Value: "twelve", Score: 24}, {Value: "thirty-six", Score: 72}, }), - "ZunionKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + expectedResponse: 13, + expectedError: nil, + }, + { + // 3. Get the union between 3 sorted sets with scores. + // Use MIN aggregate. + name: "3. Get the union between 3 sorted sets with scores.", + preset: true, + presetValues: map[string]interface{}{ + "ZunionStoreKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionStoreKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZunionStoreKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, {Value: "thirty-six", Score: 72}, + }), + }, + destination: "ZunionStoreDestinationKey3", + command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey3", "ZunionStoreKey6", "ZunionStoreKey7", "ZunionStoreKey8", "WITHSCORES", "AGGREGATE", "MIN"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, {Value: "eleven", Score: 11}, {Value: "twelve", Score: 12}, {Value: "thirty-six", Score: 36}, }), + expectedResponse: 13, + expectedError: nil, }, - command: []string{"ZUNION", "ZunionKey3", "ZunionKey4", "ZunionKey5", "WITHSCORES"}, - expectedResponse: [][]string{ - {"one", "3"}, {"two", "4"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}, - {"seven", "7"}, {"eight", "24"}, {"nine", "9"}, {"ten", "10"}, {"eleven", "11"}, - {"twelve", "24"}, {"thirty-six", "72"}, - }, - expectedError: nil, - }, - { - // 3. Get the union between 3 sorted sets with scores. - // Use MIN aggregate. - name: "3. Get the union between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZunionKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZunionKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + { + // 4. Get the union between 3 sorted sets with scores. + // Use MAX aggregate. + name: "4. Get the union between 3 sorted sets with scores.", + preset: true, + presetValues: map[string]interface{}{ + "ZunionStoreKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionStoreKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZunionStoreKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, {Value: "thirty-six", Score: 72}, + }), + }, + destination: "ZunionStoreDestinationKey4", + command: []string{ + "ZUNIONSTORE", "ZunionStoreDestinationKey4", "ZunionStoreKey9", "ZunionStoreKey10", "ZunionStoreKey11", "WITHSCORES", "AGGREGATE", "MAX", + }, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, {Value: "eleven", Score: 11}, {Value: "twelve", Score: 12}, {Value: "thirty-six", Score: 72}, }), + expectedResponse: 13, + expectedError: nil, }, - command: []string{"ZUNION", "ZunionKey6", "ZunionKey7", "ZunionKey8", "WITHSCORES", "AGGREGATE", "MIN"}, - expectedResponse: [][]string{ - {"one", "1"}, {"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}, - {"seven", "7"}, {"eight", "8"}, {"nine", "9"}, {"ten", "10"}, {"eleven", "11"}, - {"twelve", "12"}, {"thirty-six", "36"}, + { + // 5. Get the union between 3 sorted sets with scores. + // Use SUM aggregate with weights modifier. + name: "5. Get the union between 3 sorted sets with scores.", + preset: true, + presetValues: map[string]interface{}{ + "ZunionStoreKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionStoreKey13": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZunionStoreKey14": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + destination: "ZunionStoreDestinationKey5", + command: []string{ + "ZUNIONSTORE", "ZunionStoreDestinationKey5", "ZunionStoreKey12", "ZunionStoreKey13", "ZunionStoreKey14", + "WITHSCORES", "AGGREGATE", "SUM", "WEIGHTS", "1", "2", "3", + }, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 3102}, {Value: "two", Score: 6}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 2568}, + {Value: "nine", Score: 27}, {Value: "ten", Score: 30}, {Value: "eleven", Score: 22}, + {Value: "twelve", Score: 60}, {Value: "thirty-six", Score: 72}, + }), + expectedResponse: 13, + expectedError: nil, }, - expectedError: nil, - }, - { - // 4. Get the union between 3 sorted sets with scores. - // Use MAX aggregate. - name: "4. Get the union between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZunionKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZunionKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, {Value: "thirty-six", Score: 72}, + { + // 6. Get the union between 3 sorted sets with scores. + // Use MAX aggregate with added weights. + name: "6. Get the union between 3 sorted sets with scores.", + preset: true, + presetValues: map[string]interface{}{ + "ZunionStoreKey15": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionStoreKey16": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZunionStoreKey17": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + destination: "ZunionStoreDestinationKey6", + command: []string{ + "ZUNIONSTORE", "ZunionStoreDestinationKey6", "ZunionStoreKey15", "ZunionStoreKey16", "ZunionStoreKey17", + "WITHSCORES", "AGGREGATE", "MAX", "WEIGHTS", "1", "2", "3"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 3000}, {Value: "two", Score: 4}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 2400}, + {Value: "nine", Score: 27}, {Value: "ten", Score: 30}, {Value: "eleven", Score: 22}, + {Value: "twelve", Score: 36}, {Value: "thirty-six", Score: 72}, }), + expectedResponse: 13, + expectedError: nil, }, - command: []string{"ZUNION", "ZunionKey9", "ZunionKey10", "ZunionKey11", "WITHSCORES", "AGGREGATE", "MAX"}, - expectedResponse: [][]string{ - {"one", "1000"}, {"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}, - {"seven", "7"}, {"eight", "800"}, {"nine", "9"}, {"ten", "10"}, {"eleven", "11"}, - {"twelve", "12"}, {"thirty-six", "72"}, + { + // 7. Get the union between 3 sorted sets with scores. + // Use MIN aggregate with added weights. + name: "7. Get the union between 3 sorted sets with scores.", + preset: true, + presetValues: map[string]interface{}{ + "ZunionStoreKey18": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 100}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionStoreKey19": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, + }), + "ZunionStoreKey20": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + destination: "ZunionStoreDestinationKey7", + command: []string{ + "ZUNIONSTORE", "ZunionStoreDestinationKey7", "ZunionStoreKey18", "ZunionStoreKey19", "ZunionStoreKey20", + "WITHSCORES", "AGGREGATE", "MIN", "WEIGHTS", "1", "2", "3", + }, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 2}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 27}, {Value: "ten", Score: 30}, {Value: "eleven", Score: 22}, + {Value: "twelve", Score: 24}, {Value: "thirty-six", Score: 72}, + }), + expectedResponse: 13, + expectedError: nil, }, - expectedError: nil, - }, - { - // 5. Get the union between 3 sorted sets with scores. - // Use SUM aggregate with weights modifier. - name: "5. Get the union between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZunionKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionKey13": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZunionKey14": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), + { + name: "8. Throw an error if there are more weights than keys", + preset: true, + presetValues: map[string]interface{}{ + "ZunionStoreKey21": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionStoreKey22": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + destination: "ZunionStoreDestinationKey8", + command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey8", "ZunionStoreKey21", "ZunionStoreKey22", "WEIGHTS", "1", "2", "3"}, + expectedResponse: 0, + expectedError: errors.New("number of weights should match number of keys"), }, - command: []string{"ZUNION", "ZunionKey12", "ZunionKey13", "ZunionKey14", "WITHSCORES", "AGGREGATE", "SUM", "WEIGHTS", "1", "2", "3"}, - expectedResponse: [][]string{ - {"one", "3102"}, {"two", "6"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}, - {"seven", "7"}, {"eight", "2568"}, {"nine", "27"}, {"ten", "30"}, {"eleven", "22"}, - {"twelve", "60"}, {"thirty-six", "72"}, + { + name: "9. Throw an error if there are fewer weights than keys", + preset: true, + presetValues: map[string]interface{}{ + "ZunionStoreKey23": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionStoreKey24": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + }), + "ZunionStoreKey25": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + destination: "ZunionStoreDestinationKey9", + command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey9", "ZunionStoreKey23", "ZunionStoreKey24", "ZunionStoreKey25", "WEIGHTS", "5", "4"}, + expectedResponse: 0, + expectedError: errors.New("number of weights should match number of keys"), }, - expectedError: nil, - }, - { - // 6. Get the union between 3 sorted sets with scores. - // Use MAX aggregate with added weights. - name: "6. Get the union between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZunionKey15": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionKey16": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZunionKey17": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), + { + name: "10. Throw an error if there are no keys provided", + preset: true, + presetValues: map[string]interface{}{ + "ZunionStoreKey26": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + "ZunionStoreKey27": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + "ZunionStoreKey28": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + command: []string{"ZUNIONSTORE", "WEIGHTS", "5", "4"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), }, - command: []string{"ZUNION", "ZunionKey15", "ZunionKey16", "ZunionKey17", "WITHSCORES", "AGGREGATE", "MAX", "WEIGHTS", "1", "2", "3"}, - expectedResponse: [][]string{ - {"one", "3000"}, {"two", "4"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}, - {"seven", "7"}, {"eight", "2400"}, {"nine", "27"}, {"ten", "30"}, {"eleven", "22"}, - {"twelve", "36"}, {"thirty-six", "72"}, + { + name: "11. Throw an error if any of the provided keys are not sorted sets", + preset: true, + presetValues: map[string]interface{}{ + "ZunionStoreKey29": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "three", Score: 3}, {Value: "four", Score: 4}, + {Value: "five", Score: 5}, {Value: "six", Score: 6}, + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + }), + "ZunionStoreKey30": "Default value", + "ZunionStoreKey31": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), + }, + destination: "ZunionStoreDestinationKey11", + command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey11", "ZunionStoreKey29", "ZunionStoreKey30", "ZunionStoreKey31"}, + expectedResponse: 0, + expectedError: errors.New("value at ZunionStoreKey30 is not a sorted set"), }, - expectedError: nil, - }, - { - // 7. Get the union between 3 sorted sets with scores. - // Use MIN aggregate with added weights. - name: "7. Get the union between 3 sorted sets with scores.", - presetValues: map[string]interface{}{ - "ZunionKey18": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionKey19": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZunionKey20": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, + { + name: "12. If any of the keys does not exist, skip it.", + preset: true, + presetValues: map[string]interface{}{ + "ZunionStoreKey32": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, + {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, + {Value: "eleven", Score: 11}, + }), + "ZunionStoreKey33": sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, + {Value: "twelve", Score: 12}, + }), + }, + destination: "ZunionStoreDestinationKey12", + command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey12", "non-existent", "ZunionStoreKey32", "ZunionStoreKey33"}, + expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ + {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, + {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, {Value: "eleven", Score: 11}, {Value: "twelve", Score: 24}, + {Value: "thirty-six", Score: 36}, }), + expectedResponse: 9, + expectedError: nil, }, - command: []string{"ZUNION", "ZunionKey18", "ZunionKey19", "ZunionKey20", "WITHSCORES", "AGGREGATE", "MIN", "WEIGHTS", "1", "2", "3"}, - expectedResponse: [][]string{ - {"one", "2"}, {"two", "2"}, {"three", "3"}, {"four", "4"}, {"five", "5"}, {"six", "6"}, {"seven", "7"}, - {"eight", "8"}, {"nine", "27"}, {"ten", "30"}, {"eleven", "22"}, {"twelve", "24"}, {"thirty-six", "72"}, + { + name: "13. Command too short", + preset: false, + command: []string{"ZUNIONSTORE"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), }, - expectedError: nil, - }, - { - name: "8. Throw an error if there are more weights than keys", - presetValues: map[string]interface{}{ - "ZunionKey21": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionKey22": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - command: []string{"ZUNION", "ZunionKey21", "ZunionKey22", "WEIGHTS", "1", "2", "3"}, - expectedResponse: nil, - expectedError: errors.New("number of weights should match number of keys"), - }, - { - name: "9. Throw an error if there are fewer weights than keys", - presetValues: map[string]interface{}{ - "ZunionKey23": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionKey24": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - }), - "ZunionKey25": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - command: []string{"ZUNION", "ZunionKey23", "ZunionKey24", "ZunionKey25", "WEIGHTS", "5", "4"}, - expectedResponse: nil, - expectedError: errors.New("number of weights should match number of keys"), - }, - { - name: "10. Throw an error if there are no keys provided", - presetValues: map[string]interface{}{ - "ZunionKey26": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - "ZunionKey27": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - "ZunionKey28": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - command: []string{"ZUNION", "WEIGHTS", "5", "4"}, - expectedResponse: nil, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "11. Throw an error if any of the provided keys are not sorted sets", - presetValues: map[string]interface{}{ - "ZunionKey29": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionKey30": "Default value", - "ZunionKey31": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - command: []string{"ZUNION", "ZunionKey29", "ZunionKey30", "ZunionKey31"}, - expectedResponse: nil, - expectedError: errors.New("value at ZunionKey30 is not a sorted set"), - }, - { - name: "12. If any of the keys does not exist, skip it.", - presetValues: map[string]interface{}{ - "ZunionKey32": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, - }), - "ZunionKey33": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - command: []string{"ZUNION", "non-existent", "ZunionKey32", "ZunionKey33"}, - expectedResponse: [][]string{ - {"one"}, {"two"}, {"thirty-six"}, {"twelve"}, {"eleven"}, - {"seven"}, {"eight"}, {"nine"}, {"ten"}, - }, - expectedError: nil, - }, - { - name: "13. Command too short", - command: []string{"ZUNION"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValues != nil { + var command []resp.Value + var expected string + for key, value := range test.presetValues { + switch value.(type) { + case string: + command = []resp.Value{ + resp.StringValue("SET"), + resp.StringValue(key), + resp.StringValue(value.(string)), + } + expected = "ok" + case *sorted_set.SortedSet: + command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} + for _, member := range value.(*sorted_set.SortedSet).GetAll() { + command = append(command, []resp.Value{ + resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), + resp.StringValue(string(member.Value)), + }...) + } + expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) - } - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if len(res.Array()) != len(test.expectedResponse) { - t.Errorf("expected response array of length %d, got %d", len(test.expectedResponse), len(res.Array())) - } - - for _, item := range res.Array() { - value := item.Array()[0].String() - score := func() string { - if len(item.Array()) == 2 { - return item.Array()[1].String() - } - return "" - }() - if !slices.ContainsFunc(test.expectedResponse, func(expected []string) bool { - return expected[0] == value - }) { - t.Errorf("unexpected member \"%s\" in response", value) - } - if score != "" { - for _, expected := range test.expectedResponse { - if expected[0] == value && expected[1] != score { - t.Errorf("expected score for member \"%s\" to be %s, got %s", value, expected[1], score) + if !strings.EqualFold(res.String(), expected) { + t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) } } } - } - }) - } -} - -func Test_HandleZUNIONSTORE(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error() - return - } - client := resp.NewConn(conn) - - tests := []struct { - name string - preset bool - presetValues map[string]interface{} - destination string - command []string - expectedValue *sorted_set.SortedSet - expectedResponse int - expectedError error - }{ - { - name: "1. Get the union between 2 sorted sets.", - preset: true, - presetValues: map[string]interface{}{ - "ZunionStoreKey1": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, - }), - "ZunionStoreKey2": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - }, - destination: "ZunionStoreDestinationKey1", - command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey1", "ZunionStoreKey1", "ZunionStoreKey2"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 6}, {Value: "four", Score: 8}, - {Value: "five", Score: 10}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - expectedResponse: 8, - expectedError: nil, - }, - { - // 2. Get the union between 3 sorted sets with scores. - // By default, the SUM aggregate will be used. - name: "2. Get the union between 3 sorted sets with scores.", - preset: true, - presetValues: map[string]interface{}{ - "ZunionStoreKey3": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionStoreKey4": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 8}, - }), - "ZunionStoreKey5": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, {Value: "thirty-six", Score: 36}, - }), - }, - destination: "ZunionStoreDestinationKey2", - command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey2", "ZunionStoreKey3", "ZunionStoreKey4", "ZunionStoreKey5", "WITHSCORES"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 3}, {Value: "two", Score: 4}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 24}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, {Value: "eleven", Score: 11}, - {Value: "twelve", Score: 24}, {Value: "thirty-six", Score: 72}, - }), - expectedResponse: 13, - expectedError: nil, - }, - { - // 3. Get the union between 3 sorted sets with scores. - // Use MIN aggregate. - name: "3. Get the union between 3 sorted sets with scores.", - preset: true, - presetValues: map[string]interface{}{ - "ZunionStoreKey6": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionStoreKey7": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZunionStoreKey8": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, {Value: "thirty-six", Score: 72}, - }), - }, - destination: "ZunionStoreDestinationKey3", - command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey3", "ZunionStoreKey6", "ZunionStoreKey7", "ZunionStoreKey8", "WITHSCORES", "AGGREGATE", "MIN"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, {Value: "eleven", Score: 11}, - {Value: "twelve", Score: 12}, {Value: "thirty-six", Score: 36}, - }), - expectedResponse: 13, - expectedError: nil, - }, - { - // 4. Get the union between 3 sorted sets with scores. - // Use MAX aggregate. - name: "4. Get the union between 3 sorted sets with scores.", - preset: true, - presetValues: map[string]interface{}{ - "ZunionStoreKey9": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionStoreKey10": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZunionStoreKey11": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, {Value: "thirty-six", Score: 72}, - }), - }, - destination: "ZunionStoreDestinationKey4", - command: []string{ - "ZUNIONSTORE", "ZunionStoreDestinationKey4", "ZunionStoreKey9", "ZunionStoreKey10", "ZunionStoreKey11", "WITHSCORES", "AGGREGATE", "MAX", - }, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, {Value: "eleven", Score: 11}, - {Value: "twelve", Score: 12}, {Value: "thirty-six", Score: 72}, - }), - expectedResponse: 13, - expectedError: nil, - }, - { - // 5. Get the union between 3 sorted sets with scores. - // Use SUM aggregate with weights modifier. - name: "5. Get the union between 3 sorted sets with scores.", - preset: true, - presetValues: map[string]interface{}{ - "ZunionStoreKey12": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionStoreKey13": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZunionStoreKey14": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - destination: "ZunionStoreDestinationKey5", - command: []string{ - "ZUNIONSTORE", "ZunionStoreDestinationKey5", "ZunionStoreKey12", "ZunionStoreKey13", "ZunionStoreKey14", - "WITHSCORES", "AGGREGATE", "SUM", "WEIGHTS", "1", "2", "3", - }, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 3102}, {Value: "two", Score: 6}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 2568}, - {Value: "nine", Score: 27}, {Value: "ten", Score: 30}, {Value: "eleven", Score: 22}, - {Value: "twelve", Score: 60}, {Value: "thirty-six", Score: 72}, - }), - expectedResponse: 13, - expectedError: nil, - }, - { - // 6. Get the union between 3 sorted sets with scores. - // Use MAX aggregate with added weights. - name: "6. Get the union between 3 sorted sets with scores.", - preset: true, - presetValues: map[string]interface{}{ - "ZunionStoreKey15": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionStoreKey16": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZunionStoreKey17": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - destination: "ZunionStoreDestinationKey6", - command: []string{ - "ZUNIONSTORE", "ZunionStoreDestinationKey6", "ZunionStoreKey15", "ZunionStoreKey16", "ZunionStoreKey17", - "WITHSCORES", "AGGREGATE", "MAX", "WEIGHTS", "1", "2", "3"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 3000}, {Value: "two", Score: 4}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 2400}, - {Value: "nine", Score: 27}, {Value: "ten", Score: 30}, {Value: "eleven", Score: 22}, - {Value: "twelve", Score: 36}, {Value: "thirty-six", Score: 72}, - }), - expectedResponse: 13, - expectedError: nil, - }, - { - // 7. Get the union between 3 sorted sets with scores. - // Use MIN aggregate with added weights. - name: "7. Get the union between 3 sorted sets with scores.", - preset: true, - presetValues: map[string]interface{}{ - "ZunionStoreKey18": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 100}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionStoreKey19": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, {Value: "eight", Score: 80}, - }), - "ZunionStoreKey20": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1000}, {Value: "eight", Score: 800}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - destination: "ZunionStoreDestinationKey7", - command: []string{ - "ZUNIONSTORE", "ZunionStoreDestinationKey7", "ZunionStoreKey18", "ZunionStoreKey19", "ZunionStoreKey20", - "WITHSCORES", "AGGREGATE", "MIN", "WEIGHTS", "1", "2", "3", - }, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 2}, {Value: "two", Score: 2}, {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 27}, {Value: "ten", Score: 30}, {Value: "eleven", Score: 22}, - {Value: "twelve", Score: 24}, {Value: "thirty-six", Score: 72}, - }), - expectedResponse: 13, - expectedError: nil, - }, - { - name: "8. Throw an error if there are more weights than keys", - preset: true, - presetValues: map[string]interface{}{ - "ZunionStoreKey21": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionStoreKey22": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - destination: "ZunionStoreDestinationKey8", - command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey8", "ZunionStoreKey21", "ZunionStoreKey22", "WEIGHTS", "1", "2", "3"}, - expectedResponse: 0, - expectedError: errors.New("number of weights should match number of keys"), - }, - { - name: "9. Throw an error if there are fewer weights than keys", - preset: true, - presetValues: map[string]interface{}{ - "ZunionStoreKey23": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionStoreKey24": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - }), - "ZunionStoreKey25": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - destination: "ZunionStoreDestinationKey9", - command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey9", "ZunionStoreKey23", "ZunionStoreKey24", "ZunionStoreKey25", "WEIGHTS", "5", "4"}, - expectedResponse: 0, - expectedError: errors.New("number of weights should match number of keys"), - }, - { - name: "10. Throw an error if there are no keys provided", - preset: true, - presetValues: map[string]interface{}{ - "ZunionStoreKey26": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - "ZunionStoreKey27": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - "ZunionStoreKey28": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - command: []string{"ZUNIONSTORE", "WEIGHTS", "5", "4"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "11. Throw an error if any of the provided keys are not sorted sets", - preset: true, - presetValues: map[string]interface{}{ - "ZunionStoreKey29": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "three", Score: 3}, {Value: "four", Score: 4}, - {Value: "five", Score: 5}, {Value: "six", Score: 6}, - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - }), - "ZunionStoreKey30": "Default value", - "ZunionStoreKey31": sorted_set.NewSortedSet([]sorted_set.MemberParam{{Value: "one", Score: 1}}), - }, - destination: "ZunionStoreDestinationKey11", - command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey11", "ZunionStoreKey29", "ZunionStoreKey30", "ZunionStoreKey31"}, - expectedResponse: 0, - expectedError: errors.New("value at ZunionStoreKey30 is not a sorted set"), - }, - { - name: "12. If any of the keys does not exist, skip it.", - preset: true, - presetValues: map[string]interface{}{ - "ZunionStoreKey32": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, - {Value: "thirty-six", Score: 36}, {Value: "twelve", Score: 12}, - {Value: "eleven", Score: 11}, - }), - "ZunionStoreKey33": sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, - {Value: "twelve", Score: 12}, - }), - }, - destination: "ZunionStoreDestinationKey12", - command: []string{"ZUNIONSTORE", "ZunionStoreDestinationKey12", "non-existent", "ZunionStoreKey32", "ZunionStoreKey33"}, - expectedValue: sorted_set.NewSortedSet([]sorted_set.MemberParam{ - {Value: "one", Score: 1}, {Value: "two", Score: 2}, {Value: "seven", Score: 7}, {Value: "eight", Score: 8}, - {Value: "nine", Score: 9}, {Value: "ten", Score: 10}, {Value: "eleven", Score: 11}, {Value: "twelve", Score: 24}, - {Value: "thirty-six", Score: 36}, - }), - expectedResponse: 9, - expectedError: nil, - }, - { - name: "13. Command too short", - preset: false, - command: []string{"ZUNIONSTORE"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValues != nil { - var command []resp.Value - var expected string - for key, value := range test.presetValues { - switch value.(type) { - case string: - command = []resp.Value{ - resp.StringValue("SET"), - resp.StringValue(key), - resp.StringValue(value.(string)), - } - expected = "ok" - case *sorted_set.SortedSet: - command = []resp.Value{resp.StringValue("ZADD"), resp.StringValue(key)} - for _, member := range value.(*sorted_set.SortedSet).GetAll() { - command = append(command, []resp.Value{ - resp.StringValue(strconv.FormatFloat(float64(member.Score), 'f', -1, 64)), - resp.StringValue(string(member.Value)), - }...) - } - expected = strconv.Itoa(value.(*sorted_set.SortedSet).Cardinality()) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if !strings.EqualFold(res.String(), expected) { - t.Errorf("expected preset response to be \"%s\", got %s", expected, res.String()) - } - } - } - - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) - } - return - } - - if res.Integer() != test.expectedResponse { - t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) - } - - // Check if the resulting sorted set has the expected members/scores - if test.expectedValue == nil { - return - } - - if err = client.WriteArray([]resp.Value{ - resp.StringValue("ZRANGE"), - resp.StringValue(test.destination), - resp.StringValue("-inf"), - resp.StringValue("+inf"), - resp.StringValue("BYSCORE"), - resp.StringValue("WITHSCORES"), - }); err != nil { - t.Error(err) - } - - res, _, err = client.ReadValue() - if err != nil { - t.Error(err) - } - - if len(res.Array()) != test.expectedValue.Cardinality() { - t.Errorf("expected resulting set %s to have cardinality %d, got %d", - test.destination, test.expectedValue.Cardinality(), len(res.Array())) - } - - for _, member := range res.Array() { - value := sorted_set.Value(member.Array()[0].String()) - score := sorted_set.Score(member.Array()[1].Float()) - if !test.expectedValue.Contains(value) { - t.Errorf("unexpected value %s in resulting sorted set", value) - } - if test.expectedValue.Get(value).Score != score { - t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) - } - } - }) - } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), res.Error().Error()) + } + return + } + + if res.Integer() != test.expectedResponse { + t.Errorf("expected response %d, got %d", test.expectedResponse, res.Integer()) + } + + // Check if the resulting sorted set has the expected members/scores + if test.expectedValue == nil { + return + } + + if err = client.WriteArray([]resp.Value{ + resp.StringValue("ZRANGE"), + resp.StringValue(test.destination), + resp.StringValue("-inf"), + resp.StringValue("+inf"), + resp.StringValue("BYSCORE"), + resp.StringValue("WITHSCORES"), + }); err != nil { + t.Error(err) + } + + res, _, err = client.ReadValue() + if err != nil { + t.Error(err) + } + + if len(res.Array()) != test.expectedValue.Cardinality() { + t.Errorf("expected resulting set %s to have cardinality %d, got %d", + test.destination, test.expectedValue.Cardinality(), len(res.Array())) + } + + for _, member := range res.Array() { + value := sorted_set.Value(member.Array()[0].String()) + score := sorted_set.Score(member.Array()[1].Float()) + if !test.expectedValue.Contains(value) { + t.Errorf("unexpected value %s in resulting sorted set", value) + } + if test.expectedValue.Get(value).Score != score { + t.Errorf("expected value %s to have score %v, got %v", value, test.expectedValue.Get(value).Score, score) + } + } + }) + } + }) } diff --git a/internal/modules/string/commands_test.go b/internal/modules/string/commands_test.go index 4aff8f1..c67e182 100644 --- a/internal/modules/string/commands_test.go +++ b/internal/modules/string/commands_test.go @@ -29,20 +29,26 @@ import ( "testing" ) -var mockServer *echovault.EchoVault -var addr = "localhost" -var port int +func Test_String(t *testing.T) { + port, err := internal.GetFreePort() + if err != nil { + t.Error() + return + } -func init() { - port, _ = internal.GetFreePort() - mockServer, _ = echovault.NewEchoVault( + mockServer, err := echovault.NewEchoVault( echovault.WithConfig(config.Config{ - BindAddr: addr, + BindAddr: "localhost", Port: uint16(port), DataDir: "", EvictionPolicy: constants.NoEviction, }), ) + if err != nil { + t.Error(err) + return + } + wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -50,115 +56,140 @@ func init() { mockServer.Start() }() wg.Wait() -} -func Test_HandleSetRange(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - return - } - client := resp.NewConn(conn) + t.Cleanup(func() { + mockServer.ShutDown() + }) - tests := []struct { - name string - key string - presetValue string - command []string - expectedValue string - expectedResponse int - expectedError error - }{ - { - name: "Test that SETRANGE on non-existent string creates new string", - key: "SetRangeKey1", - presetValue: "", - command: []string{"SETRANGE", "SetRangeKey1", "10", "New String Value"}, - expectedValue: "New String Value", - expectedResponse: len("New String Value"), - expectedError: nil, - }, - { - name: "Test SETRANGE with an offset that leads to a longer resulting string", - key: "SetRangeKey2", - presetValue: "Original String Value", - command: []string{"SETRANGE", "SetRangeKey2", "16", "Portion Replaced With This New String"}, - expectedValue: "Original String Portion Replaced With This New String", - expectedResponse: len("Original String Portion Replaced With This New String"), - expectedError: nil, - }, - { - name: "SETRANGE with negative offset prepends the string", - key: "SetRangeKey3", - presetValue: "This is a preset value", - command: []string{"SETRANGE", "SetRangeKey3", "-10", "Prepended "}, - expectedValue: "Prepended This is a preset value", - expectedResponse: len("Prepended This is a preset value"), - expectedError: nil, - }, - { - name: "SETRANGE with offset that embeds new string inside the old string", - key: "SetRangeKey4", - presetValue: "This is a preset value", - command: []string{"SETRANGE", "SetRangeKey4", "0", "That"}, - expectedValue: "That is a preset value", - expectedResponse: len("That is a preset value"), - expectedError: nil, - }, - { - name: "SETRANGE with offset longer than original lengths appends the string", - key: "SetRangeKey5", - presetValue: "This is a preset value", - command: []string{"SETRANGE", "SetRangeKey5", "100", " Appended"}, - expectedValue: "This is a preset value Appended", - expectedResponse: len("This is a preset value Appended"), - expectedError: nil, - }, - { - name: "SETRANGE with offset on the last character replaces last character with new string", - key: "SetRangeKey6", - presetValue: "This is a preset value", - command: []string{"SETRANGE", "SetRangeKey6", strconv.Itoa(len("This is a preset value") - 1), " replaced"}, - expectedValue: "This is a preset valu replaced", - expectedResponse: len("This is a preset valu replaced"), - expectedError: nil, - }, - { - name: " Offset not integer", - command: []string{"SETRANGE", "key", "offset", "value"}, - expectedResponse: 0, - expectedError: errors.New("offset must be an integer"), - }, - { - name: "SETRANGE target is not a string", - key: "test-int", - presetValue: "10", - command: []string{"SETRANGE", "test-int", "10", "value"}, - expectedResponse: 0, - expectedError: errors.New("value at key test-int is not a string"), - }, - { - name: "Command too short", - command: []string{"SETRANGE", "key"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "Command too long", - command: []string{"SETRANGE", "key", "offset", "value", "value1"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + t.Run("Test_HandleSetRange", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != "" { - if err = client.WriteArray([]resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue), - }); err != nil { + tests := []struct { + name string + key string + presetValue string + command []string + expectedValue string + expectedResponse int + expectedError error + }{ + { + name: "Test that SETRANGE on non-existent string creates new string", + key: "SetRangeKey1", + presetValue: "", + command: []string{"SETRANGE", "SetRangeKey1", "10", "New String Value"}, + expectedValue: "New String Value", + expectedResponse: len("New String Value"), + expectedError: nil, + }, + { + name: "Test SETRANGE with an offset that leads to a longer resulting string", + key: "SetRangeKey2", + presetValue: "Original String Value", + command: []string{"SETRANGE", "SetRangeKey2", "16", "Portion Replaced With This New String"}, + expectedValue: "Original String Portion Replaced With This New String", + expectedResponse: len("Original String Portion Replaced With This New String"), + expectedError: nil, + }, + { + name: "SETRANGE with negative offset prepends the string", + key: "SetRangeKey3", + presetValue: "This is a preset value", + command: []string{"SETRANGE", "SetRangeKey3", "-10", "Prepended "}, + expectedValue: "Prepended This is a preset value", + expectedResponse: len("Prepended This is a preset value"), + expectedError: nil, + }, + { + name: "SETRANGE with offset that embeds new string inside the old string", + key: "SetRangeKey4", + presetValue: "This is a preset value", + command: []string{"SETRANGE", "SetRangeKey4", "0", "That"}, + expectedValue: "That is a preset value", + expectedResponse: len("That is a preset value"), + expectedError: nil, + }, + { + name: "SETRANGE with offset longer than original lengths appends the string", + key: "SetRangeKey5", + presetValue: "This is a preset value", + command: []string{"SETRANGE", "SetRangeKey5", "100", " Appended"}, + expectedValue: "This is a preset value Appended", + expectedResponse: len("This is a preset value Appended"), + expectedError: nil, + }, + { + name: "SETRANGE with offset on the last character replaces last character with new string", + key: "SetRangeKey6", + presetValue: "This is a preset value", + command: []string{"SETRANGE", "SetRangeKey6", strconv.Itoa(len("This is a preset value") - 1), " replaced"}, + expectedValue: "This is a preset valu replaced", + expectedResponse: len("This is a preset valu replaced"), + expectedError: nil, + }, + { + name: " Offset not integer", + command: []string{"SETRANGE", "key", "offset", "value"}, + expectedResponse: 0, + expectedError: errors.New("offset must be an integer"), + }, + { + name: "SETRANGE target is not a string", + key: "test-int", + presetValue: "10", + command: []string{"SETRANGE", "test-int", "10", "value"}, + expectedResponse: 0, + expectedError: errors.New("value at key test-int is not a string"), + }, + { + name: "Command too short", + command: []string{"SETRANGE", "key"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "Command too long", + command: []string{"SETRANGE", "key", "offset", "value", "value1"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != "" { + if err = client.WriteArray([]resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue), + }); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } + + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } res, _, err := client.ReadValue() @@ -166,95 +197,100 @@ func Test_HandleSetRange(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected preset response to be OK, got %s", res.String()) + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) } - return - } + }) + } + }) - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - }) - } -} + t.Run("Test_HandleStrLen", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) -func Test_HandleStrLen(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - } - client := resp.NewConn(conn) + tests := []struct { + name string + key string + presetValue string + command []string + expectedResponse int + expectedError error + }{ + { + name: "Return the correct string length for an existing string", + key: "StrLenKey1", + presetValue: "Test String", + command: []string{"STRLEN", "StrLenKey1"}, + expectedResponse: len("Test String"), + expectedError: nil, + }, + { + name: "If the string does not exist, return 0", + key: "StrLenKey2", + presetValue: "", + command: []string{"STRLEN", "StrLenKey2"}, + expectedResponse: 0, + expectedError: nil, + }, + { + name: "Too few args", + key: "StrLenKey3", + presetValue: "", + command: []string{"STRLEN"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "Too many args", + key: "StrLenKey4", + presetValue: "", + command: []string{"STRLEN", "StrLenKey4", "StrLenKey5"}, + expectedResponse: 0, + expectedError: errors.New(constants.WrongArgsResponse), + }, + } - tests := []struct { - name string - key string - presetValue string - command []string - expectedResponse int - expectedError error - }{ - { - name: "Return the correct string length for an existing string", - key: "StrLenKey1", - presetValue: "Test String", - command: []string{"STRLEN", "StrLenKey1"}, - expectedResponse: len("Test String"), - expectedError: nil, - }, - { - name: "If the string does not exist, return 0", - key: "StrLenKey2", - presetValue: "", - command: []string{"STRLEN", "StrLenKey2"}, - expectedResponse: 0, - expectedError: nil, - }, - { - name: "Too few args", - key: "StrLenKey3", - presetValue: "", - command: []string{"STRLEN"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "Too many args", - key: "StrLenKey4", - presetValue: "", - command: []string{"STRLEN", "StrLenKey4", "StrLenKey5"}, - expectedResponse: 0, - expectedError: errors.New(constants.WrongArgsResponse), - }, - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != "" { + if err = client.WriteArray([]resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue), + }); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != "" { - if err = client.WriteArray([]resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue), - }); err != nil { + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } res, _, err := client.ReadValue() @@ -262,140 +298,145 @@ func Test_HandleStrLen(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected preset response to be OK, got %s", res.String()) + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if res.Integer() != test.expectedResponse { + t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) } - return - } + }) + } + }) - if res.Integer() != test.expectedResponse { - t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, res.Integer()) - } - }) - } -} + t.Run("Test_HandleSubStr", func(t *testing.T) { + t.Parallel() + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) + if err != nil { + t.Error(err) + return + } + defer func() { + _ = conn.Close() + }() + client := resp.NewConn(conn) -func Test_HandleSubStr(t *testing.T) { - conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", addr, port)) - if err != nil { - t.Error(err) - } - client := resp.NewConn(conn) + tests := []struct { + name string + key string + presetValue string + command []string + expectedResponse string + expectedError error + }{ + { + name: "Return substring within the range of the string", + key: "SubStrKey1", + presetValue: "Test String One", + command: []string{"SUBSTR", "SubStrKey1", "5", "10"}, + expectedResponse: "String", + expectedError: nil, + }, + { + name: "Return substring at the end of the string with exact end index", + key: "SubStrKey2", + presetValue: "Test String Two", + command: []string{"SUBSTR", "SubStrKey2", "12", "14"}, + expectedResponse: "Two", + expectedError: nil, + }, + { + name: "Return substring at the end of the string with end index greater than length", + key: "SubStrKey3", + presetValue: "Test String Three", + command: []string{"SUBSTR", "SubStrKey3", "12", "75"}, + expectedResponse: "Three", + expectedError: nil, + }, + { + name: "Return the substring at the start of the string with 0 start index", + key: "SubStrKey4", + presetValue: "Test String Four", + command: []string{"SUBSTR", "SubStrKey4", "0", "3"}, + expectedResponse: "Test", + expectedError: nil, + }, + { + // Return the substring with negative start index. + // Substring should begin abs(start) from the end of the string when start is negative. + name: "Return the substring with negative start index", + key: "SubStrKey5", + presetValue: "Test String Five", + command: []string{"SUBSTR", "SubStrKey5", "-11", "10"}, + expectedResponse: "String", + expectedError: nil, + }, + { + // Return reverse substring with end index smaller than start index. + // When end index is smaller than start index, the 2 indices are reversed. + name: "Return reverse substring with end index smaller than start index", + key: "SubStrKey6", + presetValue: "Test String Six", + command: []string{"SUBSTR", "SubStrKey6", "4", "0"}, + expectedResponse: "tseT", + expectedError: nil, + }, + { + name: "Command too short", + command: []string{"SUBSTR", "key", "10"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "Command too long", + command: []string{"SUBSTR", "key", "10", "15", "20"}, + expectedError: errors.New(constants.WrongArgsResponse), + }, + { + name: "Start index is not an integer", + command: []string{"SUBSTR", "key", "start", "10"}, + expectedError: errors.New("start and end indices must be integers"), + }, + { + name: "End index is not an integer", + command: []string{"SUBSTR", "key", "0", "end"}, + expectedError: errors.New("start and end indices must be integers"), + }, + { + name: "Non-existent key", + command: []string{"SUBSTR", "non-existent-key", "0", "10"}, + expectedError: errors.New("key non-existent-key does not exist"), + }, + } - tests := []struct { - name string - key string - presetValue string - command []string - expectedResponse string - expectedError error - }{ - { - name: "Return substring within the range of the string", - key: "SubStrKey1", - presetValue: "Test String One", - command: []string{"SUBSTR", "SubStrKey1", "5", "10"}, - expectedResponse: "String", - expectedError: nil, - }, - { - name: "Return substring at the end of the string with exact end index", - key: "SubStrKey2", - presetValue: "Test String Two", - command: []string{"SUBSTR", "SubStrKey2", "12", "14"}, - expectedResponse: "Two", - expectedError: nil, - }, - { - name: "Return substring at the end of the string with end index greater than length", - key: "SubStrKey3", - presetValue: "Test String Three", - command: []string{"SUBSTR", "SubStrKey3", "12", "75"}, - expectedResponse: "Three", - expectedError: nil, - }, - { - name: "Return the substring at the start of the string with 0 start index", - key: "SubStrKey4", - presetValue: "Test String Four", - command: []string{"SUBSTR", "SubStrKey4", "0", "3"}, - expectedResponse: "Test", - expectedError: nil, - }, - { - // Return the substring with negative start index. - // Substring should begin abs(start) from the end of the string when start is negative. - name: "Return the substring with negative start index", - key: "SubStrKey5", - presetValue: "Test String Five", - command: []string{"SUBSTR", "SubStrKey5", "-11", "10"}, - expectedResponse: "String", - expectedError: nil, - }, - { - // Return reverse substring with end index smaller than start index. - // When end index is smaller than start index, the 2 indices are reversed. - name: "Return reverse substring with end index smaller than start index", - key: "SubStrKey6", - presetValue: "Test String Six", - command: []string{"SUBSTR", "SubStrKey6", "4", "0"}, - expectedResponse: "tseT", - expectedError: nil, - }, - { - name: "Command too short", - command: []string{"SUBSTR", "key", "10"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "Command too long", - command: []string{"SUBSTR", "key", "10", "15", "20"}, - expectedError: errors.New(constants.WrongArgsResponse), - }, - { - name: "Start index is not an integer", - command: []string{"SUBSTR", "key", "start", "10"}, - expectedError: errors.New("start and end indices must be integers"), - }, - { - name: "End index is not an integer", - command: []string{"SUBSTR", "key", "0", "end"}, - expectedError: errors.New("start and end indices must be integers"), - }, - { - name: "Non-existent key", - command: []string{"SUBSTR", "non-existent-key", "0", "10"}, - expectedError: errors.New("key non-existent-key does not exist"), - }, - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.presetValue != "" { + if err = client.WriteArray([]resp.Value{ + resp.StringValue("SET"), + resp.StringValue(test.key), + resp.StringValue(test.presetValue), + }); err != nil { + t.Error(err) + } + res, _, err := client.ReadValue() + if err != nil { + t.Error(err) + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if test.presetValue != "" { - if err = client.WriteArray([]resp.Value{ - resp.StringValue("SET"), - resp.StringValue(test.key), - resp.StringValue(test.presetValue), - }); err != nil { + if !strings.EqualFold(res.String(), "ok") { + t.Errorf("expected preset response to be OK, got %s", res.String()) + } + } + + command := make([]resp.Value, len(test.command)) + for i, c := range test.command { + command[i] = resp.StringValue(c) + } + + if err = client.WriteArray(command); err != nil { t.Error(err) } res, _, err := client.ReadValue() @@ -403,34 +444,17 @@ func Test_HandleSubStr(t *testing.T) { t.Error(err) } - if !strings.EqualFold(res.String(), "ok") { - t.Errorf("expected preset response to be OK, got %s", res.String()) + if test.expectedError != nil { + if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { + t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + } + return } - } - command := make([]resp.Value, len(test.command)) - for i, c := range test.command { - command[i] = resp.StringValue(c) - } - - if err = client.WriteArray(command); err != nil { - t.Error(err) - } - res, _, err := client.ReadValue() - if err != nil { - t.Error(err) - } - - if test.expectedError != nil { - if !strings.Contains(res.Error().Error(), test.expectedError.Error()) { - t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) + if res.String() != test.expectedResponse { + t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) } - return - } - - if res.String() != test.expectedResponse { - t.Errorf("expected response \"%s\", got \"%s\"", test.expectedResponse, res.String()) - } - }) - } + }) + } + }) } diff --git a/internal/raft/raft.go b/internal/raft/raft.go index d75c7f9..d06290d 100644 --- a/internal/raft/raft.go +++ b/internal/raft/raft.go @@ -216,12 +216,13 @@ func (r *Raft) TakeSnapshot() error { } func (r *Raft) RaftShutdown() { - // Leadership transfer if current node is the leader + // Leadership transfer if current node is the leader. if r.IsRaftLeader() { err := r.raft.LeadershipTransfer().Error() if err != nil { - log.Fatal(err) + log.Printf("raft shutdown: %v\n", err) + return } - log.Println("Leadership transfer successful.") + log.Println("leadership transfer successful.") } }