Removed etc and get modules and replaced them with generic module. Implemented functions to set and remove the expiry of a key. Implemented LRU and LFU caches using heap.

This commit is contained in:
Kelvin Clement Mwinuka
2024-03-03 16:21:12 +08:00
parent e569bf6837
commit 28f97656c4
22 changed files with 797 additions and 544 deletions

View File

@@ -21,6 +21,8 @@ CMD "./server" \
"--data-dir" "${DATA_DIR}" \ "--data-dir" "${DATA_DIR}" \
"--snapshot-threshold" "${SNAPSHOT_THRESHOLD}" \ "--snapshot-threshold" "${SNAPSHOT_THRESHOLD}" \
"--snapshot-interval" "${SNAPSHOT_INTERVAL}" \ "--snapshot-interval" "${SNAPSHOT_INTERVAL}" \
"--max-memory" "${MAX_MEMORY}" \
"--eviction-policy" "${EVICTION_POLICY}" \
"--tls=${TLS}" \ "--tls=${TLS}" \
"--mtls=${MTLS}" \ "--mtls=${MTLS}" \
"--in-memory=${IN_MEMORY}" \ "--in-memory=${IN_MEMORY}" \
@@ -32,7 +34,6 @@ CMD "./server" \
"--restore-snapshot=${RESTORE_SNAPSHOT}" \ "--restore-snapshot=${RESTORE_SNAPSHOT}" \
"--restore-aof=${RESTORE_AOF}" \ "--restore-aof=${RESTORE_AOF}" \
"--aof-sync-strategy=${AOF_SYNC_STRATEGY}" \ "--aof-sync-strategy=${AOF_SYNC_STRATEGY}" \
"--max-memory=${MAX_MEMORY}" \
# List of server cert/key pairs # List of server cert/key pairs
"--cert-key-pair=${CERT_KEY_PAIR_1}" \ "--cert-key-pair=${CERT_KEY_PAIR_1}" \
"--cert-key-pair=${CERT_KEY_PAIR_2}" \ "--cert-key-pair=${CERT_KEY_PAIR_2}" \

View File

@@ -18,7 +18,7 @@ services:
- PLUGIN_DIR=/usr/local/lib/echovault - PLUGIN_DIR=/usr/local/lib/echovault
- DATA_DIR=/var/lib/echovault - DATA_DIR=/var/lib/echovault
- IN_MEMORY=false - IN_MEMORY=false
- TLS=true - TLS=false
- MTLS=false - MTLS=false
- BOOTSTRAP_CLUSTER=false - BOOTSTRAP_CLUSTER=false
- ACL_CONFIG=/etc/config/echovault/acl.yml - ACL_CONFIG=/etc/config/echovault/acl.yml
@@ -31,6 +31,7 @@ services:
- RESTORE_AOF=true - RESTORE_AOF=true
- AOF_SYNC_STRATEGY=everysec - AOF_SYNC_STRATEGY=everysec
- MAX_MEMORY=100kb - MAX_MEMORY=100kb
- EVICTION_POLICY=allkeys-lfu
# List of server cert/key pairs # List of server cert/key pairs
- CERT_KEY_PAIR_1=/etc/ssl/certs/echovault/server/server1.crt,/etc/ssl/certs/echovault/server/server1.key - CERT_KEY_PAIR_1=/etc/ssl/certs/echovault/server/server1.crt,/etc/ssl/certs/echovault/server/server1.key
- CERT_KEY_PAIR_2=/etc/ssl/certs/echovault/server/server2.crt,/etc/ssl/certs/echovault/server/server2.key - CERT_KEY_PAIR_2=/etc/ssl/certs/echovault/server/server2.crt,/etc/ssl/certs/echovault/server/server2.key
@@ -55,15 +56,15 @@ services:
# - PORT=7480 # - PORT=7480
# - RAFT_PORT=8000 # - RAFT_PORT=8000
# - ML_PORT=7946 # - ML_PORT=7946
# - KEY=/etc/ssl/certs/echovault/server1.key # - KEY=/generic/ssl/certs/echovault/server1.key
# - CERT=/etc/ssl/certs/echovault/server1.crt # - CERT=/generic/ssl/certs/echovault/server1.crt
# - SERVER_ID=1 # - SERVER_ID=1
# - DATA_DIR=/var/lib/echovault # - DATA_DIR=/var/lib/echovault
# - IN_MEMORY=false # - IN_MEMORY=false
# - TLS=true # - TLS=true
# - MTLS=true # - MTLS=true
# - BOOTSTRAP_CLUSTER=true # - BOOTSTRAP_CLUSTER=true
# - ACL_CONFIG=/etc/config/echovault/acl.yml # - ACL_CONFIG=/generic/config/echovault/acl.yml
# - REQUIRE_PASS=false # - REQUIRE_PASS=false
# - FORWARD_COMMAND=true # - FORWARD_COMMAND=true
# - SNAPSHOT_THRESHOLD=1000 # - SNAPSHOT_THRESHOLD=1000
@@ -72,16 +73,16 @@ services:
# - RESTORE_AOF=false # - RESTORE_AOF=false
# - AOF_SYNC_STRATEGY=everysec # - AOF_SYNC_STRATEGY=everysec
# # List of server cert/key pairs # # List of server cert/key pairs
# - CERT_KEY_PAIR_1=/etc/ssl/certs/echovault/server/server1.crt,/etc/ssl/certs/echovault/server/server1.key # - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key
# - CERT_KEY_PAIR_2=/etc/ssl/certs/echovault/server/server2.crt,/etc/ssl/certs/echovault/server/server2.key # - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key
# # List of client certificate authorities # # List of client certificate authorities
# - CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt # - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt
# ports: # ports:
# - "7481:7480" # - "7481:7480"
# - "7945:7946" # - "7945:7946"
# - "8000:8000" # - "8000:8000"
# volumes: # volumes:
# - ./config/acl.yml:/etc/config/echovault/acl.yml # - ./config/acl.yml:/generic/config/echovault/acl.yml
# - ./volumes/cluster_node_1:/var/lib/echovault # - ./volumes/cluster_node_1:/var/lib/echovault
# networks: # networks:
# - testnet # - testnet
@@ -95,8 +96,8 @@ services:
# - PORT=7480 # - PORT=7480
# - RAFT_PORT=8000 # - RAFT_PORT=8000
# - ML_PORT=7946 # - ML_PORT=7946
# - KEY=/etc/ssl/certs/echovault/server1.key # - KEY=/generic/ssl/certs/echovault/server1.key
# - CERT=/etc/ssl/certs/echovault/server1.crt # - CERT=/generic/ssl/certs/echovault/server1.crt
# - SERVER_ID=2 # - SERVER_ID=2
# - JOIN_ADDR=cluster_node_1:7946 # - JOIN_ADDR=cluster_node_1:7946
# - DATA_DIR=/var/lib/echovault # - DATA_DIR=/var/lib/echovault
@@ -104,7 +105,7 @@ services:
# - TLS=true # - TLS=true
# - MTLS=true # - MTLS=true
# - BOOTSTRAP_CLUSTER=false # - BOOTSTRAP_CLUSTER=false
# - ACL_CONFIG=/etc/config/echovault/acl.yml # - ACL_CONFIG=/generic/config/echovault/acl.yml
# - REQUIRE_PASS=false # - REQUIRE_PASS=false
# - FORWARD_COMMAND=true # - FORWARD_COMMAND=true
# - SNAPSHOT_THRESHOLD=1000 # - SNAPSHOT_THRESHOLD=1000
@@ -113,16 +114,16 @@ services:
# - RESTORE_AOF=false # - RESTORE_AOF=false
# - AOF_SYNC_STRATEGY=everysec # - AOF_SYNC_STRATEGY=everysec
# # List of server cert/key pairs # # List of server cert/key pairs
# - CERT_KEY_PAIR_1=/etc/ssl/certs/echovault/server/server1.crt,/etc/ssl/certs/echovault/server/server1.key # - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key
# - CERT_KEY_PAIR_2=/etc/ssl/certs/echovault/server/server2.crt,/etc/ssl/certs/echovault/server/server2.key # - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key
# # List of client certificate authorities # # List of client certificate authorities
# - CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt # - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt
# ports: # ports:
# - "7482:7480" # - "7482:7480"
# - "7947:7946" # - "7947:7946"
# - "8001:8000" # - "8001:8000"
# volumes: # volumes:
# - ./config/acl.yml:/etc/config/echovault/acl.yml # - ./config/acl.yml:/generic/config/echovault/acl.yml
# - ./volumes/cluster_node_2:/var/lib/echovault # - ./volumes/cluster_node_2:/var/lib/echovault
# networks: # networks:
# - testnet # - testnet
@@ -136,8 +137,8 @@ services:
# - PORT=7480 # - PORT=7480
# - RAFT_PORT=8000 # - RAFT_PORT=8000
# - ML_PORT=7946 # - ML_PORT=7946
# - KEY=/etc/ssl/certs/echovault/server1.key # - KEY=/generic/ssl/certs/echovault/server1.key
# - CERT=/etc/ssl/certs/echovault/server1.crt # - CERT=/generic/ssl/certs/echovault/server1.crt
# - SERVER_ID=3 # - SERVER_ID=3
# - JOIN_ADDR=cluster_node_1:7946 # - JOIN_ADDR=cluster_node_1:7946
# - DATA_DIR=/var/lib/echovault # - DATA_DIR=/var/lib/echovault
@@ -145,7 +146,7 @@ services:
# - TLS=true # - TLS=true
# - MTLS=true # - MTLS=true
# - BOOTSTRAP_CLUSTER=false # - BOOTSTRAP_CLUSTER=false
# - ACL_CONFIG=/etc/config/echovault/acl.yml # - ACL_CONFIG=/generic/config/echovault/acl.yml
# - REQUIRE_PASS=false # - REQUIRE_PASS=false
# - FORWARD_COMMAND=true # - FORWARD_COMMAND=true
# - SNAPSHOT_THRESHOLD=1000 # - SNAPSHOT_THRESHOLD=1000
@@ -154,16 +155,16 @@ services:
# - RESTORE_AOF=false # - RESTORE_AOF=false
# - AOF_SYNC_STRATEGY=everysec # - AOF_SYNC_STRATEGY=everysec
# # List of server cert/key pairs # # List of server cert/key pairs
# - CERT_KEY_PAIR_1=/etc/ssl/certs/echovault/server/server1.crt,/etc/ssl/certs/echovault/server/server1.key # - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key
# - CERT_KEY_PAIR_2=/etc/ssl/certs/echovault/server/server2.crt,/etc/ssl/certs/echovault/server/server2.key # - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key
# # List of client certificate authorities # # List of client certificate authorities
# - CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt # - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt
# ports: # ports:
# - "7483:7480" # - "7483:7480"
# - "7948:7946" # - "7948:7946"
# - "8002:8000" # - "8002:8000"
# volumes: # volumes:
# - ./config/acl.yml:/etc/config/echovault/acl.yml # - ./config/acl.yml:/generic/config/echovault/acl.yml
# - ./volumes/cluster_node_3:/var/lib/echovault # - ./volumes/cluster_node_3:/var/lib/echovault
# networks: # networks:
# - testnet # - testnet
@@ -177,8 +178,8 @@ services:
# - PORT=7480 # - PORT=7480
# - RAFT_PORT=8000 # - RAFT_PORT=8000
# - ML_PORT=7946 # - ML_PORT=7946
# - KEY=/etc/ssl/certs/echovault/server1.key # - KEY=/generic/ssl/certs/echovault/server1.key
# - CERT=/etc/ssl/certs/echovault/server1.crt # - CERT=/generic/ssl/certs/echovault/server1.crt
# - SERVER_ID=4 # - SERVER_ID=4
# - JOIN_ADDR=cluster_node_1:7946 # - JOIN_ADDR=cluster_node_1:7946
# - DATA_DIR=/var/lib/echovault # - DATA_DIR=/var/lib/echovault
@@ -186,7 +187,7 @@ services:
# - TLS=true # - TLS=true
# - MTLS=true # - MTLS=true
# - BOOTSTRAP_CLUSTER=false # - BOOTSTRAP_CLUSTER=false
# - ACL_CONFIG=/etc/config/echovault/acl.yml # - ACL_CONFIG=/generic/config/echovault/acl.yml
# - REQUIRE_PASS=false # - REQUIRE_PASS=false
# - FORWARD_COMMAND=true # - FORWARD_COMMAND=true
# - SNAPSHOT_THRESHOLD=1000 # - SNAPSHOT_THRESHOLD=1000
@@ -195,16 +196,16 @@ services:
# - RESTORE_AOF=false # - RESTORE_AOF=false
# - AOF_SYNC_STRATEGY=everysec # - AOF_SYNC_STRATEGY=everysec
# # List of server cert/key pairs # # List of server cert/key pairs
# - CERT_KEY_PAIR_1=/etc/ssl/certs/echovault/server/server1.crt,/etc/ssl/certs/echovault/server/server1.key # - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key
# - CERT_KEY_PAIR_2=/etc/ssl/certs/echovault/server/server2.crt,/etc/ssl/certs/echovault/server/server2.key # - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key
# # List of client certificate authorities # # List of client certificate authorities
# - CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt # - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt
# ports: # ports:
# - "7484:7480" # - "7484:7480"
# - "7949:7946" # - "7949:7946"
# - "8003:8000" # - "8003:8000"
# volumes: # volumes:
# - ./config/acl.yml:/etc/config/echovault/acl.yml # - ./config/acl.yml:/generic/config/echovault/acl.yml
# - ./volumes/cluster_node_4:/var/lib/echovault # - ./volumes/cluster_node_4:/var/lib/echovault
# networks: # networks:
# - testnet # - testnet
@@ -218,8 +219,8 @@ services:
# - PORT=7480 # - PORT=7480
# - RAFT_PORT=8000 # - RAFT_PORT=8000
# - ML_PORT=7946 # - ML_PORT=7946
# - KEY=/etc/ssl/certs/echovault/server1.key # - KEY=/generic/ssl/certs/echovault/server1.key
# - CERT=/etc/ssl/certs/echovault/server1.crt # - CERT=/generic/ssl/certs/echovault/server1.crt
# - SERVER_ID=5 # - SERVER_ID=5
# - JOIN_ADDR=cluster_node_1:7946 # - JOIN_ADDR=cluster_node_1:7946
# - DATA_DIR=/var/lib/echovault # - DATA_DIR=/var/lib/echovault
@@ -227,7 +228,7 @@ services:
# - TLS=true # - TLS=true
# - MTLS=true # - MTLS=true
# - BOOTSTRAP_CLUSTER=false # - BOOTSTRAP_CLUSTER=false
# - ACL_CONFIG=/etc/config/echovault/acl.yml # - ACL_CONFIG=/generic/config/echovault/acl.yml
# - REQUIRE_PASS=false # - REQUIRE_PASS=false
# - FORWARD_COMMAND=true # - FORWARD_COMMAND=true
# - SNAPSHOT_THRESHOLD=1000 # - SNAPSHOT_THRESHOLD=1000
@@ -236,16 +237,16 @@ services:
# - RESTORE_AOF=false # - RESTORE_AOF=false
# - AOF_SYNC_STRATEGY=everysec # - AOF_SYNC_STRATEGY=everysec
# # List of server cert/key pairs # # List of server cert/key pairs
# - CERT_KEY_PAIR_1=/etc/ssl/certs/echovault/server/server1.crt,/etc/ssl/certs/echovault/server/server1.key # - CERT_KEY_PAIR_1=/generic/ssl/certs/echovault/server/server1.crt,/generic/ssl/certs/echovault/server/server1.key
# - CERT_KEY_PAIR_2=/etc/ssl/certs/echovault/server/server2.crt,/etc/ssl/certs/echovault/server/server2.key # - CERT_KEY_PAIR_2=/generic/ssl/certs/echovault/server/server2.crt,/generic/ssl/certs/echovault/server/server2.key
# # List of client certificate authorities # # List of client certificate authorities
# - CLIENT_CA_1=/etc/ssl/certs/echovault/client/rootCA.crt # - CLIENT_CA_1=/generic/ssl/certs/echovault/client/rootCA.crt
# ports: # ports:
# - "7485:7480" # - "7485:7480"
# - "7950:7946" # - "7950:7946"
# - "8004:8000" # - "8004:8000"
# volumes: # volumes:
# - ./config/acl.yml:/etc/config/echovault/acl.yml # - ./config/acl.yml:/generic/config/echovault/acl.yml
# - ./volumes/cluster_node_5:/var/lib/echovault # - ./volumes/cluster_node_5:/var/lib/echovault
# networks: # networks:
# - testnet # - testnet

View File

@@ -11,47 +11,80 @@ type EntryLFU struct {
index int // The index of the entry in the heap index int // The index of the entry in the heap
} }
type Cache []*EntryLFU type CacheLFU struct {
keys map[string]bool
func (cache *Cache) Len() int { entries []*EntryLFU
return len(*cache)
} }
func (cache *Cache) Less(i, j int) bool { func NewCacheLFU() *CacheLFU {
return (*cache)[i].count > (*cache)[j].count cache := &CacheLFU{
keys: make(map[string]bool),
entries: make([]*EntryLFU, 0),
}
heap.Init(cache)
return cache
} }
func (cache *Cache) Swap(i, j int) { func (cache *CacheLFU) Len() int {
(*cache)[i], (*cache)[j] = (*cache)[j], (*cache)[i] return len(cache.entries)
(*cache)[i].index = i
(*cache)[j].index = j
} }
func (cache *Cache) Push(key any) { func (cache *CacheLFU) Less(i, j int) bool {
n := len(*cache) return cache.entries[i].count < cache.entries[j].count
*cache = append(*cache, &EntryLFU{ }
func (cache *CacheLFU) Swap(i, j int) {
cache.entries[i], cache.entries[j] = cache.entries[j], cache.entries[i]
cache.entries[i].index = i
cache.entries[j].index = j
}
func (cache *CacheLFU) Push(key any) {
n := len(cache.entries)
cache.entries = append(cache.entries, &EntryLFU{
key: key.(string), key: key.(string),
count: 1, count: 1,
index: n, index: n,
}) })
cache.keys[key.(string)] = true
} }
func (cache *Cache) Pop() any { func (cache *CacheLFU) Pop() any {
old := *cache old := cache.entries
n := len(old) n := len(old)
entry := old[n-1] entry := old[n-1]
old[n-1] = nil old[n-1] = nil
entry.index = -1 entry.index = -1
*cache = old[0 : n-1] cache.entries = old[0 : n-1]
delete(cache.keys, entry.key)
return entry.key return entry.key
} }
func (cache *Cache) Update(key string) { func (cache *CacheLFU) Update(key string) {
// If the key is not contained in the cache, push it.
if !cache.contains(key) {
heap.Push(cache, key)
return
}
// Get the item with key // Get the item with key
entryIdx := slices.IndexFunc(*cache, func(e *EntryLFU) bool { entryIdx := slices.IndexFunc(cache.entries, func(e *EntryLFU) bool {
return e.key == key return e.key == key
}) })
entry := (*cache)[entryIdx] entry := cache.entries[entryIdx]
entry.count += 1 entry.count += 1
heap.Fix(cache, entryIdx) heap.Fix(cache, entryIdx)
} }
func (cache *CacheLFU) Delete(key string) {
entryIdx := slices.IndexFunc(cache.entries, func(entry *EntryLFU) bool {
return entry.key == key
})
if entryIdx > -1 {
heap.Remove(cache, cache.entries[entryIdx].index)
}
}
func (cache *CacheLFU) contains(key string) bool {
_, ok := cache.keys[key]
return ok
}

View File

@@ -1 +1,89 @@
package eviction package eviction
import (
"container/heap"
"slices"
"time"
)
type EntryLRU struct {
key string // The key, matching the key in the store
unixTime int64 // Unix time in milliseconds when this key was accessed
index int // The index of the entry in the heap
}
type CacheLRU struct {
keys map[string]bool
entries []*EntryLRU
}
func NewCacheLRU() *CacheLRU {
cache := &CacheLRU{
keys: make(map[string]bool),
entries: make([]*EntryLRU, 0),
}
heap.Init(cache)
return cache
}
func (cache *CacheLRU) Len() int {
return len(cache.entries)
}
func (cache *CacheLRU) Less(i, j int) bool {
return cache.entries[i].unixTime > cache.entries[j].unixTime
}
func (cache *CacheLRU) Swap(i, j int) {
cache.entries[i], cache.entries[j] = cache.entries[j], cache.entries[i]
cache.entries[i].index = i
cache.entries[j].index = j
}
func (cache *CacheLRU) Push(key any) {
n := len(cache.entries)
cache.entries = append(cache.entries, &EntryLRU{
key: key.(string),
unixTime: time.Now().Unix(),
index: n,
})
}
func (cache *CacheLRU) Pop() any {
old := cache.entries
n := len(old)
entry := old[n-1]
old[n-1] = nil
entry.index = -1
cache.entries = old[0 : n-1]
delete(cache.keys, entry.key)
return entry.key
}
func (cache *CacheLRU) Update(key string) {
// If the key does not already exist in the cache, then push it
if !cache.contains(key) {
heap.Push(cache, key)
}
// Get the item with key
entryIdx := slices.IndexFunc(cache.entries, func(e *EntryLRU) bool {
return e.key == key
})
entry := cache.entries[entryIdx]
entry.unixTime = time.Now().Unix()
heap.Fix(cache, entryIdx)
}
func (cache *CacheLRU) Delete(key string) {
entryIdx := slices.IndexFunc(cache.entries, func(entry *EntryLRU) bool {
return entry.key == key
})
if entryIdx > -1 {
heap.Remove(cache, cache.entries[entryIdx].index)
}
}
func (cache *CacheLRU) contains(key string) bool {
_, ok := cache.keys[key]
return ok
}

View File

@@ -4,11 +4,10 @@ import (
"context" "context"
"github.com/echovault/echovault/src/modules/acl" "github.com/echovault/echovault/src/modules/acl"
"github.com/echovault/echovault/src/modules/admin" "github.com/echovault/echovault/src/modules/admin"
"github.com/echovault/echovault/src/modules/etc" "github.com/echovault/echovault/src/modules/connection"
"github.com/echovault/echovault/src/modules/get" "github.com/echovault/echovault/src/modules/generic"
"github.com/echovault/echovault/src/modules/hash" "github.com/echovault/echovault/src/modules/hash"
"github.com/echovault/echovault/src/modules/list" "github.com/echovault/echovault/src/modules/list"
"github.com/echovault/echovault/src/modules/ping"
"github.com/echovault/echovault/src/modules/pubsub" "github.com/echovault/echovault/src/modules/pubsub"
"github.com/echovault/echovault/src/modules/set" "github.com/echovault/echovault/src/modules/set"
"github.com/echovault/echovault/src/modules/sorted_set" "github.com/echovault/echovault/src/modules/sorted_set"
@@ -25,11 +24,10 @@ func GetCommands() []utils.Command {
var commands []utils.Command var commands []utils.Command
commands = append(commands, acl.Commands()...) commands = append(commands, acl.Commands()...)
commands = append(commands, admin.Commands()...) commands = append(commands, admin.Commands()...)
commands = append(commands, etc.Commands()...) commands = append(commands, generic.Commands()...)
commands = append(commands, get.Commands()...)
commands = append(commands, hash.Commands()...) commands = append(commands, hash.Commands()...)
commands = append(commands, list.Commands()...) commands = append(commands, list.Commands()...)
commands = append(commands, ping.Commands()...) commands = append(commands, connection.Commands()...)
commands = append(commands, pubsub.Commands()...) commands = append(commands, pubsub.Commands()...)
commands = append(commands, set.Commands()...) commands = append(commands, set.Commands()...)
commands = append(commands, sorted_set.Commands()...) commands = append(commands, sorted_set.Commands()...)

View File

@@ -271,8 +271,8 @@ func (acl *ACL) AuthorizeConnection(conn *net.Conn, cmd []string, command utils.
return nil return nil
} }
// Skip ping // Skip connection
if strings.EqualFold(comm, "ping") { if strings.EqualFold(comm, "connection") {
return nil return nil
} }

View File

@@ -156,18 +156,6 @@ func handleCommandDocs(ctx context.Context, cmd []string, server utils.Server, _
return []byte("*0\r\n"), nil return []byte("*0\r\n"), nil
} }
// func handleConfigGet(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
// return nil, errors.New("command not yet implemented")
// }
//
// func handleConfigRewrite(ctx context.Context, cmd []string, server *utils.Server, _ *net.Conn) ([]byte, error) {
// return nil, errors.New("command not yet implemented")
// }
//
// func handleConfigSet(ctx context.Context, cmd []string, server *utils.Server, _ *net.Conn) ([]byte, error) {
// return nil, errors.New("command not yet implemented")
// }
func Commands() []utils.Command { func Commands() []utils.Command {
return []utils.Command{ return []utils.Command{
{ {

View File

@@ -1,4 +1,4 @@
package ping package connection
import ( import (
"context" "context"
@@ -22,7 +22,7 @@ func handlePing(ctx context.Context, cmd []string, server utils.Server, conn *ne
func Commands() []utils.Command { func Commands() []utils.Command {
return []utils.Command{ return []utils.Command{
{ {
Command: "ping", Command: "connection",
Categories: []string{utils.FastCategory, utils.ConnectionCategory}, Categories: []string{utils.FastCategory, utils.ConnectionCategory},
Description: "(PING [value]) Ping the server. If a value is provided, the value will be echoed.", Description: "(PING [value]) Ping the server. If a value is provided, the value will be echoed.",
Sync: false, Sync: false,

View File

@@ -1,4 +1,4 @@
package ping package connection
import ( import (
"bytes" "bytes"

View File

@@ -1,148 +0,0 @@
package etc
import (
"context"
"errors"
"fmt"
"github.com/echovault/echovault/src/utils"
"net"
)
type KeyObject struct {
value interface{}
locked bool
}
func handleSet(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
keys, err := setKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
if !server.KeyExists(key) {
_, err := server.CreateKeyAndLock(ctx, key)
if err != nil {
return nil, err
}
server.SetValue(ctx, key, utils.AdaptType(cmd[2]))
server.KeyUnlock(key)
return []byte(utils.OkResponse), nil
}
if _, err := server.KeyLock(ctx, key); err != nil {
return nil, err
}
server.SetValue(ctx, key, utils.AdaptType(cmd[2]))
server.KeyUnlock(key)
return []byte(utils.OkResponse), nil
}
func handleSetNX(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
keys, err := setNXKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
if server.KeyExists(key) {
return nil, fmt.Errorf("key %s already exists", key)
}
if _, err = server.CreateKeyAndLock(ctx, key); err != nil {
return nil, err
}
server.SetValue(ctx, key, utils.AdaptType(cmd[2]))
server.KeyUnlock(key)
return []byte(utils.OkResponse), nil
}
func handleMSet(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
if _, err := msetKeyFunc(cmd); err != nil {
return nil, err
}
entries := make(map[string]KeyObject)
// Release all acquired key locks
defer func() {
for k, v := range entries {
if v.locked {
server.KeyUnlock(k)
entries[k] = KeyObject{
value: v.value,
locked: false,
}
}
}
}()
// Extract all the key/value pairs
for i, key := range cmd[1:] {
if i%2 == 0 {
entries[key] = KeyObject{
value: utils.AdaptType(cmd[1:][i+1]),
locked: false,
}
}
}
// Acquire all the locks for each key first
// If any key cannot be acquired, abandon transaction and release all currently held keys
for k, v := range entries {
if server.KeyExists(k) {
if _, err := server.KeyLock(ctx, k); err != nil {
return nil, err
}
entries[k] = KeyObject{value: v.value, locked: true}
continue
}
if _, err := server.CreateKeyAndLock(ctx, k); err != nil {
return nil, err
}
entries[k] = KeyObject{value: v.value, locked: true}
}
// Set all the values
for k, v := range entries {
server.SetValue(ctx, k, v.value)
}
return []byte(utils.OkResponse), nil
}
func handleCopy(ctx context.Context, cmd []string, server *utils.Server, _ *net.Conn) ([]byte, error) {
return nil, errors.New("command not yet implemented")
}
func Commands() []utils.Command {
return []utils.Command{
{
Command: "set",
Categories: []string{utils.WriteCategory, utils.SlowCategory},
Description: "(SET key value) Set the value of a key, considering the value's type.",
Sync: true,
KeyExtractionFunc: setKeyFunc,
HandlerFunc: handleSet,
},
{
Command: "setnx",
Categories: []string{utils.WriteCategory, utils.SlowCategory},
Description: "(SETNX key value) Set the key/value only if the key doesn't exist.",
Sync: true,
KeyExtractionFunc: setNXKeyFunc,
HandlerFunc: handleSetNX,
},
{
Command: "mset",
Categories: []string{utils.WriteCategory, utils.SlowCategory},
Description: "(MSET key value [key value ...]) Automatically etc or modify multiple key/value pairs.",
Sync: true,
KeyExtractionFunc: msetKeyFunc,
HandlerFunc: handleMSet,
},
}
}

View File

@@ -0,0 +1,249 @@
package generic
import (
"context"
"fmt"
"github.com/echovault/echovault/src/utils"
"net"
"strings"
"time"
)
type KeyObject struct {
value interface{}
locked bool
}
func handleSet(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
keys, err := setKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
value := cmd[2]
res := []byte(utils.OkResponse)
params, err := getSetCommandParams(cmd[3:], SetParams{})
if err != nil {
return nil, err
}
// If GET is provided, the response should be the current stored value.
// If there's no current value, then the response should be nil.
if params.get {
if !server.KeyExists(key) {
res = []byte("$-1\r\n")
} else {
res = []byte(fmt.Sprintf("+%v\r\n", server.GetValue(key)))
}
}
if "xx" == strings.ToLower(params.exists) {
// If XX is specified, make sure the key exists.
if !server.KeyExists(key) {
return nil, fmt.Errorf("key %s does not exist", key)
}
_, err = server.KeyLock(ctx, key)
} else if "nx" == strings.ToLower(params.exists) {
// If NX is specified, make sure that the key does not currently exist.
if server.KeyExists(key) {
return nil, fmt.Errorf("key %s already exists", key)
}
_, err = server.CreateKeyAndLock(ctx, key)
} else {
// Neither XX not NX are specified, lock or create the lock
if !server.KeyExists(key) {
// Key does not exist, create it
_, err = server.CreateKeyAndLock(ctx, key)
} else {
// Key exists, acquire the lock
_, err = server.KeyLock(ctx, key)
}
}
if err != nil {
return nil, err
}
defer server.KeyUnlock(key)
server.SetValue(ctx, key, utils.AdaptType(value))
// If expiresAt is set, set the key's expiry time as well
if params.expireAt != nil {
server.SetKeyExpiry(key, params.expireAt.(time.Time), false)
}
return res, nil
}
func handleMSet(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
if _, err := msetKeyFunc(cmd); err != nil {
return nil, err
}
entries := make(map[string]KeyObject)
// Release all acquired key locks
defer func() {
for k, v := range entries {
if v.locked {
server.KeyUnlock(k)
entries[k] = KeyObject{
value: v.value,
locked: false,
}
}
}
}()
// Extract all the key/value pairs
for i, key := range cmd[1:] {
if i%2 == 0 {
entries[key] = KeyObject{
value: utils.AdaptType(cmd[1:][i+1]),
locked: false,
}
}
}
// Acquire all the locks for each key first
// If any key cannot be acquired, abandon transaction and release all currently held keys
for k, v := range entries {
if server.KeyExists(k) {
if _, err := server.KeyLock(ctx, k); err != nil {
return nil, err
}
entries[k] = KeyObject{value: v.value, locked: true}
continue
}
if _, err := server.CreateKeyAndLock(ctx, k); err != nil {
return nil, err
}
entries[k] = KeyObject{value: v.value, locked: true}
}
// Set all the values
for k, v := range entries {
server.SetValue(ctx, k, v.value)
}
return []byte(utils.OkResponse), nil
}
func handleGet(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
keys, err := getKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
if !server.KeyExists(key) {
return []byte("$-1\r\n"), nil
}
_, err = server.KeyRLock(ctx, key)
if err != nil {
return nil, err
}
defer server.KeyRUnlock(key)
value := server.GetValue(key)
return []byte(fmt.Sprintf("+%v\r\n", value)), nil
}
func handleMGet(ctx context.Context, cmd []string, server utils.Server, _ *net.Conn) ([]byte, error) {
keys, err := mgetKeyFunc(cmd)
if err != nil {
return nil, err
}
values := make(map[string]string)
locks := make(map[string]bool)
for _, key := range keys {
if _, ok := values[key]; ok {
// Skip if we have already locked this key
continue
}
if server.KeyExists(key) {
_, err = server.KeyRLock(ctx, key)
if err != nil {
return nil, fmt.Errorf("could not obtain lock for %s key", key)
}
locks[key] = true
continue
}
values[key] = ""
}
defer func() {
for key, locked := range locks {
if locked {
server.KeyRUnlock(key)
locks[key] = false
}
}
}()
for key, _ := range locks {
values[key] = fmt.Sprintf("%v", server.GetValue(key))
}
bytes := []byte(fmt.Sprintf("*%d\r\n", len(cmd[1:])))
for _, key := range cmd[1:] {
if values[key] == "" {
bytes = append(bytes, []byte("$-1\r\n")...)
continue
}
bytes = append(bytes, []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(values[key]), values[key]))...)
}
return bytes, nil
}
func Commands() []utils.Command {
return []utils.Command{
{
Command: "set",
Categories: []string{utils.WriteCategory, utils.SlowCategory},
Description: `
(SET key value [NX | XX] [GET] [EX seconds | PX milliseconds | EXAT unix-time-seconds | PXAT unix-time-milliseconds])
Set the value of a key, considering the value's type.
NX - Only set if the key does not exist.
XX - Only set if the key exists.
GET - Return the old value stored at key, or nil if the value does not exist.
EX - Expire the key after the specified number of seconds (positive integer).
PX - Expire the key after the specified number of milliseconds (positive integer).
EXAT - Expire at the exact time in unix seconds (positive integer).
PXAT - Expire at the exat time in unix milliseconds (positive integer).`,
Sync: true,
KeyExtractionFunc: setKeyFunc,
HandlerFunc: handleSet,
},
{
Command: "mset",
Categories: []string{utils.WriteCategory, utils.SlowCategory},
Description: "(MSET key value [key value ...]) Automatically generic or modify multiple key/value pairs.",
Sync: true,
KeyExtractionFunc: msetKeyFunc,
HandlerFunc: handleMSet,
},
{
Command: "get",
Categories: []string{utils.ReadCategory, utils.FastCategory},
Description: "(GET key) Get the value at the specified key.",
Sync: false,
KeyExtractionFunc: getKeyFunc,
HandlerFunc: handleGet,
},
{
Command: "mget",
Categories: []string{utils.ReadCategory, utils.FastCategory},
Description: "(MGET key1 [key2]) Get multiple values from the specified keys.",
Sync: false,
KeyExtractionFunc: mgetKeyFunc,
HandlerFunc: handleMGet,
},
}
}

View File

@@ -1,9 +1,10 @@
package etc package generic
import ( import (
"bytes" "bytes"
"context" "context"
"errors" "errors"
"fmt"
"github.com/echovault/echovault/src/server" "github.com/echovault/echovault/src/server"
"github.com/echovault/echovault/src/utils" "github.com/echovault/echovault/src/utils"
"github.com/tidwall/resp" "github.com/tidwall/resp"
@@ -44,7 +45,7 @@ func Test_HandleSET(t *testing.T) {
expectedErr: errors.New(utils.WrongArgsResponse), expectedErr: errors.New(utils.WrongArgsResponse),
}, },
{ {
command: []string{"SET", "test", "one", "two", "three"}, command: []string{"SET", "test", "one", "two", "three", "four", "five", "eight"},
expectedResponse: "", expectedResponse: "",
expectedValue: nil, expectedValue: nil,
expectedErr: errors.New(utils.WrongArgsResponse), expectedErr: errors.New(utils.WrongArgsResponse),
@@ -102,23 +103,6 @@ func Test_HandleSET(t *testing.T) {
} }
} }
func Test_HandleSETNX(t *testing.T) {
mockServer := server.NewServer(server.Opts{})
res, err := handleSetNX(context.Background(), []string{"SET", "test", "Test_HandleSETNX"}, mockServer, nil)
if err != nil {
t.Error(err)
}
// Try to set existing key again
res, err = handleSetNX(context.Background(), []string{"SET", "test", "Test_HandleSETNX_2"}, mockServer, nil)
if res != nil {
t.Errorf("exptected nil response, got: %+v", res)
}
if err.Error() != "key test already exists" {
t.Errorf("expected key test already exists, got %s", err.Error())
}
}
func Test_HandleMSET(t *testing.T) { func Test_HandleMSET(t *testing.T) {
mockServer := server.NewServer(server.Opts{}) mockServer := server.NewServer(server.Opts{})
@@ -197,3 +181,155 @@ func Test_HandleMSET(t *testing.T) {
} }
} }
} }
func Test_HandleGET(t *testing.T) {
mockServer := server.NewServer(server.Opts{})
tests := []struct {
key string
value string
}{
{
key: "test1",
value: "value1",
},
{
key: "test2",
value: "10",
},
{
key: "test3",
value: "3.142",
},
}
// Test successful GET command
for _, test := range tests {
func(key, value string) {
ctx := context.Background()
_, err := mockServer.CreateKeyAndLock(ctx, key)
if err != nil {
t.Error(err)
}
mockServer.SetValue(ctx, key, value)
mockServer.KeyUnlock(key)
res, err := handleGet(ctx, []string{"GET", key}, mockServer, nil)
if err != nil {
t.Error(err)
}
if !bytes.Equal(res, []byte(fmt.Sprintf("+%v\r\n", value))) {
t.Errorf("expected %s, got: %s", fmt.Sprintf("+%v\r\n", value), string(res))
}
}(test.key, test.value)
}
// Test get non-existent key
res, err := handleGet(context.Background(), []string{"GET", "test4"}, mockServer, nil)
if err != nil {
t.Error(err)
}
if !bytes.Equal(res, []byte("$-1\r\n")) {
t.Errorf("expected %+v, got: %+v", "+nil\r\n", res)
}
errorTests := []struct {
command []string
expected string
}{
{
command: []string{"GET"},
expected: utils.WrongArgsResponse,
},
{
command: []string{"GET", "key", "test"},
expected: utils.WrongArgsResponse,
},
}
for _, test := range errorTests {
res, err = handleGet(context.Background(), test.command, mockServer, nil)
if res != nil {
t.Errorf("expected nil response, got: %+v", res)
}
if err.Error() != test.expected {
t.Errorf("expected error '%s', got: %s", test.expected, err.Error())
}
}
}
func Test_HandleMGET(t *testing.T) {
mockServer := server.NewServer(server.Opts{})
tests := []struct {
presetKeys []string
presetValues []string
command []string
expected []interface{}
expectedError error
}{
{
presetKeys: []string{"test1", "test2", "test3", "test4"},
presetValues: []string{"value1", "value2", "value3", "value4"},
command: []string{"MGET", "test1", "test4", "test2", "test3", "test1"},
expected: []interface{}{"value1", "value4", "value2", "value3", "value1"},
expectedError: nil,
},
{
presetKeys: []string{"test5", "test6", "test7"},
presetValues: []string{"value5", "value6", "value7"},
command: []string{"MGET", "test5", "test6", "non-existent", "non-existent", "test7", "non-existent"},
expected: []interface{}{"value5", "value6", nil, nil, "value7", nil},
expectedError: nil,
},
{
presetKeys: []string{"test5"},
presetValues: []string{"value5"},
command: []string{"MGET"},
expected: nil,
expectedError: errors.New(utils.WrongArgsResponse),
},
}
for _, test := range tests {
// Set up the values
for i, key := range test.presetKeys {
_, err := mockServer.CreateKeyAndLock(context.Background(), key)
if err != nil {
t.Error(err)
}
mockServer.SetValue(context.Background(), key, test.presetValues[i])
mockServer.KeyUnlock(key)
}
// Test the command and its results
res, err := handleMGet(context.Background(), test.command, mockServer, nil)
if test.expectedError != nil {
// If we expect and error, branch out and check error
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error %+v, got: %+v", test.expectedError, err)
}
continue
}
if err != nil {
t.Error(err)
}
rr := resp.NewReader(bytes.NewBuffer(res))
rv, _, err := rr.ReadValue()
if err != nil {
t.Error(err)
}
if rv.Type().String() != "Array" {
t.Errorf("expected type Array, got: %s", rv.Type().String())
}
for i, value := range rv.Array() {
if test.expected[i] == nil {
if !value.IsNull() {
t.Errorf("expected nil value, got %+v", value)
}
continue
}
if value.String() != test.expected[i] {
t.Errorf("expected value %s, got: %s", test.expected[i], value.String())
}
}
}
}

View File

@@ -1,4 +1,4 @@
package etc package generic
import ( import (
"errors" "errors"
@@ -6,14 +6,7 @@ import (
) )
func setKeyFunc(cmd []string) ([]string, error) { func setKeyFunc(cmd []string) ([]string, error) {
if len(cmd) != 3 { if len(cmd) < 3 || len(cmd) > 7 {
return nil, errors.New(utils.WrongArgsResponse)
}
return []string{cmd[1]}, nil
}
func setNXKeyFunc(cmd []string) ([]string, error) {
if len(cmd) != 3 {
return nil, errors.New(utils.WrongArgsResponse) return nil, errors.New(utils.WrongArgsResponse)
} }
return []string{cmd[1]}, nil return []string{cmd[1]}, nil
@@ -31,3 +24,17 @@ func msetKeyFunc(cmd []string) ([]string, error) {
} }
return keys, nil return keys, nil
} }
func getKeyFunc(cmd []string) ([]string, error) {
if len(cmd) != 2 {
return nil, errors.New(utils.WrongArgsResponse)
}
return []string{cmd[1]}, nil
}
func mgetKeyFunc(cmd []string) ([]string, error) {
if len(cmd) < 2 {
return nil, errors.New(utils.WrongArgsResponse)
}
return cmd[1:], nil
}

View File

@@ -0,0 +1,103 @@
package generic
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
)
type SetParams struct {
exists string
get bool
expireAt interface{} // Exact expireAt time un unix milliseconds
}
func getSetCommandParams(cmd []string, params SetParams) (SetParams, error) {
if len(cmd) == 0 {
return params, nil
}
switch strings.ToLower(cmd[0]) {
case "get":
params.get = true
return getSetCommandParams(cmd[1:], params)
case "nx":
if params.exists != "" {
return SetParams{}, fmt.Errorf("cannot specify NX when %s is already specified", strings.ToUpper(params.exists))
}
params.exists = "NX"
return getSetCommandParams(cmd[1:], params)
case "xx":
if params.exists != "" {
return SetParams{}, fmt.Errorf("cannot specify XX when %s is already specified", strings.ToUpper(params.exists))
}
params.exists = "XX"
return getSetCommandParams(cmd[1:], params)
case "ex":
if len(cmd) < 2 {
return SetParams{}, errors.New("seconds value required after EX")
}
if params.expireAt != nil {
return SetParams{}, errors.New("cannot specify EX when expiry time is already set")
}
secondsStr := cmd[1]
seconds, err := strconv.ParseInt(secondsStr, 10, 64)
if err != nil {
return SetParams{}, err
}
params.expireAt = time.Now().Add(time.Duration(seconds) * time.Second)
return getSetCommandParams(cmd[2:], params)
case "px":
if len(cmd) < 2 {
return SetParams{}, errors.New("seconds value required after PX")
}
if params.expireAt != nil {
return SetParams{}, errors.New("cannot specify PX when expiry time is already set")
}
millisecondsStr := cmd[1]
milliseconds, err := strconv.ParseInt(millisecondsStr, 10, 64)
if err != nil {
return SetParams{}, err
}
params.expireAt = time.Now().Add(time.Duration(milliseconds) * time.Millisecond)
return getSetCommandParams(cmd[2:], params)
case "exat":
if len(cmd) < 2 {
return SetParams{}, errors.New("seconds value required after EXAT")
}
if params.expireAt != nil {
return SetParams{}, errors.New("cannot specify EXAT when expiry time is already set")
}
secondsStr := cmd[1]
seconds, err := strconv.ParseInt(secondsStr, 10, 64)
if err != nil {
return SetParams{}, err
}
params.expireAt = time.Unix(seconds, 0)
return getSetCommandParams(cmd[2:], params)
case "pxat":
if len(cmd) < 2 {
return SetParams{}, errors.New("seconds value required after PXAT")
}
if params.expireAt != nil {
return SetParams{}, errors.New("cannot specify PXAT when expiry time is already set")
}
millisecondsStr := cmd[1]
milliseconds, err := strconv.ParseInt(millisecondsStr, 10, 64)
if err != nil {
return SetParams{}, err
}
params.expireAt = time.UnixMilli(milliseconds)
return getSetCommandParams(cmd[2:], params)
default:
return SetParams{}, fmt.Errorf("unknown option %s for set command", strings.ToUpper(cmd[0]))
}
}

View File

@@ -1,101 +0,0 @@
package get
import (
"context"
"fmt"
"github.com/echovault/echovault/src/utils"
"net"
)
func handleGet(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
keys, err := getKeyFunc(cmd)
if err != nil {
return nil, err
}
key := keys[0]
if !server.KeyExists(key) {
return []byte("$-1\r\n"), nil
}
_, err = server.KeyRLock(ctx, key)
if err != nil {
return nil, err
}
defer server.KeyRUnlock(key)
value := server.GetValue(key)
return []byte(fmt.Sprintf("+%v\r\n", value)), nil
}
func handleMGet(ctx context.Context, cmd []string, server utils.Server, conn *net.Conn) ([]byte, error) {
keys, err := mgetKeyFunc(cmd)
if err != nil {
return nil, err
}
values := make(map[string]string)
locks := make(map[string]bool)
for _, key := range keys {
if _, ok := values[key]; ok {
// Skip if we have already locked this key
continue
}
if server.KeyExists(key) {
_, err = server.KeyRLock(ctx, key)
if err != nil {
return nil, fmt.Errorf("could not obtain lock for %s key", key)
}
locks[key] = true
continue
}
values[key] = ""
}
defer func() {
for key, locked := range locks {
if locked {
server.KeyRUnlock(key)
locks[key] = false
}
}
}()
for key, _ := range locks {
values[key] = fmt.Sprintf("%v", server.GetValue(key))
}
bytes := []byte(fmt.Sprintf("*%d\r\n", len(cmd[1:])))
for _, key := range cmd[1:] {
if values[key] == "" {
bytes = append(bytes, []byte("$-1\r\n")...)
continue
}
bytes = append(bytes, []byte(fmt.Sprintf("$%d\r\n%s\r\n", len(values[key]), values[key]))...)
}
return bytes, nil
}
func Commands() []utils.Command {
return []utils.Command{
{
Command: "get",
Categories: []string{utils.ReadCategory, utils.FastCategory},
Description: "(GET key) Get the value at the specified key.",
Sync: false,
KeyExtractionFunc: getKeyFunc,
HandlerFunc: handleGet,
},
{
Command: "mget",
Categories: []string{utils.ReadCategory, utils.FastCategory},
Description: "(MGET key1 [key2]) Get multiple values from the specified keys.",
Sync: false,
KeyExtractionFunc: mgetKeyFunc,
HandlerFunc: handleMGet,
},
}
}

View File

@@ -1,164 +0,0 @@
package get
import (
"bytes"
"context"
"errors"
"fmt"
"github.com/echovault/echovault/src/server"
"github.com/echovault/echovault/src/utils"
"github.com/tidwall/resp"
"testing"
)
func Test_HandleGET(t *testing.T) {
mockServer := server.NewServer(server.Opts{})
tests := []struct {
key string
value string
}{
{
key: "test1",
value: "value1",
},
{
key: "test2",
value: "10",
},
{
key: "test3",
value: "3.142",
},
}
// Test successful GET command
for _, test := range tests {
func(key, value string) {
ctx := context.Background()
_, err := mockServer.CreateKeyAndLock(ctx, key)
if err != nil {
t.Error(err)
}
mockServer.SetValue(ctx, key, value)
mockServer.KeyUnlock(key)
res, err := handleGet(ctx, []string{"GET", key}, mockServer, nil)
if err != nil {
t.Error(err)
}
if !bytes.Equal(res, []byte(fmt.Sprintf("+%v\r\n", value))) {
t.Errorf("expected %s, got: %s", fmt.Sprintf("+%v\r\n", value), string(res))
}
}(test.key, test.value)
}
// Test get non-existent key
res, err := handleGet(context.Background(), []string{"GET", "test4"}, mockServer, nil)
if err != nil {
t.Error(err)
}
if !bytes.Equal(res, []byte("$-1\r\n")) {
t.Errorf("expected %+v, got: %+v", "+nil\r\n", res)
}
errorTests := []struct {
command []string
expected string
}{
{
command: []string{"GET"},
expected: utils.WrongArgsResponse,
},
{
command: []string{"GET", "key", "test"},
expected: utils.WrongArgsResponse,
},
}
for _, test := range errorTests {
res, err = handleGet(context.Background(), test.command, mockServer, nil)
if res != nil {
t.Errorf("expected nil response, got: %+v", res)
}
if err.Error() != test.expected {
t.Errorf("expected error '%s', got: %s", test.expected, err.Error())
}
}
}
func Test_HandleMGET(t *testing.T) {
mockServer := server.NewServer(server.Opts{})
tests := []struct {
presetKeys []string
presetValues []string
command []string
expected []interface{}
expectedError error
}{
{
presetKeys: []string{"test1", "test2", "test3", "test4"},
presetValues: []string{"value1", "value2", "value3", "value4"},
command: []string{"MGET", "test1", "test4", "test2", "test3", "test1"},
expected: []interface{}{"value1", "value4", "value2", "value3", "value1"},
expectedError: nil,
},
{
presetKeys: []string{"test5", "test6", "test7"},
presetValues: []string{"value5", "value6", "value7"},
command: []string{"MGET", "test5", "test6", "non-existent", "non-existent", "test7", "non-existent"},
expected: []interface{}{"value5", "value6", nil, nil, "value7", nil},
expectedError: nil,
},
{
presetKeys: []string{"test5"},
presetValues: []string{"value5"},
command: []string{"MGET"},
expected: nil,
expectedError: errors.New(utils.WrongArgsResponse),
},
}
for _, test := range tests {
// Set up the values
for i, key := range test.presetKeys {
_, err := mockServer.CreateKeyAndLock(context.Background(), key)
if err != nil {
t.Error(err)
}
mockServer.SetValue(context.Background(), key, test.presetValues[i])
mockServer.KeyUnlock(key)
}
// Test the command and its results
res, err := handleMGet(context.Background(), test.command, mockServer, nil)
if test.expectedError != nil {
// If we expect and error, branch out and check error
if err.Error() != test.expectedError.Error() {
t.Errorf("expected error %+v, got: %+v", test.expectedError, err)
}
continue
}
if err != nil {
t.Error(err)
}
rr := resp.NewReader(bytes.NewBuffer(res))
rv, _, err := rr.ReadValue()
if err != nil {
t.Error(err)
}
if rv.Type().String() != "Array" {
t.Errorf("expected type Array, got: %s", rv.Type().String())
}
for i, value := range rv.Array() {
if test.expected[i] == nil {
if !value.IsNull() {
t.Errorf("expected nil value, got %+v", value)
}
continue
}
if value.String() != test.expected[i] {
t.Errorf("expected value %s, got: %s", test.expected[i], value.String())
}
}
}
}

View File

@@ -1,20 +0,0 @@
package get
import (
"errors"
"github.com/echovault/echovault/src/utils"
)
func getKeyFunc(cmd []string) ([]string, error) {
if len(cmd) != 2 {
return nil, errors.New(utils.WrongArgsResponse)
}
return []string{cmd[1]}, nil
}
func mgetKeyFunc(cmd []string) ([]string, error) {
if len(cmd) < 2 {
return nil, errors.New(utils.WrongArgsResponse)
}
return cmd[1:], nil
}

View File

@@ -2,10 +2,15 @@ package server
import ( import (
"context" "context"
"github.com/echovault/echovault/src/utils"
"slices"
"strings"
"sync" "sync"
"time" "time"
) )
// KeyLock tries to acquire the write lock for the specified key every 5 milliseconds.
// If the context passed to the function finishes before the lock is acquired, an error is returned.
func (server *Server) KeyLock(ctx context.Context, key string) (bool, error) { func (server *Server) KeyLock(ctx context.Context, key string) (bool, error) {
ticker := time.NewTicker(5 * time.Millisecond) ticker := time.NewTicker(5 * time.Millisecond)
for { for {
@@ -26,6 +31,8 @@ func (server *Server) KeyUnlock(key string) {
server.keyLocks[key].Unlock() server.keyLocks[key].Unlock()
} }
// KeyRLock tries to acquire the read lock for the specified key every few milliseconds.
// If the context passed to the function finishes before the lock is acquired, an error is returned.
func (server *Server) KeyRLock(ctx context.Context, key string) (bool, error) { func (server *Server) KeyRLock(ctx context.Context, key string) (bool, error) {
ticker := time.NewTicker(5 * time.Millisecond) ticker := time.NewTicker(5 * time.Millisecond)
for { for {
@@ -50,6 +57,8 @@ func (server *Server) KeyExists(key string) bool {
return server.keyLocks[key] != nil return server.keyLocks[key] != nil
} }
// CreateKeyAndLock creates a new key lock and immediately locks it if the key does not exist.
// If the key exists, the existing key is locked.
func (server *Server) CreateKeyAndLock(ctx context.Context, key string) (bool, error) { func (server *Server) CreateKeyAndLock(ctx context.Context, key string) (bool, error) {
server.keyCreationLock.Lock() server.keyCreationLock.Lock()
defer server.keyCreationLock.Unlock() defer server.keyCreationLock.Unlock()
@@ -64,17 +73,58 @@ func (server *Server) CreateKeyAndLock(ctx context.Context, key string) (bool, e
return server.KeyLock(ctx, key) return server.KeyLock(ctx, key)
} }
// GetValue retrieves the current value at the specified key.
// The key must be read-locked before calling this function.
func (server *Server) GetValue(key string) interface{} { func (server *Server) GetValue(key string) interface{} {
server.updateKeyInCache(key)
return server.store[key] return server.store[key]
} }
func (server *Server) SetValue(ctx context.Context, key string, value interface{}) { // SetValue updates the value in the store at the specified key with the given value.
// If we're in not in cluster (i.e. in standalone mode), then the change count is incremented
// in the snapshot engine.
// This count triggers a snapshot when the threshold is reached.
// The key must be locked prior to calling this function.
func (server *Server) SetValue(_ context.Context, key string, value interface{}) {
server.store[key] = value server.store[key] = value
server.updateKeyInCache(key)
if !server.IsInCluster() { if !server.IsInCluster() {
server.SnapshotEngine.IncrementChangeCount() server.SnapshotEngine.IncrementChangeCount()
} }
} }
// The SetKeyExpiry receiver function sets the expiry time of a key.
// The key parameter represents the key whose expiry time is to be set/updated.
// The expire parameter is the new expiry time.
// The touch parameter determines whether to update the keys access count on lfu eviction policy,
// or the access time on lru eviction policy.
// The key must be locked prior to calling this function.
func (server *Server) SetKeyExpiry(key string, expire time.Time, touch bool) {
server.keyExpiry[key] = expire
if touch {
server.updateKeyInCache(key)
}
}
// RemoveKeyExpiry is called by commands that remove key expiry (e.g. PERSIST).
// The key must be locked prior ro calling this function.
func (server *Server) RemoveKeyExpiry(key string) {
server.keyExpiry[key] = time.Time{}
switch {
case slices.Contains([]string{utils.AllKeysLFU, utils.VolatileLFU}, server.Config.EvictionPolicy):
server.lfuCache.Delete(key)
case slices.Contains([]string{utils.AllKeysLRU, utils.VolatileLRU}, server.Config.EvictionPolicy):
server.lruCache.Delete(key)
}
}
// GetState creates a deep copy of the store map.
// It is used to retrieve the current state for persistence but can also be used for other
// functions that require a deep copy of the state.
// The copy only starts when there's no current copy in progress (represented by StateCopyInProgress atomic boolean)
// and when there's no current state mutation in progress (represented by StateMutationInProgress atomic boolean)
func (server *Server) GetState() map[string]interface{} { func (server *Server) GetState() map[string]interface{} {
for { for {
if !server.StateCopyInProgress.Load() && !server.StateMutationInProgress.Load() { if !server.StateCopyInProgress.Load() && !server.StateMutationInProgress.Load() {
@@ -89,3 +139,22 @@ func (server *Server) GetState() map[string]interface{} {
server.StateCopyInProgress.Store(false) server.StateCopyInProgress.Store(false)
return data return data
} }
// updateKeyInCache updates either the key access count or the most recent access time in the cache
// depending on whether an LFU or LRU strategy was used.
func (server *Server) updateKeyInCache(key string) {
switch strings.ToLower(server.Config.EvictionPolicy) {
case utils.AllKeysLFU:
server.lfuCache.Update(key)
case utils.AllKeysLRU:
server.lruCache.Update(key)
case utils.VolatileLFU:
if _, ok := server.keyExpiry[key]; ok {
server.lfuCache.Update(key)
}
case utils.VolatileLRU:
if _, ok := server.keyExpiry[key]; ok {
server.lruCache.Update(key)
}
}
}

View File

@@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/echovault/echovault/src/aof" "github.com/echovault/echovault/src/aof"
"github.com/echovault/echovault/src/eviction"
"github.com/echovault/echovault/src/memberlist" "github.com/echovault/echovault/src/memberlist"
"github.com/echovault/echovault/src/raft" "github.com/echovault/echovault/src/raft"
"github.com/echovault/echovault/src/snapshot" "github.com/echovault/echovault/src/snapshot"
@@ -28,6 +29,9 @@ type Server struct {
store map[string]interface{} store map[string]interface{}
keyLocks map[string]*sync.RWMutex keyLocks map[string]*sync.RWMutex
keyCreationLock *sync.Mutex keyCreationLock *sync.Mutex
keyExpiry map[string]time.Time
lfuCache *eviction.CacheLFU
lruCache *eviction.CacheLRU
Commands []utils.Command Commands []utils.Command
@@ -66,6 +70,7 @@ func NewServer(opts Opts) *Server {
store: make(map[string]interface{}), store: make(map[string]interface{}),
keyLocks: make(map[string]*sync.RWMutex), keyLocks: make(map[string]*sync.RWMutex),
keyCreationLock: &sync.Mutex{}, keyCreationLock: &sync.Mutex{},
keyExpiry: make(map[string]time.Time),
} }
if server.IsInCluster() { if server.IsInCluster() {
server.raft = raft.NewRaft(raft.Opts{ server.raft = raft.NewRaft(raft.Opts{
@@ -117,6 +122,11 @@ func NewServer(opts Opts) *Server {
}), }),
) )
} }
// Set up lfu and lru caches
server.lfuCache = eviction.NewCacheLFU()
server.lruCache = eviction.NewCacheLRU()
return server return server
} }

View File

@@ -121,7 +121,7 @@ Supported units (kb, mb, gb, tb, pb). There is no limit by default.`, func(memor
raftBindPort := flag.Uint("raft-port", 7481, "Port to use for intra-cluster communication. Leave on the client.") raftBindPort := flag.Uint("raft-port", 7481, "Port to use for intra-cluster communication. Leave on the client.")
mlBindPort := flag.Uint("memberlist-port", 7946, "Port to use for memberlist communication.") mlBindPort := flag.Uint("memberlist-port", 7946, "Port to use for memberlist communication.")
inMemory := flag.Bool("in-memory", false, "Whether to use memory or persistent storage for raft logs and snapshots.") inMemory := flag.Bool("in-memory", false, "Whether to use memory or persistent storage for raft logs and snapshots.")
dataDir := flag.String("data-dir", "/var/lib/echovault", "Directory to store raft snapshots and logs.") dataDir := flag.String("data-dir", "/var/lib/echovault", "Directory to store snapshots and logs.")
bootstrapCluster := flag.Bool("bootstrap-cluster", false, "Whether this instance should bootstrap a new cluster.") bootstrapCluster := flag.Bool("bootstrap-cluster", false, "Whether this instance should bootstrap a new cluster.")
aclConfig := flag.String("acl-config", "", "ACL config file path.") aclConfig := flag.String("acl-config", "", "ACL config file path.")
snapshotThreshold := flag.Uint64("snapshot-threshold", 1000, "The number of entries that trigger a snapshot. Default is 1000.") snapshotThreshold := flag.Uint64("snapshot-threshold", 1000, "The number of entries that trigger a snapshot. Default is 1000.")
@@ -213,7 +213,7 @@ It is a plain text value by default but you can provide a SHA256 hash by adding
var err error = nil var err error = nil
if conf.RequirePass && conf.Password == "" { if conf.RequirePass && conf.Password == "" {
err = errors.New("password cannot be empty if requirePass is etc to true") err = errors.New("password cannot be empty if requirePass is generic to true")
} }
return conf, err return conf, err

View File

@@ -32,7 +32,7 @@ const (
const ( const (
NoEviction = "noeviction" NoEviction = "noeviction"
AllKeysLRU = "allkeys-lru" AllKeysLRU = "allkeys-lru"
AllKeysLFU = "allkeys=lfu" AllKeysLFU = "allkeys-lfu"
VolatileLRU = "volatile-lru" VolatileLRU = "volatile-lru"
VolatileLFU = "volatile-lfu" VolatileLFU = "volatile-lfu"
AllKeysRandom = "allkeys-random" AllKeysRandom = "allkeys-random"

View File

@@ -3,6 +3,7 @@ package utils
import ( import (
"context" "context"
"net" "net"
"time"
) )
type Server interface { type Server interface {
@@ -14,6 +15,8 @@ type Server interface {
CreateKeyAndLock(ctx context.Context, key string) (bool, error) CreateKeyAndLock(ctx context.Context, key string) (bool, error)
GetValue(key string) interface{} GetValue(key string) interface{}
SetValue(ctx context.Context, key string, value interface{}) SetValue(ctx context.Context, key string, value interface{})
SetKeyExpiry(key string, expire time.Time, touch bool)
RemoveKeyExpiry(key string)
GetState() map[string]interface{} GetState() map[string]interface{}
GetAllCommands(ctx context.Context) []Command GetAllCommands(ctx context.Context) []Command
GetACL() interface{} GetACL() interface{}