KeyExists keyspace function now removes the key if the key expiry is in the past.

KeyData type moved to utils package to allow sharing between multiple packages.
Updated all commands and command tests to pass context object to KeyExists, KeyLock, keyUnlock, KeyRLock, and KeyRUnlock.
Create context object for each test in all test suites instead of just passing context.Background() to all functions that accept a context.
This commit is contained in:
Kelvin Mwinuka
2024-03-10 23:19:05 +08:00
parent 10f1aeab9e
commit c414da16b4
19 changed files with 1148 additions and 871 deletions

View File

@@ -47,216 +47,216 @@ services:
networks: networks:
- testnet - testnet
cluster_node_1: # cluster_node_1:
container_name: cluster_node_1 # container_name: cluster_node_1
build: # build:
context: . # context: .
dockerfile: Dockerfile.dev # dockerfile: Dockerfile.dev
environment: # environment:
- PORT=7480 # - PORT=7480
- RAFT_PORT=8000 # - RAFT_PORT=8000
- ML_PORT=7946 # - ML_PORT=7946
- KEY=/generic/ssl/certs/echovault/server1.key # - KEY=/generic/ssl/certs/echovault/server1.key
- CERT=/generic/ssl/certs/echovault/server1.crt # - CERT=/generic/ssl/certs/echovault/server1.crt
- SERVER_ID=1 # - SERVER_ID=1
- DATA_DIR=/var/lib/echovault # - DATA_DIR=/var/lib/echovault
- IN_MEMORY=false # - IN_MEMORY=false
- TLS=false # - TLS=false
- MTLS=false # - MTLS=false
- BOOTSTRAP_CLUSTER=true # - BOOTSTRAP_CLUSTER=true
- ACL_CONFIG=/generic/config/echovault/acl.yml # - ACL_CONFIG=/generic/config/echovault/acl.yml
- REQUIRE_PASS=false # - REQUIRE_PASS=false
- FORWARD_COMMAND=true # - FORWARD_COMMAND=true
- SNAPSHOT_THRESHOLD=1000 # - SNAPSHOT_THRESHOLD=1000
- SNAPSHOT_INTERVAL=5m30s # - SNAPSHOT_INTERVAL=5m30s
- RESTORE_SNAPSHOT=false # - RESTORE_SNAPSHOT=false
- RESTORE_AOF=false # - RESTORE_AOF=false
- AOF_SYNC_STRATEGY=everysec # - AOF_SYNC_STRATEGY=everysec
- MAX_MEMORY=2000kb # - MAX_MEMORY=2000kb
- EVICTION_POLICY=allkeys-lfu # - EVICTION_POLICY=allkeys-lfu
# List of server cert/key pairs # # List of server cert/key pairs
- CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key # - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key
- CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key # - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key
# List of client certificate authorities # # List of client certificate authorities
- CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt # - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt
ports: # ports:
- "7481:7480" # - "7481:7480"
- "7945:7946" # - "7945:7946"
- "8000:8000" # - "8000:8000"
volumes: # volumes:
- ./config/acl.yml:/generic/config/echovault/acl.yml # - ./config/acl.yml:/generic/config/echovault/acl.yml
- ./volumes/cluster_node_1:/var/lib/echovault # - ./volumes/cluster_node_1:/var/lib/echovault
networks: # networks:
- testnet # - testnet
#
cluster_node_2: # cluster_node_2:
container_name: cluster_node_2 # container_name: cluster_node_2
build: # build:
context: . # context: .
dockerfile: Dockerfile.dev # dockerfile: Dockerfile.dev
environment: # environment:
- PORT=7480 # - PORT=7480
- RAFT_PORT=8000 # - RAFT_PORT=8000
- ML_PORT=7946 # - ML_PORT=7946
- KEY=/generic/ssl/certs/echovault/server1.key # - KEY=/generic/ssl/certs/echovault/server1.key
- CERT=/generic/ssl/certs/echovault/server1.crt # - CERT=/generic/ssl/certs/echovault/server1.crt
- SERVER_ID=2 # - SERVER_ID=2
- JOIN_ADDR=cluster_node_1:7946 # - JOIN_ADDR=cluster_node_1:7946
- DATA_DIR=/var/lib/echovault # - DATA_DIR=/var/lib/echovault
- IN_MEMORY=false # - IN_MEMORY=false
- TLS=false # - TLS=false
- MTLS=false # - MTLS=false
- BOOTSTRAP_CLUSTER=false # - BOOTSTRAP_CLUSTER=false
- ACL_CONFIG=/generic/config/echovault/acl.yml # - ACL_CONFIG=/generic/config/echovault/acl.yml
- REQUIRE_PASS=false # - REQUIRE_PASS=false
- FORWARD_COMMAND=true # - FORWARD_COMMAND=true
- SNAPSHOT_THRESHOLD=1000 # - SNAPSHOT_THRESHOLD=1000
- SNAPSHOT_INTERVAL=5m30s # - SNAPSHOT_INTERVAL=5m30s
- RESTORE_SNAPSHOT=false # - RESTORE_SNAPSHOT=false
- RESTORE_AOF=false # - RESTORE_AOF=false
- AOF_SYNC_STRATEGY=everysec # - AOF_SYNC_STRATEGY=everysec
- MAX_MEMORY=2000kb # - MAX_MEMORY=2000kb
- EVICTION_POLICY=allkeys-lfu # - EVICTION_POLICY=allkeys-lfu
# List of server cert/key pairs # # List of server cert/key pairs
- CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key # - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key
- CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key # - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key
# List of client certificate authorities # # List of client certificate authorities
- CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt # - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt
ports: # ports:
- "7482:7480" # - "7482:7480"
- "7947:7946" # - "7947:7946"
- "8001:8000" # - "8001:8000"
volumes: # volumes:
- ./config/acl.yml:/generic/config/echovault/acl.yml # - ./config/acl.yml:/generic/config/echovault/acl.yml
- ./volumes/cluster_node_2:/var/lib/echovault # - ./volumes/cluster_node_2:/var/lib/echovault
networks: # networks:
- testnet # - testnet
#
cluster_node_3: # cluster_node_3:
container_name: cluster_node_3 # container_name: cluster_node_3
build: # build:
context: . # context: .
dockerfile: Dockerfile.dev # dockerfile: Dockerfile.dev
environment: # environment:
- PORT=7480 # - PORT=7480
- RAFT_PORT=8000 # - RAFT_PORT=8000
- ML_PORT=7946 # - ML_PORT=7946
- KEY=/generic/ssl/certs/echovault/server1.key # - KEY=/generic/ssl/certs/echovault/server1.key
- CERT=/generic/ssl/certs/echovault/server1.crt # - CERT=/generic/ssl/certs/echovault/server1.crt
- SERVER_ID=3 # - SERVER_ID=3
- JOIN_ADDR=cluster_node_1:7946 # - JOIN_ADDR=cluster_node_1:7946
- DATA_DIR=/var/lib/echovault # - DATA_DIR=/var/lib/echovault
- IN_MEMORY=false # - IN_MEMORY=false
- TLS=false # - TLS=false
- MTLS=false # - MTLS=false
- BOOTSTRAP_CLUSTER=false # - BOOTSTRAP_CLUSTER=false
- ACL_CONFIG=/generic/config/echovault/acl.yml # - ACL_CONFIG=/generic/config/echovault/acl.yml
- REQUIRE_PASS=false # - REQUIRE_PASS=false
- FORWARD_COMMAND=true # - FORWARD_COMMAND=true
- SNAPSHOT_THRESHOLD=1000 # - SNAPSHOT_THRESHOLD=1000
- SNAPSHOT_INTERVAL=5m30s # - SNAPSHOT_INTERVAL=5m30s
- RESTORE_SNAPSHOT=false # - RESTORE_SNAPSHOT=false
- RESTORE_AOF=false # - RESTORE_AOF=false
- AOF_SYNC_STRATEGY=everysec # - AOF_SYNC_STRATEGY=everysec
- MAX_MEMORY=2000kb # - MAX_MEMORY=2000kb
- EVICTION_POLICY=allkeys-lfu # - EVICTION_POLICY=allkeys-lfu
# List of server cert/key pairs # # List of server cert/key pairs
- CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key # - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key
- CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key # - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key
# List of client certificate authorities # # List of client certificate authorities
- CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt # - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt
ports: # ports:
- "7483:7480" # - "7483:7480"
- "7948:7946" # - "7948:7946"
- "8002:8000" # - "8002:8000"
volumes: # volumes:
- ./config/acl.yml:/generic/config/echovault/acl.yml # - ./config/acl.yml:/generic/config/echovault/acl.yml
- ./volumes/cluster_node_3:/var/lib/echovault # - ./volumes/cluster_node_3:/var/lib/echovault
networks: # networks:
- testnet # - testnet
#
cluster_node_4: # cluster_node_4:
container_name: cluster_node_4 # container_name: cluster_node_4
build: # build:
context: . # context: .
dockerfile: Dockerfile.dev # dockerfile: Dockerfile.dev
environment: # environment:
- PORT=7480 # - PORT=7480
- RAFT_PORT=8000 # - RAFT_PORT=8000
- ML_PORT=7946 # - ML_PORT=7946
- KEY=/generic/ssl/certs/echovault/server1.key # - KEY=/generic/ssl/certs/echovault/server1.key
- CERT=/generic/ssl/certs/echovault/server1.crt # - CERT=/generic/ssl/certs/echovault/server1.crt
- SERVER_ID=4 # - SERVER_ID=4
- JOIN_ADDR=cluster_node_1:7946 # - JOIN_ADDR=cluster_node_1:7946
- DATA_DIR=/var/lib/echovault # - DATA_DIR=/var/lib/echovault
- IN_MEMORY=false # - IN_MEMORY=false
- TLS=false # - TLS=false
- MTLS=false # - MTLS=false
- BOOTSTRAP_CLUSTER=false # - BOOTSTRAP_CLUSTER=false
- ACL_CONFIG=/generic/config/echovault/acl.yml # - ACL_CONFIG=/generic/config/echovault/acl.yml
- REQUIRE_PASS=false # - REQUIRE_PASS=false
- FORWARD_COMMAND=true # - FORWARD_COMMAND=true
- SNAPSHOT_THRESHOLD=1000 # - SNAPSHOT_THRESHOLD=1000
- SNAPSHOT_INTERVAL=5m30s # - SNAPSHOT_INTERVAL=5m30s
- RESTORE_SNAPSHOT=false # - RESTORE_SNAPSHOT=false
- RESTORE_AOF=false # - RESTORE_AOF=false
- AOF_SYNC_STRATEGY=everysec # - AOF_SYNC_STRATEGY=everysec
- MAX_MEMORY=2000kb # - MAX_MEMORY=2000kb
- EVICTION_POLICY=allkeys-lfu # - EVICTION_POLICY=allkeys-lfu
# List of server cert/key pairs # # List of server cert/key pairs
- CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key # - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key
- CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key # - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key
# List of client certificate authorities # # List of client certificate authorities
- CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt # - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt
ports: # ports:
- "7484:7480" # - "7484:7480"
- "7949:7946" # - "7949:7946"
- "8003:8000" # - "8003:8000"
volumes: # volumes:
- ./config/acl.yml:/generic/config/echovault/acl.yml # - ./config/acl.yml:/generic/config/echovault/acl.yml
- ./volumes/cluster_node_4:/var/lib/echovault # - ./volumes/cluster_node_4:/var/lib/echovault
networks: # networks:
- testnet # - testnet
#
cluster_node_5: # cluster_node_5:
container_name: cluster_node_5 # container_name: cluster_node_5
build: # build:
context: . # context: .
dockerfile: Dockerfile.dev # dockerfile: Dockerfile.dev
environment: # environment:
- PORT=7480 # - PORT=7480
- RAFT_PORT=8000 # - RAFT_PORT=8000
- ML_PORT=7946 # - ML_PORT=7946
- KEY=/generic/ssl/certs/echovault/server1.key # - KEY=/generic/ssl/certs/echovault/server1.key
- CERT=/generic/ssl/certs/echovault/server1.crt # - CERT=/generic/ssl/certs/echovault/server1.crt
- SERVER_ID=5 # - SERVER_ID=5
- JOIN_ADDR=cluster_node_1:7946 # - JOIN_ADDR=cluster_node_1:7946
- DATA_DIR=/var/lib/echovault # - DATA_DIR=/var/lib/echovault
- IN_MEMORY=false # - IN_MEMORY=false
- TLS=false # - TLS=false
- MTLS=false # - MTLS=false
- BOOTSTRAP_CLUSTER=false # - BOOTSTRAP_CLUSTER=false
- ACL_CONFIG=/generic/config/echovault/acl.yml # - ACL_CONFIG=/generic/config/echovault/acl.yml
- REQUIRE_PASS=false # - REQUIRE_PASS=false
- FORWARD_COMMAND=true # - FORWARD_COMMAND=true
- SNAPSHOT_THRESHOLD=1000 # - SNAPSHOT_THRESHOLD=1000
- SNAPSHOT_INTERVAL=5m30s # - SNAPSHOT_INTERVAL=5m30s
- RESTORE_SNAPSHOT=false # - RESTORE_SNAPSHOT=false
- RESTORE_AOF=false # - RESTORE_AOF=false
- AOF_SYNC_STRATEGY=everysec # - AOF_SYNC_STRATEGY=everysec
- MAX_MEMORY=2000kb # - MAX_MEMORY=2000kb
- EVICTION_POLICY=allkeys-lfu # - EVICTION_POLICY=allkeys-lfu
# List of server cert/key pairs # # List of server cert/key pairs
- CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key # - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key
- CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key # - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key
# List of client certificate authorities # # List of client certificate authorities
- CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt # - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt
ports: # ports:
- "7485:7480" # - "7485:7480"
- "7950:7946" # - "7950:7946"
- "8004:8000" # - "8004:8000"
volumes: # volumes:
- ./config/acl.yml:/generic/config/echovault/acl.yml # - ./config/acl.yml:/generic/config/echovault/acl.yml
- ./volumes/cluster_node_5:/var/lib/echovault # - ./volumes/cluster_node_5:/var/lib/echovault
networks: # networks:
- testnet # - testnet

View File

@@ -35,7 +35,7 @@ func handleSet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co
// If GET is provided, the response should be the current stored value. // If GET is provided, the response should be the current stored value.
// If there's no current value, then the response should be nil. // If there's no current value, then the response should be nil.
if params.get { if params.get {
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
res = []byte("$-1\r\n") res = []byte("$-1\r\n")
} else { } else {
res = []byte(fmt.Sprintf("+%v\r\n", server.GetValue(ctx, key))) res = []byte(fmt.Sprintf("+%v\r\n", server.GetValue(ctx, key)))
@@ -44,19 +44,19 @@ func handleSet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co
if "xx" == strings.ToLower(params.exists) { if "xx" == strings.ToLower(params.exists) {
// If XX is specified, make sure the key exists. // If XX is specified, make sure the key exists.
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return nil, fmt.Errorf("key %s does not exist", key) return nil, fmt.Errorf("key %s does not exist", key)
} }
_, err = server.KeyLock(ctx, key) _, err = server.KeyLock(ctx, key)
} else if "nx" == strings.ToLower(params.exists) { } else if "nx" == strings.ToLower(params.exists) {
// If NX is specified, make sure that the key does not currently exist. // If NX is specified, make sure that the key does not currently exist.
if server.KeyExists(key) { if server.KeyExists(ctx, key) {
return nil, fmt.Errorf("key %s already exists", key) return nil, fmt.Errorf("key %s already exists", key)
} }
_, err = server.CreateKeyAndLock(ctx, key) _, err = server.CreateKeyAndLock(ctx, key)
} else { } else {
// Neither XX not NX are specified, lock or create the lock // Neither XX not NX are specified, lock or create the lock
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
// Key does not exist, create it // Key does not exist, create it
_, err = server.CreateKeyAndLock(ctx, key) _, err = server.CreateKeyAndLock(ctx, key)
} else { } else {
@@ -67,7 +67,7 @@ func handleSet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
if err = server.SetValue(ctx, key, utils.AdaptType(value)); err != nil { if err = server.SetValue(ctx, key, utils.AdaptType(value)); err != nil {
return nil, err return nil, err
@@ -92,7 +92,7 @@ func handleMSet(ctx context.Context, cmd []string, server utils.Server, _ *net.C
defer func() { defer func() {
for k, v := range entries { for k, v := range entries {
if v.locked { if v.locked {
server.KeyUnlock(k) server.KeyUnlock(ctx, k)
entries[k] = KeyObject{ entries[k] = KeyObject{
value: v.value, value: v.value,
locked: false, locked: false,
@@ -114,7 +114,7 @@ func handleMSet(ctx context.Context, cmd []string, server utils.Server, _ *net.C
// Acquire all the locks for each key first // Acquire all the locks for each key first
// If any key cannot be acquired, abandon transaction and release all currently held keys // If any key cannot be acquired, abandon transaction and release all currently held keys
for k, v := range entries { for k, v := range entries {
if server.KeyExists(k) { if server.KeyExists(ctx, k) {
if _, err := server.KeyLock(ctx, k); err != nil { if _, err := server.KeyLock(ctx, k); err != nil {
return nil, err return nil, err
} }
@@ -144,7 +144,7 @@ func handleGet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co
} }
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
} }
@@ -152,7 +152,7 @@ func handleGet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
value := server.GetValue(ctx, key) value := server.GetValue(ctx, key)
@@ -173,7 +173,7 @@ func handleMGet(ctx context.Context, cmd []string, server utils.Server, _ *net.C
// Skip if we have already locked this key // Skip if we have already locked this key
continue continue
} }
if server.KeyExists(key) { if server.KeyExists(ctx, key) {
_, err = server.KeyRLock(ctx, key) _, err = server.KeyRLock(ctx, key)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not obtain lock for %s key", key) return nil, fmt.Errorf("could not obtain lock for %s key", key)
@@ -186,7 +186,7 @@ func handleMGet(ctx context.Context, cmd []string, server utils.Server, _ *net.C
defer func() { defer func() {
for key, locked := range locks { for key, locked := range locks {
if locked { if locked {
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
locks[key] = false locks[key] = false
} }
} }
@@ -234,14 +234,14 @@ func handlePersist(ctx context.Context, cmd []string, server utils.Server, _ *ne
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
expireAt := server.GetExpiry(ctx, key) expireAt := server.GetExpiry(ctx, key)
if expireAt == (time.Time{}) { if expireAt == (time.Time{}) {
@@ -261,14 +261,14 @@ func handleExpireTime(ctx context.Context, cmd []string, server utils.Server, _
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":-2\r\n"), nil return []byte(":-2\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
expireAt := server.GetExpiry(ctx, key) expireAt := server.GetExpiry(ctx, key)
@@ -292,14 +292,14 @@ func handleTTL(ctx context.Context, cmd []string, server utils.Server, _ *net.Co
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":-2\r\n"), nil return []byte(":-2\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
expireAt := server.GetExpiry(ctx, key) expireAt := server.GetExpiry(ctx, key)
@@ -337,14 +337,14 @@ func handleExpire(ctx context.Context, cmd []string, server utils.Server, _ *net
expireAt = time.Now().Add(time.Duration(n) * time.Millisecond) expireAt = time.Now().Add(time.Duration(n) * time.Millisecond)
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
if len(cmd) == 3 { if len(cmd) == 3 {
server.SetExpiry(ctx, key, expireAt, true) server.SetExpiry(ctx, key, expireAt, true)
@@ -405,14 +405,14 @@ func handleExpireAt(ctx context.Context, cmd []string, server utils.Server, _ *n
expireAt = time.UnixMilli(n) expireAt = time.UnixMilli(n)
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
if len(cmd) == 3 { if len(cmd) == 3 {
server.SetExpiry(ctx, key, expireAt, true) server.SetExpiry(ctx, key, expireAt, true)

View File

@@ -177,7 +177,7 @@ func Test_HandleMSET(t *testing.T) {
t.Errorf("expected value %s for key %s, got %s", ev, key, value) t.Errorf("expected value %s for key %s, got %s", ev, key, value)
} }
} }
mockServer.KeyRUnlock(key) mockServer.KeyRUnlock(context.Background(), key)
} }
} }
} }
@@ -211,8 +211,10 @@ func Test_HandleGET(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(ctx, key, value) if err = mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
res, err := handleGet(ctx, []string{"GET", key}, mockServer, nil) res, err := handleGet(ctx, []string{"GET", key}, mockServer, nil)
if err != nil { if err != nil {
@@ -297,8 +299,10 @@ func Test_HandleMGET(t *testing.T) {
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, test.presetValues[i]) if err = mockServer.SetValue(context.Background(), key, test.presetValues[i]); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(context.Background(), key)
} }
// Test the command and its results // Test the command and its results
res, err := handleMGet(context.Background(), test.command, mockServer, nil) res, err := handleMGet(context.Background(), test.command, mockServer, nil)

View File

@@ -29,12 +29,12 @@ func handleHSET(ctx context.Context, cmd []string, server utils.Server, conn *ne
entries[cmd[i]] = utils.AdaptType(cmd[i+1]) entries[cmd[i]] = utils.AdaptType(cmd[i+1])
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
_, err = server.CreateKeyAndLock(ctx, key) _, err = server.CreateKeyAndLock(ctx, key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
if err = server.SetValue(ctx, key, entries); err != nil { if err = server.SetValue(ctx, key, entries); err != nil {
return nil, err return nil, err
} }
@@ -44,7 +44,7 @@ func handleHSET(ctx context.Context, cmd []string, server utils.Server, conn *ne
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{}) hash, ok := server.GetValue(ctx, key).(map[string]interface{})
if !ok { if !ok {
@@ -79,14 +79,14 @@ func handleHGET(ctx context.Context, cmd []string, server utils.Server, conn *ne
key := keys[0] key := keys[0]
fields := cmd[2:] fields := cmd[2:]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{}) hash, ok := server.GetValue(ctx, key).(map[string]interface{})
if !ok { if !ok {
@@ -130,14 +130,14 @@ func handleHSTRLEN(ctx context.Context, cmd []string, server utils.Server, conn
key := keys[0] key := keys[0]
fields := cmd[2:] fields := cmd[2:]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{}) hash, ok := server.GetValue(ctx, key).(map[string]interface{})
if !ok { if !ok {
@@ -180,14 +180,14 @@ func handleHVALS(ctx context.Context, cmd []string, server utils.Server, conn *n
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{}) hash, ok := server.GetValue(ctx, key).(map[string]interface{})
if !ok { if !ok {
@@ -242,14 +242,14 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server utils.Server, co
} }
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{}) hash, ok := server.GetValue(ctx, key).(map[string]interface{})
if !ok { if !ok {
@@ -337,14 +337,14 @@ func handleHLEN(ctx context.Context, cmd []string, server utils.Server, conn *ne
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{}) hash, ok := server.GetValue(ctx, key).(map[string]interface{})
if !ok { if !ok {
@@ -362,14 +362,14 @@ func handleHKEYS(ctx context.Context, cmd []string, server utils.Server, conn *n
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{}) hash, ok := server.GetValue(ctx, key).(map[string]interface{})
if !ok { if !ok {
@@ -410,11 +410,11 @@ func handleHINCRBY(ctx context.Context, cmd []string, server utils.Server, conn
intIncrement = i intIncrement = i
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
if _, err := server.CreateKeyAndLock(ctx, key); err != nil { if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
hash := make(map[string]interface{}) hash := make(map[string]interface{})
if strings.EqualFold(cmd[0], "hincrbyfloat") { if strings.EqualFold(cmd[0], "hincrbyfloat") {
hash[field] = floatIncrement hash[field] = floatIncrement
@@ -434,7 +434,7 @@ func handleHINCRBY(ctx context.Context, cmd []string, server utils.Server, conn
if _, err := server.KeyLock(ctx, key); err != nil { if _, err := server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{}) hash, ok := server.GetValue(ctx, key).(map[string]interface{})
if !ok { if !ok {
@@ -484,14 +484,14 @@ func handleHGETALL(ctx context.Context, cmd []string, server utils.Server, conn
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{}) hash, ok := server.GetValue(ctx, key).(map[string]interface{})
if !ok { if !ok {
@@ -525,14 +525,14 @@ func handleHEXISTS(ctx context.Context, cmd []string, server utils.Server, conn
key := keys[0] key := keys[0]
field := cmd[2] field := cmd[2]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{}) hash, ok := server.GetValue(ctx, key).(map[string]interface{})
if !ok { if !ok {
@@ -555,14 +555,14 @@ func handleHDEL(ctx context.Context, cmd []string, server utils.Server, conn *ne
key := keys[0] key := keys[0]
fields := cmd[2:] fields := cmd[2:]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
hash, ok := server.GetValue(ctx, key).(map[string]interface{}) hash, ok := server.GetValue(ctx, key).(map[string]interface{})
if !ok { if !ok {

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"fmt"
"github.com/echovault/echovault/src/server" "github.com/echovault/echovault/src/server"
"github.com/echovault/echovault/src/utils" "github.com/echovault/echovault/src/utils"
"github.com/tidwall/resp" "github.com/tidwall/resp"
@@ -98,15 +99,18 @@ func Test_HandleHSET(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HSET/HSETNX, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleHSET(context.Background(), test.command, mockServer, nil) res, err := handleHSET(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -122,10 +126,10 @@ func Test_HandleHSET(t *testing.T) {
t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, rv.Integer()) t.Errorf("expected response \"%d\", got \"%d\"", test.expectedResponse, rv.Integer())
} }
// Check that all the values are what is expected // Check that all the values are what is expected
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
hash, ok := mockServer.GetValue(context.Background(), test.key).(map[string]interface{}) hash, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{})
if !ok { if !ok {
t.Errorf("value at key \"%s\" is not a hash map", test.key) t.Errorf("value at key \"%s\" is not a hash map", test.key)
} }
@@ -242,15 +246,19 @@ func Test_HandleHINCRBY(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HINCRBY, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleHINCRBY(context.Background(), test.command, mockServer, nil) res, err := handleHINCRBY(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -275,10 +283,10 @@ func Test_HandleHINCRBY(t *testing.T) {
} }
} }
// Check that all the values are what is expected // Check that all the values are what is expected
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
hash, ok := mockServer.GetValue(context.Background(), test.key).(map[string]interface{}) hash, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{})
if !ok { if !ok {
t.Errorf("value at key \"%s\" is not a hash map", test.key) t.Errorf("value at key \"%s\" is not a hash map", test.key)
} }
@@ -340,15 +348,19 @@ func Test_HandleHGET(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HINCRBY, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleHGET(context.Background(), test.command, mockServer, nil) res, err := handleHGET(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -440,15 +452,19 @@ func Test_HandleHSTRLEN(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HSTRLEN, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleHSTRLEN(context.Background(), test.command, mockServer, nil) res, err := handleHSTRLEN(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -535,15 +551,19 @@ func Test_HandleHVALS(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HVALS, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleHVALS(context.Background(), test.command, mockServer, nil) res, err := handleHVALS(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -727,15 +747,19 @@ func Test_HandleHRANDFIELD(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HRANDFIELD, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleHRANDFIELD(context.Background(), test.command, mockServer, nil) res, err := handleHRANDFIELD(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -841,15 +865,19 @@ func Test_HandleHLEN(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HLEN, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleHLEN(context.Background(), test.command, mockServer, nil) res, err := handleHLEN(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -931,15 +959,19 @@ func Test_HandleHKeys(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HKEYS, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleHKEYS(context.Background(), test.command, mockServer, nil) res, err := handleHKEYS(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1028,15 +1060,19 @@ func Test_HandleHGETALL(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HGETALL, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleHGETALL(context.Background(), test.command, mockServer, nil) res, err := handleHGETALL(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1136,15 +1172,19 @@ func Test_HandleHEXISTS(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HEXISTS, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleHEXISTS(context.Background(), test.command, mockServer, nil) res, err := handleHEXISTS(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1226,15 +1266,19 @@ func Test_HandleHDEL(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("HDEL, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleHDEL(context.Background(), test.command, mockServer, nil) res, err := handleHDEL(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1252,10 +1296,10 @@ func Test_HandleHDEL(t *testing.T) {
} }
continue continue
} }
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
if hash, ok := mockServer.GetValue(context.Background(), test.key).(map[string]interface{}); ok { if hash, ok := mockServer.GetValue(ctx, test.key).(map[string]interface{}); ok {
for field, value := range hash { for field, value := range hash {
if value != test.expectedValue[field] { if value != test.expectedValue[field] {
t.Errorf("expected value \"%+v\", got \"%+v\"", test.expectedValue[field], value) t.Errorf("expected value \"%+v\", got \"%+v\"", test.expectedValue[field], value)

View File

@@ -19,7 +19,7 @@ func handleLLen(ctx context.Context, cmd []string, server utils.Server, _ *net.C
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
// If key does not exist, return 0 // If key does not exist, return 0
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
@@ -27,7 +27,7 @@ func handleLLen(ctx context.Context, cmd []string, server utils.Server, _ *net.C
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
if list, ok := server.GetValue(ctx, key).([]interface{}); ok { if list, ok := server.GetValue(ctx, key).([]interface{}); ok {
return []byte(fmt.Sprintf(":%d\r\n", len(list))), nil return []byte(fmt.Sprintf(":%d\r\n", len(list))), nil
@@ -49,7 +49,7 @@ func handleLIndex(ctx context.Context, cmd []string, server utils.Server, conn *
return nil, errors.New("index must be an integer") return nil, errors.New("index must be an integer")
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return nil, errors.New("LINDEX command on non-list item") return nil, errors.New("LINDEX command on non-list item")
} }
@@ -57,7 +57,7 @@ func handleLIndex(ctx context.Context, cmd []string, server utils.Server, conn *
return nil, err return nil, err
} }
list, ok := server.GetValue(ctx, key).([]interface{}) list, ok := server.GetValue(ctx, key).([]interface{})
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
if !ok { if !ok {
return nil, errors.New("LINDEX command on non-list item") return nil, errors.New("LINDEX command on non-list item")
@@ -84,14 +84,14 @@ func handleLRange(ctx context.Context, cmd []string, server utils.Server, conn *
return nil, errors.New("start and end indices must be integers") return nil, errors.New("start and end indices must be integers")
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return nil, errors.New("LRANGE command on non-list item") return nil, errors.New("LRANGE command on non-list item")
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
list, ok := server.GetValue(ctx, key).([]interface{}) list, ok := server.GetValue(ctx, key).([]interface{})
if !ok { if !ok {
@@ -162,14 +162,14 @@ func handleLSet(ctx context.Context, cmd []string, server utils.Server, conn *ne
return nil, errors.New("index must be an integer") return nil, errors.New("index must be an integer")
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return nil, errors.New("LSET command on non-list item") return nil, errors.New("LSET command on non-list item")
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
list, ok := server.GetValue(ctx, key).([]interface{}) list, ok := server.GetValue(ctx, key).([]interface{})
if !ok { if !ok {
@@ -206,14 +206,14 @@ func handleLTrim(ctx context.Context, cmd []string, server utils.Server, conn *n
return nil, errors.New("end index must be greater than start index or -1") return nil, errors.New("end index must be greater than start index or -1")
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return nil, errors.New("LTRIM command on non-list item") return nil, errors.New("LTRIM command on non-list item")
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
list, ok := server.GetValue(ctx, key).([]interface{}) list, ok := server.GetValue(ctx, key).([]interface{})
if !ok { if !ok {
@@ -253,14 +253,14 @@ func handleLRem(ctx context.Context, cmd []string, server utils.Server, conn *ne
absoluteCount := utils.AbsInt(count) absoluteCount := utils.AbsInt(count)
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return nil, errors.New("LREM command on non-list item") return nil, errors.New("LREM command on non-list item")
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
list, ok := server.GetValue(ctx, key).([]interface{}) list, ok := server.GetValue(ctx, key).([]interface{})
if !ok { if !ok {
@@ -320,20 +320,20 @@ func handleLMove(ctx context.Context, cmd []string, server utils.Server, conn *n
return nil, errors.New("wherefrom and whereto arguments must be either LEFT or RIGHT") return nil, errors.New("wherefrom and whereto arguments must be either LEFT or RIGHT")
} }
if !server.KeyExists(source) || !server.KeyExists(destination) { if !server.KeyExists(ctx, source) || !server.KeyExists(ctx, destination) {
return nil, errors.New("both source and destination must be lists") return nil, errors.New("both source and destination must be lists")
} }
if _, err = server.KeyLock(ctx, source); err != nil { if _, err = server.KeyLock(ctx, source); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(source) defer server.KeyUnlock(ctx, source)
_, err = server.KeyLock(ctx, destination) _, err = server.KeyLock(ctx, destination)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(destination) defer server.KeyUnlock(ctx, destination)
sourceList, sourceOk := server.GetValue(ctx, source).([]interface{}) sourceList, sourceOk := server.GetValue(ctx, source).([]interface{})
destinationList, destinationOk := server.GetValue(ctx, destination).([]interface{}) destinationList, destinationOk := server.GetValue(ctx, destination).([]interface{})
@@ -380,7 +380,7 @@ func handleLPush(ctx context.Context, cmd []string, server utils.Server, conn *n
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
switch strings.ToLower(cmd[0]) { switch strings.ToLower(cmd[0]) {
case "lpushx": case "lpushx":
return nil, errors.New("LPUSHX command on non-list item") return nil, errors.New("LPUSHX command on non-list item")
@@ -397,7 +397,7 @@ func handleLPush(ctx context.Context, cmd []string, server utils.Server, conn *n
return nil, err return nil, err
} }
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
currentList := server.GetValue(ctx, key) currentList := server.GetValue(ctx, key)
@@ -426,7 +426,7 @@ func handleRPush(ctx context.Context, cmd []string, server utils.Server, conn *n
newElems = append(newElems, utils.AdaptType(elem)) newElems = append(newElems, utils.AdaptType(elem))
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
switch strings.ToLower(cmd[0]) { switch strings.ToLower(cmd[0]) {
case "rpushx": case "rpushx":
return nil, errors.New("RPUSHX command on non-list item") return nil, errors.New("RPUSHX command on non-list item")
@@ -434,7 +434,7 @@ func handleRPush(ctx context.Context, cmd []string, server utils.Server, conn *n
if _, err = server.CreateKeyAndLock(ctx, key); err != nil { if _, err = server.CreateKeyAndLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
if err = server.SetValue(ctx, key, []interface{}{}); err != nil { if err = server.SetValue(ctx, key, []interface{}{}); err != nil {
return nil, err return nil, err
} }
@@ -443,7 +443,7 @@ func handleRPush(ctx context.Context, cmd []string, server utils.Server, conn *n
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
} }
currentList := server.GetValue(ctx, key) currentList := server.GetValue(ctx, key)
@@ -468,14 +468,14 @@ func handlePop(ctx context.Context, cmd []string, server utils.Server, conn *net
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(cmd[0])) return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(cmd[0]))
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
list, ok := server.GetValue(ctx, key).([]interface{}) list, ok := server.GetValue(ctx, key).([]interface{})
if !ok { if !ok {

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"fmt"
"github.com/echovault/echovault/src/server" "github.com/echovault/echovault/src/server"
"github.com/echovault/echovault/src/utils" "github.com/echovault/echovault/src/utils"
"github.com/tidwall/resp" "github.com/tidwall/resp"
@@ -69,15 +70,19 @@ func Test_HandleLLEN(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LLEN, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleLLen(context.Background(), test.command, mockServer, nil) res, err := handleLLen(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -199,15 +204,19 @@ func Test_HandleLINDEX(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LINDEX, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleLIndex(context.Background(), test.command, mockServer, nil) res, err := handleLIndex(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -350,15 +359,19 @@ func Test_HandleLRANGE(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LRANGE, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleLRange(context.Background(), test.command, mockServer, nil) res, err := handleLRange(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -486,15 +499,19 @@ func Test_HandleLSET(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LSET, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleLSet(context.Background(), test.command, mockServer, nil) res, err := handleLSet(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -509,10 +526,10 @@ func Test_HandleLSET(t *testing.T) {
if rv.String() != test.expectedResponse { if rv.String() != test.expectedResponse {
t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String()) t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String())
} }
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
if !ok { if !ok {
t.Error("expected value to be list, got another type") t.Error("expected value to be list, got another type")
} }
@@ -524,7 +541,7 @@ func Test_HandleLSET(t *testing.T) {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i]) t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
} }
} }
mockServer.KeyRUnlock(test.key) mockServer.KeyRUnlock(ctx, test.key)
} }
} }
@@ -644,15 +661,19 @@ func Test_HandleLTRIM(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LTRIM, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleLTrim(context.Background(), test.command, mockServer, nil) res, err := handleLTrim(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -667,10 +688,10 @@ func Test_HandleLTRIM(t *testing.T) {
if rv.String() != test.expectedResponse { if rv.String() != test.expectedResponse {
t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String()) t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String())
} }
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
if !ok { if !ok {
t.Error("expected value to be list, got another type") t.Error("expected value to be list, got another type")
} }
@@ -682,7 +703,7 @@ func Test_HandleLTRIM(t *testing.T) {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i]) t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
} }
} }
mockServer.KeyRUnlock(test.key) mockServer.KeyRUnlock(ctx, test.key)
} }
} }
@@ -763,15 +784,19 @@ func Test_HandleLREM(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LREM, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleLRem(context.Background(), test.command, mockServer, nil) res, err := handleLRem(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -786,10 +811,10 @@ func Test_HandleLREM(t *testing.T) {
if rv.String() != test.expectedResponse { if rv.String() != test.expectedResponse {
t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String()) t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String())
} }
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
if !ok { if !ok {
t.Error("expected value to be list, got another type") t.Error("expected value to be list, got another type")
} }
@@ -801,7 +826,7 @@ func Test_HandleLREM(t *testing.T) {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i]) t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
} }
} }
mockServer.KeyRUnlock(test.key) mockServer.KeyRUnlock(ctx, test.key)
} }
} }
@@ -970,17 +995,21 @@ func Test_HandleLMOVE(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LMOVE, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValue { for key, value := range test.presetValue {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleLMove(context.Background(), test.command, mockServer, nil) res, err := handleLMove(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -996,10 +1025,10 @@ func Test_HandleLMOVE(t *testing.T) {
t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String()) t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String())
} }
for key, value := range test.expectedValue { for key, value := range test.expectedValue {
if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { if _, err = mockServer.KeyRLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
list, ok := mockServer.GetValue(context.Background(), key).([]interface{}) list, ok := mockServer.GetValue(ctx, key).([]interface{})
if !ok { if !ok {
t.Error("expected value to be list, got another type") t.Error("expected value to be list, got another type")
} }
@@ -1015,7 +1044,7 @@ func Test_HandleLMOVE(t *testing.T) {
t.Errorf("expected element at index %d to be %+v, got %+v", i, expectedList[i], list[i]) t.Errorf("expected element at index %d to be %+v, got %+v", i, expectedList[i], list[i])
} }
} }
mockServer.KeyRUnlock(key) mockServer.KeyRUnlock(ctx, key)
} }
} }
} }
@@ -1079,15 +1108,19 @@ func Test_HandleLPUSH(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LPUSH/LPUSHX, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleLPush(context.Background(), test.command, mockServer, nil) res, err := handleLPush(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1102,10 +1135,10 @@ func Test_HandleLPUSH(t *testing.T) {
if rv.String() != test.expectedResponse { if rv.String() != test.expectedResponse {
t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String()) t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String())
} }
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
if !ok { if !ok {
t.Error("expected value to be list, got another type") t.Error("expected value to be list, got another type")
} }
@@ -1117,7 +1150,7 @@ func Test_HandleLPUSH(t *testing.T) {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i]) t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
} }
} }
mockServer.KeyRUnlock(test.key) mockServer.KeyRUnlock(ctx, test.key)
} }
} }
@@ -1180,15 +1213,19 @@ func Test_HandleRPUSH(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("RPUSH/RPUSHX, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleRPush(context.Background(), test.command, mockServer, nil) res, err := handleRPush(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1203,10 +1240,10 @@ func Test_HandleRPUSH(t *testing.T) {
if rv.String() != test.expectedResponse { if rv.String() != test.expectedResponse {
t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String()) t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String())
} }
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
if !ok { if !ok {
t.Error("expected value to be list, got another type") t.Error("expected value to be list, got another type")
} }
@@ -1218,7 +1255,7 @@ func Test_HandleRPUSH(t *testing.T) {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i]) t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
} }
} }
mockServer.KeyRUnlock(test.key) mockServer.KeyRUnlock(ctx, test.key)
} }
} }
@@ -1290,15 +1327,19 @@ func Test_HandlePop(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("LPOP/RPOP, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handlePop(context.Background(), test.command, mockServer, nil) res, err := handlePop(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1313,10 +1354,10 @@ func Test_HandlePop(t *testing.T) {
if rv.String() != test.expectedResponse { if rv.String() != test.expectedResponse {
t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String()) t.Errorf("expected \"%s\" response, got \"%s\"", test.expectedResponse, rv.String())
} }
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) list, ok := mockServer.GetValue(ctx, test.key).([]interface{})
if !ok { if !ok {
t.Error("expected value to be list, got another type") t.Error("expected value to be list, got another type")
} }
@@ -1328,6 +1369,6 @@ func Test_HandlePop(t *testing.T) {
t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i]) t.Errorf("expected element at index %d to be %+v, got %+v", i, test.expectedValue[i], list[i])
} }
} }
mockServer.KeyRUnlock(test.key) mockServer.KeyRUnlock(ctx, test.key)
} }
} }

View File

@@ -20,7 +20,7 @@ func handleSADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
var set *Set var set *Set
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
set = NewSet(cmd[2:]) set = NewSet(cmd[2:])
if ok, err := server.CreateKeyAndLock(ctx, key); !ok && err != nil { if ok, err := server.CreateKeyAndLock(ctx, key); !ok && err != nil {
return nil, err return nil, err
@@ -28,14 +28,14 @@ func handleSADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
if err = server.SetValue(ctx, key, set); err != nil { if err = server.SetValue(ctx, key, set); err != nil {
return nil, err return nil, err
} }
server.KeyUnlock(key) server.KeyUnlock(ctx, key)
return []byte(fmt.Sprintf(":%d\r\n", len(cmd[2:]))), nil return []byte(fmt.Sprintf(":%d\r\n", len(cmd[2:]))), nil
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*Set) set, ok := server.GetValue(ctx, key).(*Set)
if !ok { if !ok {
@@ -55,14 +55,14 @@ func handleSCARD(ctx context.Context, cmd []string, server utils.Server, conn *n
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(fmt.Sprintf(":0\r\n")), nil return []byte(fmt.Sprintf(":0\r\n")), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*Set) set, ok := server.GetValue(ctx, key).(*Set)
if !ok { if !ok {
@@ -81,13 +81,13 @@ func handleSDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n
} }
// Extract base set first // Extract base set first
if !server.KeyExists(keys[0]) { if !server.KeyExists(ctx, keys[0]) {
return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys[0]) return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys[0])
} }
if _, err = server.KeyRLock(ctx, keys[0]); err != nil { if _, err = server.KeyRLock(ctx, keys[0]); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(keys[0]) defer server.KeyRUnlock(ctx, keys[0])
baseSet, ok := server.GetValue(ctx, keys[0]).(*Set) baseSet, ok := server.GetValue(ctx, keys[0]).(*Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a set", keys[0]) return nil, fmt.Errorf("value at key %s is not a set", keys[0])
@@ -97,13 +97,13 @@ func handleSDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n
defer func() { defer func() {
for key, locked := range locks { for key, locked := range locks {
if locked { if locked {
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
} }
} }
}() }()
for _, key := range keys[1:] { for _, key := range keys[1:] {
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
continue continue
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -144,13 +144,13 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co
destination := keys[0] destination := keys[0]
// Extract base set first // Extract base set first
if !server.KeyExists(keys[1]) { if !server.KeyExists(ctx, keys[1]) {
return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys[1]) return nil, fmt.Errorf("key for base set \"%s\" does not exist", keys[1])
} }
if _, err := server.KeyRLock(ctx, keys[1]); err != nil { if _, err := server.KeyRLock(ctx, keys[1]); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(keys[1]) defer server.KeyRUnlock(ctx, keys[1])
baseSet, ok := server.GetValue(ctx, keys[1]).(*Set) baseSet, ok := server.GetValue(ctx, keys[1]).(*Set)
if !ok { if !ok {
return nil, fmt.Errorf("value at key %s is not a set", keys[1]) return nil, fmt.Errorf("value at key %s is not a set", keys[1])
@@ -160,13 +160,13 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co
defer func() { defer func() {
for key, locked := range locks { for key, locked := range locks {
if locked { if locked {
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
} }
} }
}() }()
for _, key := range keys[2:] { for _, key := range keys[2:] {
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
continue continue
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -189,14 +189,14 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co
res := fmt.Sprintf(":%d\r\n", len(elems)) res := fmt.Sprintf(":%d\r\n", len(elems))
if server.KeyExists(destination) { if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil { if _, err = server.KeyLock(ctx, destination); err != nil {
return nil, err return nil, err
} }
if err = server.SetValue(ctx, destination, diff); err != nil { if err = server.SetValue(ctx, destination, diff); err != nil {
return nil, err return nil, err
} }
server.KeyUnlock(destination) server.KeyUnlock(ctx, destination)
return []byte(res), nil return []byte(res), nil
} }
@@ -206,7 +206,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co
if err = server.SetValue(ctx, destination, diff); err != nil { if err = server.SetValue(ctx, destination, diff); err != nil {
return nil, err return nil, err
} }
server.KeyUnlock(destination) server.KeyUnlock(ctx, destination)
return []byte(res), nil return []byte(res), nil
} }
@@ -221,13 +221,13 @@ func handleSINTER(ctx context.Context, cmd []string, server utils.Server, conn *
defer func() { defer func() {
for key, locked := range locks { for key, locked := range locks {
if locked { if locked {
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
} }
} }
}() }()
for _, key := range keys[0:] { for _, key := range keys[0:] {
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
// If key does not exist, then there is no intersection // If key does not exist, then there is no intersection
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
@@ -297,13 +297,13 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server utils.Server, co
defer func() { defer func() {
for key, locked := range locks { for key, locked := range locks {
if locked { if locked {
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
} }
} }
}() }()
for _, key := range keys { for _, key := range keys {
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
// If key does not exist, then there is no intersection // If key does not exist, then there is no intersection
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
@@ -343,13 +343,13 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c
defer func() { defer func() {
for key, locked := range locks { for key, locked := range locks {
if locked { if locked {
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
} }
} }
}() }()
for _, key := range keys[1:] { for _, key := range keys[1:] {
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
// If key does not exist, then there is no intersection // If key does not exist, then there is no intersection
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
@@ -373,7 +373,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c
intersect, _ := Intersection(0, sets...) intersect, _ := Intersection(0, sets...)
destination := keys[0] destination := keys[0]
if server.KeyExists(destination) { if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil { if _, err = server.KeyLock(ctx, destination); err != nil {
return nil, err return nil, err
} }
@@ -386,7 +386,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c
if err = server.SetValue(ctx, destination, intersect); err != nil { if err = server.SetValue(ctx, destination, intersect); err != nil {
return nil, err return nil, err
} }
server.KeyUnlock(destination) server.KeyUnlock(ctx, destination)
return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil return []byte(fmt.Sprintf(":%d\r\n", intersect.Cardinality())), nil
} }
@@ -399,14 +399,14 @@ func handleSISMEMBER(ctx context.Context, cmd []string, server utils.Server, con
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*Set) set, ok := server.GetValue(ctx, key).(*Set)
if !ok { if !ok {
@@ -428,14 +428,14 @@ func handleSMEMBERS(ctx context.Context, cmd []string, server utils.Server, conn
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*Set) set, ok := server.GetValue(ctx, key).(*Set)
if !ok { if !ok {
@@ -464,7 +464,7 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server utils.Server, co
key := keys[0] key := keys[0]
members := cmd[2:] members := cmd[2:]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
res := fmt.Sprintf("*%d", len(members)) res := fmt.Sprintf("*%d", len(members))
for i, _ := range members { for i, _ := range members {
res = fmt.Sprintf("%s\r\n:0", res) res = fmt.Sprintf("%s\r\n:0", res)
@@ -478,7 +478,7 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server utils.Server, co
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*Set) set, ok := server.GetValue(ctx, key).(*Set)
if !ok { if !ok {
@@ -508,14 +508,14 @@ func handleSMOVE(ctx context.Context, cmd []string, server utils.Server, conn *n
destination := keys[1] destination := keys[1]
member := cmd[3] member := cmd[3]
if !server.KeyExists(source) { if !server.KeyExists(ctx, source) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyLock(ctx, source); err != nil { if _, err = server.KeyLock(ctx, source); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(source) defer server.KeyUnlock(ctx, source)
sourceSet, ok := server.GetValue(ctx, source).(*Set) sourceSet, ok := server.GetValue(ctx, source).(*Set)
if !ok { if !ok {
@@ -524,12 +524,12 @@ func handleSMOVE(ctx context.Context, cmd []string, server utils.Server, conn *n
var destinationSet *Set var destinationSet *Set
if !server.KeyExists(destination) { if !server.KeyExists(ctx, destination) {
// Destination key does not exist // Destination key does not exist
if _, err = server.CreateKeyAndLock(ctx, destination); err != nil { if _, err = server.CreateKeyAndLock(ctx, destination); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(destination) defer server.KeyUnlock(ctx, destination)
destinationSet = NewSet([]string{}) destinationSet = NewSet([]string{})
if err = server.SetValue(ctx, destination, destinationSet); err != nil { if err = server.SetValue(ctx, destination, destinationSet); err != nil {
return nil, err return nil, err
@@ -539,7 +539,7 @@ func handleSMOVE(ctx context.Context, cmd []string, server utils.Server, conn *n
if _, err := server.KeyLock(ctx, destination); err != nil { if _, err := server.KeyLock(ctx, destination); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(destination) defer server.KeyUnlock(ctx, destination)
ds, ok := server.GetValue(ctx, destination).(*Set) ds, ok := server.GetValue(ctx, destination).(*Set)
if !ok { if !ok {
return nil, errors.New("destination is not a set") return nil, errors.New("destination is not a set")
@@ -569,14 +569,14 @@ func handleSPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne
count = c count = c
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("*-1\r\n"), nil return []byte("*-1\r\n"), nil
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*Set) set, ok := server.GetValue(ctx, key).(*Set)
if !ok { if !ok {
@@ -613,14 +613,14 @@ func handleSRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c
count = c count = c
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("*-1\r\n"), nil return []byte("*-1\r\n"), nil
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*Set) set, ok := server.GetValue(ctx, key).(*Set)
if !ok { if !ok {
@@ -649,14 +649,14 @@ func handleSREM(ctx context.Context, cmd []string, server utils.Server, conn *ne
key := keys[0] key := keys[0]
members := cmd[2:] members := cmd[2:]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*Set) set, ok := server.GetValue(ctx, key).(*Set)
if !ok { if !ok {
@@ -678,13 +678,13 @@ func handleSUNION(ctx context.Context, cmd []string, server utils.Server, conn *
defer func() { defer func() {
for key, locked := range locks { for key, locked := range locks {
if locked { if locked {
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
} }
} }
}() }()
for _, key := range keys { for _, key := range keys {
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
continue continue
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -729,13 +729,13 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c
defer func() { defer func() {
for key, locked := range locks { for key, locked := range locks {
if locked { if locked {
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
} }
} }
}() }()
for _, key := range keys[1:] { for _, key := range keys[1:] {
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
continue continue
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
@@ -761,7 +761,7 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c
destination := cmd[1] destination := cmd[1]
if server.KeyExists(destination) { if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil { if _, err = server.KeyLock(ctx, destination); err != nil {
return nil, err return nil, err
} }
@@ -770,7 +770,7 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c
return nil, err return nil, err
} }
} }
defer server.KeyUnlock(destination) defer server.KeyUnlock(ctx, destination)
if err = server.SetValue(ctx, destination, union); err != nil { if err = server.SetValue(ctx, destination, union); err != nil {
return nil, err return nil, err

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"fmt"
"github.com/echovault/echovault/src/server" "github.com/echovault/echovault/src/server"
"github.com/echovault/echovault/src/utils" "github.com/echovault/echovault/src/utils"
"github.com/tidwall/resp" "github.com/tidwall/resp"
@@ -59,15 +60,19 @@ func Test_HandleSADD(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SADD, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleSADD(context.Background(), test.command, mockServer, nil) res, err := handleSADD(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -85,10 +90,10 @@ func Test_HandleSADD(t *testing.T) {
if rv.Integer() != test.expectedResponse { if rv.Integer() != test.expectedResponse {
t.Errorf("expected integer response %d, got %d", test.expectedResponse, rv.Integer()) t.Errorf("expected integer response %d, got %d", test.expectedResponse, rv.Integer())
} }
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) set, ok := mockServer.GetValue(ctx, test.key).(*Set)
if !ok { if !ok {
t.Errorf("expected set value at key \"%s\"", test.key) t.Errorf("expected set value at key \"%s\"", test.key)
} }
@@ -100,7 +105,7 @@ func Test_HandleSADD(t *testing.T) {
t.Errorf("could not find member \"%s\" in expected set", member) t.Errorf("could not find member \"%s\" in expected set", member)
} }
} }
mockServer.KeyRUnlock(test.key) mockServer.KeyRUnlock(ctx, test.key)
} }
} }
@@ -160,15 +165,19 @@ func Test_HandleSCARD(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SCARD, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleSCARD(context.Background(), test.command, mockServer, nil) res, err := handleSCARD(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -260,17 +269,21 @@ func Test_HandleSDIFF(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SDIFF, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleSDIFF(context.Background(), test.command, mockServer, nil) res, err := handleSDIFF(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -376,17 +389,21 @@ func Test_HandleSDIFFSTORE(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SDIFFSTORE, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleSDIFFSTORE(context.Background(), test.command, mockServer, nil) res, err := handleSDIFFSTORE(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -405,10 +422,10 @@ func Test_HandleSDIFFSTORE(t *testing.T) {
t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer())
} }
if test.expectedValue != nil { if test.expectedValue != nil {
if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), test.destination).(*Set) set, ok := mockServer.GetValue(ctx, test.destination).(*Set)
if !ok { if !ok {
t.Errorf("expected vaule at key %s to be set, got another type", test.destination) t.Errorf("expected vaule at key %s to be set, got another type", test.destination)
} }
@@ -417,7 +434,7 @@ func Test_HandleSDIFFSTORE(t *testing.T) {
t.Errorf("could not find element %s in the expected values", elem) t.Errorf("could not find element %s in the expected values", elem)
} }
} }
mockServer.KeyRUnlock(test.destination) mockServer.KeyRUnlock(ctx, test.destination)
} }
} }
} }
@@ -493,17 +510,21 @@ func Test_HandleSINTER(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SINTER, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleSINTER(context.Background(), test.command, mockServer, nil) res, err := handleSINTER(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -596,29 +617,33 @@ func Test_HandleSINTERCARD(t *testing.T) {
"key15": NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}), "key15": NewSet([]string{"one", "two", "thirty-six", "twelve", "eleven"}),
"key16": NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}), "key16": NewSet([]string{"seven", "eight", "nine", "ten", "twelve"}),
}, },
command: []string{"SINTERSTORE", "key14", "key15", "key16"}, command: []string{"SINTERCARD", "key14", "key15", "key16"},
expectedResponse: 0, expectedResponse: 0,
expectedError: errors.New("value at key key14 is not a set"), expectedError: errors.New("value at key key14 is not a set"),
}, },
{ // 7. Command too short { // 7. Command too short
preset: false, preset: false,
command: []string{"SINTERSTORE"}, command: []string{"SINTERCARD"},
expectedResponse: 0, expectedResponse: 0,
expectedError: errors.New(utils.WrongArgsResponse), expectedError: errors.New(utils.WrongArgsResponse),
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SINTERCARD, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleSINTERCARD(context.Background(), test.command, mockServer, nil) res, err := handleSINTERCARD(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -722,17 +747,21 @@ func Test_HandleSINTERSTORE(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SINTERSTORE, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleSINTERSTORE(context.Background(), test.command, mockServer, nil) res, err := handleSINTERSTORE(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -751,10 +780,10 @@ func Test_HandleSINTERSTORE(t *testing.T) {
t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer())
} }
if test.expectedValue != nil { if test.expectedValue != nil {
if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), test.destination).(*Set) set, ok := mockServer.GetValue(ctx, test.destination).(*Set)
if !ok { if !ok {
t.Errorf("expected vaule at key %s to be set, got another type", test.destination) t.Errorf("expected vaule at key %s to be set, got another type", test.destination)
} }
@@ -763,7 +792,7 @@ func Test_HandleSINTERSTORE(t *testing.T) {
t.Errorf("could not find element %s in the expected values", elem) t.Errorf("could not find element %s in the expected values", elem)
} }
} }
mockServer.KeyRUnlock(test.destination) mockServer.KeyRUnlock(ctx, test.destination)
} }
} }
} }
@@ -819,15 +848,19 @@ func Test_HandleSISMEMBER(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SISMEMBER, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleSISMEMBER(context.Background(), test.command, mockServer, nil) res, err := handleSISMEMBER(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -897,15 +930,19 @@ func Test_HandleSMEMBERS(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SMEMBERS, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleSMEMBERS(context.Background(), test.command, mockServer, nil) res, err := handleSMEMBERS(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -979,15 +1016,19 @@ func Test_HandleSMISMEMBER(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SMISMEMBER, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleSMISMEMBER(context.Background(), test.command, mockServer, nil) res, err := handleSMISMEMBER(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1090,17 +1131,21 @@ func Test_HandleSMOVE(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SMOVE, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleSMOVE(context.Background(), test.command, mockServer, nil) res, err := handleSMOVE(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1123,10 +1168,10 @@ func Test_HandleSMOVE(t *testing.T) {
if !ok { if !ok {
t.Errorf("expected value at \"%s\" should be a set", key) t.Errorf("expected value at \"%s\" should be a set", key)
} }
if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { if _, err = mockServer.KeyRLock(ctx, key); err != nil {
t.Error(key) t.Error(key)
} }
set, ok := mockServer.GetValue(context.Background(), key).(*Set) set, ok := mockServer.GetValue(ctx, key).(*Set)
if !ok { if !ok {
t.Errorf("expected set \"%s\" to be a set, got another type", key) t.Errorf("expected set \"%s\" to be a set, got another type", key)
} }
@@ -1138,7 +1183,7 @@ func Test_HandleSMOVE(t *testing.T) {
t.Errorf("could not find element \"%s\" in the expected set", element) t.Errorf("could not find element \"%s\" in the expected set", element)
} }
} }
mockServer.KeyRUnlock(key) mockServer.KeyRUnlock(ctx, key)
} }
} }
} }
@@ -1190,15 +1235,19 @@ func Test_HandleSPOP(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SPOP, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleSPOP(context.Background(), test.command, mockServer, nil) res, err := handleSPOP(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1220,10 +1269,10 @@ func Test_HandleSPOP(t *testing.T) {
} }
} }
// 2. Fetch the set and check if its cardinality is what we expect. // 2. Fetch the set and check if its cardinality is what we expect.
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) set, ok := mockServer.GetValue(ctx, test.key).(*Set)
if !ok { if !ok {
t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key)
} }
@@ -1301,15 +1350,19 @@ func Test_HandleSRANDMEMBER(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SRANDMEMBER, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleSRANDMEMBER(context.Background(), test.command, mockServer, nil) res, err := handleSRANDMEMBER(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1331,10 +1384,10 @@ func Test_HandleSRANDMEMBER(t *testing.T) {
} }
} }
// 2. Fetch the set and check if its cardinality is what we expect. // 2. Fetch the set and check if its cardinality is what we expect.
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) set, ok := mockServer.GetValue(ctx, test.key).(*Set)
if !ok { if !ok {
t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key)
} }
@@ -1407,15 +1460,19 @@ func Test_HandleSREM(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SREM, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleSREM(context.Background(), test.command, mockServer, nil) res, err := handleSREM(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1434,10 +1491,10 @@ func Test_HandleSREM(t *testing.T) {
t.Errorf("expected integer response %d, got %d", test.expectedResponse, rv.Integer()) t.Errorf("expected integer response %d, got %d", test.expectedResponse, rv.Integer())
} }
if test.expectedValue != nil { if test.expectedValue != nil {
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) set, ok := mockServer.GetValue(ctx, test.key).(*Set)
if !ok { if !ok {
t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key)
} }
@@ -1446,7 +1503,7 @@ func Test_HandleSREM(t *testing.T) {
t.Errorf("element \"%s\" not found in expected set values but found in set", element) t.Errorf("element \"%s\" not found in expected set values but found in set", element)
} }
} }
mockServer.KeyRUnlock(test.key) mockServer.KeyRUnlock(ctx, test.key)
} }
} }
} }
@@ -1515,17 +1572,21 @@ func Test_HandleSUNION(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SUNION, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleSUNION(context.Background(), test.command, mockServer, nil) res, err := handleSUNION(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1609,17 +1670,21 @@ func Test_HandleSUNIONSTORE(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SUNIONSTORE, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleSUNIONSTORE(context.Background(), test.command, mockServer, nil) res, err := handleSUNIONSTORE(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1638,10 +1703,10 @@ func Test_HandleSUNIONSTORE(t *testing.T) {
t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer())
} }
if test.expectedValue != nil { if test.expectedValue != nil {
if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), test.destination).(*Set) set, ok := mockServer.GetValue(ctx, test.destination).(*Set)
if !ok { if !ok {
t.Errorf("expected vaule at key %s to be set, got another type", test.destination) t.Errorf("expected vaule at key %s to be set, got another type", test.destination)
} }
@@ -1650,7 +1715,7 @@ func Test_HandleSUNIONSTORE(t *testing.T) {
t.Errorf("could not find element %s in the expected values", elem) t.Errorf("could not find element %s in the expected values", elem)
} }
} }
mockServer.KeyRUnlock(test.destination) mockServer.KeyRUnlock(ctx, test.destination)
} }
} }
} }

View File

@@ -126,13 +126,13 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
} }
} }
if server.KeyExists(key) { if server.KeyExists(ctx, key) {
// Key exists // Key exists
_, err = server.KeyLock(ctx, key) _, err = server.KeyLock(ctx, key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
@@ -154,7 +154,7 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne
if _, err = server.CreateKeyAndLock(ctx, key); err != nil { if _, err = server.CreateKeyAndLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
set := NewSortedSet(members) set := NewSortedSet(members)
if err = server.SetValue(ctx, key, set); err != nil { if err = server.SetValue(ctx, key, set); err != nil {
@@ -171,14 +171,14 @@ func handleZCARD(ctx context.Context, cmd []string, server utils.Server, conn *n
} }
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
@@ -232,14 +232,14 @@ func handleZCOUNT(ctx context.Context, cmd []string, server utils.Server, conn *
maximum = Score(s) maximum = Score(s)
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
@@ -266,14 +266,14 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.Server, con
minimum := cmd[2] minimum := cmd[2]
maximum := cmd[3] maximum := cmd[3]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
@@ -318,20 +318,20 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n
defer func() { defer func() {
for key, locked := range locks { for key, locked := range locks {
if locked { if locked {
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
} }
} }
}() }()
// Extract base set // Extract base set
if !server.KeyExists(keys[0]) { if !server.KeyExists(ctx, keys[0]) {
// If base set does not exist, return an empty array // If base set does not exist, return an empty array
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, keys[0]); err != nil { if _, err = server.KeyRLock(ctx, keys[0]); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(keys[0]) defer server.KeyRUnlock(ctx, keys[0])
baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*SortedSet) baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[0]) return nil, fmt.Errorf("value at %s is not a sorted set", keys[0])
@@ -341,7 +341,7 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n
var sets []*SortedSet var sets []*SortedSet
for i := 1; i < len(keys); i++ { for i := 1; i < len(keys); i++ {
if !server.KeyExists(keys[i]) { if !server.KeyExists(ctx, keys[i]) {
continue continue
} }
locked, err := server.KeyRLock(ctx, keys[i]) locked, err := server.KeyRLock(ctx, keys[i])
@@ -386,20 +386,20 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co
defer func() { defer func() {
for key, locked := range locks { for key, locked := range locks {
if locked { if locked {
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
} }
} }
}() }()
// Extract base set // Extract base set
if !server.KeyExists(keys[0]) { if !server.KeyExists(ctx, keys[0]) {
// If base set does not exist, return 0 // If base set does not exist, return 0
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, keys[0]); err != nil { if _, err = server.KeyRLock(ctx, keys[0]); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(keys[0]) defer server.KeyRUnlock(ctx, keys[0])
baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*SortedSet) baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", keys[0]) return nil, fmt.Errorf("value at %s is not a sorted set", keys[0])
@@ -408,7 +408,7 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co
var sets []*SortedSet var sets []*SortedSet
for i := 1; i < len(keys); i++ { for i := 1; i < len(keys); i++ {
if server.KeyExists(keys[i]) { if server.KeyExists(ctx, keys[i]) {
if _, err = server.KeyRLock(ctx, keys[i]); err != nil { if _, err = server.KeyRLock(ctx, keys[i]); err != nil {
return nil, err return nil, err
} }
@@ -422,7 +422,7 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co
diff := baseSortedSet.Subtract(sets) diff := baseSortedSet.Subtract(sets)
if server.KeyExists(destination) { if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil { if _, err = server.KeyLock(ctx, destination); err != nil {
return nil, err return nil, err
} }
@@ -431,7 +431,7 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co
return nil, err return nil, err
} }
} }
defer server.KeyUnlock(destination) defer server.KeyUnlock(ctx, destination)
if err = server.SetValue(ctx, destination, diff); err != nil { if err = server.SetValue(ctx, destination, diff); err != nil {
return nil, err return nil, err
@@ -469,7 +469,7 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.Server, conn
increment = Score(s) increment = Score(s)
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
// If the key does not exist, create a new sorted set at the key with // If the key does not exist, create a new sorted set at the key with
// the member and increment as the first value // the member and increment as the first value
if _, err = server.CreateKeyAndLock(ctx, key); err != nil { if _, err = server.CreateKeyAndLock(ctx, key); err != nil {
@@ -478,14 +478,14 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.Server, conn
if err = server.SetValue(ctx, key, NewSortedSet([]MemberParam{{value: member, score: increment}})); err != nil { if err = server.SetValue(ctx, key, NewSortedSet([]MemberParam{{value: member, score: increment}})); err != nil {
return nil, err return nil, err
} }
server.KeyUnlock(key) server.KeyUnlock(ctx, key)
return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(float64(increment), 'f', -1, 64))), nil return []byte(fmt.Sprintf("+%s\r\n", strconv.FormatFloat(float64(increment), 'f', -1, 64))), nil
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
@@ -518,7 +518,7 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn *
defer func() { defer func() {
for key, locked := range locks { for key, locked := range locks {
if locked { if locked {
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
} }
} }
}() }()
@@ -526,7 +526,7 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn *
var setParams []SortedSetParam var setParams []SortedSetParam
for i := 0; i < len(keys); i++ { for i := 0; i < len(keys); i++ {
if !server.KeyExists(keys[i]) { if !server.KeyExists(ctx, keys[i]) {
// If any of the keys is non-existent, return an empty array as there's no intersect // If any of the keys is non-existent, return an empty array as there's no intersect
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
@@ -585,7 +585,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c
defer func() { defer func() {
for key, locked := range locks { for key, locked := range locks {
if locked { if locked {
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
} }
} }
}() }()
@@ -593,7 +593,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c
var setParams []SortedSetParam var setParams []SortedSetParam
for i := 0; i < len(keys); i++ { for i := 0; i < len(keys); i++ {
if !server.KeyExists(keys[i]) { if !server.KeyExists(ctx, keys[i]) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, keys[i]); err != nil { if _, err = server.KeyRLock(ctx, keys[i]); err != nil {
@@ -612,7 +612,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c
intersect := Intersect(aggregate, setParams...) intersect := Intersect(aggregate, setParams...)
if server.KeyExists(destination) && intersect.Cardinality() > 0 { if server.KeyExists(ctx, destination) && intersect.Cardinality() > 0 {
if _, err = server.KeyLock(ctx, destination); err != nil { if _, err = server.KeyLock(ctx, destination); err != nil {
return nil, err return nil, err
} }
@@ -621,7 +621,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c
return nil, err return nil, err
} }
} }
defer server.KeyUnlock(destination) defer server.KeyUnlock(ctx, destination)
if err = server.SetValue(ctx, destination, intersect); err != nil { if err = server.SetValue(ctx, destination, intersect); err != nil {
return nil, err return nil, err
@@ -677,21 +677,21 @@ func handleZMPOP(ctx context.Context, cmd []string, server utils.Server, conn *n
} }
for i := 0; i < len(keys); i++ { for i := 0; i < len(keys); i++ {
if server.KeyExists(keys[i]) { if server.KeyExists(ctx, keys[i]) {
if _, err = server.KeyLock(ctx, keys[i]); err != nil { if _, err = server.KeyLock(ctx, keys[i]); err != nil {
continue continue
} }
v, ok := server.GetValue(ctx, keys[i]).(*SortedSet) v, ok := server.GetValue(ctx, keys[i]).(*SortedSet)
if !ok || v.Cardinality() == 0 { if !ok || v.Cardinality() == 0 {
server.KeyUnlock(keys[i]) server.KeyUnlock(ctx, keys[i])
continue continue
} }
popped, err := v.Pop(count, policy) popped, err := v.Pop(count, policy)
if err != nil { if err != nil {
server.KeyUnlock(keys[i]) server.KeyUnlock(ctx, keys[i])
return nil, err return nil, err
} }
server.KeyUnlock(keys[i]) server.KeyUnlock(ctx, keys[i])
res := fmt.Sprintf("*%d", popped.Cardinality()) res := fmt.Sprintf("*%d", popped.Cardinality())
@@ -730,14 +730,14 @@ func handleZPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne
count = c count = c
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
@@ -767,14 +767,14 @@ func handleZMSCORE(ctx context.Context, cmd []string, server utils.Server, conn
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
@@ -826,14 +826,14 @@ func handleZRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c
} }
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
@@ -870,14 +870,14 @@ func handleZRANK(ctx context.Context, cmd []string, server utils.Server, conn *n
withscores = true withscores = true
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
@@ -914,14 +914,14 @@ func handleZREM(ctx context.Context, cmd []string, server utils.Server, conn *ne
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
@@ -946,13 +946,13 @@ func handleZSCORE(ctx context.Context, cmd []string, server utils.Server, conn *
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("$-1\r\n"), nil return []byte("$-1\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
return nil, fmt.Errorf("value at %s is not a sorted set", key) return nil, fmt.Errorf("value at %s is not a sorted set", key)
@@ -987,14 +987,14 @@ func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server utils.Serv
return nil, err return nil, err
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
@@ -1029,14 +1029,14 @@ func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server utils.Serve
return nil, err return nil, err
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
@@ -1086,14 +1086,14 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.Server
minimum := cmd[2] minimum := cmd[2]
maximum := cmd[3] maximum := cmd[3]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err = server.KeyLock(ctx, key); err != nil { if _, err = server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
@@ -1184,14 +1184,14 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn *
} }
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
set, ok := server.GetValue(ctx, key).(*SortedSet) set, ok := server.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
@@ -1321,14 +1321,14 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c
} }
} }
if !server.KeyExists(source) { if !server.KeyExists(ctx, source) {
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
if _, err = server.KeyRLock(ctx, source); err != nil { if _, err = server.KeyRLock(ctx, source); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(source) defer server.KeyRUnlock(ctx, source)
set, ok := server.GetValue(ctx, source).(*SortedSet) set, ok := server.GetValue(ctx, source).(*SortedSet)
if !ok { if !ok {
@@ -1387,7 +1387,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c
newSortedSet := NewSortedSet(resultMembers) newSortedSet := NewSortedSet(resultMembers)
if server.KeyExists(destination) { if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil { if _, err = server.KeyLock(ctx, destination); err != nil {
return nil, err return nil, err
} }
@@ -1396,7 +1396,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c
return nil, err return nil, err
} }
} }
defer server.KeyUnlock(destination) defer server.KeyUnlock(ctx, destination)
if err = server.SetValue(ctx, destination, newSortedSet); err != nil { if err = server.SetValue(ctx, destination, newSortedSet); err != nil {
return nil, err return nil, err
@@ -1419,7 +1419,7 @@ func handleZUNION(ctx context.Context, cmd []string, server utils.Server, conn *
defer func() { defer func() {
for key, locked := range locks { for key, locked := range locks {
if locked { if locked {
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
} }
} }
}() }()
@@ -1427,7 +1427,7 @@ func handleZUNION(ctx context.Context, cmd []string, server utils.Server, conn *
var setParams []SortedSetParam var setParams []SortedSetParam
for i := 0; i < len(keys); i++ { for i := 0; i < len(keys); i++ {
if server.KeyExists(keys[i]) { if server.KeyExists(ctx, keys[i]) {
if _, err = server.KeyRLock(ctx, keys[i]); err != nil { if _, err = server.KeyRLock(ctx, keys[i]); err != nil {
return nil, err return nil, err
} }
@@ -1481,7 +1481,7 @@ func handleZUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c
defer func() { defer func() {
for key, locked := range locks { for key, locked := range locks {
if locked { if locked {
server.KeyRUnlock(key) server.KeyRUnlock(ctx, key)
} }
} }
}() }()
@@ -1489,7 +1489,7 @@ func handleZUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c
var setParams []SortedSetParam var setParams []SortedSetParam
for i := 0; i < len(keys); i++ { for i := 0; i < len(keys); i++ {
if server.KeyExists(keys[i]) { if server.KeyExists(ctx, keys[i]) {
if _, err = server.KeyRLock(ctx, keys[i]); err != nil { if _, err = server.KeyRLock(ctx, keys[i]); err != nil {
return nil, err return nil, err
} }
@@ -1507,7 +1507,7 @@ func handleZUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c
union := Union(aggregate, setParams...) union := Union(aggregate, setParams...)
if server.KeyExists(destination) { if server.KeyExists(ctx, destination) {
if _, err = server.KeyLock(ctx, destination); err != nil { if _, err = server.KeyLock(ctx, destination); err != nil {
return nil, err return nil, err
} }
@@ -1516,7 +1516,7 @@ func handleZUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c
return nil, err return nil, err
} }
} }
defer server.KeyUnlock(destination) defer server.KeyUnlock(ctx, destination)
if err = server.SetValue(ctx, destination, union); err != nil { if err = server.SetValue(ctx, destination, union); err != nil {
return nil, err return nil, err

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"fmt"
"github.com/echovault/echovault/src/server" "github.com/echovault/echovault/src/server"
"github.com/echovault/echovault/src/utils" "github.com/echovault/echovault/src/utils"
"github.com/tidwall/resp" "github.com/tidwall/resp"
@@ -218,15 +219,19 @@ func Test_HandleZADD(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZADD, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleZADD(context.Background(), test.command, mockServer, nil) res, err := handleZADD(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -245,10 +250,10 @@ func Test_HandleZADD(t *testing.T) {
t.Errorf("expected response %d at key \"%s\", got %d", test.expectedResponse, test.key, rv.Integer()) t.Errorf("expected response %d at key \"%s\", got %d", test.expectedResponse, test.key, rv.Integer())
} }
// Fetch the sorted set from the server and check it against the expected result // Fetch the sorted set from the server and check it against the expected result
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
sortedSet, ok := mockServer.GetValue(context.Background(), test.key).(*SortedSet) sortedSet, ok := mockServer.GetValue(ctx, test.key).(*SortedSet)
if !ok { if !ok {
t.Errorf("expected the value at key \"%s\" to be a sorted set, got another type", test.key) t.Errorf("expected the value at key \"%s\" to be a sorted set, got another type", test.key)
} }
@@ -258,7 +263,7 @@ func Test_HandleZADD(t *testing.T) {
if !sortedSet.Equals(test.expectedValue) { if !sortedSet.Equals(test.expectedValue) {
t.Errorf("expected sorted set %+v, got %+v", test.expectedValue, sortedSet) t.Errorf("expected sorted set %+v, got %+v", test.expectedValue, sortedSet)
} }
mockServer.KeyRUnlock(test.key) mockServer.KeyRUnlock(ctx, test.key)
} }
} }
@@ -325,15 +330,19 @@ func Test_HandleZCARD(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZCARD, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleZCARD(context.Background(), test.command, mockServer, nil) res, err := handleZCARD(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -464,15 +473,19 @@ func Test_HandleZCOUNT(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZCARD, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleZCOUNT(context.Background(), test.command, mockServer, nil) res, err := handleZCOUNT(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -577,15 +590,19 @@ func Test_HandleZLEXCOUNT(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZLEXCOUNT, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleZLEXCOUNT(context.Background(), test.command, mockServer, nil) res, err := handleZLEXCOUNT(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -728,17 +745,21 @@ func Test_HandleZDIFF(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZDIFF, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZDIFF(context.Background(), test.command, mockServer, nil) res, err := handleZDIFF(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -903,17 +924,21 @@ func Test_HandleZDIFFSTORE(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZDIFFSTORE, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZDIFFSTORE(context.Background(), test.command, mockServer, nil) res, err := handleZDIFFSTORE(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -932,10 +957,10 @@ func Test_HandleZDIFFSTORE(t *testing.T) {
t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer())
} }
if test.expectedValue != nil { if test.expectedValue != nil {
if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) set, ok := mockServer.GetValue(ctx, test.destination).(*SortedSet)
if !ok { if !ok {
t.Errorf("expected vaule at key %s to be set, got another type", test.destination) t.Errorf("expected vaule at key %s to be set, got another type", test.destination)
} }
@@ -944,7 +969,7 @@ func Test_HandleZDIFFSTORE(t *testing.T) {
t.Errorf("could not find element %s in the expected values", elem.value) t.Errorf("could not find element %s in the expected values", elem.value)
} }
} }
mockServer.KeyRUnlock(test.destination) mockServer.KeyRUnlock(ctx, test.destination)
} }
} }
} }
@@ -1119,15 +1144,19 @@ func Test_HandleZINCRBY(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZINCRBY, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleZINCRBY(context.Background(), test.command, mockServer, nil) res, err := handleZINCRBY(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1146,10 +1175,10 @@ func Test_HandleZINCRBY(t *testing.T) {
t.Errorf("expected response integer %s, got %s", test.expectedResponse, rv.String()) t.Errorf("expected response integer %s, got %s", test.expectedResponse, rv.String())
} }
if test.expectedValue != nil { if test.expectedValue != nil {
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), test.key).(*SortedSet) set, ok := mockServer.GetValue(ctx, test.key).(*SortedSet)
if !ok { if !ok {
t.Errorf("expected vaule at key %s to be set, got another type", test.key) t.Errorf("expected vaule at key %s to be set, got another type", test.key)
} }
@@ -1165,7 +1194,7 @@ func Test_HandleZINCRBY(t *testing.T) {
) )
} }
} }
mockServer.KeyRUnlock(test.key) mockServer.KeyRUnlock(ctx, test.key)
} }
} }
} }
@@ -1339,17 +1368,21 @@ func Test_HandleZMPOP(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZMPOP, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZMPOP(context.Background(), test.command, mockServer, nil) res, err := handleZMPOP(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1381,10 +1414,10 @@ func Test_HandleZMPOP(t *testing.T) {
} }
} }
for key, expectedSortedSet := range test.expectedValues { for key, expectedSortedSet := range test.expectedValues {
if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { if _, err = mockServer.KeyRLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) set, ok := mockServer.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
t.Errorf("expected key \"%s\" to be a sorted set, got another type", key) t.Errorf("expected key \"%s\" to be a sorted set, got another type", key)
} }
@@ -1510,17 +1543,21 @@ func Test_HandleZPOP(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZPOPMIN/ZPOPMAX, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZPOP(context.Background(), test.command, mockServer, nil) res, err := handleZPOP(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1552,10 +1589,10 @@ func Test_HandleZPOP(t *testing.T) {
} }
} }
for key, expectedSortedSet := range test.expectedValues { for key, expectedSortedSet := range test.expectedValues {
if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { if _, err = mockServer.KeyRLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) set, ok := mockServer.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
t.Errorf("expected key \"%s\" to be a sorted set, got another type", key) t.Errorf("expected key \"%s\" to be a sorted set, got another type", key)
} }
@@ -1610,17 +1647,21 @@ func Test_HandleZMSCORE(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZMSCORE, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZMSCORE(context.Background(), test.command, mockServer, nil) res, err := handleZMSCORE(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1710,17 +1751,21 @@ func Test_HandleZSCORE(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZSCORE, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZSCORE(context.Background(), test.command, mockServer, nil) res, err := handleZSCORE(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1825,15 +1870,19 @@ func Test_HandleZRANDMEMBER(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZRANDMEMBER, %d", i))
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleZRANDMEMBER(context.Background(), test.command, mockServer, nil) res, err := handleZRANDMEMBER(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -1866,10 +1915,10 @@ func Test_HandleZRANDMEMBER(t *testing.T) {
} }
} }
// 2. Fetch the set and check if its cardinality is what we expect. // 2. Fetch the set and check if its cardinality is what we expect.
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), test.key).(*SortedSet) set, ok := mockServer.GetValue(ctx, test.key).(*SortedSet)
if !ok { if !ok {
t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key)
} }
@@ -1980,17 +2029,21 @@ func Test_HandleZRANK(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZRANK, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZRANK(context.Background(), test.command, mockServer, nil) res, err := handleZRANK(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -2078,17 +2131,21 @@ func Test_HandleZREM(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZREM, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZREM(context.Background(), test.command, mockServer, nil) res, err := handleZREM(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -2109,10 +2166,10 @@ func Test_HandleZREM(t *testing.T) {
// Check if the expected sorted set is the same at the current one // Check if the expected sorted set is the same at the current one
if test.expectedValues != nil { if test.expectedValues != nil {
for key, expectedSet := range test.expectedValues { for key, expectedSet := range test.expectedValues {
if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { if _, err = mockServer.KeyRLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) set, ok := mockServer.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key)
} }
@@ -2184,17 +2241,21 @@ func Test_HandleZREMRANGEBYSCORE(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZREMRANGEBYSCORE, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZREMRANGEBYSCORE(context.Background(), test.command, mockServer, nil) res, err := handleZREMRANGEBYSCORE(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -2215,10 +2276,10 @@ func Test_HandleZREMRANGEBYSCORE(t *testing.T) {
// Check if the expected values are the same // Check if the expected values are the same
if test.expectedValues != nil { if test.expectedValues != nil {
for key, expectedSet := range test.expectedValues { for key, expectedSet := range test.expectedValues {
if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { if _, err = mockServer.KeyRLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) set, ok := mockServer.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key)
} }
@@ -2344,17 +2405,21 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZREMRANGEBYRANK, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZREMRANGEBYRANK(context.Background(), test.command, mockServer, nil) res, err := handleZREMRANGEBYRANK(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -2375,10 +2440,10 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) {
// Check if the expected values are the same // Check if the expected values are the same
if test.expectedValues != nil { if test.expectedValues != nil {
for key, expectedSet := range test.expectedValues { for key, expectedSet := range test.expectedValues {
if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { if _, err = mockServer.KeyRLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) set, ok := mockServer.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key)
} }
@@ -2475,17 +2540,21 @@ func Test_HandleZREMRANGEBYLEX(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZREMRANGEBYLEX, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZREMRANGEBYLEX(context.Background(), test.command, mockServer, nil) res, err := handleZREMRANGEBYLEX(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -2506,10 +2575,10 @@ func Test_HandleZREMRANGEBYLEX(t *testing.T) {
// Check if the expected values are the same // Check if the expected values are the same
if test.expectedValues != nil { if test.expectedValues != nil {
for key, expectedSet := range test.expectedValues { for key, expectedSet := range test.expectedValues {
if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { if _, err = mockServer.KeyRLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) set, ok := mockServer.GetValue(ctx, key).(*SortedSet)
if !ok { if !ok {
t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key)
} }
@@ -2720,17 +2789,21 @@ func Test_HandleZRANGE(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZRANGE, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZRANGE(context.Background(), test.command, mockServer, nil) res, err := handleZRANGE(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -2790,7 +2863,7 @@ func Test_HandleZRANGESTORE(t *testing.T) {
}), }),
}, },
destination: "destination1", destination: "destination1",
command: []string{"ZRANGE", "destination1", "key1", "3", "7", "BYSCORE"}, command: []string{"ZRANGESTORE", "destination1", "key1", "3", "7", "BYSCORE"},
expectedResponse: 5, expectedResponse: 5,
expectedValue: NewSortedSet([]MemberParam{ expectedValue: NewSortedSet([]MemberParam{
{value: "three", score: 3}, {value: "four", score: 4}, {value: "five", score: 5}, {value: "three", score: 3}, {value: "four", score: 4}, {value: "five", score: 5},
@@ -2809,7 +2882,7 @@ func Test_HandleZRANGESTORE(t *testing.T) {
}), }),
}, },
destination: "destination2", destination: "destination2",
command: []string{"ZRANGE", "destination2", "key2", "3", "7", "BYSCORE", "WITHSCORES"}, command: []string{"ZRANGESTORE", "destination2", "key2", "3", "7", "BYSCORE", "WITHSCORES"},
expectedResponse: 5, expectedResponse: 5,
expectedValue: NewSortedSet([]MemberParam{ expectedValue: NewSortedSet([]MemberParam{
{value: "three", score: 3}, {value: "four", score: 4}, {value: "five", score: 5}, {value: "three", score: 3}, {value: "four", score: 4}, {value: "five", score: 5},
@@ -2829,7 +2902,7 @@ func Test_HandleZRANGESTORE(t *testing.T) {
}), }),
}, },
destination: "destination3", destination: "destination3",
command: []string{"ZRANGE", "destination3", "key3", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4"}, command: []string{"ZRANGESTORE", "destination3", "key3", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4"},
expectedResponse: 3, expectedResponse: 3,
expectedValue: NewSortedSet([]MemberParam{ expectedValue: NewSortedSet([]MemberParam{
{value: "three", score: 3}, {value: "four", score: 4}, {value: "five", score: 5}, {value: "three", score: 3}, {value: "four", score: 4}, {value: "five", score: 5},
@@ -2849,7 +2922,7 @@ func Test_HandleZRANGESTORE(t *testing.T) {
}), }),
}, },
destination: "destination4", destination: "destination4",
command: []string{"ZRANGE", "destination4", "key4", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4", "REV"}, command: []string{"ZRANGESTORE", "destination4", "key4", "3", "7", "BYSCORE", "WITHSCORES", "LIMIT", "2", "4", "REV"},
expectedResponse: 3, expectedResponse: 3,
expectedValue: NewSortedSet([]MemberParam{ expectedValue: NewSortedSet([]MemberParam{
{value: "six", score: 6}, {value: "five", score: 5}, {value: "four", score: 4}, {value: "six", score: 6}, {value: "five", score: 5}, {value: "four", score: 4},
@@ -2867,7 +2940,7 @@ func Test_HandleZRANGESTORE(t *testing.T) {
}), }),
}, },
destination: "destination5", destination: "destination5",
command: []string{"ZRANGE", "destination5", "key5", "c", "g", "BYLEX"}, command: []string{"ZRANGESTORE", "destination5", "key5", "c", "g", "BYLEX"},
expectedResponse: 5, expectedResponse: 5,
expectedValue: NewSortedSet([]MemberParam{ expectedValue: NewSortedSet([]MemberParam{
{value: "c", score: 1}, {value: "d", score: 1}, {value: "e", score: 1}, {value: "c", score: 1}, {value: "d", score: 1}, {value: "e", score: 1},
@@ -2886,7 +2959,7 @@ func Test_HandleZRANGESTORE(t *testing.T) {
}), }),
}, },
destination: "destination6", destination: "destination6",
command: []string{"ZRANGE", "destination6", "key6", "a", "f", "BYLEX", "WITHSCORES"}, command: []string{"ZRANGESTORE", "destination6", "key6", "a", "f", "BYLEX", "WITHSCORES"},
expectedResponse: 6, expectedResponse: 6,
expectedValue: NewSortedSet([]MemberParam{ expectedValue: NewSortedSet([]MemberParam{
{value: "a", score: 1}, {value: "b", score: 1}, {value: "c", score: 1}, {value: "a", score: 1}, {value: "b", score: 1}, {value: "c", score: 1},
@@ -2906,7 +2979,7 @@ func Test_HandleZRANGESTORE(t *testing.T) {
}), }),
}, },
destination: "destination7", destination: "destination7",
command: []string{"ZRANGE", "destination7", "key7", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, command: []string{"ZRANGESTORE", "destination7", "key7", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"},
expectedResponse: 3, expectedResponse: 3,
expectedValue: NewSortedSet([]MemberParam{ expectedValue: NewSortedSet([]MemberParam{
{value: "c", score: 1}, {value: "d", score: 1}, {value: "e", score: 1}, {value: "c", score: 1}, {value: "d", score: 1}, {value: "e", score: 1},
@@ -2926,7 +2999,7 @@ func Test_HandleZRANGESTORE(t *testing.T) {
}), }),
}, },
destination: "destination8", destination: "destination8",
command: []string{"ZRANGE", "destination8", "key8", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4", "REV"}, command: []string{"ZRANGESTORE", "destination8", "key8", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4", "REV"},
expectedResponse: 3, expectedResponse: 3,
expectedValue: NewSortedSet([]MemberParam{ expectedValue: NewSortedSet([]MemberParam{
{value: "f", score: 1}, {value: "e", score: 1}, {value: "d", score: 1}, {value: "f", score: 1}, {value: "e", score: 1}, {value: "d", score: 1},
@@ -2944,7 +3017,7 @@ func Test_HandleZRANGESTORE(t *testing.T) {
}), }),
}, },
destination: "destination9", destination: "destination9",
command: []string{"ZRANGE", "destination9", "key9", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, command: []string{"ZRANGESTORE", "destination9", "key9", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"},
expectedResponse: 0, expectedResponse: 0,
expectedValue: nil, expectedValue: nil,
expectedError: nil, expectedError: nil,
@@ -2952,28 +3025,28 @@ func Test_HandleZRANGESTORE(t *testing.T) {
{ // 10. Throw error when limit does not provide both offset and limit { // 10. Throw error when limit does not provide both offset and limit
preset: false, preset: false,
presetValues: nil, presetValues: nil,
command: []string{"ZRANGE", "destination10", "key10", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2"}, command: []string{"ZRANGESTORE", "destination10", "key10", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2"},
expectedResponse: 0, expectedResponse: 0,
expectedError: errors.New("limit should contain offset and count as integers"), expectedError: errors.New("limit should contain offset and count as integers"),
}, },
{ // 11. Throw error when offset is not a valid integer { // 11. Throw error when offset is not a valid integer
preset: false, preset: false,
presetValues: nil, presetValues: nil,
command: []string{"ZRANGE", "destination11", "key11", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "offset", "4"}, command: []string{"ZRANGESTORE", "destination11", "key11", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "offset", "4"},
expectedResponse: 0, expectedResponse: 0,
expectedError: errors.New("limit offset must be integer"), expectedError: errors.New("limit offset must be integer"),
}, },
{ // 12. Throw error when limit is not a valid integer { // 12. Throw error when limit is not a valid integer
preset: false, preset: false,
presetValues: nil, presetValues: nil,
command: []string{"ZRANGE", "destination12", "key12", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "4", "limit"}, command: []string{"ZRANGESTORE", "destination12", "key12", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "4", "limit"},
expectedResponse: 0, expectedResponse: 0,
expectedError: errors.New("limit count must be integer"), expectedError: errors.New("limit count must be integer"),
}, },
{ // 13. Throw error when offset is negative { // 13. Throw error when offset is negative
preset: false, preset: false,
presetValues: nil, presetValues: nil,
command: []string{"ZRANGE", "destination13", "key13", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9"}, command: []string{"ZRANGESTORE", "destination13", "key13", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9"},
expectedResponse: 0, expectedResponse: 0,
expectedError: errors.New("limit offset must be >= 0"), expectedError: errors.New("limit offset must be >= 0"),
}, },
@@ -2982,37 +3055,41 @@ func Test_HandleZRANGESTORE(t *testing.T) {
presetValues: map[string]interface{}{ presetValues: map[string]interface{}{
"key14": "Default value", "key14": "Default value",
}, },
command: []string{"ZRANGE", "destination14", "key14", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"}, command: []string{"ZRANGESTORE", "destination14", "key14", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "2", "4"},
expectedResponse: 0, expectedResponse: 0,
expectedError: errors.New("value at key14 is not a sorted set"), expectedError: errors.New("value at key14 is not a sorted set"),
}, },
{ // 15. Command too short { // 15. Command too short
preset: false, preset: false,
presetValues: nil, presetValues: nil,
command: []string{"ZRANGE", "key15", "1"}, command: []string{"ZRANGESTORE", "key15", "1"},
expectedResponse: 0, expectedResponse: 0,
expectedError: errors.New(utils.WrongArgsResponse), expectedError: errors.New(utils.WrongArgsResponse),
}, },
{ // 16 Command too long { // 16 Command too long
preset: false, preset: false,
presetValues: nil, presetValues: nil,
command: []string{"ZRANGE", "destination16", "key16", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9", "REV", "WITHSCORES"}, command: []string{"ZRANGESTORE", "destination16", "key16", "a", "h", "BYLEX", "WITHSCORES", "LIMIT", "-4", "9", "REV", "WITHSCORES"},
expectedResponse: 0, expectedResponse: 0,
expectedError: errors.New(utils.WrongArgsResponse), expectedError: errors.New(utils.WrongArgsResponse),
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZRANGESTORE, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZRANGESTORE(context.Background(), test.command, mockServer, nil) res, err := handleZRANGESTORE(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -3031,17 +3108,17 @@ func Test_HandleZRANGESTORE(t *testing.T) {
t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer())
} }
if test.expectedValue != nil { if test.expectedValue != nil {
if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) set, ok := mockServer.GetValue(ctx, test.destination).(*SortedSet)
if !ok { if !ok {
t.Errorf("expected vaule at key %s to be set, got another type", test.destination) t.Errorf("expected vaule at key %s to be set, got another type", test.destination)
} }
if !set.Equals(test.expectedValue) { if !set.Equals(test.expectedValue) {
t.Errorf("expected sorted set %+v, got %+v", test.expectedValue, set) t.Errorf("expected sorted set %+v, got %+v", test.expectedValue, set)
} }
mockServer.KeyRUnlock(test.destination) mockServer.KeyRUnlock(ctx, test.destination)
} }
} }
} }
@@ -3316,17 +3393,21 @@ func Test_HandleZINTER(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZINTER, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZINTER(context.Background(), test.command, mockServer, nil) res, err := handleZINTER(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -3661,17 +3742,21 @@ func Test_HandleZINTERSTORE(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZINTERSTORE, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZINTERSTORE(context.Background(), test.command, mockServer, nil) res, err := handleZINTERSTORE(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -3690,10 +3775,10 @@ func Test_HandleZINTERSTORE(t *testing.T) {
t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer())
} }
if test.expectedValue != nil { if test.expectedValue != nil {
if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) set, ok := mockServer.GetValue(ctx, test.destination).(*SortedSet)
if !ok { if !ok {
t.Errorf("expected vaule at key %s to be set, got another type", test.destination) t.Errorf("expected vaule at key %s to be set, got another type", test.destination)
} }
@@ -3702,7 +3787,7 @@ func Test_HandleZINTERSTORE(t *testing.T) {
t.Errorf("could not find element %s in the expected values", elem.value) t.Errorf("could not find element %s in the expected values", elem.value)
} }
} }
mockServer.KeyRUnlock(test.destination) mockServer.KeyRUnlock(ctx, test.destination)
} }
} }
} }
@@ -4002,17 +4087,21 @@ func Test_HandleZUNION(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZUNION, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZUNION(context.Background(), test.command, mockServer, nil) res, err := handleZUNION(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -4386,17 +4475,21 @@ func Test_HandleZUNIONSTORE(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("ZUNIONSTORE, %d", i))
if test.preset { if test.preset {
for key, value := range test.presetValues { for key, value := range test.presetValues {
if _, err := mockServer.CreateKeyAndLock(context.Background(), key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), key, value) if err := mockServer.SetValue(ctx, key, value); err != nil {
mockServer.KeyUnlock(key) t.Error(err)
}
mockServer.KeyUnlock(ctx, key)
} }
} }
res, err := handleZUNIONSTORE(context.Background(), test.command, mockServer, nil) res, err := handleZUNIONSTORE(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -4415,10 +4508,10 @@ func Test_HandleZUNIONSTORE(t *testing.T) {
t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer()) t.Errorf("expected response integer %d, got %d", test.expectedResponse, rv.Integer())
} }
if test.expectedValue != nil { if test.expectedValue != nil {
if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { if _, err = mockServer.KeyRLock(ctx, test.destination); err != nil {
t.Error(err) t.Error(err)
} }
set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) set, ok := mockServer.GetValue(ctx, test.destination).(*SortedSet)
if !ok { if !ok {
t.Errorf("expected vaule at key %s to be set, got another type", test.destination) t.Errorf("expected vaule at key %s to be set, got another type", test.destination)
} }
@@ -4427,7 +4520,7 @@ func Test_HandleZUNIONSTORE(t *testing.T) {
t.Errorf("could not find element %s in the expected values", elem.value) t.Errorf("could not find element %s in the expected values", elem.value)
} }
} }
mockServer.KeyRUnlock(test.destination) mockServer.KeyRUnlock(ctx, test.destination)
} }
} }
} }

View File

@@ -23,21 +23,21 @@ func handleSetRange(ctx context.Context, cmd []string, server utils.Server, conn
newStr := cmd[3] newStr := cmd[3]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
if _, err = server.CreateKeyAndLock(ctx, key); err != nil { if _, err = server.CreateKeyAndLock(ctx, key); err != nil {
return nil, err return nil, err
} }
if err = server.SetValue(ctx, key, newStr); err != nil { if err = server.SetValue(ctx, key, newStr); err != nil {
return nil, err return nil, err
} }
server.KeyUnlock(key) server.KeyUnlock(ctx, key)
return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil return []byte(fmt.Sprintf(":%d\r\n", len(newStr))), nil
} }
if _, err := server.KeyLock(ctx, key); err != nil { if _, err := server.KeyLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyUnlock(key) defer server.KeyUnlock(ctx, key)
str, ok := server.GetValue(ctx, key).(string) str, ok := server.GetValue(ctx, key).(string)
if !ok { if !ok {
@@ -91,14 +91,14 @@ func handleStrLen(ctx context.Context, cmd []string, server utils.Server, conn *
key := keys[0] key := keys[0]
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return []byte(":0\r\n"), nil return []byte(":0\r\n"), nil
} }
if _, err := server.KeyRLock(ctx, key); err != nil { if _, err := server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
value, ok := server.GetValue(ctx, key).(string) value, ok := server.GetValue(ctx, key).(string)
@@ -125,14 +125,14 @@ func handleSubStr(ctx context.Context, cmd []string, server utils.Server, conn *
return nil, errors.New("start and end indices must be integers") return nil, errors.New("start and end indices must be integers")
} }
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
return nil, fmt.Errorf("key %s does not exist", key) return nil, fmt.Errorf("key %s does not exist", key)
} }
if _, err := server.KeyRLock(ctx, key); err != nil { if _, err = server.KeyRLock(ctx, key); err != nil {
return nil, err return nil, err
} }
defer server.KeyRUnlock(key) defer server.KeyRUnlock(ctx, key)
value, ok := server.GetValue(ctx, key).(string) value, ok := server.GetValue(ctx, key).(string)
if !ok { if !ok {

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"fmt"
"github.com/echovault/echovault/src/server" "github.com/echovault/echovault/src/server"
"github.com/echovault/echovault/src/utils" "github.com/echovault/echovault/src/utils"
"github.com/tidwall/resp" "github.com/tidwall/resp"
@@ -105,17 +106,21 @@ func Test_HandleSetRange(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SETRANGE, %d", i))
// If there's a preset step, carry it out here // If there's a preset step, carry it out here
if test.preset { if test.preset {
if _, err := mockServer.CreateKeyAndLock(context.Background(), test.key); err != nil { if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, utils.AdaptType(test.presetValue)) if err := mockServer.SetValue(ctx, test.key, utils.AdaptType(test.presetValue)); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleSetRange(context.Background(), test.command, mockServer, nil) res, err := handleSetRange(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -135,17 +140,17 @@ func Test_HandleSetRange(t *testing.T) {
} }
// Get the value from the server and check against the expected value // Get the value from the server and check against the expected value
if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { if _, err = mockServer.KeyRLock(ctx, test.key); err != nil {
t.Error(err) t.Error(err)
} }
value, ok := mockServer.GetValue(context.Background(), test.key).(string) value, ok := mockServer.GetValue(ctx, test.key).(string)
if !ok { if !ok {
t.Error("expected string data type, got another type") t.Error("expected string data type, got another type")
} }
if value != test.expectedValue { if value != test.expectedValue {
t.Errorf("expected value \"%s\", got \"%s\"", test.expectedValue, value) t.Errorf("expected value \"%s\", got \"%s\"", test.expectedValue, value)
} }
mockServer.KeyRUnlock(test.key) mockServer.KeyRUnlock(ctx, test.key)
} }
} }
@@ -194,16 +199,20 @@ func Test_HandleStrLen(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("STRLEN, %d", i))
if test.preset { if test.preset {
_, err := mockServer.CreateKeyAndLock(context.Background(), test.key) _, err := mockServer.CreateKeyAndLock(ctx, test.key)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleStrLen(context.Background(), test.command, mockServer, nil) res, err := handleStrLen(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())
@@ -307,16 +316,19 @@ func Test_HandleSubStr(t *testing.T) {
}, },
} }
for _, test := range tests { for i, test := range tests {
ctx := context.WithValue(context.Background(), "test_name", fmt.Sprintf("SUBSTR, %d", i))
if test.preset { if test.preset {
_, err := mockServer.CreateKeyAndLock(context.Background(), test.key) if _, err := mockServer.CreateKeyAndLock(ctx, test.key); err != nil {
if err != nil {
t.Error(err) t.Error(err)
} }
mockServer.SetValue(context.Background(), test.key, test.presetValue) if err := mockServer.SetValue(ctx, test.key, test.presetValue); err != nil {
mockServer.KeyUnlock(test.key) t.Error(err)
}
mockServer.KeyUnlock(ctx, test.key)
} }
res, err := handleSubStr(context.Background(), test.command, mockServer, nil) res, err := handleSubStr(ctx, test.command, mockServer, nil)
if test.expectedError != nil { if test.expectedError != nil {
if err.Error() != test.expectedError.Error() { if err.Error() != test.expectedError.Error() {
t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error()) t.Errorf("expected error \"%s\", got \"%s\"", test.expectedError.Error(), err.Error())

View File

@@ -98,19 +98,21 @@ func (fsm *FSM) Restore(snapshot io.ReadCloser) error {
LatestSnapshotMilliseconds: 0, LatestSnapshotMilliseconds: 0,
} }
if err := json.Unmarshal(b, &data); err != nil { if err = json.Unmarshal(b, &data); err != nil {
log.Fatal(err) log.Fatal(err)
return err return err
} }
// Set state // Set state
ctx := context.Background()
for k, v := range data.State { for k, v := range data.State {
_, err := fsm.options.Server.CreateKeyAndLock(context.Background(), k) if _, err = fsm.options.Server.CreateKeyAndLock(ctx, k); err != nil {
if err != nil {
log.Fatal(err) log.Fatal(err)
} }
fsm.options.Server.SetValue(context.Background(), k, v) if err = fsm.options.Server.SetValue(ctx, k, v); err != nil {
fsm.options.Server.KeyUnlock(k) log.Fatal(err)
}
fsm.options.Server.KeyUnlock(ctx, k)
} }
// Set latest snapshot milliseconds // Set latest snapshot milliseconds
fsm.options.Server.SetLatestSnapshot(data.LatestSnapshotMilliseconds) fsm.options.Server.SetLatestSnapshot(data.LatestSnapshotMilliseconds)

View File

@@ -33,8 +33,8 @@ func (server *Server) KeyLock(ctx context.Context, key string) (bool, error) {
} }
} }
func (server *Server) KeyUnlock(key string) { func (server *Server) KeyUnlock(ctx context.Context, key string) {
if server.KeyExists(key) { if _, ok := server.keyLocks[key]; ok {
server.keyLocks[key].Unlock() server.keyLocks[key].Unlock()
} }
} }
@@ -58,14 +58,27 @@ func (server *Server) KeyRLock(ctx context.Context, key string) (bool, error) {
} }
} }
func (server *Server) KeyRUnlock(key string) { func (server *Server) KeyRUnlock(ctx context.Context, key string) {
if server.KeyExists(key) { if _, ok := server.keyLocks[key]; ok {
server.keyLocks[key].RUnlock() server.keyLocks[key].RUnlock()
} }
} }
func (server *Server) KeyExists(key string) bool { func (server *Server) KeyExists(ctx context.Context, key string) bool {
return server.keyLocks[key] != nil entry, ok := server.store[key]
if !ok {
return false
}
if entry.ExpireAt != (time.Time{}) && entry.ExpireAt.Before(time.Now()) {
err := server.DeleteKey(ctx, key)
if err != nil {
log.Printf("keyExists: %+v\n", err)
}
return false
}
return true
} }
// CreateKeyAndLock creates a new key lock and immediately locks it if the key does not exist. // CreateKeyAndLock creates a new key lock and immediately locks it if the key does not exist.
@@ -78,15 +91,15 @@ func (server *Server) CreateKeyAndLock(ctx context.Context, key string) (bool, e
server.keyCreationLock.Lock() server.keyCreationLock.Lock()
defer server.keyCreationLock.Unlock() defer server.keyCreationLock.Unlock()
if !server.KeyExists(key) { if !server.KeyExists(ctx, key) {
// Create Lock // Create Lock
keyLock := &sync.RWMutex{} keyLock := &sync.RWMutex{}
keyLock.Lock() keyLock.Lock()
server.keyLocks[key] = keyLock server.keyLocks[key] = keyLock
// Create key entry // Create key entry
server.store[key] = KeyData{ server.store[key] = utils.KeyData{
value: nil, Value: nil,
expireAt: time.Time{}, ExpireAt: time.Time{},
} }
return true, nil return true, nil
} }
@@ -100,7 +113,7 @@ func (server *Server) GetValue(ctx context.Context, key string) interface{} {
if err := server.updateKeyInCache(ctx, key); err != nil { if err := server.updateKeyInCache(ctx, key); err != nil {
log.Printf("GetValue error: %+v\n", err) log.Printf("GetValue error: %+v\n", err)
} }
return server.store[key].value return server.store[key].Value
} }
// SetValue updates the value in the store at the specified key with the given value. // SetValue updates the value in the store at the specified key with the given value.
@@ -113,9 +126,9 @@ func (server *Server) SetValue(ctx context.Context, key string, value interface{
return errors.New("max memory reached, key value not set") return errors.New("max memory reached, key value not set")
} }
server.store[key] = KeyData{ server.store[key] = utils.KeyData{
value: value, Value: value,
expireAt: server.store[key].expireAt, ExpireAt: server.store[key].ExpireAt,
} }
err := server.updateKeyInCache(ctx, key) err := server.updateKeyInCache(ctx, key)
@@ -136,7 +149,7 @@ func (server *Server) GetExpiry(ctx context.Context, key string) time.Time {
if err := server.updateKeyInCache(ctx, key); err != nil { if err := server.updateKeyInCache(ctx, key); err != nil {
log.Printf("GetKeyExpiry error: %+v\n", err) log.Printf("GetKeyExpiry error: %+v\n", err)
} }
return server.store[key].expireAt return server.store[key].ExpireAt
} }
// The SetExpiry receiver function sets the expiry time of a key. // The SetExpiry receiver function sets the expiry time of a key.
@@ -146,9 +159,9 @@ func (server *Server) GetExpiry(ctx context.Context, key string) time.Time {
// or the access time on lru eviction policy. // or the access time on lru eviction policy.
// The key must be locked prior to calling this function. // The key must be locked prior to calling this function.
func (server *Server) SetExpiry(ctx context.Context, key string, expireAt time.Time, touch bool) { func (server *Server) SetExpiry(ctx context.Context, key string, expireAt time.Time, touch bool) {
server.store[key] = KeyData{ server.store[key] = utils.KeyData{
value: server.store[key].value, Value: server.store[key].Value,
expireAt: expireAt, ExpireAt: expireAt,
} }
if touch { if touch {
err := server.updateKeyInCache(ctx, key) err := server.updateKeyInCache(ctx, key)
@@ -161,9 +174,9 @@ func (server *Server) SetExpiry(ctx context.Context, key string, expireAt time.T
// RemoveExpiry is called by commands that remove key expiry (e.g. PERSIST). // RemoveExpiry is called by commands that remove key expiry (e.g. PERSIST).
// The key must be locked prior ro calling this function. // The key must be locked prior ro calling this function.
func (server *Server) RemoveExpiry(key string) { func (server *Server) RemoveExpiry(key string) {
server.store[key] = KeyData{ server.store[key] = utils.KeyData{
value: server.store[key].value, Value: server.store[key].Value,
expireAt: time.Time{}, ExpireAt: time.Time{},
} }
switch { switch {
case slices.Contains([]string{utils.AllKeysLFU, utils.VolatileLFU}, server.Config.EvictionPolicy): case slices.Contains([]string{utils.AllKeysLFU, utils.VolatileLFU}, server.Config.EvictionPolicy):
@@ -198,9 +211,11 @@ func (server *Server) DeleteKey(ctx context.Context, key string) error {
if _, err := server.KeyLock(ctx, key); err != nil { if _, err := server.KeyLock(ctx, key); err != nil {
return fmt.Errorf("deleteKey: %+v", err) return fmt.Errorf("deleteKey: %+v", err)
} }
// Delete the keys // Remove key expiry
delete(server.store, key) server.RemoveExpiry(key)
// Delete the key from keyLocks and store
delete(server.keyLocks, key) delete(server.keyLocks, key)
delete(server.store, key)
return nil return nil
} }
@@ -227,13 +242,13 @@ func (server *Server) updateKeyInCache(ctx context.Context, key string) error {
case utils.VolatileLFU: case utils.VolatileLFU:
server.lfuCache.mutex.Lock() server.lfuCache.mutex.Lock()
defer server.lfuCache.mutex.Unlock() defer server.lfuCache.mutex.Unlock()
if server.store[key].expireAt != (time.Time{}) { if server.store[key].ExpireAt != (time.Time{}) {
server.lfuCache.cache.Update(key) server.lfuCache.cache.Update(key)
} }
case utils.VolatileLRU: case utils.VolatileLRU:
server.lruCache.mutex.Lock() server.lruCache.mutex.Lock()
defer server.lruCache.mutex.Unlock() defer server.lruCache.mutex.Unlock()
if server.store[key].expireAt != (time.Time{}) { if server.store[key].ExpireAt != (time.Time{}) {
server.lruCache.cache.Update(key) server.lruCache.cache.Update(key)
} }
} }
@@ -347,7 +362,7 @@ func (server *Server) adjustMemoryUsage(ctx context.Context) error {
for key, _ := range server.keyLocks { for key, _ := range server.keyLocks {
if idx == 0 { if idx == 0 {
// If the key is not volatile, break the loop // If the key is not volatile, break the loop
if server.store[key].expireAt == (time.Time{}) { if server.store[key].ExpireAt == (time.Time{}) {
break break
} }
// Delete the key // Delete the key

View File

@@ -21,17 +21,12 @@ import (
"time" "time"
) )
type KeyData struct {
value interface{}
expireAt time.Time
}
type Server struct { type Server struct {
Config utils.Config Config utils.Config
ConnID atomic.Uint64 ConnID atomic.Uint64
store map[string]KeyData store map[string]utils.KeyData
keyLocks map[string]*sync.RWMutex keyLocks map[string]*sync.RWMutex
keyCreationLock *sync.Mutex keyCreationLock *sync.Mutex
lfuCache struct { lfuCache struct {
@@ -77,7 +72,7 @@ func NewServer(opts Opts) *Server {
PubSub: opts.PubSub, PubSub: opts.PubSub,
CancelCh: opts.CancelCh, CancelCh: opts.CancelCh,
Commands: opts.Commands, Commands: opts.Commands,
store: make(map[string]KeyData), store: make(map[string]utils.KeyData),
keyLocks: make(map[string]*sync.RWMutex), keyLocks: make(map[string]*sync.RWMutex),
keyCreationLock: &sync.Mutex{}, keyCreationLock: &sync.Mutex{},
} }
@@ -105,13 +100,14 @@ func NewServer(opts Opts) *Server {
SetLatestSnapshotMilliseconds: server.SetLatestSnapshot, SetLatestSnapshotMilliseconds: server.SetLatestSnapshot,
GetLatestSnapshotMilliseconds: server.GetLatestSnapshot, GetLatestSnapshotMilliseconds: server.GetLatestSnapshot,
SetValue: func(key string, value interface{}) error { SetValue: func(key string, value interface{}) error {
if _, err := server.CreateKeyAndLock(context.Background(), key); err != nil { ctx := context.Background()
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
return err return err
} }
if err := server.SetValue(context.Background(), key, value); err != nil { if err := server.SetValue(ctx, key, value); err != nil {
return err return err
} }
server.KeyUnlock(key) server.KeyUnlock(ctx, key)
return nil return nil
}, },
}) })
@@ -123,13 +119,14 @@ func NewServer(opts Opts) *Server {
aof.WithFinishRewriteFunc(server.FinishRewriteAOF), aof.WithFinishRewriteFunc(server.FinishRewriteAOF),
aof.WithGetStateFunc(server.GetState), aof.WithGetStateFunc(server.GetState),
aof.WithSetValueFunc(func(key string, value interface{}) error { aof.WithSetValueFunc(func(key string, value interface{}) error {
if _, err := server.CreateKeyAndLock(context.Background(), key); err != nil { ctx := context.Background()
if _, err := server.CreateKeyAndLock(ctx, key); err != nil {
return err return err
} }
if err := server.SetValue(context.Background(), key, value); err != nil { if err := server.SetValue(ctx, key, value); err != nil {
return err return err
} }
server.KeyUnlock(key) server.KeyUnlock(ctx, key)
return nil return nil
}), }),
aof.WithHandleCommandFunc(func(command []byte) { aof.WithHandleCommandFunc(func(command []byte) {

View File

@@ -98,12 +98,11 @@ There is no limit by default.`, func(memory string) error {
4) volatile-lfu - Evict the least frequently used keys with an expiration. 4) volatile-lfu - Evict the least frequently used keys with an expiration.
5) volatile-lru - Evict the least recently used keys with an expiration. 5) volatile-lru - Evict the least recently used keys with an expiration.
6) allkeys-random - Evict random keys until we get under the max-memory limit. 6) allkeys-random - Evict random keys until we get under the max-memory limit.
7) volatile-random - Evict random keys with an expiration. 7) volatile-random - Evict random keys with an expiration.`, func(policy string) error {
8) volatile-ttl - Evict the keys with the shortest remaining ttl.`, func(policy string) error {
policies := []string{ policies := []string{
NoEviction, NoEviction,
AllKeysLFU, AllKeysLRU, AllKeysRandom, AllKeysLFU, AllKeysLRU, AllKeysRandom,
VolatileLFU, VolatileLRU, VolatileRandom, VolatileTTL, VolatileLFU, VolatileLRU, VolatileRandom,
} }
policyIdx := slices.Index(policies, strings.ToLower(policy)) policyIdx := slices.Index(policies, strings.ToLower(policy))
if policyIdx == -1 { if policyIdx == -1 {

View File

@@ -37,5 +37,4 @@ const (
VolatileLFU = "volatile-lfu" VolatileLFU = "volatile-lfu"
AllKeysRandom = "allkeys-random" AllKeysRandom = "allkeys-random"
VolatileRandom = "volatile-random" VolatileRandom = "volatile-random"
VolatileTTL = "volatile-ttl"
) )

View File

@@ -6,12 +6,18 @@ import (
"time" "time"
) )
// KeyData holds the structure of the in-memory data stored at a string key.
type KeyData struct {
Value interface{}
ExpireAt time.Time
}
type Server interface { type Server interface {
KeyLock(ctx context.Context, key string) (bool, error) KeyLock(ctx context.Context, key string) (bool, error)
KeyUnlock(key string) KeyUnlock(ctx context.Context, key string)
KeyRLock(ctx context.Context, key string) (bool, error) KeyRLock(ctx context.Context, key string) (bool, error)
KeyRUnlock(key string) KeyRUnlock(ctx context.Context, key string)
KeyExists(key string) bool KeyExists(ctx context.Context, key string) bool
CreateKeyAndLock(ctx context.Context, key string) (bool, error) CreateKeyAndLock(ctx context.Context, key string) (bool, error)
GetValue(ctx context.Context, key string) interface{} GetValue(ctx context.Context, key string) interface{}
SetValue(ctx context.Context, key string, value interface{}) error SetValue(ctx context.Context, key string, value interface{}) error