diff --git a/docker-compose.yaml b/docker-compose.yaml index e80559c..be72e2e 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -47,216 +47,216 @@ services: networks: - testnet - cluster_node_1: - container_name: cluster_node_1 - build: - context: . - dockerfile: Dockerfile.dev - environment: - - PORT=7480 - - RAFT_PORT=8000 - - ML_PORT=7946 - - KEY=/generic/ssl/certs/echovault/server1.key - - CERT=/generic/ssl/certs/echovault/server1.crt - - SERVER_ID=1 - - DATA_DIR=/var/lib/echovault - - IN_MEMORY=false - - TLS=false - - MTLS=false - - BOOTSTRAP_CLUSTER=true - - ACL_CONFIG=/generic/config/echovault/acl.yml - - REQUIRE_PASS=false - - FORWARD_COMMAND=true - - SNAPSHOT_THRESHOLD=1000 - - SNAPSHOT_INTERVAL=5m30s - - RESTORE_SNAPSHOT=false - - RESTORE_AOF=false - - AOF_SYNC_STRATEGY=everysec - - MAX_MEMORY=2000kb - - EVICTION_POLICY=allkeys-lfu - # List of server cert/key pairs - - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key - - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key - # List of client certificate authorities - - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt - ports: - - "7481:7480" - - "7945:7946" - - "8000:8000" - volumes: - - ./config/acl.yml:/generic/config/echovault/acl.yml - - ./volumes/cluster_node_1:/var/lib/echovault - networks: - - testnet - - cluster_node_2: - container_name: cluster_node_2 - build: - context: . - dockerfile: Dockerfile.dev - environment: - - PORT=7480 - - RAFT_PORT=8000 - - ML_PORT=7946 - - KEY=/generic/ssl/certs/echovault/server1.key - - CERT=/generic/ssl/certs/echovault/server1.crt - - SERVER_ID=2 - - JOIN_ADDR=cluster_node_1:7946 - - DATA_DIR=/var/lib/echovault - - IN_MEMORY=false - - TLS=false - - MTLS=false - - BOOTSTRAP_CLUSTER=false - - ACL_CONFIG=/generic/config/echovault/acl.yml - - REQUIRE_PASS=false - - FORWARD_COMMAND=true - - SNAPSHOT_THRESHOLD=1000 - - SNAPSHOT_INTERVAL=5m30s - - RESTORE_SNAPSHOT=false - - RESTORE_AOF=false - - AOF_SYNC_STRATEGY=everysec - - MAX_MEMORY=2000kb - - EVICTION_POLICY=allkeys-lfu - # List of server cert/key pairs - - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key - - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key - # List of client certificate authorities - - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt - ports: - - "7482:7480" - - "7947:7946" - - "8001:8000" - volumes: - - ./config/acl.yml:/generic/config/echovault/acl.yml - - ./volumes/cluster_node_2:/var/lib/echovault - networks: - - testnet - - cluster_node_3: - container_name: cluster_node_3 - build: - context: . - dockerfile: Dockerfile.dev - environment: - - PORT=7480 - - RAFT_PORT=8000 - - ML_PORT=7946 - - KEY=/generic/ssl/certs/echovault/server1.key - - CERT=/generic/ssl/certs/echovault/server1.crt - - SERVER_ID=3 - - JOIN_ADDR=cluster_node_1:7946 - - DATA_DIR=/var/lib/echovault - - IN_MEMORY=false - - TLS=false - - MTLS=false - - BOOTSTRAP_CLUSTER=false - - ACL_CONFIG=/generic/config/echovault/acl.yml - - REQUIRE_PASS=false - - FORWARD_COMMAND=true - - SNAPSHOT_THRESHOLD=1000 - - SNAPSHOT_INTERVAL=5m30s - - RESTORE_SNAPSHOT=false - - RESTORE_AOF=false - - AOF_SYNC_STRATEGY=everysec - - MAX_MEMORY=2000kb - - EVICTION_POLICY=allkeys-lfu - # List of server cert/key pairs - - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key - - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key - # List of client certificate authorities - - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt - ports: - - "7483:7480" - - "7948:7946" - - "8002:8000" - volumes: - - ./config/acl.yml:/generic/config/echovault/acl.yml - - ./volumes/cluster_node_3:/var/lib/echovault - networks: - - testnet - - cluster_node_4: - container_name: cluster_node_4 - build: - context: . - dockerfile: Dockerfile.dev - environment: - - PORT=7480 - - RAFT_PORT=8000 - - ML_PORT=7946 - - KEY=/generic/ssl/certs/echovault/server1.key - - CERT=/generic/ssl/certs/echovault/server1.crt - - SERVER_ID=4 - - JOIN_ADDR=cluster_node_1:7946 - - DATA_DIR=/var/lib/echovault - - IN_MEMORY=false - - TLS=false - - MTLS=false - - BOOTSTRAP_CLUSTER=false - - ACL_CONFIG=/generic/config/echovault/acl.yml - - REQUIRE_PASS=false - - FORWARD_COMMAND=true - - SNAPSHOT_THRESHOLD=1000 - - SNAPSHOT_INTERVAL=5m30s - - RESTORE_SNAPSHOT=false - - RESTORE_AOF=false - - AOF_SYNC_STRATEGY=everysec - - MAX_MEMORY=2000kb - - EVICTION_POLICY=allkeys-lfu - # List of server cert/key pairs - - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key - - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key - # List of client certificate authorities - - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt - ports: - - "7484:7480" - - "7949:7946" - - "8003:8000" - volumes: - - ./config/acl.yml:/generic/config/echovault/acl.yml - - ./volumes/cluster_node_4:/var/lib/echovault - networks: - - testnet - - cluster_node_5: - container_name: cluster_node_5 - build: - context: . - dockerfile: Dockerfile.dev - environment: - - PORT=7480 - - RAFT_PORT=8000 - - ML_PORT=7946 - - KEY=/generic/ssl/certs/echovault/server1.key - - CERT=/generic/ssl/certs/echovault/server1.crt - - SERVER_ID=5 - - JOIN_ADDR=cluster_node_1:7946 - - DATA_DIR=/var/lib/echovault - - IN_MEMORY=false - - TLS=false - - MTLS=false - - BOOTSTRAP_CLUSTER=false - - ACL_CONFIG=/generic/config/echovault/acl.yml - - REQUIRE_PASS=false - - FORWARD_COMMAND=true - - SNAPSHOT_THRESHOLD=1000 - - SNAPSHOT_INTERVAL=5m30s - - RESTORE_SNAPSHOT=false - - RESTORE_AOF=false - - AOF_SYNC_STRATEGY=everysec - - MAX_MEMORY=2000kb - - EVICTION_POLICY=allkeys-lfu - # List of server cert/key pairs - - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key - - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key - # List of client certificate authorities - - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt - ports: - - "7485:7480" - - "7950:7946" - - "8004:8000" - volumes: - - ./config/acl.yml:/generic/config/echovault/acl.yml - - ./volumes/cluster_node_5:/var/lib/echovault - networks: - - testnet \ No newline at end of file +# cluster_node_1: +# container_name: cluster_node_1 +# build: +# context: . +# dockerfile: Dockerfile.dev +# environment: +# - PORT=7480 +# - RAFT_PORT=8000 +# - ML_PORT=7946 +# - KEY=/generic/ssl/certs/echovault/server1.key +# - CERT=/generic/ssl/certs/echovault/server1.crt +# - SERVER_ID=1 +# - DATA_DIR=/var/lib/echovault +# - IN_MEMORY=false +# - TLS=false +# - MTLS=false +# - BOOTSTRAP_CLUSTER=true +# - ACL_CONFIG=/generic/config/echovault/acl.yml +# - REQUIRE_PASS=false +# - FORWARD_COMMAND=true +# - SNAPSHOT_THRESHOLD=1000 +# - SNAPSHOT_INTERVAL=5m30s +# - RESTORE_SNAPSHOT=false +# - RESTORE_AOF=false +# - AOF_SYNC_STRATEGY=everysec +# - MAX_MEMORY=2000kb +# - EVICTION_POLICY=allkeys-lfu +# # List of server cert/key pairs +# - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key +# - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key +# # List of client certificate authorities +# - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt +# ports: +# - "7481:7480" +# - "7945:7946" +# - "8000:8000" +# volumes: +# - ./config/acl.yml:/generic/config/echovault/acl.yml +# - ./volumes/cluster_node_1:/var/lib/echovault +# networks: +# - testnet +# +# cluster_node_2: +# container_name: cluster_node_2 +# build: +# context: . +# dockerfile: Dockerfile.dev +# environment: +# - PORT=7480 +# - RAFT_PORT=8000 +# - ML_PORT=7946 +# - KEY=/generic/ssl/certs/echovault/server1.key +# - CERT=/generic/ssl/certs/echovault/server1.crt +# - SERVER_ID=2 +# - JOIN_ADDR=cluster_node_1:7946 +# - DATA_DIR=/var/lib/echovault +# - IN_MEMORY=false +# - TLS=false +# - MTLS=false +# - BOOTSTRAP_CLUSTER=false +# - ACL_CONFIG=/generic/config/echovault/acl.yml +# - REQUIRE_PASS=false +# - FORWARD_COMMAND=true +# - SNAPSHOT_THRESHOLD=1000 +# - SNAPSHOT_INTERVAL=5m30s +# - RESTORE_SNAPSHOT=false +# - RESTORE_AOF=false +# - AOF_SYNC_STRATEGY=everysec +# - MAX_MEMORY=2000kb +# - EVICTION_POLICY=allkeys-lfu +# # List of server cert/key pairs +# - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key +# - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key +# # List of client certificate authorities +# - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt +# ports: +# - "7482:7480" +# - "7947:7946" +# - "8001:8000" +# volumes: +# - ./config/acl.yml:/generic/config/echovault/acl.yml +# - ./volumes/cluster_node_2:/var/lib/echovault +# networks: +# - testnet +# +# cluster_node_3: +# container_name: cluster_node_3 +# build: +# context: . +# dockerfile: Dockerfile.dev +# environment: +# - PORT=7480 +# - RAFT_PORT=8000 +# - ML_PORT=7946 +# - KEY=/generic/ssl/certs/echovault/server1.key +# - CERT=/generic/ssl/certs/echovault/server1.crt +# - SERVER_ID=3 +# - JOIN_ADDR=cluster_node_1:7946 +# - DATA_DIR=/var/lib/echovault +# - IN_MEMORY=false +# - TLS=false +# - MTLS=false +# - BOOTSTRAP_CLUSTER=false +# - ACL_CONFIG=/generic/config/echovault/acl.yml +# - REQUIRE_PASS=false +# - FORWARD_COMMAND=true +# - SNAPSHOT_THRESHOLD=1000 +# - SNAPSHOT_INTERVAL=5m30s +# - RESTORE_SNAPSHOT=false +# - RESTORE_AOF=false +# - AOF_SYNC_STRATEGY=everysec +# - MAX_MEMORY=2000kb +# - EVICTION_POLICY=allkeys-lfu +# # List of server cert/key pairs +# - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key +# - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key +# # List of client certificate authorities +# - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt +# ports: +# - "7483:7480" +# - "7948:7946" +# - "8002:8000" +# volumes: +# - ./config/acl.yml:/generic/config/echovault/acl.yml +# - ./volumes/cluster_node_3:/var/lib/echovault +# networks: +# - testnet +# +# cluster_node_4: +# container_name: cluster_node_4 +# build: +# context: . +# dockerfile: Dockerfile.dev +# environment: +# - PORT=7480 +# - RAFT_PORT=8000 +# - ML_PORT=7946 +# - KEY=/generic/ssl/certs/echovault/server1.key +# - CERT=/generic/ssl/certs/echovault/server1.crt +# - SERVER_ID=4 +# - JOIN_ADDR=cluster_node_1:7946 +# - DATA_DIR=/var/lib/echovault +# - IN_MEMORY=false +# - TLS=false +# - MTLS=false +# - BOOTSTRAP_CLUSTER=false +# - ACL_CONFIG=/generic/config/echovault/acl.yml +# - REQUIRE_PASS=false +# - FORWARD_COMMAND=true +# - SNAPSHOT_THRESHOLD=1000 +# - SNAPSHOT_INTERVAL=5m30s +# - RESTORE_SNAPSHOT=false +# - RESTORE_AOF=false +# - AOF_SYNC_STRATEGY=everysec +# - MAX_MEMORY=2000kb +# - EVICTION_POLICY=allkeys-lfu +# # List of server cert/key pairs +# - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key +# - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key +# # List of client certificate authorities +# - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt +# ports: +# - "7484:7480" +# - "7949:7946" +# - "8003:8000" +# volumes: +# - ./config/acl.yml:/generic/config/echovault/acl.yml +# - ./volumes/cluster_node_4:/var/lib/echovault +# networks: +# - testnet +# +# cluster_node_5: +# container_name: cluster_node_5 +# build: +# context: . +# dockerfile: Dockerfile.dev +# environment: +# - PORT=7480 +# - RAFT_PORT=8000 +# - ML_PORT=7946 +# - KEY=/generic/ssl/certs/echovault/server1.key +# - CERT=/generic/ssl/certs/echovault/server1.crt +# - SERVER_ID=5 +# - JOIN_ADDR=cluster_node_1:7946 +# - DATA_DIR=/var/lib/echovault +# - IN_MEMORY=false +# - TLS=false +# - MTLS=false +# - BOOTSTRAP_CLUSTER=false +# - ACL_CONFIG=/generic/config/echovault/acl.yml +# - REQUIRE_PASS=false +# - FORWARD_COMMAND=true +# - SNAPSHOT_THRESHOLD=1000 +# - SNAPSHOT_INTERVAL=5m30s +# - RESTORE_SNAPSHOT=false +# - RESTORE_AOF=false +# - AOF_SYNC_STRATEGY=everysec +# - MAX_MEMORY=2000kb +# - EVICTION_POLICY=allkeys-lfu +# # List of server cert/key pairs +# - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key +# - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key +# # List of client certificate authorities +# - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt +# ports: +# - "7485:7480" +# - "7950:7946" +# - "8004:8000" +# volumes: +# - ./config/acl.yml:/generic/config/echovault/acl.yml +# - ./volumes/cluster_node_5:/var/lib/echovault +# networks: +# - testnet \ No newline at end of file diff --git a/src/modules/generic/commands.go b/src/modules/generic/commands.go index 556eec6..3d545d2 100644 --- a/src/modules/generic/commands.go +++ b/src/modules/generic/commands.go @@ -35,7 +35,7 @@ func handleSet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co // If GET is provided, the response should be the current stored value. // If there's no current value, then the response should be nil. if params.get { - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { res = []byte("$-1\r\n") } else { res = []byte(fmt.Sprintf("+%v\r\n", server.GetValue(ctx, key))) @@ -44,19 +44,19 @@ func handleSet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co if "xx" == strings.ToLower(params.exists) { // If XX is specified, make sure the key exists. - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return nil, fmt.Errorf("key %s does not exist", key) } _, err = server.KeyLock(ctx, key) } else if "nx" == strings.ToLower(params.exists) { // If NX is specified, make sure that the key does not currently exist. - if server.KeyExists(key) { + if server.KeyExists(ctx, key) { return nil, fmt.Errorf("key %s already exists", key) } _, err = server.CreateKeyAndLock(ctx, key) } else { // Neither XX not NX are specified, lock or create the lock - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { // Key does not exist, create it _, err = server.CreateKeyAndLock(ctx, key) } else { @@ -67,7 +67,7 @@ func handleSet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co if err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) if err = server.SetValue(ctx, key, utils.AdaptType(value)); err != nil { return nil, err @@ -92,7 +92,7 @@ func handleMSet(ctx context.Context, cmd []string, server utils.Server, _ *net.C defer func() { for k, v := range entries { if v.locked { - server.KeyUnlock(k) + server.KeyUnlock(ctx, k) entries[k] = KeyObject{ value: v.value, locked: false, @@ -114,7 +114,7 @@ func handleMSet(ctx context.Context, cmd []string, server utils.Server, _ *net.C // Acquire all the locks for each key first // If any key cannot be acquired, abandon transaction and release all currently held keys for k, v := range entries { - if server.KeyExists(k) { + if server.KeyExists(ctx, k) { if _, err := server.KeyLock(ctx, k); err != nil { return nil, err } @@ -144,7 +144,7 @@ func handleGet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co } key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("$-1\r\n"), nil } @@ -152,7 +152,7 @@ func handleGet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co if err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) value := server.GetValue(ctx, key) @@ -173,7 +173,7 @@ func handleMGet(ctx context.Context, cmd []string, server utils.Server, _ *net.C // Skip if we have already locked this key continue } - if server.KeyExists(key) { + if server.KeyExists(ctx, key) { _, err = server.KeyRLock(ctx, key) if err != nil { return nil, fmt.Errorf("could not obtain lock for %s key", key) @@ -186,7 +186,7 @@ func handleMGet(ctx context.Context, cmd []string, server utils.Server, _ *net.C defer func() { for key, locked := range locks { if locked { - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) locks[key] = false } } @@ -234,14 +234,14 @@ func handlePersist(ctx context.Context, cmd []string, server utils.Server, _ *ne key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) expireAt := server.GetExpiry(ctx, key) if expireAt == (time.Time{}) { @@ -261,14 +261,14 @@ func handleExpireTime(ctx context.Context, cmd []string, server utils.Server, _ key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":-2\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) expireAt := server.GetExpiry(ctx, key) @@ -292,14 +292,14 @@ func handleTTL(ctx context.Context, cmd []string, server utils.Server, _ *net.Co key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":-2\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) expireAt := server.GetExpiry(ctx, key) @@ -337,14 +337,14 @@ func handleExpire(ctx context.Context, cmd []string, server utils.Server, _ *net expireAt = time.Now().Add(time.Duration(n) * time.Millisecond) } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) if len(cmd) == 3 { server.SetExpiry(ctx, key, expireAt, true) @@ -405,14 +405,14 @@ func handleExpireAt(ctx context.Context, cmd []string, server utils.Server, _ *n expireAt = time.UnixMilli(n) } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) if len(cmd) == 3 { server.SetExpiry(ctx, key, expireAt, true) diff --git a/src/modules/generic/commands_test.go b/src/modules/generic/commands_test.go index a054577..62d002a 100644 --- a/src/modules/generic/commands_test.go +++ b/src/modules/generic/commands_test.go @@ -177,7 +177,7 @@ func Test_HandleMSET(t *testing.T) { t.Errorf("expected value %s for key %s, got %s", ev, key, value) } } - mockServer.KeyRUnlock(key) + mockServer.KeyRUnlock(context.Background(), key) } } } @@ -211,8 +211,10 @@ func Test_HandleGET(t *testing.T) { if err != nil { t.Error(err) } - mockServer.SetValue(ctx, key, value) - mockServer.KeyUnlock(key) + if err = mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) res, err := handleGet(ctx, []string{"GET", key}, mockServer, nil) if err != nil { @@ -297,8 +299,10 @@ func Test_HandleMGET(t *testing.T) { if err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, test.presetValues[i]) - mockServer.KeyUnlock(key) + if err = mockServer.SetValue(context.Background(), key, test.presetValues[i]); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(context.Background(), key) } // Test the command and its results res, err := handleMGet(context.Background(), test.command, mockServer, nil) diff --git a/src/modules/hash/commands.go b/src/modules/hash/commands.go index 29ca636..22d586e 100644 --- a/src/modules/hash/commands.go +++ b/src/modules/hash/commands.go @@ -29,12 +29,12 @@ func handleHSET(ctx context.Context, cmd []string, server utils.Server, conn *ne entries[cmd[i]] = utils.AdaptType(cmd[i+1]) } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { _, err = server.CreateKeyAndLock(ctx, key) if err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) if err = server.SetValue(ctx, key, entries); err != nil { return nil, err } @@ -44,7 +44,7 @@ func handleHSET(ctx context.Context, cmd []string, server utils.Server, conn *ne if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { @@ -79,14 +79,14 @@ func handleHGET(ctx context.Context, cmd []string, server utils.Server, conn *ne key := keys[0] fields := cmd[2:] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("$-1\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { @@ -130,14 +130,14 @@ func handleHSTRLEN(ctx context.Context, cmd []string, server utils.Server, conn key := keys[0] fields := cmd[2:] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("$-1\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { @@ -180,14 +180,14 @@ func handleHVALS(ctx context.Context, cmd []string, server utils.Server, conn *n key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { @@ -242,14 +242,14 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server utils.Server, co } } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { @@ -337,14 +337,14 @@ func handleHLEN(ctx context.Context, cmd []string, server utils.Server, conn *ne key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { @@ -362,14 +362,14 @@ func handleHKEYS(ctx context.Context, cmd []string, server utils.Server, conn *n key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { @@ -410,11 +410,11 @@ func handleHINCRBY(ctx context.Context, cmd []string, server utils.Server, conn intIncrement = i } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { if _, err := server.CreateKeyAndLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) hash := make(map[string]interface{}) if strings.EqualFold(cmd[0], "hincrbyfloat") { hash[field] = floatIncrement @@ -434,7 +434,7 @@ func handleHINCRBY(ctx context.Context, cmd []string, server utils.Server, conn if _, err := server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { @@ -484,14 +484,14 @@ func handleHGETALL(ctx context.Context, cmd []string, server utils.Server, conn key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { @@ -525,14 +525,14 @@ func handleHEXISTS(ctx context.Context, cmd []string, server utils.Server, conn key := keys[0] field := cmd[2] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { @@ -555,14 +555,14 @@ func handleHDEL(ctx context.Context, cmd []string, server utils.Server, conn *ne key := keys[0] fields := cmd[2:] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { diff --git a/src/modules/hash/commands_test.go b/src/modules/hash/commands_test.go index 09a5fc6..b8f5035 100644 --- a/src/modules/hash/commands_test.go +++ b/src/modules/hash/commands_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "fmt" "github.com/echovault/echovault/src/server" "github.com/echovault/echovault/src/utils" "github.com/tidwall/resp" @@ -98,15 +99,18 @@ func Test_HandleHSET(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HSET/HSETNX, %d", i)) if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleHSET(context.Background(), test.command, mockServer, nil) + res, err := handleHSET(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -122,10 +126,10 @@ func Test_HandleHSET(t *testing.T) { t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, rv.Integer()) } // Check that all the values are what is expected - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - hash, ok := mockServer.GetValue(context.Background(), test.key).(map[string]interface{}) + hash, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{}) if !ok { t.Errorf("value at key \"%s\" is not a hash map", test.key) } @@ -242,15 +246,19 @@ func Test_HandleHINCRBY(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HINCRBY, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleHINCRBY(context.Background(), test.command, mockServer, nil) + res, err := handleHINCRBY(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -275,10 +283,10 @@ func Test_HandleHINCRBY(t *testing.T) { } } // Check that all the values are what is expected - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - hash, ok := mockServer.GetValue(context.Background(), test.key).(map[string]interface{}) + hash, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{}) if !ok { t.Errorf("value at key \"%s\" is not a hash map", test.key) } @@ -340,15 +348,19 @@ func Test_HandleHGET(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HINCRBY, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleHGET(context.Background(), test.command, mockServer, nil) + res, err := handleHGET(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -440,15 +452,19 @@ func Test_HandleHSTRLEN(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HSTRLEN, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleHSTRLEN(context.Background(), test.command, mockServer, nil) + res, err := handleHSTRLEN(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -535,15 +551,19 @@ func Test_HandleHVALS(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HVALS, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleHVALS(context.Background(), test.command, mockServer, nil) + res, err := handleHVALS(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -727,15 +747,19 @@ func Test_HandleHRANDFIELD(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HRANDFIELD, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleHRANDFIELD(context.Background(), test.command, mockServer, nil) + res, err := handleHRANDFIELD(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -841,15 +865,19 @@ func Test_HandleHLEN(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HLEN, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleHLEN(context.Background(), test.command, mockServer, nil) + res, err := handleHLEN(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -931,15 +959,19 @@ func Test_HandleHKeys(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HKEYS, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleHKEYS(context.Background(), test.command, mockServer, nil) + res, err := handleHKEYS(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1028,15 +1060,19 @@ func Test_HandleHGETALL(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HGETALL, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleHGETALL(context.Background(), test.command, mockServer, nil) + res, err := handleHGETALL(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1136,15 +1172,19 @@ func Test_HandleHEXISTS(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HEXISTS, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleHEXISTS(context.Background(), test.command, mockServer, nil) + res, err := handleHEXISTS(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1226,15 +1266,19 @@ func Test_HandleHDEL(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HDEL, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleHDEL(context.Background(), test.command, mockServer, nil) + res, err := handleHDEL(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1252,10 +1296,10 @@ func Test_HandleHDEL(t *testing.T) { } continue } - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - if hash, ok := mockServer.GetValue(context.Background(), test.key).(map[string]interface{}); ok { + if hash, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{}); ok { for field, value := range hash { if value != test.expectedValue[field] { t.Errorf("expected value \"%+v\", got \"%+v\"", test.expectedValue[field], value) diff --git a/src/modules/list/commands.go b/src/modules/list/commands.go index 3a0f6d4..5591125 100644 --- a/src/modules/list/commands.go +++ b/src/modules/list/commands.go @@ -19,7 +19,7 @@ func handleLLen(ctx context.Context, cmd []string, server utils.Server, _ *net.C key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { // If key does not exist, return 0 return []byte(":0\r\n"), nil } @@ -27,7 +27,7 @@ func handleLLen(ctx context.Context, cmd []string, server utils.Server, _ *net.C if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) if list, ok := server.GetValue(ctx, key).([]interface{}); ok { return []byte(fmt.Sprintf(":%d\r\n", len(list))), nil @@ -49,7 +49,7 @@ func handleLIndex(ctx context.Context, cmd []string, server utils.Server, conn * return nil, errors.New("index must be an integer") } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return nil, errors.New("LINDEX command on non-list item") } @@ -57,7 +57,7 @@ func handleLIndex(ctx context.Context, cmd []string, server utils.Server, conn * return nil, err } list, ok := server.GetValue(ctx, key).([]interface{}) - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) if !ok { return nil, errors.New("LINDEX command on non-list item") @@ -84,14 +84,14 @@ func handleLRange(ctx context.Context, cmd []string, server utils.Server, conn * return nil, errors.New("start and end indices must be integers") } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return nil, errors.New("LRANGE command on non-list item") } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { @@ -162,14 +162,14 @@ func handleLSet(ctx context.Context, cmd []string, server utils.Server, conn *ne return nil, errors.New("index must be an integer") } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return nil, errors.New("LSET command on non-list item") } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { @@ -206,14 +206,14 @@ func handleLTrim(ctx context.Context, cmd []string, server utils.Server, conn *n return nil, errors.New("end index must be greater than start index or -1") } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return nil, errors.New("LTRIM command on non-list item") } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { @@ -253,14 +253,14 @@ func handleLRem(ctx context.Context, cmd []string, server utils.Server, conn *ne absoluteCount := utils.AbsInt(count) - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return nil, errors.New("LREM command on non-list item") } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { @@ -320,20 +320,20 @@ func handleLMove(ctx context.Context, cmd []string, server utils.Server, conn *n return nil, errors.New("wherefrom and whereto arguments must be either LEFT or RIGHT") } - if !server.KeyExists(source) || !server.KeyExists(destination) { + if !server.KeyExists(ctx, source) || !server.KeyExists(ctx, destination) { return nil, errors.New("both source and destination must be lists") } if _, err = server.KeyLock(ctx, source); err != nil { return nil, err } - defer server.KeyUnlock(source) + defer server.KeyUnlock(ctx, source) _, err = server.KeyLock(ctx, destination) if err != nil { return nil, err } - defer server.KeyUnlock(destination) + defer server.KeyUnlock(ctx, destination) sourceList, sourceOk := server.GetValue(ctx, source).([]interface{}) destinationList, destinationOk := server.GetValue(ctx, destination).([]interface{}) @@ -380,7 +380,7 @@ func handleLPush(ctx context.Context, cmd []string, server utils.Server, conn *n key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { switch strings.ToLower(cmd[0]) { case "lpushx": return nil, errors.New("LPUSHX command on non-list item") @@ -397,7 +397,7 @@ func handleLPush(ctx context.Context, cmd []string, server utils.Server, conn *n return nil, err } } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) currentList := server.GetValue(ctx, key) @@ -426,7 +426,7 @@ func handleRPush(ctx context.Context, cmd []string, server utils.Server, conn *n newElems = append(newElems, utils.AdaptType(elem)) } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { switch strings.ToLower(cmd[0]) { case "rpushx": return nil, errors.New("RPUSHX command on non-list item") @@ -434,7 +434,7 @@ func handleRPush(ctx context.Context, cmd []string, server utils.Server, conn *n if _, err = server.CreateKeyAndLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) if err = server.SetValue(ctx, key, []interface{}{}); err != nil { return nil, err } @@ -443,7 +443,7 @@ func handleRPush(ctx context.Context, cmd []string, server utils.Server, conn *n if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) } currentList := server.GetValue(ctx, key) @@ -468,14 +468,14 @@ func handlePop(ctx context.Context, cmd []string, server utils.Server, conn *net key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(cmd[0])) } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { diff --git a/src/modules/list/commands_test.go b/src/modules/list/commands_test.go index a04dc1c..62dd728 100644 --- a/src/modules/list/commands_test.go +++ b/src/modules/list/commands_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "fmt" "github.com/echovault/echovault/src/server" "github.com/echovault/echovault/src/utils" "github.com/tidwall/resp" @@ -69,15 +70,19 @@ func Test_HandleLLEN(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LLEN, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleLLen(context.Background(), test.command, mockServer, nil) + res, err := handleLLen(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -199,15 +204,19 @@ func Test_HandleLINDEX(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LINDEX, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleLIndex(context.Background(), test.command, mockServer, nil) + res, err := handleLIndex(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -350,15 +359,19 @@ func Test_HandleLRANGE(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LRANGE, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleLRange(context.Background(), test.command, mockServer, nil) + res, err := handleLRange(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -486,15 +499,19 @@ func Test_HandleLSET(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LSET, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleLSet(context.Background(), test.command, mockServer, nil) + res, err := handleLSet(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -509,10 +526,10 @@ func Test_HandleLSET(t *testing.T) { if rv.String() != test.expectedResponse { t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String()) } - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) + list, ok := mockServer.GetValue(ctx, test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -524,7 +541,7 @@ func Test_HandleLSET(t *testing.T) { t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i]) } } - mockServer.KeyRUnlock(test.key) + mockServer.KeyRUnlock(ctx, test.key) } } @@ -644,15 +661,19 @@ func Test_HandleLTRIM(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LTRIM, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleLTrim(context.Background(), test.command, mockServer, nil) + res, err := handleLTrim(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -667,10 +688,10 @@ func Test_HandleLTRIM(t *testing.T) { if rv.String() != test.expectedResponse { t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String()) } - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) + list, ok := mockServer.GetValue(ctx, test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -682,7 +703,7 @@ func Test_HandleLTRIM(t *testing.T) { t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i]) } } - mockServer.KeyRUnlock(test.key) + mockServer.KeyRUnlock(ctx, test.key) } } @@ -763,15 +784,19 @@ func Test_HandleLREM(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LREM, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleLRem(context.Background(), test.command, mockServer, nil) + res, err := handleLRem(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -786,10 +811,10 @@ func Test_HandleLREM(t *testing.T) { if rv.String() != test.expectedResponse { t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String()) } - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) + list, ok := mockServer.GetValue(ctx, test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -801,7 +826,7 @@ func Test_HandleLREM(t *testing.T) { t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i]) } } - mockServer.KeyRUnlock(test.key) + mockServer.KeyRUnlock(ctx, test.key) } } @@ -970,17 +995,21 @@ func Test_HandleLMOVE(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LMOVE, %d", i)) + if test.preset { for key, value := range test.presetValue { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleLMove(context.Background(), test.command, mockServer, nil) + res, err := handleLMove(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -996,10 +1025,10 @@ func Test_HandleLMOVE(t *testing.T) { t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String()) } for key, value := range test.expectedValue { - if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { + if _, err = mockServer.KeyRLock(ctx, key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(context.Background(), key).([]interface{}) + list, ok := mockServer.GetValue(ctx, key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -1015,7 +1044,7 @@ func Test_HandleLMOVE(t *testing.T) { t.Errorf("expected element at index %d to be %+v, got %+v", i, expectedList[i], list[i]) } } - mockServer.KeyRUnlock(key) + mockServer.KeyRUnlock(ctx, key) } } } @@ -1079,15 +1108,19 @@ func Test_HandleLPUSH(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LPUSH/LPUSHX, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleLPush(context.Background(), test.command, mockServer, nil) + res, err := handleLPush(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1102,10 +1135,10 @@ func Test_HandleLPUSH(t *testing.T) { if rv.String() != test.expectedResponse { t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String()) } - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) + list, ok := mockServer.GetValue(ctx, test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -1117,7 +1150,7 @@ func Test_HandleLPUSH(t *testing.T) { t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i]) } } - mockServer.KeyRUnlock(test.key) + mockServer.KeyRUnlock(ctx, test.key) } } @@ -1180,15 +1213,19 @@ func Test_HandleRPUSH(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("RPUSH/RPUSHX, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleRPush(context.Background(), test.command, mockServer, nil) + res, err := handleRPush(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1203,10 +1240,10 @@ func Test_HandleRPUSH(t *testing.T) { if rv.String() != test.expectedResponse { t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String()) } - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) + list, ok := mockServer.GetValue(ctx, test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -1218,7 +1255,7 @@ func Test_HandleRPUSH(t *testing.T) { t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i]) } } - mockServer.KeyRUnlock(test.key) + mockServer.KeyRUnlock(ctx, test.key) } } @@ -1290,15 +1327,19 @@ func Test_HandlePop(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LPOP/RPOP, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handlePop(context.Background(), test.command, mockServer, nil) + res, err := handlePop(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1313,10 +1354,10 @@ func Test_HandlePop(t *testing.T) { if rv.String() != test.expectedResponse { t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String()) } - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) + list, ok := mockServer.GetValue(ctx, test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -1328,6 +1369,6 @@ func Test_HandlePop(t *testing.T) { t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i]) } } - mockServer.KeyRUnlock(test.key) + mockServer.KeyRUnlock(ctx, test.key) } } diff --git a/src/modules/set/commands.go b/src/modules/set/commands.go index 20923e1..1f26755 100644 --- a/src/modules/set/commands.go +++ b/src/modules/set/commands.go @@ -20,7 +20,7 @@ func handleSADD(ctx context.Context, cmd []string, server utils.Server, conn *ne var set *Set - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { set = NewSet(cmd[2:]) if ok, err := server.CreateKeyAndLock(ctx, key); !ok && err != nil { return nil, err @@ -28,14 +28,14 @@ func handleSADD(ctx context.Context, cmd []string, server utils.Server, conn *ne if err = server.SetValue(ctx, key, set); err != nil { return nil, err } - server.KeyUnlock(key) + server.KeyUnlock(ctx, key) return []byte(fmt.Sprintf(":%d\r\n", len(cmd[2:]))), nil } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*Set) if !ok { @@ -55,14 +55,14 @@ func handleSCARD(ctx context.Context, cmd []string, server utils.Server, conn *n key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(fmt.Sprintf(":0\r\n")), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*Set) if !ok { @@ -81,13 +81,13 @@ func handleSDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n } // Extract base set first - if !server.KeyExists(keys[0]) { + if !server.KeyExists(ctx, keys[0]) { return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys[0]) } if _, err = server.KeyRLock(ctx, keys[0]); err != nil { return nil, err } - defer server.KeyRUnlock(keys[0]) + defer server.KeyRUnlock(ctx, keys[0]) baseSet, ok := server.GetValue(ctx, keys[0]).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", keys[0]) @@ -97,13 +97,13 @@ func handleSDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n defer func() { for key, locked := range locks { if locked { - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) } } }() for _, key := range keys[1:] { - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { continue } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -144,13 +144,13 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co destination := keys[0] // Extract base set first - if !server.KeyExists(keys[1]) { + if !server.KeyExists(ctx, keys[1]) { return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys[1]) } if _, err := server.KeyRLock(ctx, keys[1]); err != nil { return nil, err } - defer server.KeyRUnlock(keys[1]) + defer server.KeyRUnlock(ctx, keys[1]) baseSet, ok := server.GetValue(ctx, keys[1]).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", keys[1]) @@ -160,13 +160,13 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co defer func() { for key, locked := range locks { if locked { - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) } } }() for _, key := range keys[2:] { - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { continue } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -189,14 +189,14 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co res := fmt.Sprintf(":%d\r\n", len(elems)) - if server.KeyExists(destination) { + if server.KeyExists(ctx, destination) { if _, err = server.KeyLock(ctx, destination); err != nil { return nil, err } if err = server.SetValue(ctx, destination, diff); err != nil { return nil, err } - server.KeyUnlock(destination) + server.KeyUnlock(ctx, destination) return []byte(res), nil } @@ -206,7 +206,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co if err = server.SetValue(ctx, destination, diff); err != nil { return nil, err } - server.KeyUnlock(destination) + server.KeyUnlock(ctx, destination) return []byte(res), nil } @@ -221,13 +221,13 @@ func handleSINTER(ctx context.Context, cmd []string, server utils.Server, conn * defer func() { for key, locked := range locks { if locked { - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) } } }() for _, key := range keys[0:] { - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { // If key does not exist, then there is no intersection return []byte("*0\r\n"), nil } @@ -297,13 +297,13 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server utils.Server, co defer func() { for key, locked := range locks { if locked { - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) } } }() for _, key := range keys { - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { // If key does not exist, then there is no intersection return []byte(":0\r\n"), nil } @@ -343,13 +343,13 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c defer func() { for key, locked := range locks { if locked { - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) } } }() for _, key := range keys[1:] { - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { // If key does not exist, then there is no intersection return []byte(":0\r\n"), nil } @@ -373,7 +373,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c intersect, _ := Intersection(0, sets...) destination := keys[0] - if server.KeyExists(destination) { + if server.KeyExists(ctx, destination) { if _, err = server.KeyLock(ctx, destination); err != nil { return nil, err } @@ -386,7 +386,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c if err = server.SetValue(ctx, destination, intersect); err != nil { return nil, err } - server.KeyUnlock(destination) + server.KeyUnlock(ctx, destination) return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil } @@ -399,14 +399,14 @@ func handleSISMEMBER(ctx context.Context, cmd []string, server utils.Server, con key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*Set) if !ok { @@ -428,14 +428,14 @@ func handleSMEMBERS(ctx context.Context, cmd []string, server utils.Server, conn key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*Set) if !ok { @@ -464,7 +464,7 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server utils.Server, co key := keys[0] members := cmd[2:] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { res := fmt.Sprintf("*%d", len(members)) for i, _ := range members { res = fmt.Sprintf("%s\r\n:0", res) @@ -478,7 +478,7 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server utils.Server, co if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*Set) if !ok { @@ -508,14 +508,14 @@ func handleSMOVE(ctx context.Context, cmd []string, server utils.Server, conn *n destination := keys[1] member := cmd[3] - if !server.KeyExists(source) { + if !server.KeyExists(ctx, source) { return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, source); err != nil { return nil, err } - defer server.KeyUnlock(source) + defer server.KeyUnlock(ctx, source) sourceSet, ok := server.GetValue(ctx, source).(*Set) if !ok { @@ -524,12 +524,12 @@ func handleSMOVE(ctx context.Context, cmd []string, server utils.Server, conn *n var destinationSet *Set - if !server.KeyExists(destination) { + if !server.KeyExists(ctx, destination) { // Destination key does not exist if _, err = server.CreateKeyAndLock(ctx, destination); err != nil { return nil, err } - defer server.KeyUnlock(destination) + defer server.KeyUnlock(ctx, destination) destinationSet = NewSet([]string{}) if err = server.SetValue(ctx, destination, destinationSet); err != nil { return nil, err @@ -539,7 +539,7 @@ func handleSMOVE(ctx context.Context, cmd []string, server utils.Server, conn *n if _, err := server.KeyLock(ctx, destination); err != nil { return nil, err } - defer server.KeyUnlock(destination) + defer server.KeyUnlock(ctx, destination) ds, ok := server.GetValue(ctx, destination).(*Set) if !ok { return nil, errors.New("destination is not a set") @@ -569,14 +569,14 @@ func handleSPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne count = c } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("*-1\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*Set) if !ok { @@ -613,14 +613,14 @@ func handleSRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c count = c } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("*-1\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*Set) if !ok { @@ -649,14 +649,14 @@ func handleSREM(ctx context.Context, cmd []string, server utils.Server, conn *ne key := keys[0] members := cmd[2:] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*Set) if !ok { @@ -678,13 +678,13 @@ func handleSUNION(ctx context.Context, cmd []string, server utils.Server, conn * defer func() { for key, locked := range locks { if locked { - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) } } }() for _, key := range keys { - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { continue } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -729,13 +729,13 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c defer func() { for key, locked := range locks { if locked { - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) } } }() for _, key := range keys[1:] { - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { continue } if _, err = server.KeyRLock(ctx, key); err != nil { @@ -761,7 +761,7 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c destination := cmd[1] - if server.KeyExists(destination) { + if server.KeyExists(ctx, destination) { if _, err = server.KeyLock(ctx, destination); err != nil { return nil, err } @@ -770,7 +770,7 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c return nil, err } } - defer server.KeyUnlock(destination) + defer server.KeyUnlock(ctx, destination) if err = server.SetValue(ctx, destination, union); err != nil { return nil, err diff --git a/src/modules/set/commant_test.go b/src/modules/set/commant_test.go index 3081cd9..0776efc 100644 --- a/src/modules/set/commant_test.go +++ b/src/modules/set/commant_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "fmt" "github.com/echovault/echovault/src/server" "github.com/echovault/echovault/src/utils" "github.com/tidwall/resp" @@ -59,15 +60,19 @@ func Test_HandleSADD(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SADD, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleSADD(context.Background(), test.command, mockServer, nil) + res, err := handleSADD(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -85,10 +90,10 @@ func Test_HandleSADD(t *testing.T) { if rv.Integer() != test.expectedResponse { t.Errorf("expected integer response %d, got %d", test.expectedResponse, rv.Integer()) } - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) + set, ok := mockServer.GetValue(ctx, test.key).(*Set) if !ok { t.Errorf("expected set value at key \"%s\"", test.key) } @@ -100,7 +105,7 @@ func Test_HandleSADD(t *testing.T) { t.Errorf("could not find member \"%s\" in expected set", member) } } - mockServer.KeyRUnlock(test.key) + mockServer.KeyRUnlock(ctx, test.key) } } @@ -160,15 +165,19 @@ func Test_HandleSCARD(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SCARD, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleSCARD(context.Background(), test.command, mockServer, nil) + res, err := handleSCARD(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -260,17 +269,21 @@ func Test_HandleSDIFF(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SDIFF, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleSDIFF(context.Background(), test.command, mockServer, nil) + res, err := handleSDIFF(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -376,17 +389,21 @@ func Test_HandleSDIFFSTORE(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SDIFFSTORE, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleSDIFFSTORE(context.Background(), test.command, mockServer, nil) + res, err := handleSDIFFSTORE(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -405,10 +422,10 @@ func Test_HandleSDIFFSTORE(t *testing.T) { t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) } if test.expectedValue != nil { - if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), test.destination).(*Set) + set, ok := mockServer.GetValue(ctx, test.destination).(*Set) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -417,7 +434,7 @@ func Test_HandleSDIFFSTORE(t *testing.T) { t.Errorf("could not find element %s in the expected values", elem) } } - mockServer.KeyRUnlock(test.destination) + mockServer.KeyRUnlock(ctx, test.destination) } } } @@ -493,17 +510,21 @@ func Test_HandleSINTER(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SINTER, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleSINTER(context.Background(), test.command, mockServer, nil) + res, err := handleSINTER(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -596,29 +617,33 @@ func Test_HandleSINTERCARD(t *testing.T) { "key15": NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), "key16": NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), }, - command: []string{"SINTERSTORE", "key14", "key15", "key16"}, + command: []string{"SINTERCARD", "key14", "key15", "key16"}, expectedResponse: 0, expectedError: errors.New("value at key key14 is not a set"), }, { // 7. Command too short preset: false, - command: []string{"SINTERSTORE"}, + command: []string{"SINTERCARD"}, expectedResponse: 0, expectedError: errors.New(utils.WrongArgsResponse), }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SINTERCARD, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleSINTERCARD(context.Background(), test.command, mockServer, nil) + res, err := handleSINTERCARD(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -722,17 +747,21 @@ func Test_HandleSINTERSTORE(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SINTERSTORE, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleSINTERSTORE(context.Background(), test.command, mockServer, nil) + res, err := handleSINTERSTORE(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -751,10 +780,10 @@ func Test_HandleSINTERSTORE(t *testing.T) { t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) } if test.expectedValue != nil { - if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), test.destination).(*Set) + set, ok := mockServer.GetValue(ctx, test.destination).(*Set) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -763,7 +792,7 @@ func Test_HandleSINTERSTORE(t *testing.T) { t.Errorf("could not find element %s in the expected values", elem) } } - mockServer.KeyRUnlock(test.destination) + mockServer.KeyRUnlock(ctx, test.destination) } } } @@ -819,15 +848,19 @@ func Test_HandleSISMEMBER(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SISMEMBER, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleSISMEMBER(context.Background(), test.command, mockServer, nil) + res, err := handleSISMEMBER(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -897,15 +930,19 @@ func Test_HandleSMEMBERS(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SMEMBERS, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleSMEMBERS(context.Background(), test.command, mockServer, nil) + res, err := handleSMEMBERS(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -979,15 +1016,19 @@ func Test_HandleSMISMEMBER(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SMISMEMBER, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleSMISMEMBER(context.Background(), test.command, mockServer, nil) + res, err := handleSMISMEMBER(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1090,17 +1131,21 @@ func Test_HandleSMOVE(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SMOVE, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleSMOVE(context.Background(), test.command, mockServer, nil) + res, err := handleSMOVE(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1123,10 +1168,10 @@ func Test_HandleSMOVE(t *testing.T) { if !ok { t.Errorf("expected value at \"%s\" should be a set", key) } - if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { + if _, err = mockServer.KeyRLock(ctx, key); err != nil { t.Error(key) } - set, ok := mockServer.GetValue(context.Background(), key).(*Set) + set, ok := mockServer.GetValue(ctx, key).(*Set) if !ok { t.Errorf("expected set \"%s\" to be a set, got another type", key) } @@ -1138,7 +1183,7 @@ func Test_HandleSMOVE(t *testing.T) { t.Errorf("could not find element \"%s\" in the expected set", element) } } - mockServer.KeyRUnlock(key) + mockServer.KeyRUnlock(ctx, key) } } } @@ -1190,15 +1235,19 @@ func Test_HandleSPOP(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SPOP, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleSPOP(context.Background(), test.command, mockServer, nil) + res, err := handleSPOP(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1220,10 +1269,10 @@ func Test_HandleSPOP(t *testing.T) { } } // 2. Fetch the set and check if its cardinality is what we expect. - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) + set, ok := mockServer.GetValue(ctx, test.key).(*Set) if !ok { t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) } @@ -1301,15 +1350,19 @@ func Test_HandleSRANDMEMBER(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SRANDMEMBER, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleSRANDMEMBER(context.Background(), test.command, mockServer, nil) + res, err := handleSRANDMEMBER(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1331,10 +1384,10 @@ func Test_HandleSRANDMEMBER(t *testing.T) { } } // 2. Fetch the set and check if its cardinality is what we expect. - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) + set, ok := mockServer.GetValue(ctx, test.key).(*Set) if !ok { t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) } @@ -1407,15 +1460,19 @@ func Test_HandleSREM(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SREM, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleSREM(context.Background(), test.command, mockServer, nil) + res, err := handleSREM(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1434,10 +1491,10 @@ func Test_HandleSREM(t *testing.T) { t.Errorf("expected integer response %d, got %d", test.expectedResponse, rv.Integer()) } if test.expectedValue != nil { - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) + set, ok := mockServer.GetValue(ctx, test.key).(*Set) if !ok { t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) } @@ -1446,7 +1503,7 @@ func Test_HandleSREM(t *testing.T) { t.Errorf("element \"%s\" not found in expected set values but found in set", element) } } - mockServer.KeyRUnlock(test.key) + mockServer.KeyRUnlock(ctx, test.key) } } } @@ -1515,17 +1572,21 @@ func Test_HandleSUNION(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SUNION, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleSUNION(context.Background(), test.command, mockServer, nil) + res, err := handleSUNION(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1609,17 +1670,21 @@ func Test_HandleSUNIONSTORE(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SUNIONSTORE, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleSUNIONSTORE(context.Background(), test.command, mockServer, nil) + res, err := handleSUNIONSTORE(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1638,10 +1703,10 @@ func Test_HandleSUNIONSTORE(t *testing.T) { t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) } if test.expectedValue != nil { - if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), test.destination).(*Set) + set, ok := mockServer.GetValue(ctx, test.destination).(*Set) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -1650,7 +1715,7 @@ func Test_HandleSUNIONSTORE(t *testing.T) { t.Errorf("could not find element %s in the expected values", elem) } } - mockServer.KeyRUnlock(test.destination) + mockServer.KeyRUnlock(ctx, test.destination) } } } diff --git a/src/modules/sorted_set/commands.go b/src/modules/sorted_set/commands.go index 903f6b7..b327abd 100644 --- a/src/modules/sorted_set/commands.go +++ b/src/modules/sorted_set/commands.go @@ -126,13 +126,13 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne } } - if server.KeyExists(key) { + if server.KeyExists(ctx, key) { // Key exists _, err = server.KeyLock(ctx, key) if err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) @@ -154,7 +154,7 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne if _, err = server.CreateKeyAndLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) set := NewSortedSet(members) if err = server.SetValue(ctx, key, set); err != nil { @@ -171,14 +171,14 @@ func handleZCARD(ctx context.Context, cmd []string, server utils.Server, conn *n } key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { @@ -232,14 +232,14 @@ func handleZCOUNT(ctx context.Context, cmd []string, server utils.Server, conn * maximum = Score(s) } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { @@ -266,14 +266,14 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.Server, con minimum := cmd[2] maximum := cmd[3] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { @@ -318,20 +318,20 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n defer func() { for key, locked := range locks { if locked { - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) } } }() // Extract base set - if !server.KeyExists(keys[0]) { + if !server.KeyExists(ctx, keys[0]) { // If base set does not exist, return an empty array return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, keys[0]); err != nil { return nil, err } - defer server.KeyRUnlock(keys[0]) + defer server.KeyRUnlock(ctx, keys[0]) baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[0]) @@ -341,7 +341,7 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n var sets []*SortedSet for i := 1; i < len(keys); i++ { - if !server.KeyExists(keys[i]) { + if !server.KeyExists(ctx, keys[i]) { continue } locked, err := server.KeyRLock(ctx, keys[i]) @@ -386,20 +386,20 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co defer func() { for key, locked := range locks { if locked { - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) } } }() // Extract base set - if !server.KeyExists(keys[0]) { + if !server.KeyExists(ctx, keys[0]) { // If base set does not exist, return 0 return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, keys[0]); err != nil { return nil, err } - defer server.KeyRUnlock(keys[0]) + defer server.KeyRUnlock(ctx, keys[0]) baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[0]) @@ -408,7 +408,7 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co var sets []*SortedSet for i := 1; i < len(keys); i++ { - if server.KeyExists(keys[i]) { + if server.KeyExists(ctx, keys[i]) { if _, err = server.KeyRLock(ctx, keys[i]); err != nil { return nil, err } @@ -422,7 +422,7 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co diff := baseSortedSet.Subtract(sets) - if server.KeyExists(destination) { + if server.KeyExists(ctx, destination) { if _, err = server.KeyLock(ctx, destination); err != nil { return nil, err } @@ -431,7 +431,7 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co return nil, err } } - defer server.KeyUnlock(destination) + defer server.KeyUnlock(ctx, destination) if err = server.SetValue(ctx, destination, diff); err != nil { return nil, err @@ -469,7 +469,7 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.Server, conn increment = Score(s) } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { // If the key does not exist, create a new sorted set at the key with // the member and increment as the first value if _, err = server.CreateKeyAndLock(ctx, key); err != nil { @@ -478,14 +478,14 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.Server, conn if err = server.SetValue(ctx, key, NewSortedSet([]MemberParam{{value: member, score: increment}})); err != nil { return nil, err } - server.KeyUnlock(key) + server.KeyUnlock(ctx, key) return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(float64(increment), 'f', -1, 64))), nil } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) @@ -518,7 +518,7 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn * defer func() { for key, locked := range locks { if locked { - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) } } }() @@ -526,7 +526,7 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn * var setParams []SortedSetParam for i := 0; i < len(keys); i++ { - if !server.KeyExists(keys[i]) { + if !server.KeyExists(ctx, keys[i]) { // If any of the keys is non-existent, return an empty array as there's no intersect return []byte("*0\r\n"), nil } @@ -585,7 +585,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c defer func() { for key, locked := range locks { if locked { - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) } } }() @@ -593,7 +593,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c var setParams []SortedSetParam for i := 0; i < len(keys); i++ { - if !server.KeyExists(keys[i]) { + if !server.KeyExists(ctx, keys[i]) { return []byte(":0\r\n"), nil } if _, err = server.KeyRLock(ctx, keys[i]); err != nil { @@ -612,7 +612,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c intersect := Intersect(aggregate, setParams...) - if server.KeyExists(destination) && intersect.Cardinality() > 0 { + if server.KeyExists(ctx, destination) && intersect.Cardinality() > 0 { if _, err = server.KeyLock(ctx, destination); err != nil { return nil, err } @@ -621,7 +621,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c return nil, err } } - defer server.KeyUnlock(destination) + defer server.KeyUnlock(ctx, destination) if err = server.SetValue(ctx, destination, intersect); err != nil { return nil, err @@ -677,21 +677,21 @@ func handleZMPOP(ctx context.Context, cmd []string, server utils.Server, conn *n } for i := 0; i < len(keys); i++ { - if server.KeyExists(keys[i]) { + if server.KeyExists(ctx, keys[i]) { if _, err = server.KeyLock(ctx, keys[i]); err != nil { continue } v, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok || v.Cardinality() == 0 { - server.KeyUnlock(keys[i]) + server.KeyUnlock(ctx, keys[i]) continue } popped, err := v.Pop(count, policy) if err != nil { - server.KeyUnlock(keys[i]) + server.KeyUnlock(ctx, keys[i]) return nil, err } - server.KeyUnlock(keys[i]) + server.KeyUnlock(ctx, keys[i]) res := fmt.Sprintf("*%d", popped.Cardinality()) @@ -730,14 +730,14 @@ func handleZPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne count = c } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("*0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { @@ -767,14 +767,14 @@ func handleZMSCORE(ctx context.Context, cmd []string, server utils.Server, conn key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { @@ -826,14 +826,14 @@ func handleZRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c } } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("$-1\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { @@ -870,14 +870,14 @@ func handleZRANK(ctx context.Context, cmd []string, server utils.Server, conn *n withscores = true } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("$-1\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { @@ -914,14 +914,14 @@ func handleZREM(ctx context.Context, cmd []string, server utils.Server, conn *ne key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { @@ -946,13 +946,13 @@ func handleZSCORE(ctx context.Context, cmd []string, server utils.Server, conn * key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("$-1\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) @@ -987,14 +987,14 @@ func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server utils.Serv return nil, err } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { @@ -1029,14 +1029,14 @@ func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server utils.Serve return nil, err } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { @@ -1086,14 +1086,14 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.Server minimum := cmd[2] maximum := cmd[3] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err = server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { @@ -1184,14 +1184,14 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn * } } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { @@ -1321,14 +1321,14 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c } } - if !server.KeyExists(source) { + if !server.KeyExists(ctx, source) { return []byte("*0\r\n"), nil } if _, err = server.KeyRLock(ctx, source); err != nil { return nil, err } - defer server.KeyRUnlock(source) + defer server.KeyRUnlock(ctx, source) set, ok := server.GetValue(ctx, source).(*SortedSet) if !ok { @@ -1387,7 +1387,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c newSortedSet := NewSortedSet(resultMembers) - if server.KeyExists(destination) { + if server.KeyExists(ctx, destination) { if _, err = server.KeyLock(ctx, destination); err != nil { return nil, err } @@ -1396,7 +1396,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c return nil, err } } - defer server.KeyUnlock(destination) + defer server.KeyUnlock(ctx, destination) if err = server.SetValue(ctx, destination, newSortedSet); err != nil { return nil, err @@ -1419,7 +1419,7 @@ func handleZUNION(ctx context.Context, cmd []string, server utils.Server, conn * defer func() { for key, locked := range locks { if locked { - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) } } }() @@ -1427,7 +1427,7 @@ func handleZUNION(ctx context.Context, cmd []string, server utils.Server, conn * var setParams []SortedSetParam for i := 0; i < len(keys); i++ { - if server.KeyExists(keys[i]) { + if server.KeyExists(ctx, keys[i]) { if _, err = server.KeyRLock(ctx, keys[i]); err != nil { return nil, err } @@ -1481,7 +1481,7 @@ func handleZUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c defer func() { for key, locked := range locks { if locked { - server.KeyRUnlock(key) + server.KeyRUnlock(ctx, key) } } }() @@ -1489,7 +1489,7 @@ func handleZUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c var setParams []SortedSetParam for i := 0; i < len(keys); i++ { - if server.KeyExists(keys[i]) { + if server.KeyExists(ctx, keys[i]) { if _, err = server.KeyRLock(ctx, keys[i]); err != nil { return nil, err } @@ -1507,7 +1507,7 @@ func handleZUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c union := Union(aggregate, setParams...) - if server.KeyExists(destination) { + if server.KeyExists(ctx, destination) { if _, err = server.KeyLock(ctx, destination); err != nil { return nil, err } @@ -1516,7 +1516,7 @@ func handleZUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c return nil, err } } - defer server.KeyUnlock(destination) + defer server.KeyUnlock(ctx, destination) if err = server.SetValue(ctx, destination, union); err != nil { return nil, err diff --git a/src/modules/sorted_set/commands_test.go b/src/modules/sorted_set/commands_test.go index 2cef684..7e591f1 100644 --- a/src/modules/sorted_set/commands_test.go +++ b/src/modules/sorted_set/commands_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "fmt" "github.com/echovault/echovault/src/server" "github.com/echovault/echovault/src/utils" "github.com/tidwall/resp" @@ -218,15 +219,19 @@ func Test_HandleZADD(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZADD, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleZADD(context.Background(), test.command, mockServer, nil) + res, err := handleZADD(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -245,10 +250,10 @@ func Test_HandleZADD(t *testing.T) { t.Errorf("expected response %d at key \"%s\", got %d", test.expectedResponse, test.key, rv.Integer()) } // Fetch the sorted set from the server and check it against the expected result - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - sortedSet, ok := mockServer.GetValue(context.Background(), test.key).(*SortedSet) + sortedSet, ok := mockServer.GetValue(ctx, test.key).(*SortedSet) if !ok { t.Errorf("expected the value at key \"%s\" to be a sorted set, got another type", test.key) } @@ -258,7 +263,7 @@ func Test_HandleZADD(t *testing.T) { if !sortedSet.Equals(test.expectedValue) { t.Errorf("expected sorted set %+v, got %+v", test.expectedValue, sortedSet) } - mockServer.KeyRUnlock(test.key) + mockServer.KeyRUnlock(ctx, test.key) } } @@ -325,15 +330,19 @@ func Test_HandleZCARD(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZCARD, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleZCARD(context.Background(), test.command, mockServer, nil) + res, err := handleZCARD(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -464,15 +473,19 @@ func Test_HandleZCOUNT(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZCARD, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleZCOUNT(context.Background(), test.command, mockServer, nil) + res, err := handleZCOUNT(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -577,15 +590,19 @@ func Test_HandleZLEXCOUNT(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZLEXCOUNT, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleZLEXCOUNT(context.Background(), test.command, mockServer, nil) + res, err := handleZLEXCOUNT(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -728,17 +745,21 @@ func Test_HandleZDIFF(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZDIFF, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZDIFF(context.Background(), test.command, mockServer, nil) + res, err := handleZDIFF(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -903,17 +924,21 @@ func Test_HandleZDIFFSTORE(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZDIFFSTORE, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZDIFFSTORE(context.Background(), test.command, mockServer, nil) + res, err := handleZDIFFSTORE(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -932,10 +957,10 @@ func Test_HandleZDIFFSTORE(t *testing.T) { t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) } if test.expectedValue != nil { - if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) + set, ok := mockServer.GetValue(ctx, test.destination).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -944,7 +969,7 @@ func Test_HandleZDIFFSTORE(t *testing.T) { t.Errorf("could not find element %s in the expected values", elem.value) } } - mockServer.KeyRUnlock(test.destination) + mockServer.KeyRUnlock(ctx, test.destination) } } } @@ -1119,15 +1144,19 @@ func Test_HandleZINCRBY(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZINCRBY, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleZINCRBY(context.Background(), test.command, mockServer, nil) + res, err := handleZINCRBY(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1146,10 +1175,10 @@ func Test_HandleZINCRBY(t *testing.T) { t.Errorf("expected response integer %s, got %s", test.expectedResponse, rv.String()) } if test.expectedValue != nil { - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), test.key).(*SortedSet) + set, ok := mockServer.GetValue(ctx, test.key).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.key) } @@ -1165,7 +1194,7 @@ func Test_HandleZINCRBY(t *testing.T) { ) } } - mockServer.KeyRUnlock(test.key) + mockServer.KeyRUnlock(ctx, test.key) } } } @@ -1339,17 +1368,21 @@ func Test_HandleZMPOP(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZMPOP, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZMPOP(context.Background(), test.command, mockServer, nil) + res, err := handleZMPOP(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1381,10 +1414,10 @@ func Test_HandleZMPOP(t *testing.T) { } } for key, expectedSortedSet := range test.expectedValues { - if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { + if _, err = mockServer.KeyRLock(ctx, key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) + set, ok := mockServer.GetValue(ctx, key).(*SortedSet) if !ok { t.Errorf("expected key \"%s\" to be a sorted set, got another type", key) } @@ -1510,17 +1543,21 @@ func Test_HandleZPOP(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZPOPMIN/ZPOPMAX, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZPOP(context.Background(), test.command, mockServer, nil) + res, err := handleZPOP(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1552,10 +1589,10 @@ func Test_HandleZPOP(t *testing.T) { } } for key, expectedSortedSet := range test.expectedValues { - if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { + if _, err = mockServer.KeyRLock(ctx, key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) + set, ok := mockServer.GetValue(ctx, key).(*SortedSet) if !ok { t.Errorf("expected key \"%s\" to be a sorted set, got another type", key) } @@ -1610,17 +1647,21 @@ func Test_HandleZMSCORE(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZMSCORE, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZMSCORE(context.Background(), test.command, mockServer, nil) + res, err := handleZMSCORE(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1710,17 +1751,21 @@ func Test_HandleZSCORE(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZSCORE, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZSCORE(context.Background(), test.command, mockServer, nil) + res, err := handleZSCORE(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1825,15 +1870,19 @@ func Test_HandleZRANDMEMBER(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZRANDMEMBER, %d", i)) + if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleZRANDMEMBER(context.Background(), test.command, mockServer, nil) + res, err := handleZRANDMEMBER(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -1866,10 +1915,10 @@ func Test_HandleZRANDMEMBER(t *testing.T) { } } // 2. Fetch the set and check if its cardinality is what we expect. - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), test.key).(*SortedSet) + set, ok := mockServer.GetValue(ctx, test.key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) } @@ -1980,17 +2029,21 @@ func Test_HandleZRANK(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZRANK, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZRANK(context.Background(), test.command, mockServer, nil) + res, err := handleZRANK(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -2078,17 +2131,21 @@ func Test_HandleZREM(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZREM, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZREM(context.Background(), test.command, mockServer, nil) + res, err := handleZREM(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -2109,10 +2166,10 @@ func Test_HandleZREM(t *testing.T) { // Check if the expected sorted set is the same at the current one if test.expectedValues != nil { for key, expectedSet := range test.expectedValues { - if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { + if _, err = mockServer.KeyRLock(ctx, key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) + set, ok := mockServer.GetValue(ctx, key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) } @@ -2184,17 +2241,21 @@ func Test_HandleZREMRANGEBYSCORE(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZREMRANGEBYSCORE, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZREMRANGEBYSCORE(context.Background(), test.command, mockServer, nil) + res, err := handleZREMRANGEBYSCORE(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -2215,10 +2276,10 @@ func Test_HandleZREMRANGEBYSCORE(t *testing.T) { // Check if the expected values are the same if test.expectedValues != nil { for key, expectedSet := range test.expectedValues { - if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { + if _, err = mockServer.KeyRLock(ctx, key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) + set, ok := mockServer.GetValue(ctx, key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) } @@ -2344,17 +2405,21 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZREMRANGEBYRANK, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZREMRANGEBYRANK(context.Background(), test.command, mockServer, nil) + res, err := handleZREMRANGEBYRANK(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -2375,10 +2440,10 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) { // Check if the expected values are the same if test.expectedValues != nil { for key, expectedSet := range test.expectedValues { - if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { + if _, err = mockServer.KeyRLock(ctx, key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) + set, ok := mockServer.GetValue(ctx, key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) } @@ -2475,17 +2540,21 @@ func Test_HandleZREMRANGEBYLEX(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZREMRANGEBYLEX, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZREMRANGEBYLEX(context.Background(), test.command, mockServer, nil) + res, err := handleZREMRANGEBYLEX(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -2506,10 +2575,10 @@ func Test_HandleZREMRANGEBYLEX(t *testing.T) { // Check if the expected values are the same if test.expectedValues != nil { for key, expectedSet := range test.expectedValues { - if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { + if _, err = mockServer.KeyRLock(ctx, key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) + set, ok := mockServer.GetValue(ctx, key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) } @@ -2720,17 +2789,21 @@ func Test_HandleZRANGE(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZRANGE, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZRANGE(context.Background(), test.command, mockServer, nil) + res, err := handleZRANGE(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -2790,7 +2863,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { }), }, destination: "destination1", - command: []string{"ZRANGE", "destination1", "key1", "3", "7", "BYSCORE"}, + command: []string{"ZRANGESTORE", "destination1", "key1", "3", "7", "BYSCORE"}, expectedResponse: 5, expectedValue: NewSortedSet([]MemberParam{ {value: "three", score: 3}, {value: "four", score: 4}, {value: "five", score: 5}, @@ -2809,7 +2882,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { }), }, destination: "destination2", - command: []string{"ZRANGE", "destination2", "key2", "3", "7", "BYSCORE", "WITHSCORES"}, + command: []string{"ZRANGESTORE", "destination2", "key2", "3", "7", "BYSCORE", "WITHSCORES"}, expectedResponse: 5, expectedValue: NewSortedSet([]MemberParam{ {value: "three", score: 3}, {value: "four", score: 4}, {value: "five", score: 5}, @@ -2829,7 +2902,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { }), }, destination: "destination3", - command: []string{"ZRANGE", "destination3", "key3", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4"}, + command: []string{"ZRANGESTORE", "destination3", "key3", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4"}, expectedResponse: 3, expectedValue: NewSortedSet([]MemberParam{ {value: "three", score: 3}, {value: "four", score: 4}, {value: "five", score: 5}, @@ -2849,7 +2922,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { }), }, destination: "destination4", - command: []string{"ZRANGE", "destination4", "key4", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4", "REV"}, + command: []string{"ZRANGESTORE", "destination4", "key4", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4", "REV"}, expectedResponse: 3, expectedValue: NewSortedSet([]MemberParam{ {value: "six", score: 6}, {value: "five", score: 5}, {value: "four", score: 4}, @@ -2867,7 +2940,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { }), }, destination: "destination5", - command: []string{"ZRANGE", "destination5", "key5", "c", "g", "BYLEX"}, + command: []string{"ZRANGESTORE", "destination5", "key5", "c", "g", "BYLEX"}, expectedResponse: 5, expectedValue: NewSortedSet([]MemberParam{ {value: "c", score: 1}, {value: "d", score: 1}, {value: "e", score: 1}, @@ -2886,7 +2959,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { }), }, destination: "destination6", - command: []string{"ZRANGE", "destination6", "key6", "a", "f", "BYLEX", "WITHSCORES"}, + command: []string{"ZRANGESTORE", "destination6", "key6", "a", "f", "BYLEX", "WITHSCORES"}, expectedResponse: 6, expectedValue: NewSortedSet([]MemberParam{ {value: "a", score: 1}, {value: "b", score: 1}, {value: "c", score: 1}, @@ -2906,7 +2979,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { }), }, destination: "destination7", - command: []string{"ZRANGE", "destination7", "key7", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, + command: []string{"ZRANGESTORE", "destination7", "key7", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, expectedResponse: 3, expectedValue: NewSortedSet([]MemberParam{ {value: "c", score: 1}, {value: "d", score: 1}, {value: "e", score: 1}, @@ -2926,7 +2999,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { }), }, destination: "destination8", - command: []string{"ZRANGE", "destination8", "key8", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4", "REV"}, + command: []string{"ZRANGESTORE", "destination8", "key8", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4", "REV"}, expectedResponse: 3, expectedValue: NewSortedSet([]MemberParam{ {value: "f", score: 1}, {value: "e", score: 1}, {value: "d", score: 1}, @@ -2944,7 +3017,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { }), }, destination: "destination9", - command: []string{"ZRANGE", "destination9", "key9", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, + command: []string{"ZRANGESTORE", "destination9", "key9", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, expectedResponse: 0, expectedValue: nil, expectedError: nil, @@ -2952,28 +3025,28 @@ func Test_HandleZRANGESTORE(t *testing.T) { { // 10. Throw error when limit does not provide both offset and limit preset: false, presetValues: nil, - command: []string{"ZRANGE", "destination10", "key10", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2"}, + command: []string{"ZRANGESTORE", "destination10", "key10", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2"}, expectedResponse: 0, expectedError: errors.New("limit should contain offset and count as integers"), }, { // 11. Throw error when offset is not a valid integer preset: false, presetValues: nil, - command: []string{"ZRANGE", "destination11", "key11", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "offset", "4"}, + command: []string{"ZRANGESTORE", "destination11", "key11", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "offset", "4"}, expectedResponse: 0, expectedError: errors.New("limit offset must be integer"), }, { // 12. Throw error when limit is not a valid integer preset: false, presetValues: nil, - command: []string{"ZRANGE", "destination12", "key12", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "4", "limit"}, + command: []string{"ZRANGESTORE", "destination12", "key12", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "4", "limit"}, expectedResponse: 0, expectedError: errors.New("limit count must be integer"), }, { // 13. Throw error when offset is negative preset: false, presetValues: nil, - command: []string{"ZRANGE", "destination13", "key13", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9"}, + command: []string{"ZRANGESTORE", "destination13", "key13", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9"}, expectedResponse: 0, expectedError: errors.New("limit offset must be >= 0"), }, @@ -2982,37 +3055,41 @@ func Test_HandleZRANGESTORE(t *testing.T) { presetValues: map[string]interface{}{ "key14": "Default value", }, - command: []string{"ZRANGE", "destination14", "key14", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, + command: []string{"ZRANGESTORE", "destination14", "key14", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, expectedResponse: 0, expectedError: errors.New("value at key14 is not a sorted set"), }, { // 15. Command too short preset: false, presetValues: nil, - command: []string{"ZRANGE", "key15", "1"}, + command: []string{"ZRANGESTORE", "key15", "1"}, expectedResponse: 0, expectedError: errors.New(utils.WrongArgsResponse), }, { // 16 Command too long preset: false, presetValues: nil, - command: []string{"ZRANGE", "destination16", "key16", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9", "REV", "WITHSCORES"}, + command: []string{"ZRANGESTORE", "destination16", "key16", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9", "REV", "WITHSCORES"}, expectedResponse: 0, expectedError: errors.New(utils.WrongArgsResponse), }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZRANGESTORE, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZRANGESTORE(context.Background(), test.command, mockServer, nil) + res, err := handleZRANGESTORE(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -3031,17 +3108,17 @@ func Test_HandleZRANGESTORE(t *testing.T) { t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) } if test.expectedValue != nil { - if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) + set, ok := mockServer.GetValue(ctx, test.destination).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } if !set.Equals(test.expectedValue) { t.Errorf("expected sorted set %+v, got %+v", test.expectedValue, set) } - mockServer.KeyRUnlock(test.destination) + mockServer.KeyRUnlock(ctx, test.destination) } } } @@ -3316,17 +3393,21 @@ func Test_HandleZINTER(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZINTER, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZINTER(context.Background(), test.command, mockServer, nil) + res, err := handleZINTER(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -3661,17 +3742,21 @@ func Test_HandleZINTERSTORE(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZINTERSTORE, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZINTERSTORE(context.Background(), test.command, mockServer, nil) + res, err := handleZINTERSTORE(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -3690,10 +3775,10 @@ func Test_HandleZINTERSTORE(t *testing.T) { t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) } if test.expectedValue != nil { - if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) + set, ok := mockServer.GetValue(ctx, test.destination).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -3702,7 +3787,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { t.Errorf("could not find element %s in the expected values", elem.value) } } - mockServer.KeyRUnlock(test.destination) + mockServer.KeyRUnlock(ctx, test.destination) } } } @@ -4002,17 +4087,21 @@ func Test_HandleZUNION(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZUNION, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZUNION(context.Background(), test.command, mockServer, nil) + res, err := handleZUNION(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -4386,17 +4475,21 @@ func Test_HandleZUNIONSTORE(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZUNIONSTORE, %d", i)) + if test.preset { for key, value := range test.presetValues { - if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), key, value) - mockServer.KeyUnlock(key) + if err := mockServer.SetValue(ctx, key, value); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, key) } } - res, err := handleZUNIONSTORE(context.Background(), test.command, mockServer, nil) + res, err := handleZUNIONSTORE(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -4415,10 +4508,10 @@ func Test_HandleZUNIONSTORE(t *testing.T) { t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) } if test.expectedValue != nil { - if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) + set, ok := mockServer.GetValue(ctx, test.destination).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -4427,7 +4520,7 @@ func Test_HandleZUNIONSTORE(t *testing.T) { t.Errorf("could not find element %s in the expected values", elem.value) } } - mockServer.KeyRUnlock(test.destination) + mockServer.KeyRUnlock(ctx, test.destination) } } } diff --git a/src/modules/string/commands.go b/src/modules/string/commands.go index ee92e6d..ac495e0 100644 --- a/src/modules/string/commands.go +++ b/src/modules/string/commands.go @@ -23,21 +23,21 @@ func handleSetRange(ctx context.Context, cmd []string, server utils.Server, conn newStr := cmd[3] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { if _, err = server.CreateKeyAndLock(ctx, key); err != nil { return nil, err } if err = server.SetValue(ctx, key, newStr); err != nil { return nil, err } - server.KeyUnlock(key) + server.KeyUnlock(ctx, key) return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil } if _, err := server.KeyLock(ctx, key); err != nil { return nil, err } - defer server.KeyUnlock(key) + defer server.KeyUnlock(ctx, key) str, ok := server.GetValue(ctx, key).(string) if !ok { @@ -91,14 +91,14 @@ func handleStrLen(ctx context.Context, cmd []string, server utils.Server, conn * key := keys[0] - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return []byte(":0\r\n"), nil } if _, err := server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) value, ok := server.GetValue(ctx, key).(string) @@ -125,14 +125,14 @@ func handleSubStr(ctx context.Context, cmd []string, server utils.Server, conn * return nil, errors.New("start and end indices must be integers") } - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { return nil, fmt.Errorf("key %s does not exist", key) } - if _, err := server.KeyRLock(ctx, key); err != nil { + if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - defer server.KeyRUnlock(key) + defer server.KeyRUnlock(ctx, key) value, ok := server.GetValue(ctx, key).(string) if !ok { diff --git a/src/modules/string/commands_test.go b/src/modules/string/commands_test.go index c531d1b..0f965f5 100644 --- a/src/modules/string/commands_test.go +++ b/src/modules/string/commands_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "fmt" "github.com/echovault/echovault/src/server" "github.com/echovault/echovault/src/utils" "github.com/tidwall/resp" @@ -105,17 +106,21 @@ func Test_HandleSetRange(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SETRANGE, %d", i)) + // If there's a preset step, carry it out here if test.preset { - if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, utils.AdaptType(test.presetValue)) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, utils.AdaptType(test.presetValue)); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleSetRange(context.Background(), test.command, mockServer, nil) + res, err := handleSetRange(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -135,17 +140,17 @@ func Test_HandleSetRange(t *testing.T) { } // Get the value from the server and check against the expected value - if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { + if _, err = mockServer.KeyRLock(ctx, test.key); err != nil { t.Error(err) } - value, ok := mockServer.GetValue(context.Background(), test.key).(string) + value, ok := mockServer.GetValue(ctx, test.key).(string) if !ok { t.Error("expected string data type, got another type") } if value != test.expectedValue { t.Errorf("expected value \"%s\", got \"%s\"", test.expectedValue, value) } - mockServer.KeyRUnlock(test.key) + mockServer.KeyRUnlock(ctx, test.key) } } @@ -194,16 +199,20 @@ func Test_HandleStrLen(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("STRLEN, %d", i)) + if test.preset { - _, err := mockServer.CreateKeyAndLock(context.Background(), test.key) + _, err := mockServer.CreateKeyAndLock(ctx, test.key) if err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleStrLen(context.Background(), test.command, mockServer, nil) + res, err := handleStrLen(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) @@ -307,16 +316,19 @@ func Test_HandleSubStr(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { + ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SUBSTR, %d", i)) + if test.preset { - _, err := mockServer.CreateKeyAndLock(context.Background(), test.key) - if err != nil { + if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil { t.Error(err) } - mockServer.SetValue(context.Background(), test.key, test.presetValue) - mockServer.KeyUnlock(test.key) + if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil { + t.Error(err) + } + mockServer.KeyUnlock(ctx, test.key) } - res, err := handleSubStr(context.Background(), test.command, mockServer, nil) + res, err := handleSubStr(ctx, test.command, mockServer, nil) if test.expectedError != nil { if err.Error() != test.expectedError.Error() { t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) diff --git a/src/raft/fsm.go b/src/raft/fsm.go index f6d4f58..fb6cab8 100644 --- a/src/raft/fsm.go +++ b/src/raft/fsm.go @@ -98,19 +98,21 @@ func (fsm *FSM) Restore(snapshot io.ReadCloser) error { LatestSnapshotMilliseconds: 0, } - if err := json.Unmarshal(b, &data); err != nil { + if err = json.Unmarshal(b, &data); err != nil { log.Fatal(err) return err } // Set state + ctx := context.Background() for k, v := range data.State { - _, err := fsm.options.Server.CreateKeyAndLock(context.Background(), k) - if err != nil { + if _, err = fsm.options.Server.CreateKeyAndLock(ctx, k); err != nil { log.Fatal(err) } - fsm.options.Server.SetValue(context.Background(), k, v) - fsm.options.Server.KeyUnlock(k) + if err = fsm.options.Server.SetValue(ctx, k, v); err != nil { + log.Fatal(err) + } + fsm.options.Server.KeyUnlock(ctx, k) } // Set latest snapshot milliseconds fsm.options.Server.SetLatestSnapshot(data.LatestSnapshotMilliseconds) diff --git a/src/server/keyspace.go b/src/server/keyspace.go index 2860656..1ae495a 100644 --- a/src/server/keyspace.go +++ b/src/server/keyspace.go @@ -33,8 +33,8 @@ func (server *Server) KeyLock(ctx context.Context, key string) (bool, error) { } } -func (server *Server) KeyUnlock(key string) { - if server.KeyExists(key) { +func (server *Server) KeyUnlock(ctx context.Context, key string) { + if _, ok := server.keyLocks[key]; ok { server.keyLocks[key].Unlock() } } @@ -58,14 +58,27 @@ func (server *Server) KeyRLock(ctx context.Context, key string) (bool, error) { } } -func (server *Server) KeyRUnlock(key string) { - if server.KeyExists(key) { +func (server *Server) KeyRUnlock(ctx context.Context, key string) { + if _, ok := server.keyLocks[key]; ok { server.keyLocks[key].RUnlock() } } -func (server *Server) KeyExists(key string) bool { - return server.keyLocks[key] != nil +func (server *Server) KeyExists(ctx context.Context, key string) bool { + entry, ok := server.store[key] + if !ok { + return false + } + + if entry.ExpireAt != (time.Time{}) && entry.ExpireAt.Before(time.Now()) { + err := server.DeleteKey(ctx, key) + if err != nil { + log.Printf("keyExists: %+v\n", err) + } + return false + } + + return true } // CreateKeyAndLock creates a new key lock and immediately locks it if the key does not exist. @@ -78,15 +91,15 @@ func (server *Server) CreateKeyAndLock(ctx context.Context, key string) (bool, e server.keyCreationLock.Lock() defer server.keyCreationLock.Unlock() - if !server.KeyExists(key) { + if !server.KeyExists(ctx, key) { // Create Lock keyLock := &sync.RWMutex{} keyLock.Lock() server.keyLocks[key] = keyLock // Create key entry - server.store[key] = KeyData{ - value: nil, - expireAt: time.Time{}, + server.store[key] = utils.KeyData{ + Value: nil, + ExpireAt: time.Time{}, } return true, nil } @@ -100,7 +113,7 @@ func (server *Server) GetValue(ctx context.Context, key string) interface{} { if err := server.updateKeyInCache(ctx, key); err != nil { log.Printf("GetValue error: %+v\n", err) } - return server.store[key].value + return server.store[key].Value } // SetValue updates the value in the store at the specified key with the given value. @@ -113,9 +126,9 @@ func (server *Server) SetValue(ctx context.Context, key string, value interface{ return errors.New("max memory reached, key value not set") } - server.store[key] = KeyData{ - value: value, - expireAt: server.store[key].expireAt, + server.store[key] = utils.KeyData{ + Value: value, + ExpireAt: server.store[key].ExpireAt, } err := server.updateKeyInCache(ctx, key) @@ -136,7 +149,7 @@ func (server *Server) GetExpiry(ctx context.Context, key string) time.Time { if err := server.updateKeyInCache(ctx, key); err != nil { log.Printf("GetKeyExpiry error: %+v\n", err) } - return server.store[key].expireAt + return server.store[key].ExpireAt } // The SetExpiry receiver function sets the expiry time of a key. @@ -146,9 +159,9 @@ func (server *Server) GetExpiry(ctx context.Context, key string) time.Time { // or the access time on lru eviction policy. // The key must be locked prior to calling this function. func (server *Server) SetExpiry(ctx context.Context, key string, expireAt time.Time, touch bool) { - server.store[key] = KeyData{ - value: server.store[key].value, - expireAt: expireAt, + server.store[key] = utils.KeyData{ + Value: server.store[key].Value, + ExpireAt: expireAt, } if touch { err := server.updateKeyInCache(ctx, key) @@ -161,9 +174,9 @@ func (server *Server) SetExpiry(ctx context.Context, key string, expireAt time.T // RemoveExpiry is called by commands that remove key expiry (e.g. PERSIST). // The key must be locked prior ro calling this function. func (server *Server) RemoveExpiry(key string) { - server.store[key] = KeyData{ - value: server.store[key].value, - expireAt: time.Time{}, + server.store[key] = utils.KeyData{ + Value: server.store[key].Value, + ExpireAt: time.Time{}, } switch { case slices.Contains([]string{utils.AllKeysLFU, utils.VolatileLFU}, server.Config.EvictionPolicy): @@ -198,9 +211,11 @@ func (server *Server) DeleteKey(ctx context.Context, key string) error { if _, err := server.KeyLock(ctx, key); err != nil { return fmt.Errorf("deleteKey: %+v", err) } - // Delete the keys - delete(server.store, key) + // Remove key expiry + server.RemoveExpiry(key) + // Delete the key from keyLocks and store delete(server.keyLocks, key) + delete(server.store, key) return nil } @@ -227,13 +242,13 @@ func (server *Server) updateKeyInCache(ctx context.Context, key string) error { case utils.VolatileLFU: server.lfuCache.mutex.Lock() defer server.lfuCache.mutex.Unlock() - if server.store[key].expireAt != (time.Time{}) { + if server.store[key].ExpireAt != (time.Time{}) { server.lfuCache.cache.Update(key) } case utils.VolatileLRU: server.lruCache.mutex.Lock() defer server.lruCache.mutex.Unlock() - if server.store[key].expireAt != (time.Time{}) { + if server.store[key].ExpireAt != (time.Time{}) { server.lruCache.cache.Update(key) } } @@ -347,7 +362,7 @@ func (server *Server) adjustMemoryUsage(ctx context.Context) error { for key, _ := range server.keyLocks { if idx == 0 { // If the key is not volatile, break the loop - if server.store[key].expireAt == (time.Time{}) { + if server.store[key].ExpireAt == (time.Time{}) { break } // Delete the key diff --git a/src/server/server.go b/src/server/server.go index 9143508..4b8b936 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -21,17 +21,12 @@ import ( "time" ) -type KeyData struct { - value interface{} - expireAt time.Time -} - type Server struct { Config utils.Config ConnID atomic.Uint64 - store map[string]KeyData + store map[string]utils.KeyData keyLocks map[string]*sync.RWMutex keyCreationLock *sync.Mutex lfuCache struct { @@ -77,7 +72,7 @@ func NewServer(opts Opts) *Server { PubSub: opts.PubSub, CancelCh: opts.CancelCh, Commands: opts.Commands, - store: make(map[string]KeyData), + store: make(map[string]utils.KeyData), keyLocks: make(map[string]*sync.RWMutex), keyCreationLock: &sync.Mutex{}, } @@ -105,13 +100,14 @@ func NewServer(opts Opts) *Server { SetLatestSnapshotMilliseconds: server.SetLatestSnapshot, GetLatestSnapshotMilliseconds: server.GetLatestSnapshot, SetValue: func(key string, value interface{}) error { - if _, err := server.CreateKeyAndLock(context.Background(), key); err != nil { + ctx := context.Background() + if _, err := server.CreateKeyAndLock(ctx, key); err != nil { return err } - if err := server.SetValue(context.Background(), key, value); err != nil { + if err := server.SetValue(ctx, key, value); err != nil { return err } - server.KeyUnlock(key) + server.KeyUnlock(ctx, key) return nil }, }) @@ -123,13 +119,14 @@ func NewServer(opts Opts) *Server { aof.WithFinishRewriteFunc(server.FinishRewriteAOF), aof.WithGetStateFunc(server.GetState), aof.WithSetValueFunc(func(key string, value interface{}) error { - if _, err := server.CreateKeyAndLock(context.Background(), key); err != nil { + ctx := context.Background() + if _, err := server.CreateKeyAndLock(ctx, key); err != nil { return err } - if err := server.SetValue(context.Background(), key, value); err != nil { + if err := server.SetValue(ctx, key, value); err != nil { return err } - server.KeyUnlock(key) + server.KeyUnlock(ctx, key) return nil }), aof.WithHandleCommandFunc(func(command []byte) { diff --git a/src/utils/config.go b/src/utils/config.go index 0bdeec2..eae8dd9 100644 --- a/src/utils/config.go +++ b/src/utils/config.go @@ -98,12 +98,11 @@ There is no limit by default.`, func(memory string) error { 4) volatile-lfu - Evict the least frequently used keys with an expiration. 5) volatile-lru - Evict the least recently used keys with an expiration. 6) allkeys-random - Evict random keys until we get under the max-memory limit. -7) volatile-random - Evict random keys with an expiration. -8) volatile-ttl - Evict the keys with the shortest remaining ttl.`, func(policy string) error { +7) volatile-random - Evict random keys with an expiration.`, func(policy string) error { policies := []string{ NoEviction, AllKeysLFU, AllKeysLRU, AllKeysRandom, - VolatileLFU, VolatileLRU, VolatileRandom, VolatileTTL, + VolatileLFU, VolatileLRU, VolatileRandom, } policyIdx := slices.Index(policies, strings.ToLower(policy)) if policyIdx == -1 { diff --git a/src/utils/const.go b/src/utils/const.go index d56dea1..cea7f69 100644 --- a/src/utils/const.go +++ b/src/utils/const.go @@ -37,5 +37,4 @@ const ( VolatileLFU = "volatile-lfu" AllKeysRandom = "allkeys-random" VolatileRandom = "volatile-random" - VolatileTTL = "volatile-ttl" ) diff --git a/src/utils/types.go b/src/utils/types.go index be5b090..3c00915 100644 --- a/src/utils/types.go +++ b/src/utils/types.go @@ -6,12 +6,18 @@ import ( "time" ) +// KeyData holds the structure of the in-memory data stored at a string key. +type KeyData struct { + Value interface{} + ExpireAt time.Time +} + type Server interface { KeyLock(ctx context.Context, key string) (bool, error) - KeyUnlock(key string) + KeyUnlock(ctx context.Context, key string) KeyRLock(ctx context.Context, key string) (bool, error) - KeyRUnlock(key string) - KeyExists(key string) bool + KeyRUnlock(ctx context.Context, key string) + KeyExists(ctx context.Context, key string) bool CreateKeyAndLock(ctx context.Context, key string) (bool, error) GetValue(ctx context.Context, key string) interface{} SetValue(ctx context.Context, key string, value interface{}) error