From 82be1f606827d1baf90f53bd83e9ffe116f5e49a Mon Sep 17 00:00:00 2001 From: Kelvin Mwinuka Date: Fri, 8 Mar 2024 00:26:49 +0800 Subject: [PATCH] Added addedTime to EntryLFU object to control which entry will be removed if access count is the same. If two entries have the same access count, the older entry should be removed first. SetKeyExpiry and GetValue keyspace receiver functions now require context object to be passed. Created adjustMemoryUsage function for key eviction for LFU, LRU, and Random eviction policies. Updated all modules to pass context to SetKeyExpirty and GetValue functions. --- src/eviction/lfu.go | 20 ++- src/modules/generic/commands.go | 8 +- src/modules/generic/commands_test.go | 8 +- src/modules/hash/commands.go | 22 +-- src/modules/hash/commands_test.go | 6 +- src/modules/list/commands.go | 22 +-- src/modules/list/commands_test.go | 14 +- src/modules/set/commands.go | 38 ++--- src/modules/set/commant_test.go | 16 +- src/modules/sorted_set/commands.go | 50 +++--- src/modules/sorted_set/commands_test.go | 26 +-- src/modules/string/commands.go | 6 +- src/modules/string/commands_test.go | 2 +- src/server/keyspace.go | 204 +++++++++++++++++++++--- src/server/server.go | 17 +- src/utils/config.go | 3 +- src/utils/types.go | 4 +- src/utils/utils.go | 11 +- 18 files changed, 330 insertions(+), 147 deletions(-) diff --git a/src/eviction/lfu.go b/src/eviction/lfu.go index d1c38f7..b28d779 100644 --- a/src/eviction/lfu.go +++ b/src/eviction/lfu.go @@ -3,12 +3,14 @@ package eviction import ( "container/heap" "slices" + "time" ) type EntryLFU struct { - key string // The key, matching the key in the store - count int // The number of times this key has been accessed - index int // The index of the entry in the heap + key string // The key, matching the key in the store + count int // The number of times this key has been accessed + addedTime int64 // The time this entry was added to the cache in unix milliseconds + index int // The index of the entry in the heap } type CacheLFU struct { @@ -30,6 +32,11 @@ func (cache *CacheLFU) Len() int { } func (cache *CacheLFU) Less(i, j int) bool { + // If 2 entries have the same count, return the older one + if cache.entries[i].count == cache.entries[j].count { + return cache.entries[i].addedTime > cache.entries[j].addedTime + } + // Otherwise, return the one with a lower count return cache.entries[i].count < cache.entries[j].count } @@ -42,9 +49,10 @@ func (cache *CacheLFU) Swap(i, j int) { func (cache *CacheLFU) Push(key any) { n := len(cache.entries) cache.entries = append(cache.entries, &EntryLFU{ - key: key.(string), - count: 1, - index: n, + key: key.(string), + count: 1, + addedTime: time.Now().UnixMilli(), + index: n, }) cache.keys[key.(string)] = true } diff --git a/src/modules/generic/commands.go b/src/modules/generic/commands.go index 5ecd0dc..9c40d46 100644 --- a/src/modules/generic/commands.go +++ b/src/modules/generic/commands.go @@ -35,7 +35,7 @@ func handleSet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co if !server.KeyExists(key) { res = []byte("$-1\r\n") } else { - res = []byte(fmt.Sprintf("+%v\r\n", server.GetValue(key))) + res = []byte(fmt.Sprintf("+%v\r\n", server.GetValue(ctx, key))) } } @@ -72,7 +72,7 @@ func handleSet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co // If expiresAt is set, set the key's expiry time as well if params.expireAt != nil { - server.SetKeyExpiry(key, params.expireAt.(time.Time), false) + server.SetKeyExpiry(ctx, key, params.expireAt.(time.Time), false) } return res, nil @@ -151,7 +151,7 @@ func handleGet(ctx context.Context, cmd []string, server utils.Server, _ *net.Co } defer server.KeyRUnlock(key) - value := server.GetValue(key) + value := server.GetValue(ctx, key) return []byte(fmt.Sprintf("+%v\r\n", value)), nil } @@ -190,7 +190,7 @@ func handleMGet(ctx context.Context, cmd []string, server utils.Server, _ *net.C }() for key, _ := range locks { - values[key] = fmt.Sprintf("%v", server.GetValue(key)) + values[key] = fmt.Sprintf("%v", server.GetValue(ctx, key)) } bytes := []byte(fmt.Sprintf("*%d\r\n", len(cmd[1:]))) diff --git a/src/modules/generic/commands_test.go b/src/modules/generic/commands_test.go index 41a2849..a054577 100644 --- a/src/modules/generic/commands_test.go +++ b/src/modules/generic/commands_test.go @@ -71,7 +71,7 @@ func Test_HandleSET(t *testing.T) { if rv.String() != test.expectedResponse { t.Errorf("expected response %s, got %s", test.expectedResponse, rv.String()) } - value := mockServer.GetValue(test.command[1]) + value := mockServer.GetValue(context.Background(), test.command[1]) switch value.(type) { default: t.Error("unexpected type for expectedValue") @@ -151,7 +151,7 @@ func Test_HandleMSET(t *testing.T) { t.Error("unexpected type for expectedValue") case int: ev, _ := expectedValue.(int) - value, ok := mockServer.GetValue(key).(int) + value, ok := mockServer.GetValue(context.Background(), key).(int) if !ok { t.Errorf("expected integer type for key %s, got another type", key) } @@ -160,7 +160,7 @@ func Test_HandleMSET(t *testing.T) { } case float64: ev, _ := expectedValue.(float64) - value, ok := mockServer.GetValue(key).(float64) + value, ok := mockServer.GetValue(context.Background(), key).(float64) if !ok { t.Errorf("expected float type for key %s, got another type", key) } @@ -169,7 +169,7 @@ func Test_HandleMSET(t *testing.T) { } case string: ev, _ := expectedValue.(string) - value, ok := mockServer.GetValue(key).(string) + value, ok := mockServer.GetValue(context.Background(), key).(string) if !ok { t.Errorf("expected string type for key %s, got another type", key) } diff --git a/src/modules/hash/commands.go b/src/modules/hash/commands.go index f591801..29ca636 100644 --- a/src/modules/hash/commands.go +++ b/src/modules/hash/commands.go @@ -46,7 +46,7 @@ func handleHSET(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -88,7 +88,7 @@ func handleHGET(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -139,7 +139,7 @@ func handleHSTRLEN(ctx context.Context, cmd []string, server utils.Server, conn } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -189,7 +189,7 @@ func handleHVALS(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -251,7 +251,7 @@ func handleHRANDFIELD(ctx context.Context, cmd []string, server utils.Server, co } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -346,7 +346,7 @@ func handleHLEN(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -371,7 +371,7 @@ func handleHKEYS(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -436,7 +436,7 @@ func handleHINCRBY(ctx context.Context, cmd []string, server utils.Server, conn } defer server.KeyUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -493,7 +493,7 @@ func handleHGETALL(ctx context.Context, cmd []string, server utils.Server, conn } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -534,7 +534,7 @@ func handleHEXISTS(ctx context.Context, cmd []string, server utils.Server, conn } defer server.KeyRUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } @@ -564,7 +564,7 @@ func handleHDEL(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - hash, ok := server.GetValue(key).(map[string]interface{}) + hash, ok := server.GetValue(ctx, key).(map[string]interface{}) if !ok { return nil, fmt.Errorf("value at %s is not a hash", key) } diff --git a/src/modules/hash/commands_test.go b/src/modules/hash/commands_test.go index 989ed97..09a5fc6 100644 --- a/src/modules/hash/commands_test.go +++ b/src/modules/hash/commands_test.go @@ -125,7 +125,7 @@ func Test_HandleHSET(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - hash, ok := mockServer.GetValue(test.key).(map[string]interface{}) + hash, ok := mockServer.GetValue(context.Background(), test.key).(map[string]interface{}) if !ok { t.Errorf("value at key \"%s\" is not a hash map", test.key) } @@ -278,7 +278,7 @@ func Test_HandleHINCRBY(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - hash, ok := mockServer.GetValue(test.key).(map[string]interface{}) + hash, ok := mockServer.GetValue(context.Background(), test.key).(map[string]interface{}) if !ok { t.Errorf("value at key \"%s\" is not a hash map", test.key) } @@ -1255,7 +1255,7 @@ func Test_HandleHDEL(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - if hash, ok := mockServer.GetValue(test.key).(map[string]interface{}); ok { + if hash, ok := mockServer.GetValue(context.Background(), test.key).(map[string]interface{}); ok { for field, value := range hash { if value != test.expectedValue[field] { t.Errorf("expected value \"%+v\", got \"%+v\"", test.expectedValue[field], value) diff --git a/src/modules/list/commands.go b/src/modules/list/commands.go index bb124e5..3a0f6d4 100644 --- a/src/modules/list/commands.go +++ b/src/modules/list/commands.go @@ -29,7 +29,7 @@ func handleLLen(ctx context.Context, cmd []string, server utils.Server, _ *net.C } defer server.KeyRUnlock(key) - if list, ok := server.GetValue(key).([]interface{}); ok { + if list, ok := server.GetValue(ctx, key).([]interface{}); ok { return []byte(fmt.Sprintf(":%d\r\n", len(list))), nil } @@ -56,7 +56,7 @@ func handleLIndex(ctx context.Context, cmd []string, server utils.Server, conn * if _, err = server.KeyRLock(ctx, key); err != nil { return nil, err } - list, ok := server.GetValue(key).([]interface{}) + list, ok := server.GetValue(ctx, key).([]interface{}) server.KeyRUnlock(key) if !ok { @@ -93,7 +93,7 @@ func handleLRange(ctx context.Context, cmd []string, server utils.Server, conn * } defer server.KeyRUnlock(key) - list, ok := server.GetValue(key).([]interface{}) + list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { return nil, errors.New("LRANGE command on non-list item") } @@ -171,7 +171,7 @@ func handleLSet(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - list, ok := server.GetValue(key).([]interface{}) + list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { return nil, errors.New("LSET command on non-list item") } @@ -215,7 +215,7 @@ func handleLTrim(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyUnlock(key) - list, ok := server.GetValue(key).([]interface{}) + list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { return nil, errors.New("LTRIM command on non-list item") } @@ -262,7 +262,7 @@ func handleLRem(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - list, ok := server.GetValue(key).([]interface{}) + list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { return nil, errors.New("LREM command on non-list item") } @@ -335,8 +335,8 @@ func handleLMove(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyUnlock(destination) - sourceList, sourceOk := server.GetValue(source).([]interface{}) - destinationList, destinationOk := server.GetValue(destination).([]interface{}) + sourceList, sourceOk := server.GetValue(ctx, source).([]interface{}) + destinationList, destinationOk := server.GetValue(ctx, destination).([]interface{}) if !sourceOk || !destinationOk { return nil, errors.New("both source and destination must be lists") @@ -399,7 +399,7 @@ func handleLPush(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyUnlock(key) - currentList := server.GetValue(key) + currentList := server.GetValue(ctx, key) l, ok := currentList.([]interface{}) if !ok { @@ -446,7 +446,7 @@ func handleRPush(ctx context.Context, cmd []string, server utils.Server, conn *n defer server.KeyUnlock(key) } - currentList := server.GetValue(key) + currentList := server.GetValue(ctx, key) l, ok := currentList.([]interface{}) @@ -477,7 +477,7 @@ func handlePop(ctx context.Context, cmd []string, server utils.Server, conn *net } defer server.KeyUnlock(key) - list, ok := server.GetValue(key).([]interface{}) + list, ok := server.GetValue(ctx, key).([]interface{}) if !ok { return nil, fmt.Errorf("%s command on non-list item", strings.ToUpper(cmd[0])) } diff --git a/src/modules/list/commands_test.go b/src/modules/list/commands_test.go index ae2eabf..a04dc1c 100644 --- a/src/modules/list/commands_test.go +++ b/src/modules/list/commands_test.go @@ -512,7 +512,7 @@ func Test_HandleLSET(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(test.key).([]interface{}) + list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -670,7 +670,7 @@ func Test_HandleLTRIM(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(test.key).([]interface{}) + list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -789,7 +789,7 @@ func Test_HandleLREM(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(test.key).([]interface{}) + list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -999,7 +999,7 @@ func Test_HandleLMOVE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(key).([]interface{}) + list, ok := mockServer.GetValue(context.Background(), key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -1105,7 +1105,7 @@ func Test_HandleLPUSH(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(test.key).([]interface{}) + list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -1206,7 +1206,7 @@ func Test_HandleRPUSH(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(test.key).([]interface{}) + list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } @@ -1316,7 +1316,7 @@ func Test_HandlePop(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - list, ok := mockServer.GetValue(test.key).([]interface{}) + list, ok := mockServer.GetValue(context.Background(), test.key).([]interface{}) if !ok { t.Error("expected value to be list, got another type") } diff --git a/src/modules/set/commands.go b/src/modules/set/commands.go index c77531e..20923e1 100644 --- a/src/modules/set/commands.go +++ b/src/modules/set/commands.go @@ -37,7 +37,7 @@ func handleSADD(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } @@ -64,7 +64,7 @@ func handleSCARD(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } @@ -88,7 +88,7 @@ func handleSDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n return nil, err } defer server.KeyRUnlock(keys[0]) - baseSet, ok := server.GetValue(keys[0]).(*Set) + baseSet, ok := server.GetValue(ctx, keys[0]).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", keys[0]) } @@ -114,7 +114,7 @@ func handleSDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n var sets []*Set for _, key := range cmd[2:] { - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { continue } @@ -151,7 +151,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co return nil, err } defer server.KeyRUnlock(keys[1]) - baseSet, ok := server.GetValue(keys[1]).(*Set) + baseSet, ok := server.GetValue(ctx, keys[1]).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", keys[1]) } @@ -177,7 +177,7 @@ func handleSDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co var sets []*Set for _, key := range keys[2:] { - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { continue } @@ -240,7 +240,7 @@ func handleSINTER(ctx context.Context, cmd []string, server utils.Server, conn * var sets []*Set for key, _ := range locks { - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { // If the value at the key is not a set, return error return nil, fmt.Errorf("value at key %s is not a set", key) @@ -316,7 +316,7 @@ func handleSINTERCARD(ctx context.Context, cmd []string, server utils.Server, co var sets []*Set for key, _ := range locks { - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { // If the value at the key is not a set, return error return nil, fmt.Errorf("value at key %s is not a set", key) @@ -362,7 +362,7 @@ func handleSINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c var sets []*Set for key, _ := range locks { - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { // If the value at the key is not a set, return error return nil, fmt.Errorf("value at key %s is not a set", key) @@ -408,7 +408,7 @@ func handleSISMEMBER(ctx context.Context, cmd []string, server utils.Server, con } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } @@ -437,7 +437,7 @@ func handleSMEMBERS(ctx context.Context, cmd []string, server utils.Server, conn } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } @@ -480,7 +480,7 @@ func handleSMISMEMBER(ctx context.Context, cmd []string, server utils.Server, co } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } @@ -517,7 +517,7 @@ func handleSMOVE(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyUnlock(source) - sourceSet, ok := server.GetValue(source).(*Set) + sourceSet, ok := server.GetValue(ctx, source).(*Set) if !ok { return nil, errors.New("source is not a set") } @@ -540,7 +540,7 @@ func handleSMOVE(ctx context.Context, cmd []string, server utils.Server, conn *n return nil, err } defer server.KeyUnlock(destination) - ds, ok := server.GetValue(destination).(*Set) + ds, ok := server.GetValue(ctx, destination).(*Set) if !ok { return nil, errors.New("destination is not a set") } @@ -578,7 +578,7 @@ func handleSPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at %s is not a set", key) } @@ -622,7 +622,7 @@ func handleSRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at %s is not a set", key) } @@ -658,7 +658,7 @@ func handleSREM(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } @@ -699,7 +699,7 @@ func handleSUNION(ctx context.Context, cmd []string, server utils.Server, conn * if !locked { continue } - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } @@ -750,7 +750,7 @@ func handleSUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c if !locked { continue } - set, ok := server.GetValue(key).(*Set) + set, ok := server.GetValue(ctx, key).(*Set) if !ok { return nil, fmt.Errorf("value at key %s is not a set", key) } diff --git a/src/modules/set/commant_test.go b/src/modules/set/commant_test.go index 90999cc..3081cd9 100644 --- a/src/modules/set/commant_test.go +++ b/src/modules/set/commant_test.go @@ -88,7 +88,7 @@ func Test_HandleSADD(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.key).(*Set) + set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) if !ok { t.Errorf("expected set value at key \"%s\"", test.key) } @@ -408,7 +408,7 @@ func Test_HandleSDIFFSTORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.destination).(*Set) + set, ok := mockServer.GetValue(context.Background(), test.destination).(*Set) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -754,7 +754,7 @@ func Test_HandleSINTERSTORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.destination).(*Set) + set, ok := mockServer.GetValue(context.Background(), test.destination).(*Set) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -1126,7 +1126,7 @@ func Test_HandleSMOVE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(key) } - set, ok := mockServer.GetValue(key).(*Set) + set, ok := mockServer.GetValue(context.Background(), key).(*Set) if !ok { t.Errorf("expected set \"%s\" to be a set, got another type", key) } @@ -1223,7 +1223,7 @@ func Test_HandleSPOP(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.key).(*Set) + set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) if !ok { t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) } @@ -1334,7 +1334,7 @@ func Test_HandleSRANDMEMBER(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.key).(*Set) + set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) if !ok { t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) } @@ -1437,7 +1437,7 @@ func Test_HandleSREM(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.key).(*Set) + set, ok := mockServer.GetValue(context.Background(), test.key).(*Set) if !ok { t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) } @@ -1641,7 +1641,7 @@ func Test_HandleSUNIONSTORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.destination).(*Set) + set, ok := mockServer.GetValue(context.Background(), test.destination).(*Set) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } diff --git a/src/modules/sorted_set/commands.go b/src/modules/sorted_set/commands.go index 8dd2142..903f6b7 100644 --- a/src/modules/sorted_set/commands.go +++ b/src/modules/sorted_set/commands.go @@ -133,7 +133,7 @@ func handleZADD(ctx context.Context, cmd []string, server utils.Server, conn *ne return nil, err } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -180,7 +180,7 @@ func handleZCARD(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -241,7 +241,7 @@ func handleZCOUNT(ctx context.Context, cmd []string, server utils.Server, conn * } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -275,7 +275,7 @@ func handleZLEXCOUNT(ctx context.Context, cmd []string, server utils.Server, con } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -332,7 +332,7 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n return nil, err } defer server.KeyRUnlock(keys[0]) - baseSortedSet, ok := server.GetValue(keys[0]).(*SortedSet) + baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[0]) } @@ -349,7 +349,7 @@ func handleZDIFF(ctx context.Context, cmd []string, server utils.Server, conn *n return nil, err } locks[keys[i]] = locked - set, ok := server.GetValue(keys[i]).(*SortedSet) + set, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } @@ -400,7 +400,7 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co return nil, err } defer server.KeyRUnlock(keys[0]) - baseSortedSet, ok := server.GetValue(keys[0]).(*SortedSet) + baseSortedSet, ok := server.GetValue(ctx, keys[0]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[0]) } @@ -412,7 +412,7 @@ func handleZDIFFSTORE(ctx context.Context, cmd []string, server utils.Server, co if _, err = server.KeyRLock(ctx, keys[i]); err != nil { return nil, err } - set, ok := server.GetValue(keys[i]).(*SortedSet) + set, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } @@ -486,7 +486,7 @@ func handleZINCRBY(ctx context.Context, cmd []string, server utils.Server, conn return nil, err } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -534,7 +534,7 @@ func handleZINTER(ctx context.Context, cmd []string, server utils.Server, conn * return nil, err } locks[keys[i]] = true - set, ok := server.GetValue(keys[i]).(*SortedSet) + set, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } @@ -600,7 +600,7 @@ func handleZINTERSTORE(ctx context.Context, cmd []string, server utils.Server, c return nil, err } locks[keys[i]] = true - set, ok := server.GetValue(keys[i]).(*SortedSet) + set, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } @@ -681,7 +681,7 @@ func handleZMPOP(ctx context.Context, cmd []string, server utils.Server, conn *n if _, err = server.KeyLock(ctx, keys[i]); err != nil { continue } - v, ok := server.GetValue(keys[i]).(*SortedSet) + v, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok || v.Cardinality() == 0 { server.KeyUnlock(keys[i]) continue @@ -739,7 +739,7 @@ func handleZPOP(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at key %s is not a sorted set", key) } @@ -776,7 +776,7 @@ func handleZMSCORE(ctx context.Context, cmd []string, server utils.Server, conn } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -835,7 +835,7 @@ func handleZRANDMEMBER(ctx context.Context, cmd []string, server utils.Server, c } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -879,7 +879,7 @@ func handleZRANK(ctx context.Context, cmd []string, server utils.Server, conn *n } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -923,7 +923,7 @@ func handleZREM(ctx context.Context, cmd []string, server utils.Server, conn *ne } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -953,7 +953,7 @@ func handleZSCORE(ctx context.Context, cmd []string, server utils.Server, conn * return nil, err } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -996,7 +996,7 @@ func handleZREMRANGEBYSCORE(ctx context.Context, cmd []string, server utils.Serv } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -1038,7 +1038,7 @@ func handleZREMRANGEBYRANK(ctx context.Context, cmd []string, server utils.Serve } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -1095,7 +1095,7 @@ func handleZREMRANGEBYLEX(ctx context.Context, cmd []string, server utils.Server } defer server.KeyUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -1193,7 +1193,7 @@ func handleZRANGE(ctx context.Context, cmd []string, server utils.Server, conn * } defer server.KeyRUnlock(key) - set, ok := server.GetValue(key).(*SortedSet) + set, ok := server.GetValue(ctx, key).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", key) } @@ -1330,7 +1330,7 @@ func handleZRANGESTORE(ctx context.Context, cmd []string, server utils.Server, c } defer server.KeyRUnlock(source) - set, ok := server.GetValue(source).(*SortedSet) + set, ok := server.GetValue(ctx, source).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", source) } @@ -1432,7 +1432,7 @@ func handleZUNION(ctx context.Context, cmd []string, server utils.Server, conn * return nil, err } locks[keys[i]] = true - set, ok := server.GetValue(keys[i]).(*SortedSet) + set, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } @@ -1494,7 +1494,7 @@ func handleZUNIONSTORE(ctx context.Context, cmd []string, server utils.Server, c return nil, err } locks[keys[i]] = true - set, ok := server.GetValue(keys[i]).(*SortedSet) + set, ok := server.GetValue(ctx, keys[i]).(*SortedSet) if !ok { return nil, fmt.Errorf("value at %s is not a sorted set", keys[i]) } diff --git a/src/modules/sorted_set/commands_test.go b/src/modules/sorted_set/commands_test.go index 45ddd93..2cef684 100644 --- a/src/modules/sorted_set/commands_test.go +++ b/src/modules/sorted_set/commands_test.go @@ -248,7 +248,7 @@ func Test_HandleZADD(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - sortedSet, ok := mockServer.GetValue(test.key).(*SortedSet) + sortedSet, ok := mockServer.GetValue(context.Background(), test.key).(*SortedSet) if !ok { t.Errorf("expected the value at key \"%s\" to be a sorted set, got another type", test.key) } @@ -935,7 +935,7 @@ func Test_HandleZDIFFSTORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.destination).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -1149,7 +1149,7 @@ func Test_HandleZINCRBY(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), test.key).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.key) } @@ -1384,7 +1384,7 @@ func Test_HandleZMPOP(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) if !ok { t.Errorf("expected key \"%s\" to be a sorted set, got another type", key) } @@ -1555,7 +1555,7 @@ func Test_HandleZPOP(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) if !ok { t.Errorf("expected key \"%s\" to be a sorted set, got another type", key) } @@ -1869,7 +1869,7 @@ func Test_HandleZRANDMEMBER(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), test.key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a set, got another type", test.key) } @@ -2112,7 +2112,7 @@ func Test_HandleZREM(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) } @@ -2218,7 +2218,7 @@ func Test_HandleZREMRANGEBYSCORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) } @@ -2378,7 +2378,7 @@ func Test_HandleZREMRANGEBYRANK(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) } @@ -2509,7 +2509,7 @@ func Test_HandleZREMRANGEBYLEX(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), key); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(key).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), key).(*SortedSet) if !ok { t.Errorf("expected value at key \"%s\" to be a sorted set, got another type", key) } @@ -3034,7 +3034,7 @@ func Test_HandleZRANGESTORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.destination).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -3693,7 +3693,7 @@ func Test_HandleZINTERSTORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.destination).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } @@ -4418,7 +4418,7 @@ func Test_HandleZUNIONSTORE(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.destination); err != nil { t.Error(err) } - set, ok := mockServer.GetValue(test.destination).(*SortedSet) + set, ok := mockServer.GetValue(context.Background(), test.destination).(*SortedSet) if !ok { t.Errorf("expected vaule at key %s to be set, got another type", test.destination) } diff --git a/src/modules/string/commands.go b/src/modules/string/commands.go index 35a3548..ee92e6d 100644 --- a/src/modules/string/commands.go +++ b/src/modules/string/commands.go @@ -39,7 +39,7 @@ func handleSetRange(ctx context.Context, cmd []string, server utils.Server, conn } defer server.KeyUnlock(key) - str, ok := server.GetValue(key).(string) + str, ok := server.GetValue(ctx, key).(string) if !ok { return nil, fmt.Errorf("value at key %s is not a string", key) } @@ -100,7 +100,7 @@ func handleStrLen(ctx context.Context, cmd []string, server utils.Server, conn * } defer server.KeyRUnlock(key) - value, ok := server.GetValue(key).(string) + value, ok := server.GetValue(ctx, key).(string) if !ok { return nil, fmt.Errorf("value at key %s is not a string", key) @@ -134,7 +134,7 @@ func handleSubStr(ctx context.Context, cmd []string, server utils.Server, conn * } defer server.KeyRUnlock(key) - value, ok := server.GetValue(key).(string) + value, ok := server.GetValue(ctx, key).(string) if !ok { return nil, fmt.Errorf("value at key %s is not a string", key) } diff --git a/src/modules/string/commands_test.go b/src/modules/string/commands_test.go index f06142e..c531d1b 100644 --- a/src/modules/string/commands_test.go +++ b/src/modules/string/commands_test.go @@ -138,7 +138,7 @@ func Test_HandleSetRange(t *testing.T) { if _, err = mockServer.KeyRLock(context.Background(), test.key); err != nil { t.Error(err) } - value, ok := mockServer.GetValue(test.key).(string) + value, ok := mockServer.GetValue(context.Background(), test.key).(string) if !ok { t.Error("expected string data type, got another type") } diff --git a/src/server/keyspace.go b/src/server/keyspace.go index 4357fa4..9294514 100644 --- a/src/server/keyspace.go +++ b/src/server/keyspace.go @@ -3,20 +3,26 @@ package server import ( "context" "errors" + "fmt" "github.com/echovault/echovault/src/utils" + "log" + "math/rand" + "runtime" "slices" "strings" "sync" "time" ) -// KeyLock tries to acquire the write lock for the specified key every 5 milliseconds. +// KeyLock tries to acquire the write lock for the specified key. // If the context passed to the function finishes before the lock is acquired, an error is returned. func (server *Server) KeyLock(ctx context.Context, key string) (bool, error) { - ticker := time.NewTicker(5 * time.Millisecond) for { select { default: + if server.keyLocks[key] == nil { + return false, fmt.Errorf("key %s not found", key) + } ok := server.keyLocks[key].TryLock() if ok { return true, nil @@ -24,21 +30,24 @@ func (server *Server) KeyLock(ctx context.Context, key string) (bool, error) { case <-ctx.Done(): return false, context.Cause(ctx) } - <-ticker.C } } func (server *Server) KeyUnlock(key string) { - server.keyLocks[key].Unlock() + if server.KeyExists(key) { + server.keyLocks[key].Unlock() + } } -// KeyRLock tries to acquire the read lock for the specified key every few milliseconds. +// KeyRLock tries to acquire the read lock for the specified key. // If the context passed to the function finishes before the lock is acquired, an error is returned. func (server *Server) KeyRLock(ctx context.Context, key string) (bool, error) { - ticker := time.NewTicker(5 * time.Millisecond) for { select { default: + if server.keyLocks[key] == nil { + return false, fmt.Errorf("key %s not found", key) + } ok := server.keyLocks[key].TryRLock() if ok { return true, nil @@ -46,12 +55,13 @@ func (server *Server) KeyRLock(ctx context.Context, key string) (bool, error) { case <-ctx.Done(): return false, context.Cause(ctx) } - <-ticker.C } } func (server *Server) KeyRUnlock(key string) { - server.keyLocks[key].RUnlock() + if server.KeyExists(key) { + server.keyLocks[key].RUnlock() + } } func (server *Server) KeyExists(key string) bool { @@ -61,7 +71,7 @@ func (server *Server) KeyExists(key string) bool { // CreateKeyAndLock creates a new key lock and immediately locks it if the key does not exist. // If the key exists, the existing key is locked. func (server *Server) CreateKeyAndLock(ctx context.Context, key string) (bool, error) { - if utils.IsMaxMemoryExceeded(server.Config) && server.Config.EvictionPolicy == utils.NoEviction { + if utils.IsMaxMemoryExceeded(server.Config.MaxMemory) && server.Config.EvictionPolicy == utils.NoEviction { return false, errors.New("max memory reached, key not created") } @@ -80,8 +90,11 @@ func (server *Server) CreateKeyAndLock(ctx context.Context, key string) (bool, e // GetValue retrieves the current value at the specified key. // The key must be read-locked before calling this function. -func (server *Server) GetValue(key string) interface{} { - server.updateKeyInCache(key) +func (server *Server) GetValue(ctx context.Context, key string) interface{} { + err := server.updateKeyInCache(ctx, key) + if err != nil { + log.Printf("GetValue error: %+v\n", err) + } return server.store[key] } @@ -90,14 +103,17 @@ func (server *Server) GetValue(key string) interface{} { // in the snapshot engine. // This count triggers a snapshot when the threshold is reached. // The key must be locked prior to calling this function. -func (server *Server) SetValue(_ context.Context, key string, value interface{}) error { - if utils.IsMaxMemoryExceeded(server.Config) && server.Config.EvictionPolicy == utils.NoEviction { +func (server *Server) SetValue(ctx context.Context, key string, value interface{}) error { + if utils.IsMaxMemoryExceeded(server.Config.MaxMemory) && server.Config.EvictionPolicy == utils.NoEviction { return errors.New("max memory reached, key value not set") } server.store[key] = value - server.updateKeyInCache(key) + err := server.updateKeyInCache(ctx, key) + if err != nil { + log.Printf("SetValue error: %+v\n", err) + } if !server.IsInCluster() { server.SnapshotEngine.IncrementChangeCount() @@ -112,10 +128,13 @@ func (server *Server) SetValue(_ context.Context, key string, value interface{}) // The touch parameter determines whether to update the keys access count on lfu eviction policy, // or the access time on lru eviction policy. // The key must be locked prior to calling this function. -func (server *Server) SetKeyExpiry(key string, expire time.Time, touch bool) { +func (server *Server) SetKeyExpiry(ctx context.Context, key string, expire time.Time, touch bool) { server.keyExpiry[key] = expire if touch { - server.updateKeyInCache(key) + err := server.updateKeyInCache(ctx, key) + if err != nil { + log.Printf("SetKeyExpiry error: %+v\n", err) + } } } @@ -153,7 +172,11 @@ func (server *Server) GetState() map[string]interface{} { // updateKeyInCache updates either the key access count or the most recent access time in the cache // depending on whether an LFU or LRU strategy was used. -func (server *Server) updateKeyInCache(key string) { +func (server *Server) updateKeyInCache(ctx context.Context, key string) error { + // If max memory is 0, there's no max so no need to update caches + if server.Config.MaxMemory == 0 { + return nil + } switch strings.ToLower(server.Config.EvictionPolicy) { case utils.AllKeysLFU: server.lfuCache.Update(key) @@ -168,5 +191,150 @@ func (server *Server) updateKeyInCache(key string) { server.lruCache.Update(key) } } - // TODO: Check if memory usage is above max-memory. If it is, pop items from the cache until we get under the limit. + if err := server.adjustMemoryUsage(ctx); err != nil { + return fmt.Errorf("updateKeyInCache: %+v", err) + } + return nil +} + +func (server *Server) adjustMemoryUsage(ctx context.Context) error { + // If max memory is 0, there's no need to adjust memory usage. + if server.Config.MaxMemory == 0 { + return nil + } + // Check if memory usage is above max-memory. + // If it is, pop items from the cache until we get under the limit. + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + // If we're using less memory than the max-memory, there's no need to evict. + if memStats.HeapInuse < server.Config.MaxMemory { + return nil + } + // Force a garbage collection first before we start evicting key. + runtime.GC() + runtime.ReadMemStats(&memStats) + if memStats.HeapInuse < server.Config.MaxMemory { + return nil + } + // We've done a GC, but we're still at or above the max memory limit. + // Start a loop that evicts keys until either the heap is empty or + // we're below the max memory limit. + for { + switch { + case slices.Contains([]string{utils.AllKeysLFU, utils.VolatileLFU}, strings.ToLower(server.Config.EvictionPolicy)): + // Remove keys from LFU cache until we're below the max memory limit or + // until the LFU cache is empty. + for { + // Return if cache is empty + if server.lfuCache.Len() == 0 { + return fmt.Errorf("adjsutMemoryUsage -> LFU cache empty") + } + key := server.lfuCache.Pop().(string) + if _, err := server.KeyLock(ctx, key); err != nil { + return fmt.Errorf("adjustMemoryUsage -> LFU cache eviction: %+v", err) + } + // Delete the keys + delete(server.store, key) + delete(server.keyExpiry, key) + delete(server.keyLocks, key) + // Run garbage collection + runtime.GC() + // Return if we're below max memory + runtime.ReadMemStats(&memStats) + if memStats.HeapInuse < server.Config.MaxMemory { + return nil + } + } + case slices.Contains([]string{utils.AllKeysLRU, utils.VolatileLRU}, strings.ToLower(server.Config.EvictionPolicy)): + // Remove keys from th LRU cache until we're below the max memory limit or + // until the LRU cache is empty. + for { + // Return if cache is empty + if server.lruCache.Len() == 0 { + return fmt.Errorf("adjsutMemoryUsage -> LRU cache empty") + } + key := server.lruCache.Pop().(string) + if _, err := server.KeyLock(ctx, key); err != nil { + return fmt.Errorf("adjustMemoryUsage -> LRU cache eviction: %+v", err) + } + // Delete the keys + delete(server.store, key) + delete(server.keyExpiry, key) + delete(server.keyLocks, key) + // Run garbage collection + runtime.GC() + // Return if we're below max memory + runtime.ReadMemStats(&memStats) + if memStats.HeapInuse < server.Config.MaxMemory { + return nil + } + } + case slices.Contains([]string{utils.AllKeysRandom}, strings.ToLower(server.Config.EvictionPolicy)): + // Remove random keys until we're below the max memory limit + // or there are no more keys remaining. + for { + // If there are no keys, return error + if len(server.keyLocks) == 0 { + err := errors.New("no keys to evict") + return fmt.Errorf("adjustMemoryUsage -> all keys random: %+v", err) + } + // Get random key + idx := rand.Intn(len(server.keyLocks)) + for key, _ := range server.keyLocks { + if idx == 0 { + // Lock the key + if _, err := server.KeyLock(ctx, key); err != nil { + return fmt.Errorf("adjustMemoryUsage -> all keys random: %+v", err) + } + // Delete the key + delete(server.keyLocks, key) + delete(server.store, key) + delete(server.keyExpiry, key) + // Run garbage collection + runtime.GC() + // Return if we're below max memory + runtime.ReadMemStats(&memStats) + if memStats.HeapInuse < server.Config.MaxMemory { + return nil + } + } + idx-- + } + } + case slices.Contains([]string{utils.VolatileRandom}, strings.ToLower(server.Config.EvictionPolicy)): + // Remove random keys with expiry time until we're below the max memory limit + // or there are no more keys with expiry time. + for { + // If there are no volatile keys, return error + if len(server.keyExpiry) == 0 { + err := errors.New("no volatile keys to evict") + return fmt.Errorf("adjustMemoryUsage -> volatile keys random: %+v", err) + } + // Get random volatile key + idx := rand.Intn(len(server.keyExpiry)) + for key, _ := range server.keyExpiry { + if idx == 0 { + // Lock the key + if _, err := server.KeyLock(ctx, key); err != nil { + return fmt.Errorf("adjustMemoryUsage -> volatile keys random: %+v", err) + } + // Delete the key + delete(server.keyLocks, key) + delete(server.store, key) + delete(server.keyExpiry, key) + // Run garbage collection + runtime.GC() + // Return if we're below max memory + runtime.ReadMemStats(&memStats) + if memStats.HeapInuse < server.Config.MaxMemory { + return nil + } + } + idx-- + } + } + default: + return nil + } + } } diff --git a/src/server/server.go b/src/server/server.go index 9648b7c..92dbecb 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -43,13 +43,13 @@ type Server struct { ACL utils.ACL PubSub utils.PubSub - SnapshotInProgress atomic.Bool - RewriteAOFInProgress atomic.Bool - StateCopyInProgress atomic.Bool - StateMutationInProgress atomic.Bool - LatestSnapshotMilliseconds atomic.Int64 // Unix epoch in milliseconds - SnapshotEngine *snapshot.Engine - AOFEngine *aof.Engine + SnapshotInProgress atomic.Bool // Atomic boolean that's true when actively taking a snapshot. + RewriteAOFInProgress atomic.Bool // Atomic boolean that's true when actively rewriting AOF file is in progress. + StateCopyInProgress atomic.Bool // Atomic boolean that's true when actively copying state for snapshotting or preamble generation. + StateMutationInProgress atomic.Bool // Atomic boolean that is set to true when state mutation is in progress. + LatestSnapshotMilliseconds atomic.Int64 // Unix epoch in milliseconds + SnapshotEngine *snapshot.Engine // Snapshot engine for standalone mode + AOFEngine *aof.Engine // AOF engine for standalone mode } type Opts struct { @@ -136,7 +136,8 @@ func NewServer(opts Opts) *Server { server.lfuCache = eviction.NewCacheLFU() server.lruCache = eviction.NewCacheLRU() - // TODO: Start goroutine that continuously reads the mem stats before triggering purge once max-memory is reached + // TODO: If eviction policy is volatile-ttl, start goroutine that continuously reads the mem stats + // TODO: before triggering purge once max-memory is reached return server } diff --git a/src/utils/config.go b/src/utils/config.go index 98156de..0bdeec2 100644 --- a/src/utils/config.go +++ b/src/utils/config.go @@ -80,7 +80,8 @@ The options are 'always' for syncing on each command, 'everysec' to sync every s var maxMemory uint64 = 0 flag.Func("max-memory", `Upper memory limit before triggering eviction. -Supported units (kb, mb, gb, tb, pb). There is no limit by default.`, func(memory string) error { +Supported units (kb, mb, gb, tb, pb). When 0 is passed, there will be no memory limit. +There is no limit by default.`, func(memory string) error { b, err := ParseMemory(memory) if err != nil { return err diff --git a/src/utils/types.go b/src/utils/types.go index 2b42cbe..3679a4f 100644 --- a/src/utils/types.go +++ b/src/utils/types.go @@ -13,9 +13,9 @@ type Server interface { KeyRUnlock(key string) KeyExists(key string) bool CreateKeyAndLock(ctx context.Context, key string) (bool, error) - GetValue(key string) interface{} + GetValue(ctx context.Context, key string) interface{} SetValue(ctx context.Context, key string, value interface{}) error - SetKeyExpiry(key string, expire time.Time, touch bool) + SetKeyExpiry(ctx context.Context, key string, expire time.Time, touch bool) RemoveKeyExpiry(key string) GetState() map[string]interface{} GetAllCommands(ctx context.Context) []Command diff --git a/src/utils/utils.go b/src/utils/utils.go index 28ff13c..8fb237d 100644 --- a/src/utils/utils.go +++ b/src/utils/utils.go @@ -163,12 +163,16 @@ func ParseMemory(memory string) (uint64, error) { } // IsMaxMemoryExceeded checks whether we have exceeded the current maximum memory limit -func IsMaxMemoryExceeded(config Config) bool { +func IsMaxMemoryExceeded(maxMemory uint64) bool { + if maxMemory == 0 { + return false + } + var memStats runtime.MemStats runtime.ReadMemStats(&memStats) // If we're currently using less than the configured max memory, return false - if memStats.HeapInuse < config.MaxMemory { + if memStats.HeapInuse < maxMemory { return false } @@ -176,7 +180,8 @@ func IsMaxMemoryExceeded(config Config) bool { // This measure is to prevent deleting keys that may be important when some memory can be reclaimed // by just collecting garbage. runtime.GC() + runtime.ReadMemStats(&memStats) // Return true when whe are above or equal to max memory. - return memStats.HeapInuse >= config.MaxMemory + return memStats.HeapInuse >= maxMemory }